生成对抗网络基础
生成对抗网络(Generative Adversarial Network, GAN)由Goodfellow等人于2014年提出,是一种基于博弈论的生成模型1。GAN的核心思想是通过两个神经网络的对抗训练,让生成器学会从随机噪声中生成与真实数据分布相似的样本。
1. 基本原理
1.1 Minimax零和博弈框架
GAN采用生成器(Generator, ) 和判别器(Discriminator, ) 的对抗架构:
其中:
- :从先验分布(通常为均匀或正态分布)采样的随机噪声
- :生成器将噪声映射到数据空间
- :判别器输出样本为真实数据的概率
- :真实数据分布
1.2 纳什均衡
GAN训练的目标是找到纳什均衡点,此时:
- 生成器 产生的分布 接近真实分布
- 判别器 无法区分真假样本,即 对所有样本
当 时,损失函数达到全局最优:
其中 为Jensen-Shannon散度。
1.3 训练动态
GAN的训练过程可以理解为交替优化:
| 阶段 | 优化目标 | 训练步数比例 |
|---|---|---|
| 判别器训练 | 最大化区分真假样本的能力 | 1-5步/轮 |
| 生成器训练 | 最小化判别器的区分能力 | 1步/轮 |
# GAN训练伪代码
for epoch in range(training_epochs):
# 训练判别器 k 步
for _ in range(k):
z = sample_noise(batch_size)
fake_samples = generator(z)
# 真实样本标签为1,假样本标签为0
d_loss = -torch.mean(torch.log(discriminator(real)) +
torch.log(1 - discriminator(fake_samples.detach())))
d_loss.backward()
optimizer_D.step()
# 训练生成器
z = sample_noise(batch_size)
fake_samples = generator(z)
g_loss = -torch.mean(torch.log(discriminator(fake_samples)))
g_loss.backward()
optimizer_G.step()2. 训练挑战
2.1 模式崩溃(Mode Collapse)
问题描述:生成器只学习到真实分布的少数模式,产生的样本缺乏多样性。
原因分析:
- Minimax损失导致生成器倾向于生成”安全”的样本
- 判别器快速收敛后,生成器的梯度变得不可靠
解决方案:
| 方法 | 原理 |
|---|---|
| 小批量判别 | 在判别器中加入批次内样本间的关系 |
| Wasserstein距离 | 使用更平滑的损失函数 |
| Unrolled GAN | 预测判别器的更新方向 |
| MSGGAN | 使用多个独立生成器 |
2.2 训练不稳定性
问题描述:判别器和生成器之间存在竞争失衡,导致训练崩溃。
梯度消失问题:
- 当判别器过于强大时, 趋近于0
- 生成器收到的梯度接近零,无法有效学习
梯度爆炸问题:
- 初始化不当或学习率过高时
- 判别器输出的极端值导致梯度爆炸
2.3 平衡问题
判别器和生成器的训练需要精心平衡:
- 过强判别器:生成器梯度消失
- 过弱判别器:生成器失去学习目标
TTUR(Two Time-Scale Update Rule):为判别器和生成器设置不同的学习率,通常判别器学习率更高。
3. 核心概念
3.1 潜在空间(Latent Space)
生成器的输入 构成潜在空间 。理想情况下:
- 潜在空间应具有良好的语义结构
- 相似的 应生成语义相似的样本
- 插值操作在潜在空间中应有平滑的视觉效果
# 潜在空间插值演示
def interpolate_z(z1, z2, steps=10):
"""在两个潜在向量之间进行线性插值"""
alphas = torch.linspace(0, 1, steps)
interpolations = []
for alpha in alphas:
z_interp = alpha * z1 + (1 - alpha) * z2
interpolations.append(z_interp)
return torch.stack(interpolations)3.2 条件生成
条件GAN(cGAN) 通过额外输入条件信息 来控制生成内容:
应用场景:
- 类别条件生成(class-conditional)
- 文本到图像生成
- 图像到图像转换
- 风格控制
3.3 对抗训练中的博弈动态
GAN训练过程中的博弈动态可以用以下图示理解:
生成器 G
↑
| 生成的假样本 G(z)
|
梯度更新 | 梯度更新
试图欺骗D ←——————→ D ←———————→ 更好地区分真假
|
| 真实样本 x ~ p_data
|
判别器 D
理想状态:经过充分训练后,判别器输出 ,即无法区分真假。
4. GAN损失函数变体
4.1 原始GAN损失
生成器损失(Minimax的直接形式):
判别器损失:
4.2 非饱和损失(Non-Saturating Loss)
为解决饱和问题,生成器使用:
当 时,梯度仍然存在,避免了梯度消失。
4.3 最小二乘损失(Least Squares GAN)
使用MSE损失提高训练稳定性:
4.4 Wasserstein损失
使用Wasserstein-1距离替代JS散度,具体见wasserstein-gan。
5. GAN的表示学习特性
5.1 隐式密度
GAN属于隐式生成模型,通过采样过程定义分布,不直接给出概率密度函数 。
优点:
- 可以表示复杂、高维分布
- 避免了马尔可夫链蒙特卡洛(MCMC)采样的问题
- 适合大规模数据集
缺点:
- 无法直接计算似然
- 难以进行精确的统计推断
5.2 表示分解
GAN的判别器学习到的表示可以用于:
- 特征可视化
- 半监督学习
- 迁移学习
6. 实践指南
6.1 训练技巧
- 使用BatchNorm:稳定训练,提供清晰的损失景观
- 使用Adam优化器:学习率通常设置为 ,
- 标签平滑:真实标签使用 而非
- 避免ReLU/LeakyReLU的极端斜率:通常使用
- 监控损失曲线:判别器损失趋近于 是理想状态
6.2 架构选择
| 任务 | 推荐架构 |
|---|---|
| 图像生成 | DCGAN、StyleGAN |
| 文本生成 | 脆性较强,通常用其他方法 |
| 语音合成 | WaveGAN |
| 图生成 | GRAN |
6.3 常见问题排查
| 问题 | 可能原因 | 解决方案 |
|---|---|---|
| 判别器损失迅速降至0 | 判别器过强 | 提高判别器学习率或减少判别器训练步数 |
| 生成器损失爆炸 | 梯度爆炸 | 降低学习率,使用梯度裁剪 |
| 模式崩溃 | 生成器缺乏多样性 | 添加小批量判别,使用WGAN |
| 训练不收敛 | 学习率过高 | 使用WGAN-GP或谱归一化 |
7. GAN与其他生成模型对比
| 特性 | GAN | VAE | Flow | Diffusion |
|---|---|---|---|---|
| 似然计算 | ✗ | ✓ | ✓ | ✓ |
| 隐式密度 | ✓ | ✗ | ✗ | ✗ |
| 训练稳定性 | 中等 | 稳定 | 稳定 | 非常稳定 |
| 样本质量 | 高 | 中等 | 中等 | 高 |
| 采样速度 | 快 | 快 | 快 | 慢 |
| 模式覆盖 | 差 | 好 | 好 | 好 |
8. 参考资料
扩展阅读:
Footnotes
-
Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial networks. NeurIPS, 2014. arXiv:1406.2661 ↩