概述

DiT (Diffusion Transformer) 是扩散模型架构的重要演进方向,由 Peebles 和 Xie 在 2023 年提出。与传统的 UNet 架构相比,DiT 使用纯 Transformer 作为去噪网络的骨干,带来了更好的可扩展性和更强的表达能力。1

本文件系统介绍 DiT 的核心架构设计、条件机制、数学基础,以及最新的改进方向。


1. 从 UNet 到 Transformer

传统扩散模型的局限性

早期扩散模型(如 DDPM、ADM)普遍采用 UNet 作为去噪网络骨干。UNet 虽然在图像任务上表现出色,但存在以下局限:

方面UNetTransformer
归纳偏置强局部性假设弱归纳偏置,更灵活
参数效率中等高(相同参数更强表达)
可扩展性困难容易(天然支持并行)
长程依赖需多层级联全局注意力天然建模

DiT 的核心洞察

DiT 的关键发现:当使用足够强大的架构时,扩散模型可以在纯 Transformer 骨干上达到甚至超越 UNet 的性能。

“We replace the standard ConvNet encoder-decoder with a Vision Transformer operating on latent image patches.” — Peebles & Xie


2. 潜扩散框架

DiT 通常在潜空间(Latent Space)中操作,这得益于 VAE 的帮助。

完整流程

原始图像 x ∈ R^{H×W×3}
    ↓ VAE Encoder
潜变量 z ∈ R^{h×w×c}  (通常 h=H/8, w=W/8)
    ↓ Patchify
序列 tokens p_i ∈ R^{d}  (每个 patch 转为 d 维 token)
    ↓ DiT Transformer
去噪预测 ε_θ(z_t, t, c)
    ↓ VAE Decoder
生成图像 x̂

关键参数

参数含义典型值
原始图像分辨率256, 512
潜空间分辨率32, 64
Patch 大小2
潜通道数4
Transformer 隐藏维度1024

Token 数量计算

例如, 图像,


3. Patchify 模块

Patchify 是 DiT 处理图像的核心操作,将 2D 图像 patch 线性投影为序列 token。

数学表示

设输入潜变量 ,Patchify 操作定义为:

其中 ,

PyTorch 实现

class PatchEmbed(nn.Module):
    """将图像 patch 投影为 tokens"""
    
    def __init__(self, patch_size=2, in_channels=4, hidden_dim=1024):
        super().__init__()
        self.patch_size = patch_size
        # 每个 patch 展平后线性投影
        self.proj = nn.Linear(
            patch_size * patch_size * in_channels, 
            hidden_dim
        )
    
    def forward(self, x):
        """
        x: [B, C, H, W]
        """
        B, C, H, W = x.shape
        p = self.patch_size
        
        # 分割为 patches: [B, C, H/p, p, W/p, p] → [B, H/p, W/p, C*p*p]
        x = x.view(B, C, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
        x = x.view(B, (H // p) * (W // p), C * p * p)
        
        # 线性投影
        x = self.proj(x)  # [B, N, hidden_dim]
        return x

位置编码

DiT 使用标准的可学习位置编码或傅里叶位置编码:

class DiTPositionEmbedding(nn.Module):
    def __init__(self, num_patches, hidden_dim):
        super().__init__()
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches, hidden_dim)
        )
    
    def forward(self, x):
        return x + self.pos_embed

4. 条件机制:AdaLN

DiT 使用 Adaptive Layer Norm (AdaLN) 来注入时间步 和类别 条件。

AdaLN vs 原始 Layer Norm

原始 Layer Norm

AdaLN

其中 由条件 通过 MLP 生成。

AdaLN-Zero

DiT 的一大创新是 AdaLN-Zero,将 作为初始化。这使得残差分支的初始输出为零,促进训练的稳定性。

class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.hidden_size = hidden_size
        
        # 条件调制网络
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )
        
        # 注意力
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        
        # MLP
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_ratio * hidden_size),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_ratio * hidden_size, hidden_size)
        )
        
        # 初始化 AdaLN 输出为零
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)
    
    def forward(self, x, c):
        # c: 条件嵌入 (time + class)
        # 生成调制参数
        modulator = self.adaLN_modulation(c)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = 
            modulator.chunk(6, dim=-1)
        
        # Self-attention with modulation
        x = x + gate_msa * self.attn(
            self.norm1(x) * (1 + scale_msa) + shift_msa,
            self.norm1(x) * (1 + scale_msa) + shift_msa,
            self.norm1(x) * (1 + scale_msa) + shift_msa
        )[0]
        
        # MLP with modulation
        x = x + gate_mlp * self.mlp(
            self.norm2(x) * (1 + scale_mlp) + shift_mlp
        )
        
        return x

条件的注入方式比较

方法描述优缺点
Cross-attention条件作为额外 key/value通用但计算量大
Adaptive Norm调整归一化参数高效且有效
AdaLN-ZeroAdaLN + 零初始化最常用,稳定训练

5. DiT 主干网络

完整 DiT 架构

