DiT: Diffusion Transformer 架构详解
概述
DiT (Diffusion Transformer) 由 William Peebles 和 Saining Xie 于 ICCV 2023 提出,是首个成功将 Transformer 架构应用于扩散模型的工作。1
DiT 的核心贡献是证明了:
- Transformer 可以作为扩散模型的主干网络
- DiT 遵循与 LLM 相同的 Scaling Laws
- U-Net 的归纳偏置在扩散模型中并非必须
1. DiT 整体架构
1.1 架构概览
┌─────────────────────────────────────────────────────────────────┐
│ DiT 架构 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 输入图像 │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ VAE │ 编码到 latent 空间 (如 32×32) │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ │
│ │ Patchify │ 将 latent 分割成 patches (如 2×2) │
│ └──────┬──────┘ │
│ │ │
│ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Linear Proj │ │ DiT Blocks │ ... │ DiT Blocks │ │
│ │ + Pos Emb │ ──▶ │ (×N) │ │ (×N) │ │
│ └─────────────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │
│ └──────────┬─────────┘ │
│ ▼ │
│ ┌─────────────────┐ │
│ │ Norm + Proj │ │
│ └────────┬────────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ Unpatch │ 将 patches 合并回 latent │
│ └────┬────┘ │
│ │ │
│ ▼ │
│ ┌─────────┐ │
│ │ VAE │ 解码回图像空间 │
│ └─────────┘ │
│ │ │
│ ▼ │
│ 输出图像 │
│ │
└─────────────────────────────────────────────────────────────────┘
1.2 核心组件
| 组件 | 描述 |
|---|---|
| VAE | 将图像编码到 latent 空间 (如 SD 的 VAE, 下采样 8x) |
| Patchify | 将 latent 分割成非重叠的 patches |
| DiT Block | 自适应 LayerNorm (AdaLN) + Self-Attention + FFN |
| 位置编码 | 傅里叶位置编码 (RoPE 或标准正弦) |
| Final Layer | 预测噪声 和方差 |
2. Patchify 层
2.1 作用
Patchify 将 2D latent 特征图转换为 patch 序列,类似于 ViT 对图像的处理。
输入:
- Latent 特征:
- 例如 SD VAE 输出: (256×256 图像下采样 8x)
输出:
- Patch 序列:
- 其中 , 是 patch size
2.2 实现
class Patchify(nn.Module):
"""
将 latent 特征图分割成 patches
"""
def __init__(self, in_channels=4, patch_size=2, hidden_dim=768):
super().__init__()
self.patch_size = patch_size
self.num_patches = (32 // patch_size) ** 2 # 取决于 VAE
# 每个 patch 映射到 hidden_dim
self.proj = nn.Conv2d(
in_channels,
hidden_dim,
kernel_size=patch_size,
stride=patch_size
)
def forward(self, x):
"""
Args:
x: (B, C, H, W) latent 特征
Returns:
patches: (B, N, D) 其中 N = H*W/patch_size²
"""
# Conv2d 实现 patchify
x = self.proj(x) # (B, D, H/p, W/p)
x = x.flatten(2).transpose(1, 2) # (B, N, D)
return x2.3 Patch Size 的影响
| Patch Size | Token 数 | 计算量 | 质量 |
|---|---|---|---|
| 8 | 16 | 最低 | 较低 |
| 4 | 64 | 中等 | 中等 |
| 2 | 256 | 较高 | 最高 |
发现:Patch size=2 的 DiT 生成质量最高,但计算量也最大。
3. 位置编码
3.1 傅里叶位置编码
DiT 使用标准的正弦位置编码(与 ViT 相同):
def get_positional_encoding(seq_len, dim, device):
"""
生成傅里叶位置编码
Args:
seq_len: 序列长度 (patch 数)
dim: 编码维度
device: 设备
Returns:
pos_emb: (1, seq_len, dim)
"""
position = torch.arange(seq_len, device=device).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, dim, 2, device=device) * (-np.log(10000) / dim)
)
pe = torch.zeros(seq_len, dim, device=device)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe.unsqueeze(0)3.2 条件位置编码
对于时间步 和类别 的条件,DiT 使用自适应方式注入位置编码的缩放/偏移。
4. DiT Block:自适应条件注入
4.1 四种 Conditioning 方案对比
DiT 论文系统比较了四种条件注入方式:
| 方法 | 描述 | 效果 |
|---|---|---|
| In-context | 将 和 作为额外 token 拼接 | 较差 |
| Cross-attention | 添加独立的 cross-attention 层 | 中等 |
| AdaLN | 自适应 LayerNorm | 较好 |
| AdaLN-Zero | AdaLN + 门控初始化为 0 | 最优 |
4.2 AdaLN (Adaptive Layer Norm)
AdaLN 根据时间步 和类别 自适应调整 LayerNorm 的参数:
class AdaLN(nn.Module):
"""
自适应 LayerNorm
根据条件 c 动态调整 scale 和 shift
"""
def __init__(self, dim, cond_dim):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
# 预测 scale, shift, gate
self.linear = nn.Linear(cond_dim, dim * 3)
def forward(self, x, cond):
"""
Args:
x: (B, N, D) 输入特征
cond: (B, cond_dim) 条件 (t, c)
Returns:
自适应归一化后的特征
"""
# 预测调制参数
scale, shift, gate = self.linear(cond).chunk(3, dim=-1)
# LayerNorm
x = self.norm(x)
# 调制
x = x * (1 + scale) + shift
x = x * (1 + gate) # 门控
return x4.3 AdaLN-Zero
AdaLN-Zero 是 DiT 的核心创新:
class AdaLNZero(nn.Module):
"""
AdaLN-Zero: 门控初始化为 0
核心洞察:每个 block 训练初期应该接近恒等映射
"""
def __init__(self, dim, cond_dim):
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False)
# 预测 scale, shift, gate (gate 初始化为 0)
self.linear = nn.Linear(cond_dim, dim * 3)
# 初始化:gate 为 0,其他为标准初始化
nn.init.zeros_(self.linear.weight[:, 2*dim:]) # gate 置 0
nn.init.zeros_(self.linear.bias[2*dim:]) # gate bias 置 0
def forward(self, x, cond):
# 预测调制参数
scale, shift, gate = self.linear(cond).chunk(3, dim=-1)
# LayerNorm
x = self.norm(x)
# 调制
x = x * (1 + scale) + shift
x = gate * x # 训练初期 gate≈0,接近恒等映射
return xAdaLN-Zero 的优势:
- 训练初期每个 block 等效为恒等映射
- 随着训练进行,门控逐渐学习到有用调制
- 显著提升深度网络的训练稳定性
4.4 完整的 DiT Block
class DiTBlock(nn.Module):
"""
DiT Block: AdaLN + Self-Attention + FFN
"""
def __init__(self, dim, num_heads, mlp_ratio=4.0, cond_dim=256):
super().__init__()
self.norm1 = AdaLNZero(dim, cond_dim)
self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.norm2 = AdaLNZero(dim, cond_dim)
self.mlp = nn.Sequential(
nn.Linear(dim, dim * mlp_ratio),
nn.GELU(),
nn.Linear(dim * mlp_ratio, dim)
)
# MLP 门控也初始化为 0
nn.init.zeros_(self.mlp[0].weight)
nn.init.zeros_(self.mlp[0].bias)
def forward(self, x, cond):
# Self-Attention
x = x + self.norm1.gate * self.attn(
self.norm1(x, cond),
self.norm1(x, cond),
self.norm1(x, cond)
)[0]
# FFN
x = x + self.norm2.gate * self.mlp(self.norm2(x, cond))
return x5. 时间步与类别 Embedding
5.1 时间步 Embedding
class TimestepEmbedder(nn.Module):
"""
时间步 embedding: MLP 编码正弦位置编码
"""
def __init__(self, hidden_dim, frequency_dim=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim)
)
self.frequency_dim = frequency_dim
def forward(self, t):
"""
Args:
t: (B,) 时间步 (整数或连续值)
Returns:
emb: (B, hidden_dim)
"""
# 正弦位置编码
freqs = torch.exp(
-torch.log(torch.tensor(10000)) *
torch.arange(0, self.frequency_dim, 2, device=t.device) / self.frequency_dim
)
# 编码
t_emb = t[:, None] * freqs[None, :]
t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
return self.mlp(t_emb)5.2 类别 Embedding (可选)
对于分类条件,使用类似的 embedding 层:
class LabelEmbedder(nn.Module):
"""类别 embedding"""
def __init__(self, num_classes, hidden_dim):
super().__init__()
self.embedding = nn.Embedding(num_classes, hidden_dim)
def forward(self, y):
return self.embedding(y)5.3 条件融合
def combine_condition(t_emb, y_emb=None):
"""融合时间和类别条件"""
if y_emb is not None:
# 拼接
return torch.cat([t_emb, y_emb], dim=-1)
else:
return t_emb6. 模型变体与缩放
6.1 DiT 模型系列
| 模型 | Depth | Hidden Dim | Heads | 参数量 | GFLOPs | FID |
|---|---|---|---|---|---|---|
| DiT-S/2 | 12 | 384 | 6 | 33M | 118 | 68.4 |
| DiT-B/2 | 12 | 768 | 12 | 130M | 1185 | 43.5 |
| DiT-L/2 | 24 | 1024 | 16 | 458M | 3291 | 9.62 |
| DiT-XL/2 | 28 | 1152 | 16 | 675M | 1181 | 2.27 |
关键发现:DiT-XL/2 使用与 DiT-B/2 相似的计算量,但质量大幅提升!
6.2 Scaling Laws
DiT 遵循与语言模型相同的 scaling laws:
FID (越低越好)
│ ★ DiT-XL/2
│ ★ DiT-L/2
│ ★ DiT-B/2
│ ★ DiT-S/2
│
└───────────────────────────────────────▶ 计算量 (GFLOPs)
100 1000 10000
实验结论:
- 固定计算预算下,更大的模型 + 更少的步骤优于小模型 + 多步骤
- FID 与 GFLOPs 呈幂律关系
- 模型大小比 token 数量更重要
7. 输出预测:噪声与方差
7.1 两种输出方式
DiT 需要同时预测:
- 噪声
- 方差 (或 )
class FinalLayer(nn.Module):
"""
最终层:预测噪声和方差
"""
def __init__(self, hidden_dim, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False)
self.linear = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)
# 方差预测
self.linear_out = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)
# 初始化
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
def forward(self, x):
"""
Returns:
noise: (B, N, P*D)
variance: (B, N, P*D) 或标量
"""
x = self.norm_final(x)
noise = self.linear(x)
# 可以预测 log variance 或直接预测 beta
# 取决于噪声调度
return noise, None # 简化版7.2 条件独立预测
更常见的实现是条件独立预测噪声和方差:
class DiTOutput(nn.Module):
"""DiT 输出头"""
def __init__(self, hidden_dim, patch_size, out_channels=4):
super().__init__()
self.norm = nn.LayerNorm(hidden_dim)
# 噪声预测头
self.head_noise = nn.Linear(hidden_dim, patch_size**2 * out_channels)
# 方差预测头
self.head_variance = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
x = self.norm(x)
noise = self.head_noise(x)
log_var = self.head_variance(x)
return noise, log_var8. 完整 DiT 实现
class DiT(nn.Module):
"""
完整的 DiT 模型
"""
def __init__(
self,
in_channels=4,
patch_size=2,
hidden_dim=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
num_classes=None,
cond_dim=None
):
super().__init__()
self.hidden_dim = hidden_dim
self.out_channels = in_channels
self.patch_size = patch_size
# Patchify
self.patchify = Patchify(in_channels, patch_size, hidden_dim)
self.num_patches = self.patchify.num_patches
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_dim))
nn.init.normal_(self.pos_embed, std=0.02)
# 时间步和类别 embedding
self.t_embed = TimestepEmbedder(hidden_dim)
if num_classes is not None:
self.y_embed = LabelEmbedder(num_classes, hidden_dim)
cond_dim = hidden_dim * 2
else:
cond_dim = hidden_dim
# DiT Blocks
self.blocks = nn.ModuleList([
DiTBlock(hidden_dim, num_heads, mlp_ratio, cond_dim)
for _ in range(depth)
])
# Final Layer
self.final_layer = FinalLayer(hidden_dim, patch_size, in_channels)
self.initialize_weights()
def initialize_weights(self):
"""权重初始化"""
def _basic_init(module):
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
self.apply(_basic_init)
# AdaLN-Zero 门控初始化为 0
for block in self.blocks:
nn.init.zeros_(block.norm1.linear.weight[:, 2*self.hidden_dim:])
nn.init.zeros_(block.norm1.linear.bias[2*self.hidden_dim:])
def unpatchify(self, x):
"""将 patches 合并回特征图"""
B, N, D = x.shape
H = W = int(N ** 0.5)
p = self.patch_size
x = x.reshape(B, H, W, p, p, self.out_channels)
x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
x = x.reshape(B, self.out_channels, H*p, W*p)
return x
def forward(self, x, t, y=None):
"""
前向传播
Args:
x: (B, C, H, W) latent 特征
t: (B,) 时间步
y: (B,) 类别标签 (可选)
Returns:
noise: (B, C, H, W) 预测的噪声
"""
B = x.shape[0]
# Patchify
x = self.patchify(x) # (B, N, D)
x = x + self.pos_embed
# 时间步 embedding
t_emb = self.t_embed(t)
# 类别 embedding
if y is not None:
y_emb = self.y_embed(y)
cond = torch.cat([t_emb, y_emb], dim=-1)
else:
cond = t_emb
# DiT Blocks
for block in self.blocks:
x = block(x, cond)
# Final Layer
x = self.final_layer(x)
x = self.unpatchify(x)
return x9. 训练与采样
9.1 训练目标
def training_step(model, batch, optimizer):
"""
DiT 训练步骤
"""
x0, y = batch # 原始图像和标签
# VAE 编码
with torch.no_grad():
z = vae.encode(x0).latent_dist.sample()
z = z * vae.config.scaling_factor
# 采样时间步
t = torch.randint(0, model.num_timesteps, (B,), device=device)
# 添加噪声
noise = torch.randn_like(z)
z_t = add_noise(z, t, alpha_bar)
# 预测噪声
noise_pred = model(z_t, t, y)
# 损失
loss = F.mse_loss(noise_pred, noise)
# 反向传播
loss.backward()
optimizer.step()
optimizer.zero_grad()
return loss9.2 采样
@torch.no_grad()
def sampling(model, z_T, y=None, num_steps=50):
"""
DDIM 采样
"""
model.eval()
timesteps = torch.linspace(model.num_timesteps-1, 0, num_steps).long()
z = z_T
for i, t in enumerate(tqdm(timesteps)):
t_batch = torch.full((B,), t, device=device, dtype=torch.long)
# 预测噪声
noise_pred = model(z, t_batch, y)
# DDIM 更新 (简化版)
alpha_bar_t = alphas_cumprod[t]
alpha_bar_t_prev = alphas_cumprod[timesteps[i-1]] if i > 0 else 1
# 预测 x0
pred_x0 = (z - torch.sqrt(1-alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)
# 方向指向 x0
direction = torch.sqrt(1 - alpha_bar_t_prev) * noise_pred
# 组合
z = torch.sqrt(alpha_bar_t_prev) * pred_x0 + direction
# VAE 解码
x = vae.decode(z / vae.config.scaling_factor)
return x10. 与 U-Net 的对比
10.1 架构差异
| 维度 | U-Net | DiT |
|---|---|---|
| 结构 | Encoder-Decoder + Skip | 等向性 Transformer |
| 位置信息 | 内置于卷积 | 显式位置编码 |
| 感受野 | 局部到全局 | 全局自注意力 |
| 层次结构 | 多尺度特征融合 | 统一表示 |
10.2 优缺点对比
| 方面 | U-Net | DiT |
|---|---|---|
| 收敛速度 | 快 | 较慢 |
| Scaling | 无明确规律 | 遵循 Scaling Laws |
| 计算效率 | 高分辨率下效率下降 | 可通过 patch_size 调节 |
| 长距离依赖 | 需要空洞卷积/注意力 | 自然建模 |
| 实现复杂度 | 中等 | 较高 |
10.3 融合趋势:U-DiT
2024-2025 年的研究开始融合两者的优点:
- U-DiT:在 DiT 中引入 U-Net 的层次结构和 skip connection
- U-ViT:在 ViT/DiT 中引入类似 U-Net 的下采样/上采样路径
- MM-DiT(SD3):多模态 DiT,融合文本和图像 embedding
11. 参考文献
相关链接
Footnotes
-
Peebles & Xie, “Scalable Diffusion Models with Transformers”, ICCV 2023. https://arxiv.org/abs/2212.09748 ↩