生成对抗网络基础

生成对抗网络(Generative Adversarial Network, GAN)由Goodfellow等人于2014年提出,是一种基于博弈论的生成模型1。GAN的核心思想是通过两个神经网络的对抗训练,让生成器学会从随机噪声中生成与真实数据分布相似的样本。

1. 基本原理

1.1 Minimax零和博弈框架

GAN采用生成器(Generator, 判别器(Discriminator, 的对抗架构:

其中:

  • :从先验分布(通常为均匀或正态分布)采样的随机噪声
  • :生成器将噪声映射到数据空间
  • :判别器输出样本为真实数据的概率
  • :真实数据分布

1.2 纳什均衡

GAN训练的目标是找到纳什均衡点,此时:

  • 生成器 产生的分布 接近真实分布
  • 判别器 无法区分真假样本,即 对所有样本

时,损失函数达到全局最优:

其中 为Jensen-Shannon散度。

1.3 训练动态

GAN的训练过程可以理解为交替优化:

阶段优化目标训练步数比例
判别器训练最大化区分真假样本的能力1-5步/轮
生成器训练最小化判别器的区分能力1步/轮
# GAN训练伪代码
for epoch in range(training_epochs):
    # 训练判别器 k 步
    for _ in range(k):
        z = sample_noise(batch_size)
        fake_samples = generator(z)
        
        # 真实样本标签为1,假样本标签为0
        d_loss = -torch.mean(torch.log(discriminator(real)) + 
                             torch.log(1 - discriminator(fake_samples.detach())))
        d_loss.backward()
        optimizer_D.step()
    
    # 训练生成器
    z = sample_noise(batch_size)
    fake_samples = generator(z)
    g_loss = -torch.mean(torch.log(discriminator(fake_samples)))
    g_loss.backward()
    optimizer_G.step()

2. 训练挑战

2.1 模式崩溃(Mode Collapse)

问题描述:生成器只学习到真实分布的少数模式,产生的样本缺乏多样性。

原因分析

  • Minimax损失导致生成器倾向于生成”安全”的样本
  • 判别器快速收敛后,生成器的梯度变得不可靠

解决方案

方法原理
小批量判别在判别器中加入批次内样本间的关系
Wasserstein距离使用更平滑的损失函数
Unrolled GAN预测判别器的更新方向
MSGGAN使用多个独立生成器

2.2 训练不稳定性

问题描述:判别器和生成器之间存在竞争失衡,导致训练崩溃。

梯度消失问题

  • 当判别器过于强大时, 趋近于0
  • 生成器收到的梯度接近零,无法有效学习

梯度爆炸问题

  • 初始化不当或学习率过高时
  • 判别器输出的极端值导致梯度爆炸

2.3 平衡问题

判别器和生成器的训练需要精心平衡:

  • 过强判别器:生成器梯度消失
  • 过弱判别器:生成器失去学习目标

TTUR(Two Time-Scale Update Rule):为判别器和生成器设置不同的学习率,通常判别器学习率更高。

3. 核心概念

3.1 潜在空间(Latent Space)

生成器的输入 构成潜在空间 。理想情况下:

  • 潜在空间应具有良好的语义结构
  • 相似的 应生成语义相似的样本
  • 插值操作在潜在空间中应有平滑的视觉效果
# 潜在空间插值演示
def interpolate_z(z1, z2, steps=10):
    """在两个潜在向量之间进行线性插值"""
    alphas = torch.linspace(0, 1, steps)
    interpolations = []
    for alpha in alphas:
        z_interp = alpha * z1 + (1 - alpha) * z2
        interpolations.append(z_interp)
    return torch.stack(interpolations)

3.2 条件生成

条件GAN(cGAN) 通过额外输入条件信息 来控制生成内容:

应用场景

  • 类别条件生成(class-conditional)
  • 文本到图像生成
  • 图像到图像转换
  • 风格控制

3.3 对抗训练中的博弈动态

GAN训练过程中的博弈动态可以用以下图示理解:

                    生成器 G
                        ↑
                        |  生成的假样本 G(z)
                        |
    梯度更新            |            梯度更新
    试图欺骗D  ←——————→  D ←———————→  更好地区分真假
                        |
                        |  真实样本 x ~ p_data
                        |
                    判别器 D

理想状态:经过充分训练后,判别器输出 ,即无法区分真假。

4. GAN损失函数变体

4.1 原始GAN损失

生成器损失(Minimax的直接形式):

判别器损失

4.2 非饱和损失(Non-Saturating Loss)

为解决饱和问题,生成器使用:

时,梯度仍然存在,避免了梯度消失。

4.3 最小二乘损失(Least Squares GAN)

使用MSE损失提高训练稳定性:

4.4 Wasserstein损失

使用Wasserstein-1距离替代JS散度,具体见wasserstein-gan

5. GAN的表示学习特性

5.1 隐式密度

GAN属于隐式生成模型,通过采样过程定义分布,不直接给出概率密度函数

优点

  • 可以表示复杂、高维分布
  • 避免了马尔可夫链蒙特卡洛(MCMC)采样的问题
  • 适合大规模数据集

缺点

  • 无法直接计算似然
  • 难以进行精确的统计推断

5.2 表示分解

GAN的判别器学习到的表示可以用于:

  • 特征可视化
  • 半监督学习
  • 迁移学习

6. 实践指南

6.1 训练技巧

  1. 使用BatchNorm:稳定训练,提供清晰的损失景观
  2. 使用Adam优化器:学习率通常设置为
  3. 标签平滑:真实标签使用 而非
  4. 避免ReLU/LeakyReLU的极端斜率:通常使用
  5. 监控损失曲线:判别器损失趋近于 是理想状态

6.2 架构选择

任务推荐架构
图像生成DCGAN、StyleGAN
文本生成脆性较强,通常用其他方法
语音合成WaveGAN
图生成GRAN

6.3 常见问题排查

问题可能原因解决方案
判别器损失迅速降至0判别器过强提高判别器学习率或减少判别器训练步数
生成器损失爆炸梯度爆炸降低学习率,使用梯度裁剪
模式崩溃生成器缺乏多样性添加小批量判别,使用WGAN
训练不收敛学习率过高使用WGAN-GP或谱归一化

7. GAN与其他生成模型对比

特性GANVAEFlowDiffusion
似然计算
隐式密度
训练稳定性中等稳定稳定非常稳定
样本质量中等中等
采样速度
模式覆盖

8. 参考资料

扩展阅读:

Footnotes

  1. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial networks. NeurIPS, 2014. arXiv:1406.2661