GAN Architecture Variants

Since the original GAN was proposed in 2014, researchers have designed many specialized architectures for particular tasks. This document covers three of the most influential GAN architectures: DCGAN, StyleGAN, and BigGAN.
1. DCGAN (Deep Convolutional GAN)

1.1 Architectural Design Principles

DCGAN[^1] was one of the earliest GAN architectures to succeed at image generation. Its core design principles became the foundation for later work:

| Design principle | Description |
|---|---|
| Fully convolutional network | Replace pooling layers with strided convolutions |
| Batch normalization | Use BatchNorm in all convolutional layers (except the generator output and the discriminator input) |
| LeakyReLU activation | The discriminator uses LeakyReLU with slope 0.2 |
| No fully connected layers | Improves training stability |
| Tanh activation | The generator output layer uses Tanh |
1.2 DCGAN Architecture

Generator architecture:

Input: z ∈ R¹⁰⁰ ~ N(0, 1), reshaped to 100×1×1
    ↓
Transposed conv: 4×4, stride=1, padding=0 + BN + ReLU → 512×4×4
    ↓
Transposed conv: 4×4, stride=2, padding=1 + BN + ReLU → 256×8×8
    ↓
Transposed conv: 4×4, stride=2, padding=1 + BN + ReLU → 128×16×16
    ↓
Transposed conv: 4×4, stride=2, padding=1 + BN + ReLU → 64×32×32
    ↓
Transposed conv: 4×4, stride=2, padding=1 + Tanh → 3×64×64
    ↓
Output: RGB image (3, 64, 64)
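The spatial sizes in this diagram follow from the transposed-convolution output formula H_out = (H_in − 1)·stride − 2·padding + kernel (the first layer uses stride=1, padding=0 to map 1×1 to 4×4). A quick check:

```python
def deconv_out(h_in, kernel=4, stride=2, padding=1):
    """Output size of a ConvTranspose2d layer (no output_padding, no dilation)."""
    return (h_in - 1) * stride - 2 * padding + kernel

# Each stride-2 stage doubles the resolution: 4 → 8 → 16 → 32 → 64
sizes = [4]
for _ in range(4):
    sizes.append(deconv_out(sizes[-1]))
print(sizes)  # [4, 8, 16, 32, 64]
```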
1.3 PyTorch Implementation
```python
import torch
import torch.nn as nn

class DCGANGenerator(nn.Module):
    """DCGAN generator."""

    def __init__(self, latent_dim=100, channels=3, feature_g=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input: latent_dim × 1 × 1
            nn.ConvTranspose2d(latent_dim, feature_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_g * 8),
            nn.ReLU(True),
            # State: (feature_g*8) × 4 × 4
            nn.ConvTranspose2d(feature_g * 8, feature_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g * 4),
            nn.ReLU(True),
            # State: (feature_g*4) × 8 × 8
            nn.ConvTranspose2d(feature_g * 4, feature_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g * 2),
            nn.ReLU(True),
            # State: (feature_g*2) × 16 × 16
            nn.ConvTranspose2d(feature_g * 2, feature_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g),
            nn.ReLU(True),
            # State: feature_g × 32 × 32
            nn.ConvTranspose2d(feature_g, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels × 64 × 64
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Weight initialization recommended by the DCGAN paper."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)

    def forward(self, z):
        """Forward pass.

        Args:
            z: random noise, shape (batch_size, latent_dim)
        Returns:
            Generated images, shape (batch_size, channels, 64, 64)
        """
        # Reshape: (B, latent_dim) → (B, latent_dim, 1, 1)
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(z)
```
```python
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator."""

    def __init__(self, channels=3, feature_d=64):
        super().__init__()
        self.net = nn.Sequential(
            # Input: channels × 64 × 64
            nn.Conv2d(channels, feature_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: feature_d × 32 × 32
            nn.Conv2d(feature_d, feature_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*2) × 16 × 16
            nn.Conv2d(feature_d * 2, feature_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*4) × 8 × 8
            nn.Conv2d(feature_d * 4, feature_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*8) × 4 × 4
            nn.Conv2d(feature_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 × 1 × 1
        )
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)

    def forward(self, x):
        # Flatten the 1×1×1 output to a per-sample probability
        return self.net(x).view(-1, 1).squeeze(1)
```

1.4 Latent-Space Properties of DCGAN
A key finding of DCGAN is that the latent space has semantic structure:

- Vector arithmetic: the latent space supports linear semantic operations
- Smooth interpolation: interpolating between two points produces smooth visual transitions
- Feature visualization: specific latent directions correspond to specific visual attributes
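The smooth-interpolation property above can be probed by feeding the generator a batch of linearly interpolated latent vectors (a minimal sketch; `generator` stands for any model such as the `DCGANGenerator` defined earlier):

```python
import torch

def interpolate_latents(z1, z2, steps=8):
    """Return a (steps, latent_dim) batch linearly interpolating z1 → z2."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)  # (steps, 1)
    return (1 - alphas) * z1.unsqueeze(0) + alphas * z2.unsqueeze(0)

# Usage: images = generator(interpolate_latents(torch.randn(100), torch.randn(100)))
```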
```python
def latent_space_arithmetic(generator, z1, z2, z3, alpha=0.5):
    """Demonstrate latent-space vector arithmetic.

    Example: "man with glasses" - "man" + "woman" = "woman with glasses"
    """
    # Assume z1 = man with glasses, z2 = man, z3 = woman
    z_new = z1 - alpha * z2 + alpha * z3
    return generator(z_new)
```

2. The StyleGAN Family
2.1 Core Innovations of StyleGAN

StyleGAN[^2] introduced a style-based generator architecture that enabled unprecedented control over generated images.

Core innovations:

- Mapping network: maps the latent vector z to an intermediate latent space W
- Adaptive instance normalization (AdaIN): injects style information at every layer
- No learned input layer: synthesis starts from a learned constant, which reduces entanglement
2.2 Architecture Comparison

| Component | DCGAN | StyleGAN |
|---|---|---|
| Latent vector | Fed directly to the generator | Passed through a mapping network |
| Style injection | None | AdaIN at every layer |
| Noise input | Latent vector only | Additional per-layer noise |
| Image resolution | Fixed | Progressive growing |
2.3 StyleGAN Generator Architecture

z ∈ R⁵¹² → Mapping network
              ↓
       w ∈ R⁵¹² (intermediate latent vector)
              ↓
┌─────────────┼─────────────┐
↓             ↓             ↓
style s₀      style s₁      style s₂   ...
↓             ↓             ↓
AdaIN         AdaIN         AdaIN
↓             ↓             ↓
Conv 4×4      Conv 3×3      Conv 3×3   ...
↓             ↓             ↓
resolution    resolution    resolution
4×4           8×8           16×16      ...
The AdaIN operation:

AdaIN(xᵢ, w) = y_{s,i} · (xᵢ − μ(xᵢ)) / σ(xᵢ) + y_{b,i}

where y_{s,i} and y_{b,i} are scale and bias parameters learned from the style vector w, and μ(xᵢ), σ(xᵢ) are the per-channel mean and standard deviation of feature map xᵢ.
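As a sanity check on this formula, here is a standalone functional sketch of AdaIN (separate from the `AdaptiveInstanceNorm` module shown later): after the operation, each channel's statistics match the style's scale and bias.

```python
import torch

def adain(x, scale, bias, eps=1e-5):
    """Functional AdaIN. x: (B, C, H, W); scale, bias: (B, C)."""
    mu = x.mean(dim=(2, 3), keepdim=True)
    sigma = x.std(dim=(2, 3), unbiased=False, keepdim=True)
    normalized = (x - mu) / (sigma + eps)
    return normalized * scale[:, :, None, None] + bias[:, :, None, None]
```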
2.4 StyleGAN2 Improvements

StyleGAN2[^3] addressed artifact problems in StyleGAN:

Main improvements:

| Improvement | Purpose |
|---|---|
| Redesigned normalization: weight modulation/demodulation replaces AdaIN | Removes the "water droplet" artifacts and gives finer style control |
| Path length regularization | Smoother, better-behaved latent space |
Weight modulation and demodulation:

```python
import torch.nn.functional as F

class StyleGAN2GeneratorBlock(nn.Module):
    """StyleGAN2 generator block with weight modulation/demodulation (simplified)."""
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, 3, 3))
        self.to_style = nn.Linear(style_dim, in_channels)

    def forward(self, x, style):
        batch, in_ch, height, width = x.shape
        out_ch = self.weight.size(0)
        # Modulation: scale the weights per-sample by the style
        s = self.to_style(style) + 1.0                     # (B, in_ch)
        w = self.weight[None] * s[:, None, :, None, None]  # (B, out_ch, in_ch, 3, 3)
        # Demodulation: rescale each output filter to unit norm
        w = w * torch.rsqrt((w ** 2).sum(dim=(2, 3, 4), keepdim=True) + 1e-8)
        # Apply per-sample weights via a grouped convolution
        x = x.reshape(1, batch * in_ch, height, width)
        w = w.reshape(batch * out_ch, in_ch, 3, 3)
        out = F.conv2d(x, w, padding=1, groups=batch)
        return out.reshape(batch, out_ch, height, width)
```

2.5 StyleGAN3 Improvements
StyleGAN3[^4] solved StyleGAN2's "texture sticking" problem:

Problem: features appear to "stick" to fixed image coordinates, preventing smooth rotation and translation.

Solutions:

| Technique | Description |
|---|---|
| Alias-aware resampling | Carefully designed up/downsampling filters suppress aliasing |
| Equivariance constraints | Feature responses transform consistently under image translation and rotation |
| Fourier features | Replace explicit coordinate inputs with Fourier features |
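The Fourier-feature input can be sketched as a sin/cos encoding of pixel coordinates (a minimal illustration only; StyleGAN3's actual input layer is considerably more involved):

```python
import math
import torch

def fourier_features(coords, freqs):
    """Encode coordinates with sin/cos Fourier features.

    coords: (..., 2) positions in [0, 1]; freqs: (F,) frequencies.
    Returns (..., 4 * F) features.
    """
    angles = 2 * math.pi * coords[..., None] * freqs   # (..., 2, F)
    feats = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # (..., 2, 2F)
    return feats.flatten(-2)                           # (..., 4F)
```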
2.6 StyleGAN Code Example

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StyleGAN2Generator(nn.Module):
    """StyleGAN2 generator (simplified)."""

    def __init__(self, latent_dim=512, channels=3, hidden_dim=512):
        super().__init__()
        self.latent_dim = latent_dim
        # Mapping network: z -> w
        self.mapping = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        # The synthesis network starts from a learned constant
        self.const = nn.Parameter(torch.randn(1, hidden_dim, 4, 4))
        # Convolution channel plan (resolution grows from 4×4 to 64×64)
        conv_channels = [
            (hidden_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, hidden_dim),
            (hidden_dim, hidden_dim // 2),
            (hidden_dim // 2, hidden_dim // 2),
            (hidden_dim // 2, hidden_dim // 4),
            (hidden_dim // 4, hidden_dim // 4),
            (hidden_dim // 4, hidden_dim // 8),
            (hidden_dim // 8, hidden_dim // 8),
        ]
        self.convs = nn.ModuleList(
            nn.Conv2d(c_in, c_out, 3, 1, 1) for c_in, c_out in conv_channels
        )
        # Upsample 2× before convs 1, 3, 5 and 7: 4 → 8 → 16 → 32 → 64
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample_before = {1, 3, 5, 7}
        # One AdaIN per conv, matching that conv's output channels
        self.adains = nn.ModuleList(
            AdaptiveInstanceNorm(c_out, hidden_dim) for _, c_out in conv_channels
        )
        # Output layer
        self.to_rgb = nn.Conv2d(hidden_dim // 8, channels, 3, 1, 1)

    def forward(self, z, styles=None):
        """Forward pass.

        Args:
            z: latent vectors, shape (B, latent_dim)
            styles: optional list of per-layer style vectors, one per conv
        """
        # Mapping: z -> w
        w = self.mapping(z)
        # Without explicit styles, reuse w as the style for every layer
        if styles is None:
            styles = [w] * len(self.convs)
        # Start from the learned constant
        x = self.const.repeat(z.size(0), 1, 1, 1)
        # Synthesis network
        for i, conv in enumerate(self.convs):
            if i in self.upsample_before:
                x = self.upsample(x)
            x = conv(x)
            x = self.adains[i](x, styles[i])
        return torch.tanh(self.to_rgb(x))
```
```python
class AdaptiveInstanceNorm(nn.Module):
    """Adaptive instance normalization (AdaIN)."""

    def __init__(self, channels, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_scale = nn.Linear(style_dim, channels)
        self.style_bias = nn.Linear(style_dim, channels)

    def forward(self, x, style):
        normalized = self.norm(x)
        scale = self.style_scale(style).unsqueeze(2).unsqueeze(3)
        bias = self.style_bias(style).unsqueeze(2).unsqueeze(3)
        return normalized * (scale + 1) + bias
```

3. BigGAN and GigaGAN
3.1 Core Innovations of BigGAN

BigGAN[^5] scaled GANs to unprecedented size and achieved high-quality class-conditional image generation on ImageNet.

Core innovations:

| Technique | Description |
|---|---|
| Large-scale training | Very large batches and wide models trained across many accelerators in parallel |
| Truncation trick | Trades off sample quality against diversity |
| Class-conditional batch normalization | Normalization parameters derived from a class embedding |
| Orthogonal regularization | Keeps weight matrices well-conditioned |
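The orthogonal regularization entry can be sketched as a penalty on the off-diagonal entries of each weight matrix's Gram matrix (a simplified version of BigGAN's regularizer; BigGAN applies it selectively and with β = 10⁻⁴):

```python
import torch

def orthogonal_regularization(model, beta=1e-4):
    """Penalize off-diagonal entries of W Wᵀ for each weight matrix,
    pushing rows of W toward mutual orthogonality (BigGAN-style sketch)."""
    loss = torch.tensor(0.0)
    for name, param in model.named_parameters():
        if 'weight' in name and param.ndim >= 2:
            w = param.reshape(param.size(0), -1)
            gram = w @ w.t()
            off_diag = gram * (1.0 - torch.eye(gram.size(0), device=w.device))
            loss = loss + (off_diag ** 2).sum()
    return beta * loss
```

The returned value is added to the generator loss during training.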
3.2 类别条件批归一化
class ConditionalBatchNorm(nn.Module):
"""类别条件批归一化"""
def __init__(self, num_features, num_classes):
super().__init__()
self.bn = nn.BatchNorm2d(num_features, affine=False)
self.embed = nn.Embedding(num_classes, num_features * 2)
# 学习每个类别的γ和β
def forward(self, x, class_idx):
out = self.bn(x)
gamma, beta = self.embed(class_idx).chunk(2, dim=1)
return out * (gamma.view(-1, 1, 1) + 1) + beta.view(-1, 1, 1)3.3 截断技巧(Truncation Trick)
BigGAN uses the truncation trick to balance sample quality against diversity:

```python
def truncated_z(z, threshold=0.5):
    """Truncate a batch of latent vectors.

    Entries whose magnitude exceeds the threshold are resampled from N(0, 1)
    until every entry falls inside [-threshold, threshold].
    A larger threshold gives more diversity but potentially lower quality;
    a smaller threshold gives higher quality but less diversity.
    """
    if threshold > 0:
        mask = z.abs() > threshold
        while mask.any():
            z = torch.where(mask, torch.randn_like(z), z)
            mask = z.abs() > threshold
    return z
```

3.4 GigaGAN
GigaGAN[^6] pushes scaling further, with multiple generator components working in concert:

Highlights:

- Several cooperating generator components
- Overcomes the capacity limits of a single plain generator
- Remains stable to train at the billion-parameter scale
4. Architecture Selection Guide

| Scenario | Recommended architecture | Reason |
|---|---|---|
| Quick prototyping | DCGAN | Simple and easy to implement |
| High-quality faces | StyleGAN2/3 | State-of-the-art image quality |
| Class-conditional generation | BigGAN | High quality on ImageNet |
| Image editing | StyleGAN | Well-structured latent space |
| Text-to-image | DALL-E / Imagen | Diffusion models are the better fit |
5. References

[^1]: Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks. ICLR, 2016. arXiv:1511.06434
[^2]: Karras T, Laine S, Aila T. A style-based generator architecture for generative adversarial networks. CVPR, 2019. arXiv:1812.04948
[^3]: Karras T, Laine S, Aittala M, et al. Analyzing and improving the image quality of StyleGAN. CVPR, 2020. arXiv:1912.04958
[^4]: Karras T, Aittala M, Laine S, et al. Alias-free generative adversarial networks. NeurIPS, 2021. arXiv:2106.12423
[^5]: Brock A, Donahue J, Simonyan K. Large scale GAN training for high fidelity natural image synthesis. ICLR, 2019. arXiv:1809.11096
[^6]: Kang M, Zhu J-Y, Zhang R, et al. Scaling up GANs for text-to-image synthesis (GigaGAN). CVPR, 2023. arXiv:2303.05511