# GAN Training Stabilization Techniques
The cross-entropy loss of the original GAN is prone to vanishing or exploding gradients during training. This document covers three mainstream GAN training stabilization techniques: Wasserstein GAN, gradient penalty, and spectral normalization.
## 1. Wasserstein GAN (WGAN)

### 1.1 Motivation
The Jensen-Shannon divergence used by the original GAN degenerates to a constant when the supports of the two distributions do not overlap, causing vanishing gradients.

Limitation of the JS divergence:

$$\mathrm{JS}(P_r \,\|\, P_g) = \frac{1}{2}\mathrm{KL}\!\left(P_r \,\big\|\, P_m\right) + \frac{1}{2}\mathrm{KL}\!\left(P_g \,\big\|\, P_m\right)$$

where $P_m = \frac{1}{2}(P_r + P_g)$. When the supports of $P_r$ and $P_g$ do not overlap, $\mathrm{JS}(P_r \,\|\, P_g) = \log 2$ (a constant).
### 1.2 The Earth Mover Distance

Wasserstein GAN uses the Earth Mover (EM) distance, also known as the Wasserstein-1 distance:

$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\big[\|x - y\|\big]$$

Intuition: the minimum amount of "work" required to transport one distribution ($P_g$) onto the other ($P_r$).
Advantages (see the toy sketch below):
- Even when the two distributions are completely disjoint, the EM distance still measures how far apart they are
- It provides meaningful, smooth gradients
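As a toy illustration (not from the original text), consider two point masses at $0$ and $\theta$: the JS divergence is stuck at $\log 2$ for every $\theta > 0$, while the EM distance is simply $|\theta|$ and still reflects how far apart the distributions are:

```python
import math

# Point mass at 0 vs. point mass at theta: the supports are disjoint for
# theta > 0, so JS saturates at log(2) while the EM (Wasserstein-1)
# distance is |theta|.
for theta in [0.1, 1.0, 10.0]:
    js = math.log(2)   # constant: no gradient signal w.r.t. theta
    em = abs(theta)    # linear in theta: a useful training signal
    print(f"theta={theta:5.1f}  JS={js:.4f}  EM={em:.4f}")
```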
### 1.3 Kantorovich Duality

Computing the EM distance directly is intractable. WGAN uses the Kantorovich-Rubinstein duality to convert it into a tractable form:

$$W(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$

where $\|f\|_L \le 1$ means $f$ is a 1-Lipschitz function.

The discriminator's role: learn a Lipschitz-continuous function $f$ that estimates the EM distance.
### 1.4 The WGAN Algorithm
```python
import torch
import torch.nn as nn

class WGANDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            # No sigmoid! The critic outputs an unbounded score.
        )

    def forward(self, x):
        return self.net(x)

def wgan_training_loop(real_samples, generator, discriminator,
                       opt_g, opt_d, clip_value=0.01, z_dim=100):
    """One WGAN training step"""
    batch_size = real_samples.size(0)

    # 1. Train the discriminator (critic)
    opt_d.zero_grad()

    # Real samples: maximize f(x)
    real_validity = discriminator(real_samples)
    loss_d_real = -torch.mean(real_validity)

    # Fake samples: minimize f(G(z))
    z = torch.randn(batch_size, z_dim)
    fake_samples = generator(z)
    fake_validity = discriminator(fake_samples.detach())
    loss_d_fake = torch.mean(fake_validity)

    # WGAN critic loss
    loss_d = loss_d_real + loss_d_fake
    loss_d.backward()
    opt_d.step()

    # 2. Weight clipping (naive enforcement of the Lipschitz constraint)
    for p in discriminator.parameters():
        p.data.clamp_(-clip_value, clip_value)

    # 3. Train the generator
    opt_g.zero_grad()
    z = torch.randn(batch_size, z_dim)
    fake_samples = generator(z)
    fake_validity = discriminator(fake_samples)
    loss_g = -torch.mean(fake_validity)  # maximize f(G(z))
    loss_g.backward()
    opt_g.step()
```
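A rough usage sketch for the loop above (hypothetical wiring; `Generator` and `dataloader` are assumed here and are not part of the original code). The original WGAN paper pairs weight clipping with RMSprop and a small learning rate:

```python
z_dim = 100
generator = Generator(z_dim)   # assumed: maps (B, z_dim) -> (B, 784)
critic = WGANDiscriminator()

# RMSprop with a small learning rate, as recommended by the WGAN authors
opt_g = torch.optim.RMSprop(generator.parameters(), lr=5e-5)
opt_d = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

for real_samples, _ in dataloader:   # assumed: yields (images, labels)
    real_samples = real_samples.view(real_samples.size(0), -1)  # flatten to 784
    wgan_training_loop(real_samples, generator, critic,
                       opt_g, opt_d, clip_value=0.01, z_dim=z_dim)
```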
### 1.5 Limitations of Weight Clipping

The original WGAN paper uses weight clipping to enforce the Lipschitz constraint:
```python
for p in discriminator.parameters():
    p.data.clamp_(-c, c)
```

Problems:
- After clipping, the critic tends to learn overly simple functions
- It can still cause vanishing or exploding gradients
- The clipping value requires careful tuning
## 2. WGAN-GP: Gradient Penalty

### 2.1 Core Idea
WGAN-GP[^1] replaces weight clipping with a gradient penalty, a more elegant way to enforce the Lipschitz constraint:

$$L = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda \, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big]$$

where $P_{\hat{x}}$ samples uniformly along straight lines between real and generated samples:

$$\hat{x} = \alpha x + (1 - \alpha)\tilde{x}, \quad \alpha \sim U[0, 1]$$
### 2.2 Understanding the Gradient Penalty

Constraint: $\|\nabla_{\hat{x}} D(\hat{x})\|_2 \approx 1$ holds at all interpolated points.

Effect: the norm of the discriminator's gradient is kept close to 1, which provides a stable gradient signal.
### 2.3 Complete Implementation
```python
import torch
import torch.nn as nn
from torch.autograd import grad

class WGANGPGenerator(nn.Module):
    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)

def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """Compute the gradient penalty term"""
    batch_size = real_samples.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1).to(device)

    # Interpolated samples
    interpolates = alpha * real_samples + (1 - alpha) * fake_samples
    interpolates = interpolates.requires_grad_(True)

    # Gradients of the critic output w.r.t. the interpolates
    disc_interpolates = discriminator(interpolates)
    gradients = grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Per-sample gradient norm
    gradients = gradients.view(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)

    # Penalty term: (||∇D(x̂)||_2 - 1)^2
    penalty = ((gradient_norm - 1) ** 2).mean()
    return penalty

def wgan_gp_training_step(generator, discriminator, real_samples,
                          opt_g, opt_d, device, lambda_gp=10):
    """One WGAN-GP training step"""
    batch_size = real_samples.size(0)
    z = torch.randn(batch_size, generator.model[0].in_features).to(device)
    fake_samples = generator(z).detach()

    # Train the discriminator
    opt_d.zero_grad()
    real_validity = discriminator(real_samples)
    fake_validity = discriminator(fake_samples)

    # WGAN loss
    d_loss = torch.mean(fake_validity) - torch.mean(real_validity)

    # Gradient penalty
    gp = compute_gradient_penalty(discriminator, real_samples, fake_samples, device)
    d_loss = d_loss + lambda_gp * gp
    d_loss.backward()
    opt_d.step()

    # Train the generator (in practice, only once every n_critic steps)
    opt_g.zero_grad()
    z = torch.randn(batch_size, generator.model[0].in_features).to(device)
    fake_samples = generator(z)
    g_loss = -torch.mean(discriminator(fake_samples))
    g_loss.backward()
    opt_g.step()

    return d_loss.item(), g_loss.item()
```
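A minimal driver for `wgan_gp_training_step` (a sketch; the `discriminator` and `dataloader` here are assumed and not defined in the original code). The WGAN-GP paper recommends Adam with `betas=(0.0, 0.9)`:

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = WGANGPGenerator(latent_dim=100).to(device)

# Adam with betas=(0.0, 0.9), as in the WGAN-GP paper
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.0, 0.9))

for real_samples, _ in dataloader:
    d_loss, g_loss = wgan_gp_training_step(generator, discriminator,
                                           real_samples.to(device),
                                           opt_g, opt_d, device)
```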
### 2.4 Gradient Penalty Variants

| Variant | Improvement |
|---|---|
| 0-GP | Penalty target changed from 1 to 0, which can be more stable |
| R1 regularization | Penalizes the gradient norm at real samples only (used in StyleGAN, among others) |
| PPN | Progressive gradient penalty; adjusted adaptively as training proceeds |
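To make the R1 row concrete, here is a minimal sketch of the R1 regularizer (penalizing the gradient norm only at real samples); `gamma` is a hyperparameter, and the helper name `r1_penalty` is ours:

```python
def r1_penalty(discriminator, real_samples):
    # R1: squared gradient norm of the critic, evaluated at real samples only
    real_samples = real_samples.clone().requires_grad_(True)
    out = discriminator(real_samples)
    grads = torch.autograd.grad(outputs=out.sum(), inputs=real_samples,
                                create_graph=True)[0]
    return grads.view(grads.size(0), -1).pow(2).sum(dim=1).mean()

# Added to the discriminator loss:
# d_loss = d_loss + (gamma / 2) * r1_penalty(discriminator, real_samples)
```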
## 3. Spectral Normalization

### 3.1 Theoretical Basis

Spectral normalization[^2] enforces the Lipschitz constraint by bounding the spectral norm of each weight matrix in the discriminator:

$$\bar{W}_{\mathrm{SN}} = \frac{W}{\sigma(W)}$$

where $\sigma(W)$ is the largest singular value of $W$ (equal to its spectral norm).

Why it works: for a multi-layer network

$$f = l_L \circ a_{L-1} \circ l_{L-1} \circ \cdots \circ a_1 \circ l_1$$

its Lipschitz constant is bounded by:

$$\|f\|_{\mathrm{Lip}} \le \prod_{i=1}^{L} \|l_i\|_{\mathrm{Lip}} = \prod_{i=1}^{L} \sigma(W_i)$$

(assuming the activations $a_i$ are 1-Lipschitz, as ReLU and LeakyReLU are). Spectral normalization ensures each $\sigma(\bar{W}_i) = 1$, so $\|f\|_{\mathrm{Lip}} \le 1$.
### 3.2 Computing the Spectral Norm via Power Iteration

A full singular value decomposition (SVD) is prohibitively expensive for large matrices. Power iteration gives an efficient approximation:
```python
def spectral_norm(W, n_power_iterations=1, eps=1e-12):
    """
    Estimate the spectral norm of W via power iteration
    """
    # Initialize random left/right singular vector estimates
    u = torch.randn(W.size(0), 1).to(W.device)
    u = u / (u.norm() + eps)
    v = torch.randn(W.size(1), 1).to(W.device)
    v = v / (v.norm() + eps)

    for _ in range(n_power_iterations):
        # Power iteration: alternately update u and v
        Wv = W @ v
        u = Wv / (Wv.norm() + eps)
        Wu = W.t() @ u
        v = Wu / (Wu.norm() + eps)

    # Spectral norm estimate: u^T W v
    sigma = (u.t() @ W @ v).item()
    return sigma

def spectral_normalize(W, n_power_iterations=1):
    """Spectrally normalize a weight matrix"""
    sigma = spectral_norm(W, n_power_iterations)
    return W / (sigma + 1e-12)
```
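A quick sanity check of the power-iteration estimate against an exact SVD (illustrative only; the test matrix here is made up):

```python
torch.manual_seed(0)
W = torch.randn(64, 32)
exact = torch.linalg.svdvals(W)[0].item()         # largest singular value
approx = spectral_norm(W, n_power_iterations=50)  # power-iteration estimate
print(f"exact={exact:.4f}  power iteration={approx:.4f}")  # should agree closely
```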
### 3.3 PyTorch Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpectralNorm(nn.Module):
    """Spectral normalization wrapper"""
    def __init__(self, module, name='weight', n_power_iterations=1):
        super().__init__()
        self.module = module
        self.name = name
        self.n_power_iterations = n_power_iterations
        # Move the original weight to `<name>_orig` so that `<name>` can be
        # overwritten with the normalized tensor on every forward pass
        w = getattr(module, name)
        delattr(module, name)
        module.register_parameter(name + '_orig', nn.Parameter(w.detach()))
        # Initialize the u and v vectors for power iteration
        height = w.size(0)
        width = w.view(height, -1).size(1)
        self.register_buffer('_u', F.normalize(torch.randn(height), dim=0, eps=1e-12))
        self.register_buffer('_v', F.normalize(torch.randn(width), dim=0, eps=1e-12))

    def forward(self, x):
        w = getattr(self.module, self.name + '_orig')
        w_mat = w.view(w.size(0), -1)  # flatten conv kernels to a 2-D matrix
        sigma = self._compute_spectral_norm(w_mat)
        # Spectral normalization: divide the weight by its spectral norm
        setattr(self.module, self.name, w / sigma)
        return self.module(x)

    def _compute_spectral_norm(self, W):
        """Power iteration (u and v are cached between calls)"""
        u, v = self._u, self._v
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                v = F.normalize(torch.mv(W.t(), u), dim=0, eps=1e-12)
                u = F.normalize(torch.mv(W, v), dim=0, eps=1e-12)
            # Update the cached vectors
            self._u.copy_(u)
            self._v.copy_(v)
        return torch.dot(u, torch.mv(W, v))

# Usage example: a spectrally normalized convolution layer
class SNConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding)
        # Wrap the convolution; its kernel is spectrally normalized
        self.sn = SpectralNorm(conv, name='weight')

    def forward(self, x):
        return self.sn(x)
```
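Note that PyTorch ships an equivalent built-in utility, so the hand-rolled class above is mainly didactic:

```python
# Equivalent using PyTorch's built-in spectral normalization
sn_conv = nn.utils.spectral_norm(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    name='weight', n_power_iterations=1)
```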
### 3.4 Spectral Normalization vs. Gradient Penalty

| Property | Spectral normalization | WGAN-GP |
|---|---|---|
| Mechanism | Normalizes weights | Extra loss term |
| Computational cost | Moderate (power iteration) | Higher (requires gradient computation) |
| Training speed | Fast | Slower |
| Convergence stability | High | High |
| Architectural requirements | None | Needs careful design |
| Typical use | General purpose | Complex distributions |
### 3.5 Connection to LeCun Initialization

Spectral normalization has a deep connection to LeCun initialization.

LeCun initialization:

$$W_{ij} \sim \mathcal{N}\!\left(0, \frac{1}{n_{\mathrm{in}}}\right)$$

Spectral normalization:

$$\sigma(\bar{W}_{\mathrm{SN}}) = 1$$

Both stabilize gradient propagation by controlling the spectral properties of the weight matrix. Spectral normalization can be viewed as an adaptive, dynamic counterpart to initialization-time scaling.
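For reference, the LeCun rule above translates into code as follows (a sketch; the helper name `lecun_normal_` is ours):

```python
def lecun_normal_(linear):
    # LeCun initialization: W_ij ~ N(0, 1/n_in)
    fan_in = linear.weight.size(1)
    nn.init.normal_(linear.weight, mean=0.0, std=fan_in ** -0.5)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)
```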
## 4. Mini-batch Discrimination and Feature Matching

### 4.1 Mini-batch Discrimination

To combat mode collapse, the discriminator is given information about how samples within a batch differ from one another:
```python
class MiniBatchDiscrimination(nn.Module):
    def __init__(self, input_dim, num_kernels, kernel_dim):
        super().__init__()
        self.num_kernels = num_kernels
        self.kernel_dim = kernel_dim
        # Learned projection into num_kernels kernels of size kernel_dim
        self.T = nn.Parameter(torch.randn(input_dim, num_kernels * kernel_dim))

    def forward(self, x):
        batch_size = x.size(0)
        # Project: (B, D) -> (B, num_kernels, kernel_dim)
        M = (x @ self.T).view(batch_size, self.num_kernels, self.kernel_dim)
        # Pairwise L1 distances between samples, per kernel:
        # (B, 1, K, Dk) - (1, B, K, Dk) -> sum over Dk -> (B, B, K)
        diff = torch.abs(M.unsqueeze(1) - M.unsqueeze(0)).sum(dim=3)
        # Similarity of each sample to every other sample in the batch
        o = torch.exp(-diff).sum(dim=1)  # (B, num_kernels)
        # Append the batch statistics to the original features
        return torch.cat([x, o], dim=1)
```
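A hypothetical usage example: insert the layer just before the discriminator's final classification layer, so the batch statistics are visible to it:

```python
mbd = MiniBatchDiscrimination(input_dim=256, num_kernels=50, kernel_dim=5)
features = torch.randn(8, 256)   # intermediate discriminator features (B=8)
augmented = mbd(features)        # shape: (8, 256 + 50)
```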
### 4.2 Feature Matching

Train the generator to match the statistics of an intermediate discriminator layer:

$$L_{\mathrm{FM}} = \left\| \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{z \sim p_z}[f(G(z))] \right\|_2^2$$

where $f(x)$ is the feature representation at an intermediate layer of the discriminator.
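A minimal sketch of this loss, assuming the discriminator exposes its intermediate features through a method (here called `features`, a hypothetical name):

```python
def feature_matching_loss(discriminator, real_samples, fake_samples):
    # ||E[f(x)] - E[f(G(z))]||_2^2 over an intermediate layer f
    f_real = discriminator.features(real_samples).mean(dim=0).detach()
    f_fake = discriminator.features(fake_samples).mean(dim=0)
    return torch.sum((f_real - f_fake) ** 2)
```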
## 5. Training Strategy Summary

### 5.1 Recommended Training Flow
```python
def gan_training_strategies(generator, discriminator, dataloader,
                            num_epochs, device, strategy='wgan-gp'):
    """
    A guide to choosing a GAN training strategy
    """
    if strategy == 'wgan-gp':
        # WGAN-GP: the most general-purpose choice
        lambda_gp = 10
        lr_d, lr_g = 1e-4, 1e-4
        n_critic = 5  # discriminator updates per generator update
        opt_g = lambda p: torch.optim.Adam(p, lr=lr_g, betas=(0.0, 0.9))
        opt_d = lambda p: torch.optim.Adam(p, lr=lr_d, betas=(0.0, 0.9))
    elif strategy == 'spectral_norm':
        # Spectral normalization: plug and play
        # (all discriminator layers use spectral normalization)
        lambda_gp = 0  # no gradient penalty needed
        lr_d, lr_g = 2e-4, 2e-4
        n_critic = 1
        opt_g = lambda p: torch.optim.Adam(p, lr=lr_g, betas=(0.5, 0.999))
        opt_d = lambda p: torch.optim.Adam(p, lr=lr_d, betas=(0.5, 0.999))
    elif strategy == 'lsgan':
        # Least-squares GAN: simple and stable
        lambda_gp = 0
        lr_d, lr_g = 1e-3, 1e-3
        n_critic = 1

    # Generic training loop
    for epoch in range(num_epochs):
        for real_samples, _ in dataloader:
            real_samples = real_samples.to(device)
            batch_size = real_samples.size(0)

            # Train the discriminator n_critic times
            for _ in range(n_critic):
                # ... discriminator training code ...
                pass

            # Train the generator
            # ... generator training code ...
            pass
```
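The `lsgan` branch above omits the loss functions; a minimal sketch of the least-squares losses with the common 0/1 targets (assuming `discriminator` outputs raw scores of shape `(B, 1)`):

```python
mse = nn.MSELoss()

# Discriminator: push D(x) toward 1 and D(G(z)) toward 0
d_loss = mse(discriminator(real_samples), torch.ones(batch_size, 1, device=device)) \
       + mse(discriminator(fake_samples.detach()), torch.zeros(batch_size, 1, device=device))

# Generator: push D(G(z)) toward 1
g_loss = mse(discriminator(fake_samples), torch.ones(batch_size, 1, device=device))
```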
### 5.2 Optimizer and Learning Rate Choices

| Optimizer | LR (discriminator) | LR (generator) | $\beta_1$ | $\beta_2$ |
|---|---|---|---|---|
| Adam | 2e-4 | 2e-4 | 0.5 | 0.999 |
| RMSprop | 5e-5 | 5e-5 | - | - |
| SGD | 1e-3 | 1e-3 | - | - |
### 5.3 Monitoring Metrics

| Metric | Ideal behavior | Warning sign |
|---|---|---|
| $D(x)$ on real samples | around 0.5 | approaches 0 or 1 |
| $D(G(z))$ on fake samples | around 0.5 | approaches 0 or 1 |
| Generator loss | decreases gradually | oscillates or rises |
## 6. References

[^1]: Gulrajani I, Ahmed F, Arjovsky M, et al. Improved training of Wasserstein GANs. NeurIPS, 2017. arXiv:1704.00028
[^2]: Miyato T, Kataoka T, Koyama M, et al. Spectral normalization for generative adversarial networks. ICLR, 2018. arXiv:1802.05957