class DiT(nn.Module):
    def __init__(
        self,
        hidden_size=1024,
        num_heads=16,
        num_layers=28,
        patch_size=2,
        in_channels=4,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.patch_size = patch_size
        
        # Patch embedding
        self.x_embed = PatchEmbed(patch_size, in_channels, hidden_size)
        
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_size))
        
        # 时间步嵌入
        self.t_embed = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        
        # 类别嵌入
        self.y_embed = nn.Sequential(
            nn.Linear(num_classes, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size)
        )
        
        # DiT Blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio)
            for _ in range(num_layers)
        ])
        
        # 输出归一化
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
        
        # 头部:patch 回填为图像
        self.proj_out = nn.Linear(
            hidden_size, 
            patch_size * patch_size * in_channels
        )
        
        self.initialize_weights()
    
    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        
        self.apply(_basic_init)
        
        # Zero-initialize output projection
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
    
    def unpatchify(self, x):
        """将 tokens 还原为图像"""
        c = self.in_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.reshape(x.shape[0], c, h * p, w * p)
        return x
    
    def forward(self, x, t, y):
        """
        x: 噪声潜变量 [B, C, H, W]
        t: 时间步 [B]
        y: 类别标签 [B]
        """
        # 获取 shape
        B, C, H, W = x.shape
        
        # Patchify
        x = self.x_embed(x)  # [B, N, D]
        x = x + self.pos_embed
        
        # 条件嵌入
        t = self.t_embed(timestep_embedding(t, self.hidden_size))
        y = self.y_embed(F.one_hot(y, num_classes=1000).float())
        c = t + y
        
        # 应用 DiT blocks
        for block in self.blocks:
            x = block(x, c)
        
        # 输出
        x = self.norm_final(x)
        x = self.proj_out(x)
        x = self.unpatchify(x)
        
        return x

时间步嵌入

def timestep_embedding(t, dim, max_period=10000):
    """
    创建正弦时间步嵌入
    
    遵循 Attention is All You Need 的位置编码方案
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, device=t.device) / half
    ).repeat(1, 2)
    
    args = t[:, None].float() * freqs
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    
    return embedding

6. DiT 变体配置

DiT 提供了不同规模的可扩展配置:

变体层数隐藏维度头数参数量GFLOPs
DiT-S12384639M61.6
DiT-B1276812123M118.6
DiT-L24102416457M1035
DiT-XL/228115216675M118.6

命名规则:DiT-{size}/{patch_size}

  • DiT-XL/2:XL 规模,patch_size=2(最小 patch)
  • DiT-B/4:B 规模,patch_size=4

7. 训练目标

DiT 预测噪声 ,使用简单的 MSE 损失:

其中

Classifier-Free Guidance

DiT 同样支持 Classifier-Free Guidance (CFG):

其中 是引导权重, 表示无条件预测。


8. 实验结果

ImageNet 256×256 结果

模型GFLOPs参数量FID ↓IS ↑
ADM2060554M1.48265.7
LDM-8266395M3.57185.4
DiT-XL/2118.6675M1.81241.5
DiT-XL/2 + CFG118.6675M1.55247.5

关键发现

  1. Patch 大小影响:patch_size=2 的 DiT-XL/2 显著优于 patch_size=8 的 DiT-XL/8
  2. 模型规模缩放:DiT-XL >> DiT-L >> DiT-B >> DiT-S
  3. 计算效率:DiT-XL/2 在比 ADM 少 17 倍计算量的情况下达到可比性能

9. 与传统 UNet 的对比

架构差异

方面UNetDiT
空间建模多尺度特征融合全序列自注意力
跳跃连接Encoder-Decoder 跳跃无(纯前馈)
条件注入多处注入AdaLN 统一注入
位置信息卷积天然编码显式位置编码
感受野逐渐增大初始全局

优劣分析

DiT 优势

  • ✅ 更好的可扩展性
  • ✅ 更强的长程依赖建模
  • ✅ 与大语言模型技术共享

DiT 局限

  • ❌ 计算量随序列长度二次增长
  • ❌ 需要更多训练数据
  • ❌ 训练稳定性要求更高

10. 实践指南

超参数选择

# 推荐配置
config = {
    # 模型规模
    'hidden_size': 1024,  # 增大提升质量
    'num_heads': 16,
    'num_layers': 28,
    
    # Patch 大小(影响显存和速度)
    'patch_size': 2,  # 2 质量最高,8 最快
    
    # 训练
    'learning_rate': 1e-4,
    'weight_decay': 0.0,  # DiT 通常不用 weight decay
    'batch_size': 256,
    
    # 推理
    'guidance_scale': 5.0,  # CFG 强度
    'num_sampling_steps': 50,
}

训练技巧

  1. 使用 EMA:DiT 对 EMA 敏感,推荐使用 0.9999 的 EMA
  2. 傅里叶位置编码:在某些任务上优于可学习位置编码
  3. 渐进式训练:先小 patch,再大 patch
  4. 混合精度:使用 FP16 加速训练

参考


相关阅读

Footnotes

  1. Peebles, W., & Xie, S. (2023). “Scalable Diffusion Models with Transformers.” ICCV 2023. arXiv:2212.09748