GAN Training Stabilization Techniques

The original GAN's cross-entropy loss readily causes vanishing or exploding gradients during training. This document covers three mainstream techniques for stabilizing GAN training: Wasserstein GAN, gradient penalty, and spectral normalization.

1. Wasserstein GAN (WGAN)

1.1 Motivation

The Jensen-Shannon divergence used by the original GAN degenerates to a constant when the two distributions have non-overlapping supports, which makes the gradient vanish.

Limitations of the JS divergence

$$\mathrm{JS}(P_r \| P_g) = \frac{1}{2}\mathrm{KL}(P_r \| M) + \frac{1}{2}\mathrm{KL}(P_g \| M)$$

where $M = \frac{1}{2}(P_r + P_g)$. When the supports of $P_r$ and $P_g$ do not overlap, $\mathrm{JS}(P_r \| P_g) = \log 2$ (a constant), so it provides no useful gradient to the generator.

1.2 The Earth Mover Distance

Wasserstein GAN uses the Earth Mover (EM) distance, also known as the Wasserstein-1 distance:

$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\|x - y\|\big]$$

where $\Pi(P_r, P_g)$ is the set of all joint distributions whose marginals are $P_r$ and $P_g$.

Intuition: the minimum amount of "work" needed to move distribution $P_g$ onto distribution $P_r$.

Advantages

  • Even when the two distributions do not overlap at all, the EM distance still measures how far apart they are (see the numeric sketch below)
  • It provides meaningful, smooth gradients
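
A minimal numeric sketch of this contrast, using the classic example from the WGAN paper of two point masses $P = \delta_0$ and $Q = \delta_\theta$ discretized on a grid (the helper names js_divergence and w1_distance_1d are ours):

import torch

def js_divergence(p, q, eps=1e-12):
    """JS divergence between two discrete distributions on a shared grid."""
    m = 0.5 * (p + q)
    kl = lambda a, b: (a * (a.add(eps).log() - b.add(eps).log())).sum()
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def w1_distance_1d(p, q, grid):
    """W1 distance in 1-D: the integral of |CDF_p - CDF_q| over the grid."""
    cdf_gap = (p - q).cumsum(0).abs()
    return (cdf_gap[:-1] * grid.diff()).sum()

grid = torch.linspace(0.0, 10.0, 101)
p = torch.zeros(101)
p[0] = 1.0                        # point mass at 0
for idx in (10, 50, 90):          # point mass at theta = 1, 5, 9
    q = torch.zeros(101)
    q[idx] = 1.0
    print(f"theta={grid[idx]:.0f}  JS={js_divergence(p, q):.4f}  "
          f"W1={w1_distance_1d(p, q, grid):.4f}")

# JS stays at log 2 ≈ 0.6931 for every theta, while W1 grows linearly with theta.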

1.3 Kantorovich Duality

Computing the EM distance directly is intractable. WGAN uses Kantorovich(-Rubinstein) duality to convert it into a tractable form:

$$W(P_r, P_g) = \sup_{\|f\|_L \le 1} \; \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$

where $\|f\|_L \le 1$ means that $f$ is a 1-Lipschitz function.

The discriminator's role: learn a Lipschitz-continuous function $f$ that estimates the EM distance.

1.4 The WGAN Algorithm

import torch
import torch.nn as nn
 
class WGANDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            # No sigmoid: the critic outputs an unbounded real-valued score
        )
    
    def forward(self, x):
        return self.net(x)
 
def wgan_training_loop(real_samples, generator, discriminator,
                       opt_g, opt_d, z_dim, clip_value=0.01,
                       train_generator=True):
    """One step of the WGAN training loop."""
    batch_size = real_samples.size(0)
    
    # 1. Train the discriminator (the "critic")
    opt_d.zero_grad()
    
    # Real samples: push f(x) up
    real_validity = discriminator(real_samples)
    loss_d_real = -torch.mean(real_validity)
    
    # Fake samples: push f(G(z)) down
    z = torch.randn(batch_size, z_dim)
    fake_samples = generator(z)
    fake_validity = discriminator(fake_samples.detach())
    loss_d_fake = torch.mean(fake_validity)
    
    # WGAN critic loss: -(E[f(x)] - E[f(G(z))])
    loss_d = loss_d_real + loss_d_fake
    loss_d.backward()
    opt_d.step()
    
    # 2. Weight clipping (a naive way to enforce the Lipschitz constraint)
    for p in discriminator.parameters():
        p.data.clamp_(-clip_value, clip_value)
    
    # 3. Train the generator (in the full recipe, only once every
    #    n_critic steps; see the outer-loop sketch below)
    if train_generator:
        opt_g.zero_grad()
        z = torch.randn(batch_size, z_dim)
        fake_samples = generator(z)
        loss_g = -torch.mean(discriminator(fake_samples))  # maximize f(G(z))
        loss_g.backward()
        opt_g.step()
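
One detail the step above leaves implicit: the original WGAN recipe trains the critic several times per generator update (the paper uses n_critic = 5) and pairs weight clipping with RMSprop at a small learning rate (5e-5), since momentum-based optimizers interact poorly with the clipped critic. A hedged outer-loop sketch, assuming dataloader and num_epochs are defined elsewhere:

n_critic = 5
opt_d = torch.optim.RMSprop(discriminator.parameters(), lr=5e-5)
opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-5)

for epoch in range(num_epochs):
    for i, (real_samples, _) in enumerate(dataloader):
        # Update the critic on every batch; update the generator only
        # once every n_critic batches
        wgan_training_loop(real_samples, generator, discriminator,
                           opt_g, opt_d, z_dim=100,
                           train_generator=(i % n_critic == 0))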

1.5 Limitations of Weight Clipping

The original WGAN paper enforces the Lipschitz constraint with weight clipping:

for p in discriminator.parameters():
    p.data.clamp_(-c, c)

Problems

  • After clipping, the critic tends to learn overly simple functions
  • It can still cause vanishing or exploding gradients
  • The clipping value requires careful tuning

2. WGAN-GP: Gradient Penalty

2.1 Core Idea

WGAN-GP[^1] replaces weight clipping with a gradient penalty, a more elegant way to enforce the Lipschitz constraint:

$$L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[ (\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2 \right]$$

where $\hat{x}$ is sampled uniformly along straight lines between pairs of real and generated samples:

$$\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}, \qquad \epsilon \sim U[0, 1]$$

2.2 Understanding the Gradient Penalty

The constraint $\|\nabla_{\hat{x}} D(\hat{x})\|_2 = 1$ is enforced at all interpolated points.

Effect: the norm of the discriminator's gradient is held close to 1, providing a stable gradient signal.

2.3 Full Implementation

import torch
import torch.nn as nn
from torch.autograd import grad
 
class WGANGPGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape
        
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # momentum must be passed by keyword: the second positional
                # argument of BatchNorm1d is eps, not momentum
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()
        )
    
    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)
 
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """Compute the WGAN-GP gradient penalty term."""
    batch_size = real_samples.size(0)
    # One interpolation coefficient per sample (shape assumes image tensors)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)
    
    # Interpolated samples between real and fake
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates = interpolates.requires_grad_(True)
    
    # Gradients of the critic output w.r.t. the interpolated inputs
    disc_interpolates = discriminator(interpolates)
    gradients = grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]
    
    # Per-sample gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    
    # Penalty term: (||∇D(x̂)||_2 - 1)^2
    penalty = ((gradient_norm - 1) ** 2).mean()
    
    return penalty
 
def wgan_gp_training_step(generator, discriminator, real_samples,
                          opt_g, opt_d, device, lambda_gp=10):
    """One WGAN-GP training step."""
    batch_size = real_samples.size(0)
    latent_dim = generator.model[0].in_features
    z = torch.randn(batch_size, latent_dim).to(device)
    fake_samples = generator(z).detach()
    
    # Train the discriminator (critic)
    opt_d.zero_grad()
    real_validity = discriminator(real_samples)
    fake_validity = discriminator(fake_samples)
    
    # WGAN critic loss
    d_loss = torch.mean(fake_validity) - torch.mean(real_validity)
    
    # Gradient penalty
    gp = compute_gradient_penalty(discriminator, real_samples, fake_samples, device)
    d_loss = d_loss + lambda_gp * gp
    
    d_loss.backward()
    opt_d.step()
    
    # Train the generator (in practice, only once every n_critic critic steps)
    opt_g.zero_grad()
    z = torch.randn(batch_size, latent_dim).to(device)
    fake_samples = generator(z)
    g_loss = -torch.mean(discriminator(fake_samples))
    g_loss.backward()
    opt_g.step()
    
    return d_loss.item(), g_loss.item()

2.4 Gradient Penalty Variants

| Variant | Improvement |
| --- | --- |
| 0-GP | Changes the penalty target from 1 to 0, which can be more stable |
| R1 regularization | Penalizes the gradient norm at real samples only (used in StyleGAN and related models; sketched below) |
| PPN | Progressive gradient penalty, adjusted adaptively as training progresses |
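
Of these, R1 is the simplest to add to an existing setup. A minimal sketch (the helper name r1_penalty is ours; in practice the term is weighted by γ/2 in the discriminator loss):

import torch

def r1_penalty(discriminator, real_samples):
    """R1 regularizer: squared gradient norm of D at real samples only."""
    real_samples = real_samples.detach().requires_grad_(True)
    real_scores = discriminator(real_samples)
    grad_real, = torch.autograd.grad(
        outputs=real_scores.sum(),
        inputs=real_samples,
        create_graph=True,
    )
    return grad_real.view(grad_real.size(0), -1).pow(2).sum(dim=1).mean()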

3. Spectral Normalization

3.1 Theoretical Foundation

Spectral normalization[^2] enforces the Lipschitz constraint by dividing each discriminator weight matrix by its spectral norm:

$$W_{SN} = \frac{W}{\sigma(W)}$$

where $\sigma(W)$ is the largest singular value of $W$ (which equals its spectral norm).

Why it works: for a multi-layer network

$$f(x) = W_L \, a_{L-1}\big(W_{L-1} \cdots a_1(W_1 x) \cdots\big)$$

with 1-Lipschitz activations (e.g., ReLU, LeakyReLU), the Lipschitz constant is bounded by:

$$\|f\|_{Lip} \le \prod_{i=1}^{L} \sigma(W_i)$$

Spectral normalization ensures each $\sigma(W_i) = 1$, hence $\|f\|_{Lip} \le 1$.

3.2 Computing the Spectral Norm via Power Iteration

A full singular value decomposition (SVD) is prohibitively expensive for large matrices. Power iteration gives an efficient approximation:

def spectral_norm(W, n_power_iterations=1, eps=1e-12):
    """
    Estimate the spectral norm of W via power iteration.
    """
    # Initialize random left/right singular-vector estimates
    u = torch.randn(W.size(0), 1).to(W.device)
    u = u / (u.norm() + eps)
    
    v = torch.randn(W.size(1), 1).to(W.device)
    v = v / (v.norm() + eps)
    
    for _ in range(n_power_iterations):
        # Power iteration: alternate W v -> u and W^T u -> v
        Wv = W @ v
        u = Wv / (Wv.norm() + eps)
        
        Wu = W.t() @ u
        v = Wu / (Wu.norm() + eps)
    
    # Spectral norm estimate: sigma ≈ u^T W v
    sigma = (u.t() @ W @ v).item()
    
    return sigma
 
def spectral_normalize(W, n_power_iterations=1):
    """Spectrally normalize a weight matrix."""
    sigma = spectral_norm(W, n_power_iterations)
    return W / (sigma + 1e-12)
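
A quick sanity check of the approximation against an exact SVD (torch.linalg.svdvals returns singular values in descending order):

W = torch.randn(64, 128)
sigma_exact = torch.linalg.svdvals(W)[0].item()
sigma_approx = spectral_norm(W, n_power_iterations=20)
print(abs(sigma_exact - sigma_approx))  # typically tiny after ~20 iterations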

3.3 PyTorch Implementation

import torch
import torch.nn as nn
 
class SpectralNorm(nn.Module):
    """Spectral normalization wrapper (simplified, for exposition)."""
    def __init__(self, module, name='weight', n_power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.n_power_iterations = n_power_iterations
        
        # Move the raw weight out of the wrapped module's parameter list
        # so `module.weight` can be overwritten with a plain tensor on
        # every forward pass (assigning to a registered Parameter raises)
        w = getattr(module, name)
        delattr(module, name)
        self.weight_orig = nn.Parameter(w.data)
        
        # Power-iteration vectors must be 1-D for torch.mv; conv kernels
        # are flattened to an (out_channels, -1) matrix
        w_mat = w.data.view(w.size(0), -1)
        self.register_buffer('_u', nn.functional.normalize(
            torch.randn(w_mat.size(0)), dim=0, eps=1e-12))
        self.register_buffer('_v', nn.functional.normalize(
            torch.randn(w_mat.size(1)), dim=0, eps=1e-12))
    
    def forward(self, x):
        w = self.weight_orig
        w_mat = w.view(w.size(0), -1)
        sigma, u, v = self._compute_spectral_norm(w_mat)
        
        # Cache the power-iteration vectors for the next step
        self._u.copy_(u)
        self._v.copy_(v)
        
        # Install the spectrally normalized weight, then run the module
        setattr(self.module, self.name, w / sigma)
        return self.module(x)
    
    def _compute_spectral_norm(self, W):
        """Power iteration; u and v are treated as constants by autograd."""
        u, v = self._u, self._v
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                v = nn.functional.normalize(torch.mv(W.t(), u), dim=0, eps=1e-12)
                u = nn.functional.normalize(torch.mv(W, v), dim=0, eps=1e-12)
        
        # sigma = u^T W v keeps W in the graph, so gradients flow through it
        sigma = torch.dot(u, torch.mv(W, v))
        
        return sigma, u, v
 
# Usage example: a spectrally normalized convolution layer
class SNConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding)
        # Wrap the convolution so its kernel is spectrally normalized
        self.sn = SpectralNorm(conv, name='weight')
    
    def forward(self, x):
        # The wrapper normalizes the weight and then applies the convolution
        return self.sn(x)
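
The wrapper above is for exposition; in practice, PyTorch ships spectral normalization as a built-in utility that reparameterizes a layer's weight and runs one power-iteration step per forward pass:

from torch.nn.utils import spectral_norm as torch_sn  # aliased to avoid clashing with our helper above

sn_conv = torch_sn(nn.Conv2d(64, 128, kernel_size=3, padding=1))
sn_linear = torch_sn(nn.Linear(256, 1))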

3.4 Spectral Normalization vs. Gradient Penalty

| Property | Spectral normalization | WGAN-GP |
| --- | --- | --- |
| Mechanism | Normalizes the weights | Adds a loss term |
| Compute overhead | Moderate (power iteration) | Higher (needs gradient computation) |
| Training speed | Faster | Slower |
| Convergence stability | - | - |
| Architectural requirements | No special requirements | Needs careful design |
| Typical use case | General-purpose | Complex distributions |

3.5 Connection to LeCun Initialization

Spectral normalization has a deep connection to LeCun initialization:

LeCun initialization: $W_{ij} \sim \mathcal{N}(0, 1/n_{in})$, scaling the weight variance by the fan-in.

Spectral normalization: $W_{SN} = W / \sigma(W)$, so $\sigma(W_{SN}) = 1$ at every training step.

Both stabilize gradient propagation by controlling the spectral properties of the weight matrices. Spectral normalization can be viewed as an adaptive, dynamic counterpart of initialization-time scaling.
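
A quick numeric illustration of the distinction (for a large square Gaussian matrix with variance 1/fan_in, the largest singular value is known to concentrate near 2):

import torch

fan_in = fan_out = 512
W = torch.randn(fan_out, fan_in) / fan_in ** 0.5  # LeCun init: Var(W_ij) = 1/fan_in
sigma = torch.linalg.svdvals(W)[0]
print(sigma)                                # ≈ 2: bounded, but not equal to 1
print(torch.linalg.svdvals(W / sigma)[0])   # exactly 1 after spectral normalization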

4. Mini-batch Discrimination and Feature Matching

4.1 Mini-batch Discrimination

To mitigate mode collapse, the discriminator is given information about how samples within a batch differ from one another:

class MiniBatchDiscrimination(nn.Module):
    def __init__(self, input_dim, num_kernels, kernel_dim):
        super().__init__()
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        self.T = nn.Parameter(torch.randn(input_dim, num_kernels * kernel_dim))
    
    def forward(self, x):
        batch_size = x.size(0)
        
        # Project each sample onto num_kernels vectors of size kernel_dim
        M = x @ self.T  # (B, num_kernels * kernel_dim)
        M = M.view(batch_size, self.num_kernels, self.kernel_dim)
        
        # Per-kernel L1 distances between each sample and the whole batch
        o = torch.zeros(batch_size, self.num_kernels, device=x.device)
        for i in range(batch_size):
            diff = torch.abs(M - M[i]).sum(dim=2)  # (B, num_kernels)
            o[i] = torch.exp(-diff).sum(dim=0)     # (num_kernels,)
        
        # Append the batch-similarity statistics to the original features
        return torch.cat([x, o], dim=1)
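
A shape-check usage example (the dimensions are arbitrary): the layer appends num_kernels similarity statistics to each feature vector, so the next discriminator layer must accept input_dim + num_kernels inputs:

mbd = MiniBatchDiscrimination(input_dim=128, num_kernels=50, kernel_dim=5)
x = torch.randn(16, 128)
print(mbd(x).shape)  # torch.Size([16, 178])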

4.2 Feature Matching

Train the generator to match the statistics of the discriminator's intermediate-layer features:

$$L_{FM} = \big\| \mathbb{E}_{x \sim p_{data}} f(x) - \mathbb{E}_{z \sim p_z} f(G(z)) \big\|_2^2$$

where $f(\cdot)$ denotes an intermediate-layer feature representation of the discriminator.
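
A minimal sketch of the corresponding generator loss, assuming the discriminator exposes its intermediate activations through a hypothetical features(x) method returning a (batch, feature) tensor:

import torch.nn.functional as F

def feature_matching_loss(discriminator, real_samples, fake_samples):
    """Match the first moment of intermediate discriminator features."""
    real_feat = discriminator.features(real_samples).mean(dim=0)
    fake_feat = discriminator.features(fake_samples).mean(dim=0)
    return F.mse_loss(fake_feat, real_feat.detach())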

5. Training Strategy Summary

5.1 Suggested End-to-End Training Flow

def gan_training_strategies(generator, discriminator, dataloader, 
                           num_epochs, device, strategy='wgan-gp'):
    """
    A guide to choosing a GAN training strategy
    """
    if strategy == 'wgan-gp':
        # WGAN-GP: the most broadly applicable choice
        lambda_gp = 10
        lr_d, lr_g = 1e-4, 1e-4
        n_critic = 5  # discriminator updates per generator update
        opt_g = lambda p: torch.optim.Adam(p, lr=lr_g, betas=(0.0, 0.9))
        opt_d = lambda p: torch.optim.Adam(p, lr=lr_d, betas=(0.0, 0.9))
        
    elif strategy == 'spectral_norm':
        # Spectral normalization: plug-and-play
        # (apply spectral norm to every discriminator layer)
        lambda_gp = 0  # no gradient penalty needed
        lr_d, lr_g = 2e-4, 2e-4
        n_critic = 1
        opt_g = lambda p: torch.optim.Adam(p, lr=lr_g, betas=(0.5, 0.999))
        opt_d = lambda p: torch.optim.Adam(p, lr=lr_d, betas=(0.5, 0.999))
        
    elif strategy == 'lsgan':
        # Least-squares GAN: simple and stable
        lambda_gp = 0
        lr_d, lr_g = 1e-3, 1e-3
        n_critic = 1
        opt_g = lambda p: torch.optim.Adam(p, lr=lr_g, betas=(0.5, 0.999))
        opt_d = lambda p: torch.optim.Adam(p, lr=lr_d, betas=(0.5, 0.999))
        
    # Generic training loop
    for epoch in range(num_epochs):
        for real_samples, _ in dataloader:
            real_samples = real_samples.to(device)
            batch_size = real_samples.size(0)
            
            # Train the discriminator n_critic times
            for _ in range(n_critic):
                # ... discriminator training code ...
                pass
            
            # Train the generator
            # ... generator training code ...
            pass

5.2 Optimizer and Learning Rate Choices

| Optimizer | LR (discriminator) | LR (generator) | β1 | β2 |
| --- | --- | --- | --- | --- |
| Adam | 2e-4 | 2e-4 | 0.5 | 0.999 |
| RMSprop | 5e-5 | 5e-5 | - | - |
| SGD | 1e-3 | 1e-3 | - | - |

5.3 Metrics to Monitor

| Metric | Ideal behavior | Warning sign |
| --- | --- | --- |
| Discriminator output on real samples | ≈ 0.5 (standard GAN) | Saturates toward 0 or 1 |
| Discriminator output on fake samples | ≈ 0.5 (standard GAN) | Saturates toward 0 or 1 |
| Loss (e.g., the Wasserstein estimate) | Decreases gradually (see the helper below) | Oscillates or climbs |
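
For the WGAN family, the single most informative scalar is the critic's Wasserstein estimate referenced in the loss row above; a small helper sketch:

def wasserstein_estimate(discriminator, real_batch, fake_batch):
    """Critic's estimate of W(P_r, P_g); in a healthy run it shrinks over time."""
    with torch.no_grad():
        return (discriminator(real_batch).mean()
                - discriminator(fake_batch).mean()).item()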

6. References

Footnotes

[^1]: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of Wasserstein GANs. NeurIPS, 2017. arXiv:1704.00028

[^2]: Miyato T, Kataoka T, Koyama M, et al. Spectral normalization for generative adversarial networks. ICLR, 2018. arXiv:1802.05957