概述
DiT (Diffusion Transformer) 是扩散模型架构的重要演进方向,由 Peebles 和 Xie 在 2023 年提出。与传统的 UNet 架构相比,DiT 使用纯 Transformer 作为去噪网络的骨干,带来了更好的可扩展性和更强的表达能力。1
本文件系统介绍 DiT 的核心架构设计、条件机制、数学基础,以及最新的改进方向。
1. 从 UNet 到 Transformer
传统扩散模型的局限性
早期扩散模型(如 DDPM、ADM)普遍采用 UNet 作为去噪网络骨干。UNet 虽然在图像任务上表现出色,但存在以下局限:
| 方面 | UNet | Transformer |
|---|---|---|
| 归纳偏置 | 强局部性假设 | 弱归纳偏置,更灵活 |
| 参数效率 | 中等 | 高(相同参数更强表达) |
| 可扩展性 | 困难 | 容易(天然支持并行) |
| 长程依赖 | 需多层级联 | 全局注意力天然建模 |
DiT 的核心洞察
DiT 的关键发现:当使用足够强大的架构时,扩散模型可以在纯 Transformer 骨干上达到甚至超越 UNet 的性能。
“We replace the standard ConvNet encoder-decoder with a Vision Transformer operating on latent image patches.” — Peebles & Xie
2. 潜扩散框架
DiT 通常在潜空间(Latent Space)中操作,这得益于 VAE 的帮助。
完整流程
原始图像 x ∈ R^{H×W×3}
↓ VAE Encoder
潜变量 z ∈ R^{h×w×c} (通常 h=H/8, w=W/8)
↓ Patchify
序列 tokens p_i ∈ R^{d} (每个 patch 转为 d 维 token)
↓ DiT Transformer
去噪预测 ε_θ(z_t, t, c)
↓ VAE Decoder
生成图像 x̂
关键参数
| 参数 | 含义 | 典型值 |
|---|---|---|
| 原始图像分辨率 | 256, 512 | |
| 潜空间分辨率 | 32, 64 | |
| Patch 大小 | 2 | |
| 潜通道数 | 4 | |
| Transformer 隐藏维度 | 1024 |
Token 数量计算:
例如, 图像,,:
3. Patchify 模块
Patchify 是 DiT 处理图像的核心操作,将 2D 图像 patch 线性投影为序列 token。
数学表示
设输入潜变量 ,Patchify 操作定义为:
其中 , 。
PyTorch 实现
class PatchEmbed(nn.Module):
"""将图像 patch 投影为 tokens"""
def __init__(self, patch_size=2, in_channels=4, hidden_dim=1024):
super().__init__()
self.patch_size = patch_size
# 每个 patch 展平后线性投影
self.proj = nn.Linear(
patch_size * patch_size * in_channels,
hidden_dim
)
def forward(self, x):
"""
x: [B, C, H, W]
"""
B, C, H, W = x.shape
p = self.patch_size
# 分割为 patches: [B, C, H/p, p, W/p, p] → [B, H/p, W/p, C*p*p]
x = x.view(B, C, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 1, 3, 5).contiguous()
x = x.view(B, (H // p) * (W // p), C * p * p)
# 线性投影
x = self.proj(x) # [B, N, hidden_dim]
return x位置编码
DiT 使用标准的可学习位置编码或傅里叶位置编码:
class DiTPositionEmbedding(nn.Module):
def __init__(self, num_patches, hidden_dim):
super().__init__()
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches, hidden_dim)
)
def forward(self, x):
return x + self.pos_embed4. 条件机制:AdaLN
DiT 使用 Adaptive Layer Norm (AdaLN) 来注入时间步 和类别 条件。
AdaLN vs 原始 Layer Norm
原始 Layer Norm:
AdaLN:
其中 由条件 通过 MLP 生成。
AdaLN-Zero
DiT 的一大创新是 AdaLN-Zero,将 作为初始化。这使得残差分支的初始输出为零,促进训练的稳定性。
class DiTBlock(nn.Module):
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
super().__init__()
self.hidden_size = hidden_size
# 条件调制网络
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
# 注意力
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
# MLP
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_ratio * hidden_size),
nn.GELU(approximate='tanh'),
nn.Linear(mlp_ratio * hidden_size, hidden_size)
)
# 初始化 AdaLN 输出为零
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, c):
# c: 条件嵌入 (time + class)
# 生成调制参数
modulator = self.adaLN_modulation(c)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp =
modulator.chunk(6, dim=-1)
# Self-attention with modulation
x = x + gate_msa * self.attn(
self.norm1(x) * (1 + scale_msa) + shift_msa,
self.norm1(x) * (1 + scale_msa) + shift_msa,
self.norm1(x) * (1 + scale_msa) + shift_msa
)[0]
# MLP with modulation
x = x + gate_mlp * self.mlp(
self.norm2(x) * (1 + scale_mlp) + shift_mlp
)
return x条件的注入方式比较
| 方法 | 描述 | 优缺点 |
|---|---|---|
| Cross-attention | 条件作为额外 key/value | 通用但计算量大 |
| Adaptive Norm | 调整归一化参数 | 高效且有效 |
| AdaLN-Zero | AdaLN + 零初始化 | 最常用,稳定训练 |
5. DiT 主干网络
完整 DiT 架构
class DiT(nn.Module):
def __init__(
self,
hidden_size=1024,
num_heads=16,
num_layers=28,
patch_size=2,
in_channels=4,
mlp_ratio=4.0,
class_dropout_prob=0.1,
num_classes=1000
):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.patch_size = patch_size
# Patch embedding
self.x_embed = PatchEmbed(patch_size, in_channels, hidden_size)
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_size))
# 时间步嵌入
self.t_embed = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size)
)
# 类别嵌入
self.y_embed = nn.Sequential(
nn.Linear(num_classes, hidden_size),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size)
)
# DiT Blocks
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio)
for _ in range(num_layers)
])
# 输出归一化
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False)
# 头部:patch 回填为图像
self.proj_out = nn.Linear(
hidden_size,
patch_size * patch_size * in_channels
)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Zero-initialize output projection
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
def unpatchify(self, x):
"""将 tokens 还原为图像"""
c = self.in_channels
p = self.patch_size
h = w = int(x.shape[1] ** 0.5)
x = x.reshape(x.shape[0], h, w, p, p, c)
x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
x = x.reshape(x.shape[0], c, h * p, w * p)
return x
def forward(self, x, t, y):
"""
x: 噪声潜变量 [B, C, H, W]
t: 时间步 [B]
y: 类别标签 [B]
"""
# 获取 shape
B, C, H, W = x.shape
# Patchify
x = self.x_embed(x) # [B, N, D]
x = x + self.pos_embed
# 条件嵌入
t = self.t_embed(timestep_embedding(t, self.hidden_size))
y = self.y_embed(F.one_hot(y, num_classes=1000).float())
c = t + y
# 应用 DiT blocks
for block in self.blocks:
x = block(x, c)
# 输出
x = self.norm_final(x)
x = self.proj_out(x)
x = self.unpatchify(x)
return x时间步嵌入
def timestep_embedding(t, dim, max_period=10000):
"""
创建正弦时间步嵌入
遵循 Attention is All You Need 的位置编码方案
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(half, device=t.device) / half
).repeat(1, 2)
args = t[:, None].float() * freqs
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding6. DiT 变体配置
DiT 提供了不同规模的可扩展配置:
| 变体 | 层数 | 隐藏维度 | 头数 | 参数量 | GFLOPs |
|---|---|---|---|---|---|
| DiT-S | 12 | 384 | 6 | 39M | 61.6 |
| DiT-B | 12 | 768 | 12 | 123M | 118.6 |
| DiT-L | 24 | 1024 | 16 | 457M | 1035 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 118.6 |
命名规则:DiT-{size}/{patch_size}
DiT-XL/2:XL 规模,patch_size=2(最小 patch)DiT-B/4:B 规模,patch_size=4
7. 训练目标
DiT 预测噪声 ,使用简单的 MSE 损失:
其中 。
Classifier-Free Guidance
DiT 同样支持 Classifier-Free Guidance (CFG):
其中 是引导权重, 表示无条件预测。
8. 实验结果
ImageNet 256×256 结果
| 模型 | GFLOPs | 参数量 | FID ↓ | IS ↑ |
|---|---|---|---|---|
| ADM | 2060 | 554M | 1.48 | 265.7 |
| LDM-8 | 266 | 395M | 3.57 | 185.4 |
| DiT-XL/2 | 118.6 | 675M | 1.81 | 241.5 |
| DiT-XL/2 + CFG | 118.6 | 675M | 1.55 | 247.5 |
关键发现
- Patch 大小影响:patch_size=2 的 DiT-XL/2 显著优于 patch_size=8 的 DiT-XL/8
- 模型规模缩放:DiT-XL >> DiT-L >> DiT-B >> DiT-S
- 计算效率:DiT-XL/2 在比 ADM 少 17 倍计算量的情况下达到可比性能
9. 与传统 UNet 的对比
架构差异
| 方面 | UNet | DiT |
|---|---|---|
| 空间建模 | 多尺度特征融合 | 全序列自注意力 |
| 跳跃连接 | Encoder-Decoder 跳跃 | 无(纯前馈) |
| 条件注入 | 多处注入 | AdaLN 统一注入 |
| 位置信息 | 卷积天然编码 | 显式位置编码 |
| 感受野 | 逐渐增大 | 初始全局 |
优劣分析
DiT 优势:
- ✅ 更好的可扩展性
- ✅ 更强的长程依赖建模
- ✅ 与大语言模型技术共享
DiT 局限:
- ❌ 计算量随序列长度二次增长
- ❌ 需要更多训练数据
- ❌ 训练稳定性要求更高
10. 实践指南
超参数选择
# 推荐配置
config = {
# 模型规模
'hidden_size': 1024, # 增大提升质量
'num_heads': 16,
'num_layers': 28,
# Patch 大小(影响显存和速度)
'patch_size': 2, # 2 质量最高,8 最快
# 训练
'learning_rate': 1e-4,
'weight_decay': 0.0, # DiT 通常不用 weight decay
'batch_size': 256,
# 推理
'guidance_scale': 5.0, # CFG 强度
'num_sampling_steps': 50,
}训练技巧
- 使用 EMA:DiT 对 EMA 敏感,推荐使用 0.9999 的 EMA
- 傅里叶位置编码:在某些任务上优于可学习位置编码
- 渐进式训练:先小 patch,再大 patch
- 混合精度:使用 FP16 加速训练
参考
相关阅读
- Diffusion Model 基础 — 扩散模型的核心原理
- DiT vs UNet 理论分析 — 为什么 Transformer 优于 UNet
- HiDiT 高效 DiT 架构 — DiT 的效率优化
- Flow Matching 最优传输 — 替代 DDPM 的生成范式
Footnotes
-
Peebles, W., & Xie, S. (2023). “Scalable Diffusion Models with Transformers.” ICCV 2023. arXiv:2212.09748 ↩