DiT 变体与 2025 年新进展
概述
DiT (Diffusion Transformer) 自 2023 年提出以来,在图像生成领域取得了巨大成功。2024-2025 年,研究者们在 DiT 基础上提出了多种改进变体,涵盖效率优化、架构创新、应用扩展等多个方向。
本文梳理了近年来最重要的 DiT 变体及其核心创新。
1. U-DiT: U-Net 遇见 DiT
1.1 核心思想
U-DiT 由 Tian 等人提出(CVPR 2025),核心观察是:U-Net 的层次结构和 skip connection 在 DiT 中同样有效。1
1.2 架构设计
┌──────────────────────────────────────────────────────────────┐
│ U-DiT 架构 │
├──────────────────────────────────────────────────────────────┤
│ │
│ 输入 (H×W) ──┬── Tokenize ──▶ Patch Embed ──▶ Level 1 │
│ │ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ DiT Block × N₁ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ │ Downsample │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ DiT Block × N₂ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ │ Skip Connection │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ DiT Block × N₁ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ │ Upsample + Skip │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ DiT Block × N₁ │ │
│ └─────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ Patch Detokenize ──▶ 输出 │
│ │
└──────────────────────────────────────────────────────────────┘
1.3 关键创新
| 创新点 | 描述 | 效果 |
|---|---|---|
| 层次化 Token | 不同层使用不同的 patch size | 计算量减少 6 倍 |
| QKV 下采样 | Attention 中对 K, V 进行下采样 | 加速注意力计算 |
| Skip Connection | 引入类似 U-Net 的跳连 | 保留多尺度信息 |
1.4 实验结果
| 模型 | GFLOPs | FID | 质量对比 |
|---|---|---|---|
| DiT-XL/2 | 1121 | 2.27 | 基准 |
| U-DiT-L/2 | ~190 | ~2.1 | 质量相当,计算量减少 6× |
1.5 实现细节
class UDiTBlock(nn.Module):
"""
U-DiT Block: 带层次化结构的 DiT Block
"""
def __init__(self, dim, num_heads, downsample_kv=True):
super().__init__()
self.norm1 = AdaLNZero(dim)
self.attn = Attention(dim, num_heads)
self.norm2 = AdaLNZero(dim)
self.ffn = MLP(dim)
# QKV 下采样(可选)
self.downsample_kv = downsample_kv
if downsample_kv:
self.kv_downsample = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)
def forward(self, x, cond, kv_cache=None):
# 自注意力
x = x + self.norm1.gate * self.attn(self.norm1(x, cond), kv_cache)
# FFN
x = x + self.norm2.gate * self.ffn(self.norm2(x, cond))
return x2. DyDiT++: 动态 DiT
2.1 核心思想
DyDiT++ 由阿里巴巴达摩院提出(ICLR 2025),核心观察是:不同时间步和空间区域需要的计算量不同。2
2.2 两个维度的动态性
Timestep-wise Dynamic Width (TDW)
时间步 t=0.9 (早期): 需要大模型宽度 ────────────████████████ 100%
时间步 t=0.5 (中期): 需要中等模型宽度 ────────────████████ 60%
时间步 t=0.1 (晚期): 小模型宽度足够 ──────────────████ 30%
关键洞察:早期 timestep 需要更多表达能力(生成全局结构),晚期 timestep 只需精细调整(局部纹理)。
Spatial-wise Dynamic Token (SDT)
图像区域:
┌─────────────┐
│ 主体 (高重要性) │ ─── 需要完整计算 ████████████
│ │
│ 背景 (低重要性) │ ─── 可跳过计算 ░░░░░░░░░░░
└─────────────┘
判断标准:背景区域 loss 贡献低,可以跳过其计算。
2.3 架构设计
class DyDiTPPBlock(nn.Module):
"""
DyDiT++ Block: 支持动态宽度和动态 Token
"""
def __init__(self, dim, num_heads, max_width_ratio=1.0):
super().__init__()
self.max_width = int(dim * max_width_ratio)
self.current_width = dim
# 宽度控制器
self.width_predictor = nn.Sequential(
nn.Linear(dim, dim // 4),
nn.SiLU(),
nn.Linear(dim // 4, 1),
nn.Sigmoid()
)
# 标准组件
self.norm = AdaLNZero(dim)
self.attn = SparseAttention(dim, num_heads)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
def forward(self, x, t, importance_mask=None):
# 1. 动态宽度调整
width_ratio = self.width_predictor(x.mean(dim=1, keepdim=True))
current_dim = int(self.max_width * width_ratio)
# 2. 空间动态 Token(跳过低重要性区域)
if importance_mask is not None:
# 稀疏注意力:只计算重要区域的 attention
x = self.sparse_attention(x, importance_mask)
# 3. 标准前向
x = self.norm(x, t) * self.norm.gate + x
x = x + self.ffn(x)
return x
def sparse_attention(self, x, mask, threshold=0.3):
"""
稀疏注意力:跳过低重要性 token
"""
# 计算重要性分数
scores = torch.sigmoid(self.importance_net(x))
# 掩码:只保留高于阈值的 token
active_mask = scores > threshold
# 对活跃 token 进行注意力计算
active_indices = active_mask.nonzero(as_tuple=True)
if len(active_indices[0]) > 0:
x_active = x[active_indices]
# 注意力计算
attn_out = self.attn(x_active)
# 写回
x = torch.scatter(x, 0, active_indices[0].unsqueeze(-1).expand_as(attn_out), attn_out)
return x2.4 TD-LoRA: 参数高效微调
DyDiT++ 还提出了 TD-LoRA,在时间维度上分解 LoRA 矩阵:
class TDLoRA(nn.Module):
"""
Time-Decomposed LoRA: 按时间步分解的适配器
"""
def __init__(self, dim, rank=16, num_timesteps=1000):
super().__init__()
self.dim = dim
self.rank = rank
# 时间步 embedding
self.t_emb = nn.Embedding(num_timesteps, rank)
# 基础 LoRA 矩阵
self.A = nn.Linear(dim, rank, bias=False)
self.B = nn.Linear(rank, dim, bias=False)
# 时间依赖调制
self.time_scale = nn.Linear(rank, rank)
self.time_shift = nn.Linear(rank, rank)
def forward(self, x, t):
# 获取时间步的 rank 向量
t_vec = self.t_emb(t) # (B, rank)
# 时间依赖的 A, B
A_t = self.A.weight * (1 + self.time_scale(t_vec))
B_t = self.B.weight + self.time_shift(t_vec)
# LoRA 输出
return (x @ A_t.T) @ B_t.T2.5 实验结果
| 模型 | GFLOPs 减少 | 硬件加速 | FID |
|---|---|---|---|
| DiT-XL/2 | 0% | 1.0× | 2.27 |
| DyDiT++-XL | 51% | 1.73× | 2.07 |
| DyDiT++-L | 48% | 1.65× | 2.18 |
3. D²iT: 动态压缩 DiT
3.1 核心思想
D²iT(CVPR 2025)关注的问题是:不同图像区域应该使用不同的压缩率。3
3.2 问题分析
| 压缩率 | 优点 | 缺点 |
|---|---|---|
| 大压缩 | 计算量小、全局一致性 | 局部细节丢失 |
| 小压缩 | 保留局部细节 | 计算量大 |
核心洞察:主体区域需要小压缩(细节重要),背景区域可以使用大压缩(细节不重要)。
3.3 自适应压缩策略
class D2iT(nn.Module):
"""
Dynamic Diffusion Transformer: 自适应压缩
"""
def __init__(self, dim, num_classes=1000):
super().__init__()
self.dim = dim
# 重要性预测器
self.importance_predictor = nn.Sequential(
nn.Linear(dim, dim // 2),
nn.GELU(),
nn.Linear(dim // 2, 1),
nn.Sigmoid()
)
# 多尺度压缩器
self.compressors = nn.ModuleList([
PatchifyCompressor(patch_size=p) for p in [1, 2, 4]
])
# DiT blocks
self.blocks = nn.ModuleList([DiTBlock(dim) for _ in range(28)])
def forward(self, x, t, y=None):
B, C, H, W = x.shape
# 1. 编码到 latent
x = self.encoder(x)
# 2. 预测每个区域的重要性
importance = self.importance_predictor(x) # (B, N, 1)
# 3. 自适应压缩
compressed_tokens = []
for p_size in [1, 2, 4]:
tokens_p = self.compressors[p_size](x)
weight = (importance > 1/(p_size**2)).float() # 高重要性用小 patch
compressed_tokens.append((tokens_p, weight, p_size))
# 4. 融合多尺度 token
x = self.fuse_tokens(compressed_tokens)
# 5. 标准 DiT 处理
for block in self.blocks:
x = block(x, t, y)
return x
def fuse_tokens(self, multi_scale_tokens):
"""
融合多尺度的 token 表示
"""
# 简单的加权平均
all_tokens = []
all_weights = []
for tokens, weights, p_size in multi_scale_tokens:
# 上采样到统一分辨率
tokens_up = F.interpolate(
tokens.permute(0, 2, 1).reshape(B, -1, H//2, W//2),
scale_factor=p_size/2,
mode='bilinear'
)
all_tokens.append(tokens_up)
all_weights.append(weights)
# 加权平均
fused = sum(w * t for w, t in zip(all_weights, all_tokens))
weights_sum = sum(all_weights)
return fused / (weights_sum + 1e-6)3.4 实验结果
| 方法 | Patch Size | FID | 计算量 |
|---|---|---|---|
| 固定 p=1 | 1 | 1.95 | 100% |
| 固定 p=2 | 2 | 2.15 | 25% |
| 固定 p=4 | 4 | 2.42 | 6.25% |
| D²iT | 自适应 | 2.05 | ~15% |
4. MM-DiT: 多模态 DiT (Stable Diffusion 3)
4.1 核心思想
MM-DiT(Multi-Modal DiT)由 Stability AI 提出,用于 Stable Diffusion 3,核心创新是分别处理文本和图像 embedding,然后融合。4
4.2 架构设计
┌────────────────────────────────────────────────────────────────┐
│ MM-DiT 架构 │
├────────────────────────────────────────────────────────────────┤
│ │
│ 文本输入 │
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Text Enc │ T5-XXL / CLIP Text │
│ └─────┬─────┘ │
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Text Proj │ 线性投影到 hidden_dim │
│ └─────┬─────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 文本 DiT Block × N │ │
│ │ (独立权重) │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ │ 残差相加 │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────┐ │
│ │ 图像 DiT Block × N │ │
│ │ (独立权重) │ │
│ └─────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ MM-DiT Block │
│ ┌─────────────────┐ │
│ │ 共享 LayerNorm │ │
│ │ + 交叉残差 │ │
│ └─────────────────┘ │
│ │ │
│ ▼ │
│ 输出层 │
│ │
└────────────────────────────────────────────────────────────────┘
4.3 关键创新
| 创新点 | 描述 |
|---|---|
| 独立文本/图像路径 | 避免跨模态干扰 |
| 共享 AdaLN | LayerNorm 参数共享,降低参数量 |
| Q-Former 集成 | 可选使用 FLAN-T5 的 frozen encoder |
4.4 MM-DiT Block 实现
class MMDiTBlock(nn.Module):
"""
Multi-Modal DiT Block
"""
def __init__(self, dim, num_heads, text_dim):
super().__init__()
# 文本路径(独立权重)
self.text_norm = nn.LayerNorm(dim)
self.text_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.text_ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
# 图像路径(独立权重)
self.image_norm = nn.LayerNorm(dim)
self.image_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.image_ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
# 共享的 AdaLN(关键创新)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(dim, 6 * dim)
)
def forward(self, x_img, x_text, t):
"""
Args:
x_img: (B, N_img, D) 图像 token
x_text: (B, N_text, D) 文本 token
t: (B, D) 时间步 embedding
"""
# 共享的条件
c = self.adaLN_modulation(t)
c_img, c_text, c_img_gate, c_text_gate = c.chunk(4, dim=-1)
# 图像路径
x_img = x_img + c_img_gate * self.image_attn(
self.image_norm(x_img) * (1 + c_img),
self.image_norm(x_img),
self.image_norm(x_img)
)[0]
x_img = x_img + c_img_gate * self.image_ffn(
self.image_norm(x_img) * (1 + c_img)
)
# 文本路径
x_text = x_text + c_text_gate * self.text_attn(
self.text_norm(x_text) * (1 + c_text),
self.text_norm(x_text),
self.text_norm(x_text)
)[0]
x_text = x_text + c_text_gate * self.text_ffn(
self.text_norm(x_text) * (1 + c_text)
)
# 残差交叉(图像从文本获取上下文)
x_img = x_img + 0.1 * x_text.mean(dim=1, keepdim=True)
return x_img, x_text4.5 SD3 的改进
| 组件 | SD 1.x / SDXL | SD3 (MM-DiT) |
|---|---|---|
| 文本编码器 | CLIP Text | CLIP + T5-XXL |
| 架构 | U-Net | MM-DiT |
| 位置编码 | Sinusoidal | RoPE |
| 采样器 | DDPM/DDIM | Euler |
| 参数 | ~1B | ~2B |
5. 其他重要变体
5.1 ScaleDiT: 超高分辨率生成
class ScaleDiT(nn.Module):
"""
ScaleDiT: 分层局部注意力,支持 4K+ 分辨率
"""
def __init__(self, dim):
super().__init__()
# 全局注意力(低分辨率)
self.global_attn = GlobalAttention(dim)
# 局部注意力(高分辨率分块)
self.local_attn = LocalAttention(dim, window_size=32)
def forward(self, x, resolution):
if resolution > 1024:
# 分块处理 + 全局引导
return self.hierarchical_generate(x)
else:
return self.blocks(x)5.2 FlexiDiT: 弹性计算
FlexiDiT 的核心思想是单一模型支持可变计算预算:
class FlexiDiT(nn.Module):
"""
FlexiDiT: 弹性计算预算
"""
def forward(self, x, t, budget=None):
if budget is None:
# 完整计算
return self.full_forward(x, t)
else:
# 根据预算动态跳过
return self.adaptive_forward(x, t, budget)5.3 PixelDiT: 端到端像素建模
PixelDiT 尝试消除 VAE,直接在像素空间建模:
| 组件 | Latent DiT (SD) | PixelDiT |
|---|---|---|
| 输入空间 | Latent (压缩 8x) | 像素 |
| 序列长度 | 32×32 = 1024 | 256×256 = 65536 |
| 计算量 | 低 | 高 |
| 重建误差 | VAE 引入 | 无 |
6. 性能对比汇总
| 模型 | 任务 | FID ↓ | 参数量 | GFLOPs ↓ |
|---|---|---|---|---|
| DiT-XL/2 | ImageNet 256² | 2.27 | 675M | 1181 |
| U-DiT-L/2 | ImageNet 256² | ~2.1 | ~400M | ~190 |
| DyDiT++-XL | ImageNet 256² | 2.07 | 675M | ~580 |
| D²iT | ImageNet 256² | 2.05 | ~700M | ~180 |
| SD3 (MM-DiT) | T2I 1024² | SOTA | ~2B | - |
| MAR-H | ImageNet 256² | ~1.6 | - | - |
7. 总结与展望
7.1 主要研究方向
| 方向 | 代表工作 | 核心贡献 |
|---|---|---|
| 效率优化 | U-DiT, DyDiT++, D²iT | 减少计算量 |
| 多模态融合 | MM-DiT | 统一文本图像建模 |
| 超高分辨率 | ScaleDiT, PixelDiT | 突破分辨率限制 |
| 弹性计算 | FlexiDiT | 单一模型多预算 |
7.2 未来趋势
- 动态计算成为主流:根据内容和时间步自适应分配计算资源
- 多模态统一:文本、图像、视频在统一框架下建模
- 端到端建模:减少对 VAE 的依赖
- 硬件协同设计:针对 NPU/GPU 特性优化架构
7.3 实践建议
| 场景 | 推荐架构 |
|---|---|
| 快速实验 | DiT-B/2 |
| 高质量生成 | DiT-XL/2 |
| 资源受限 | U-DiT-L/2 |
| 多模态生成 | MM-DiT (SD3) |
| 少步采样 | 结合 Consistency Models |