GAN Architecture Variants

Since the original GAN was proposed in 2014, researchers have designed many specialized architectures for particular tasks. This document covers three of the most influential GAN architectures: DCGAN, StyleGAN, and BigGAN.

1. DCGAN (Deep Convolutional GAN)

1.1 Architecture Design Principles

DCGAN[1] was among the first GAN architectures to generate images successfully. Its core design principles became the foundation for later work:

| Principle | Description |
| --- | --- |
| Fully convolutional network | Replace pooling layers with strided convolutions |
| Batch normalization | Use BatchNorm after every conv layer (except the generator's output and the discriminator's input layer) |
| LeakyReLU activation | Use LeakyReLU with slope 0.2 in the discriminator |
| No fully connected layers | Remove FC hidden layers to improve stability |
| Tanh output | Use Tanh in the generator's output layer |
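The strided-convolution principle can be checked with a quick shape comparison; a minimal sketch on toy tensors (the tensor sizes here are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 64, 32, 32)

# Fixed, parameter-free downsampling
pooled = nn.MaxPool2d(kernel_size=2)(x)

# Learned downsampling: a 4×4 conv with stride 2 also halves the resolution,
# but the network can learn *how* to downsample
strided = nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1)(x)

print(pooled.shape, strided.shape)  # both (1, 64, 16, 16)
```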

1.2 DCGAN Architecture Diagram

Generator architecture

Input: z ∈ R¹⁰⁰ ~ N(0, 1), reshaped to 100×1×1
  ↓
Transposed conv: 100 → 512, 4×4, stride=1, padding=0 + BN + ReLU → 512×4×4
  ↓
Transposed conv: 512 → 256, 4×4, stride=2, padding=1 + BN + ReLU → 256×8×8
  ↓
Transposed conv: 256 → 128, 4×4, stride=2, padding=1 + BN + ReLU → 128×16×16
  ↓
Transposed conv: 128 → 64, 4×4, stride=2, padding=1 + BN + ReLU → 64×32×32
  ↓
Transposed conv: 64 → 3, 4×4, stride=2, padding=1 + Tanh → 3×64×64
  ↓
Output: RGB image [3, 64, 64]
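The spatial sizes in the diagram follow from the transposed-convolution output formula, out = (in − 1)·stride − 2·padding + kernel; a quick arithmetic check:

```python
def deconv_out(size, kernel, stride, padding):
    """Output size of a ConvTranspose2d layer (output_padding = 0)."""
    return (size - 1) * stride - 2 * padding + kernel

size = 1                              # z reshaped to a 1×1 spatial map
size = deconv_out(size, 4, 1, 0)      # first layer: 1 -> 4
for _ in range(4):                    # four stride-2 layers
    size = deconv_out(size, 4, 2, 1)  # 4 -> 8 -> 16 -> 32 -> 64
print(size)  # 64
```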

1.3 PyTorch Implementation

import torch
import torch.nn as nn
 
class DCGANGenerator(nn.Module):
    """DCGAN generator."""
    def __init__(self, latent_dim=100, channels=3, feature_g=64):
        super().__init__()
        
        self.net = nn.Sequential(
            # Input: latent_dim × 1 × 1
            nn.ConvTranspose2d(latent_dim, feature_g * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(feature_g * 8),
            nn.ReLU(True),
            # State: (feature_g*8) × 4 × 4
            
            nn.ConvTranspose2d(feature_g * 8, feature_g * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g * 4),
            nn.ReLU(True),
            # State: (feature_g*4) × 8 × 8
            
            nn.ConvTranspose2d(feature_g * 4, feature_g * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g * 2),
            nn.ReLU(True),
            # State: (feature_g*2) × 16 × 16
            
            nn.ConvTranspose2d(feature_g * 2, feature_g, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_g),
            nn.ReLU(True),
            # State: feature_g × 32 × 32
            
            nn.ConvTranspose2d(feature_g, channels, 4, 2, 1, bias=False),
            nn.Tanh()
            # Output: channels × 64 × 64
        )
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Initialization recommended by the DCGAN paper: N(0, 0.02)."""
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
    
    def forward(self, z):
        """Forward pass.
        
        Args:
            z: random noise (batch_size, latent_dim)
        Returns:
            generated images (batch_size, channels, 64, 64)
        """
        # Reshape: (B, latent_dim) -> (B, latent_dim, 1, 1)
        z = z.view(z.size(0), z.size(1), 1, 1)
        return self.net(z)
 
 
class DCGANDiscriminator(nn.Module):
    """DCGAN discriminator."""
    def __init__(self, channels=3, feature_d=64):
        super().__init__()
        
        self.net = nn.Sequential(
            # Input: channels × 64 × 64 (no BatchNorm on the input layer)
            nn.Conv2d(channels, feature_d, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # State: feature_d × 32 × 32
            
            nn.Conv2d(feature_d, feature_d * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*2) × 16 × 16
            
            nn.Conv2d(feature_d * 2, feature_d * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*4) × 8 × 8
            
            nn.Conv2d(feature_d * 4, feature_d * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(feature_d * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # State: (feature_d*8) × 4 × 4
            
            nn.Conv2d(feature_d * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
            # Output: 1 × 1 × 1
        )
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d, nn.Linear)):
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight.data, 1.0, 0.02)
                nn.init.constant_(m.bias.data, 0)
    
    def forward(self, x):
        # Flatten the 1×1×1 map to a per-sample probability
        return self.net(x).view(-1, 1).squeeze(1)

1.4 Latent Space Properties of DCGAN

A key finding of DCGAN is that its latent space has semantic structure:

  • Vector arithmetic: the latent space supports linear semantic operations
  • Smooth interpolation: interpolating between two latent points yields smooth image transitions
  • Feature visualization: particular directions correspond to particular visual attributes

def latent_space_arithmetic(generator, z1, z2, z3, alpha=0.5):
    """Demonstrate latent-space vector arithmetic.
    
    Example: "man with glasses" - "man" + "woman" = "woman with glasses"
    """
    # Assume z1 = man with glasses, z2 = plain man, z3 = plain woman
    z_new = z1 - alpha * z2 + alpha * z3
    return generator(z_new)
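The smooth-interpolation property is usually demonstrated by decoding evenly spaced points on the segment between two latent codes; a minimal sketch (the function name is illustrative):

```python
import torch

def interpolate_latents(z1, z2, steps=8):
    """Evenly spaced points on the line from z1 to z2; feeding each row to
    the generator should produce a smooth visual transition."""
    alphas = torch.linspace(0.0, 1.0, steps).view(-1, 1)
    return (1 - alphas) * z1 + alphas * z2  # (steps, latent_dim)

zs = interpolate_latents(torch.zeros(100), torch.ones(100))
print(zs.shape)  # torch.Size([8, 100])
```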

2. The StyleGAN Family

2.1 StyleGAN Core Innovations

StyleGAN[2] introduced a style-based generator architecture that enables unprecedented control over generated images.

Core innovations

  1. Mapping network: maps the latent vector z to an intermediate latent space W
  2. Adaptive instance normalization (AdaIN): injects style information at every layer
  3. Constant input: synthesis starts from a learned constant rather than z, which aids disentanglement

2.2 Architecture Comparison

| Component | DCGAN | StyleGAN |
| --- | --- | --- |
| Latent vector | Fed directly into the generator | Passed through a mapping network |
| Style injection | None | AdaIN at every layer |
| Noise input | Single random z | Random z plus per-layer noise |
| Image resolution | Fixed | Progressive growing |

2.3 StyleGAN Generator Architecture

                 z ∈ R⁵¹²
                     ↓
               Mapping network
                     ↓
          w ∈ R⁵¹² (intermediate latent)
                     ↓
          ┌─────────────┼─────────────┐
          ↓             ↓             ↓
      style s₀      style s₁      style s₂ ...
          ↓             ↓             ↓
        AdaIN         AdaIN         AdaIN
          ↓             ↓             ↓
      Conv 4×4      Conv 3×3      Conv 3×3 ...
          ↓             ↓             ↓
     resolution    resolution    resolution
        4×4           8×8          16×16 ...

The AdaIN operation:

AdaIN(xᵢ, w) = y_{s,i} · (xᵢ − μ(xᵢ)) / σ(xᵢ) + y_{b,i}

where μ(xᵢ) and σ(xᵢ) are the per-channel mean and standard deviation of feature map xᵢ, and y_{s,i}, y_{b,i} are the scale and bias parameters learned from the style vector w.

2.4 StyleGAN2 Improvements

StyleGAN2[3] addressed artifact problems in the original StyleGAN:

Main improvements

| Change | Purpose |
| --- | --- |
| Redesigned normalization | Removes the "water droplet" artifacts |
| Weight modulation/demodulation | Finer style control |
| Path length regularization | Smoother latent space |
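Path length regularization penalizes how much the generator output moves for a small step in w, pulling per-sample Jacobian norms toward a reference value. A self-contained sketch (the tiny linear "generator" here is a stand-in, not the real synthesis network):

```python
import torch
import torch.nn as nn

# Toy stand-in generator: w (B, 8) -> image (B, 1, 4, 4)
g = nn.Sequential(nn.Linear(8, 16), nn.Unflatten(1, (1, 4, 4)))

def path_length_penalty(w, pl_mean=0.0):
    w = w.requires_grad_(True)
    img = g(w)
    # Random projection of the output, scaled by image size, so the gradient
    # below estimates ||J_w^T y|| without forming the full Jacobian
    y = torch.randn_like(img) / (img.shape[2] * img.shape[3]) ** 0.5
    (grad,) = torch.autograd.grad((img * y).sum(), w, create_graph=True)
    path_lengths = grad.norm(2, dim=1)  # one value per sample
    return ((path_lengths - pl_mean) ** 2).mean()

loss = path_length_penalty(torch.randn(4, 8))
```

In StyleGAN2, `pl_mean` is an exponential moving average of past path lengths rather than a constant.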

Weight modulation and demodulation

import torch.nn.functional as F

class StyleGAN2GeneratorBlock(nn.Module):
    """StyleGAN2 generator block: weight modulation/demodulation (simplified)."""
    def __init__(self, in_channels, out_channels, style_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, 3, 3))
        self.to_style = nn.Linear(style_dim, in_channels)
    
    def forward(self, x, style):
        b, c_in, h, w = x.shape
        # Modulation: scale the weight's input channels per sample
        s = self.to_style(style).view(b, 1, c_in, 1, 1)
        weight = self.weight.unsqueeze(0) * s  # (B, C_out, C_in, 3, 3)
        # Demodulation: rescale each output filter back to unit norm
        demod = torch.rsqrt((weight ** 2).sum(dim=(2, 3, 4), keepdim=True) + 1e-8)
        weight = weight * demod
        # A grouped convolution applies a different weight to each sample
        x = x.reshape(1, b * c_in, h, w)
        weight = weight.reshape(b * weight.size(1), c_in, 3, 3)
        out = F.conv2d(x, weight, padding=1, groups=b)
        return out.reshape(b, -1, h, w)

2.5 StyleGAN3 Improvements

StyleGAN3[4] addressed StyleGAN2's texture sticking problem:

Problem: features appear to "stick" to fixed image coordinates, preventing smooth rotation and translation.

Solutions

| Technique | Description |
| --- | --- |
| Critically sampled filtering | Redesigned up/downsampling filters to suppress aliasing |
| Equivariance constraints | Feature responses transform consistently under image translation and rotation |
| Fourier features | Fourier features replace explicit coordinate inputs |
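The Fourier-feature input can be sketched as mapping pixel coordinates through fixed sinusoids; the frequency matrix here is random for illustration (StyleGAN3 uses carefully chosen frequencies):

```python
import torch

def fourier_features(coords, freqs):
    """coords: (..., 2) in [0, 1]; freqs: (F, 2). Returns (..., 2F)."""
    phase = 2 * torch.pi * coords @ freqs.t()
    return torch.cat([phase.sin(), phase.cos()], dim=-1)

h = w = 4
ys, xs = torch.meshgrid(torch.linspace(0, 1, h),
                        torch.linspace(0, 1, w), indexing='ij')
grid = torch.stack([xs, ys], dim=-1)              # (4, 4, 2) coordinate grid
feat = fourier_features(grid, torch.randn(8, 2))  # (4, 4, 16)
```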

2.6 StyleGAN Code Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class StyleGAN2Generator(nn.Module):
    """Simplified style-based generator (AdaIN variant, 64×64 output)."""
    def __init__(self, latent_dim=512, channels=3, hidden_dim=512):
        super().__init__()
        self.latent_dim = latent_dim
        
        # Mapping network: z -> w
        self.mapping = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        
        # The synthesis network starts from a learned constant
        self.const = nn.Parameter(torch.randn(1, hidden_dim, 4, 4))
        
        # Per-layer (in, out) channels; resolution doubles before layers 1, 3, 5, 7
        dims = [
            (hidden_dim, hidden_dim),            # 4×4
            (hidden_dim, hidden_dim),            # 8×8
            (hidden_dim, hidden_dim),
            (hidden_dim, hidden_dim // 2),       # 16×16
            (hidden_dim // 2, hidden_dim // 2),
            (hidden_dim // 2, hidden_dim // 4),  # 32×32
            (hidden_dim // 4, hidden_dim // 4),
            (hidden_dim // 4, hidden_dim // 8),  # 64×64
            (hidden_dim // 8, hidden_dim // 8),
        ]
        self.convs = nn.ModuleList(
            nn.Conv2d(c_in, c_out, 3, 1, 1) for c_in, c_out in dims
        )
        
        # One AdaIN per conv, sized to that conv's input channels
        self.adains = nn.ModuleList(
            AdaptiveInstanceNorm(c_in, hidden_dim) for c_in, _ in dims
        )
        
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        
        # Output layer
        self.to_rgb = nn.Conv2d(hidden_dim // 8, channels, 3, 1, 1)
    
    def forward(self, z, styles=None):
        """Forward pass.
        
        Args:
            z: latent vectors (B, latent_dim)
            styles: optional list of per-layer style vectors
        """
        # Mapping: z -> w
        w = self.mapping(z)
        
        # If no per-layer styles are given, reuse w for every layer
        if styles is None:
            styles = [w] * len(self.convs)
        
        # Start from the learned constant
        x = self.const.repeat(z.size(0), 1, 1, 1)
        
        for i, (adain, conv) in enumerate(zip(self.adains, self.convs)):
            if i in (1, 3, 5, 7):     # double the resolution
                x = self.upsample(x)
            x = adain(x, styles[i])   # inject style
            x = F.leaky_relu(conv(x), 0.2)
        
        return torch.tanh(self.to_rgb(x))
 
 
class AdaptiveInstanceNorm(nn.Module):
    """Adaptive instance normalization."""
    def __init__(self, channels, style_dim):
        super().__init__()
        self.norm = nn.InstanceNorm2d(channels, affine=False)
        self.style_scale = nn.Linear(style_dim, channels)
        self.style_bias = nn.Linear(style_dim, channels)
    
    def forward(self, x, style):
        normalized = self.norm(x)
        scale = self.style_scale(style).unsqueeze(2).unsqueeze(3)
        bias = self.style_bias(style).unsqueeze(2).unsqueeze(3)
        return normalized * (scale + 1) + bias

3. BigGAN and GigaGAN

3.1 BigGAN Core Innovations

BigGAN[5] scaled GANs to unprecedented size, achieving high-quality class-conditional image generation on ImageNet.

Core innovations

| Technique | Description |
| --- | --- |
| Large-scale training | Parallel training across hundreds to thousands of accelerators |
| Truncation trick | Trades off sample quality against diversity |
| Class-conditional BatchNorm | Normalization parameters modulated by a class embedding |
| Orthogonal regularization | Keeps weight matrices well-conditioned |
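The orthogonal-regularization row can be made concrete. BigGAN uses a "soft" variant that penalizes only the off-diagonal of the Gram matrix, pushing distinct filters toward mutual orthogonality without constraining their norms; a minimal sketch:

```python
import torch

def orthogonal_regularization(weight, beta=1e-4):
    """BigGAN-style soft orthogonality penalty on a weight tensor."""
    w = weight.view(weight.size(0), -1)  # flatten conv filters to rows
    gram = w @ w.t()                     # pairwise filter inner products
    off_diag = gram * (1 - torch.eye(gram.size(0)))
    return beta * (off_diag ** 2).sum()

penalty = orthogonal_regularization(torch.randn(16, 3, 3, 3))
```

Added to the generator loss, this keeps the weight matrix well-conditioned; an exactly orthogonal matrix incurs zero penalty.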

3.2 Class-Conditional Batch Normalization

class ConditionalBatchNorm(nn.Module):
    """Class-conditional batch normalization."""
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        # Learn a per-class gamma and beta
        self.embed = nn.Embedding(num_classes, num_features * 2)
        
    def forward(self, x, class_idx):
        out = self.bn(x)
        gamma, beta = self.embed(class_idx).chunk(2, dim=1)
        gamma = gamma.view(-1, gamma.size(1), 1, 1)
        beta = beta.view(-1, beta.size(1), 1, 1)
        return out * (gamma + 1) + beta

3.3 The Truncation Trick

BigGAN uses the truncation trick to balance sample quality against diversity:

def truncated_z(z, threshold=0.5):
    """Clamp latent values to [-threshold, threshold].
    
    Smaller threshold -> higher sample quality but less diversity;
    larger threshold -> more diversity but potentially lower quality.
    (The BigGAN paper resamples out-of-range values rather than clamping.)
    """
    if threshold > 0:
        z = z.clamp(-threshold, threshold)
    return z
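BigGAN's original formulation redraws out-of-range entries from a fresh Gaussian instead of clamping them; a minimal sketch of that resampling form (the function name is illustrative):

```python
import torch

def truncated_noise(batch_size, dim, threshold=0.5):
    """Sample z ~ N(0, I), redrawing any entry whose magnitude exceeds
    the threshold (the resampling form of the truncation trick)."""
    z = torch.randn(batch_size, dim)
    mask = z.abs() > threshold
    while mask.any():
        z[mask] = torch.randn(int(mask.sum()))
        mask = z.abs() > threshold
    return z

z = truncated_noise(8, 128, threshold=0.5)
```

Unlike clamping, resampling avoids piling probability mass exactly at ±threshold.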

3.4 GigaGAN

GigaGAN[6] pushes scaling further, bringing GANs to large-scale text-to-image synthesis:

Key points

  • Selects convolution kernels adaptively per sample from a bank of filters
  • Addresses the capacity limits of a single plain convolutional generator
  • Maintains stable training at the billion-parameter scale

4. Architecture Selection Guide

| Scenario | Recommended architecture | Rationale |
| --- | --- | --- |
| Quick prototyping | DCGAN | Simple and easy to implement |
| High-quality faces | StyleGAN2/3 | State-of-the-art image quality |
| Class-conditional generation | BigGAN | High quality on ImageNet |
| Image editing | StyleGAN | Well-structured latent space |
| Text-to-image | DALL-E / Imagen | Diffusion or autoregressive models are better suited |

5. References

  1. Radford A, Metz L, Chintala S. Unsupervised representation learning with deep convolutional generative adversarial networks. ICLR, 2016. arXiv:1511.06434

  2. Karras T, Laine S, Aila T. A style-based generator architecture for generative adversarial networks. CVPR, 2019. arXiv:1812.04948

  3. Karras T, Laine S, Aittala M, et al. Analyzing and improving the image quality of StyleGAN. CVPR, 2020. arXiv:1912.04958

  4. Karras T, Aittala M, Laine S, et al. Alias-free generative adversarial networks. NeurIPS, 2021. arXiv:2106.12423

  5. Brock A, Donahue J, Simonyan K. Large scale GAN training for high fidelity natural image synthesis. ICLR, 2019. arXiv:1809.11096

  6. Kang M, Zhu J-Y, Zhang R, et al. Scaling up GANs for text-to-image synthesis. CVPR, 2023. arXiv:2303.05511