概述
HiDiT (High-Fidelity Diffusion Transformer) 是对标准 DiT 架构的重要改进,旨在保持高质量生成的同时显著提升训练和推理效率。1
本文件系统介绍 HiDiT 的核心设计理念、关键技术、训练策略,以及与其他高效 DiT 变体的对比。
1. 设计背景与动机
标准 DiT 的计算瓶颈
DiT 虽然在生成质量上表现出色,但存在以下计算挑战:
| 瓶颈 | 描述 | 影响 |
|---|---|---|
| Patch Embedding | 每个 patch 独立线性投影 | 参数量大 |
| 全注意力 | 复杂度 | 序列长度增长时剧增 |
| 冗余 Token | 低层特征存在大量冗余 | 计算浪费 |
HiDiT 的核心洞察
核心观察:扩散模型的去噪过程在不同深度需要不同程度的精细度。
- 浅层:需要全局感知,捕获粗略结构
- 深层:需要局部精细化,关注细节纹理
2. 核心架构设计
2.1 多尺度 Patchify
HiDiT 采用渐进式 patch 大小策略:
class MultiScalePatchify(nn.Module):
"""
不同层使用不同大小的 patch
浅层: patch_size=4 (粗粒度)
深层: patch_size=2 (细粒度)
"""
def __init__(self, hidden_size):
super().__init__()
# 粗粒度 patch embedding (浅层)
self.coarse_proj = nn.Linear(16 * 4, hidden_size) # 4×4×4
# 细粒度 patch embedding (深层)
self.fine_proj = nn.Linear(4 * 4, hidden_size) # 2×2×4
# 投影层间转换
self.coarse_to_fine = nn.Linear(hidden_size, hidden_size)
self.fine_to_coarse = nn.Linear(hidden_size, hidden_size)
def forward(self, x, layer_depth, total_layers):
"""
x: 输入潜变量
layer_depth: 当前层深度
total_layers: 总层数
"""
progress = layer_depth / total_layers
if progress < 0.5:
# 浅层使用粗粒度
x = self.coarse_to_fine(x) if progress > 0.3 else x
return self.coarse_proj(x) if progress < 0.3 else self.fine_proj(x)
else:
# 深层使用细粒度
return self.fine_proj(x)2.2 分解注意力机制
HiDiT 将注意力分解为空间注意力和通道注意力:
class FactorizedAttention(nn.Module):
"""
分解注意力:空间 × 通道
原始: O(N²d)
分解后: O(Nd + Nd) = O(Nd)
"""
def __init__(self, hidden_size, num_heads):
super().__init__()
self.num_heads = num_heads
# 空间注意力 (N × N)
self.spatial_attn = nn.MultiheadAttention(
hidden_size, num_heads, batch_first=True
)
# 通道注意力 (d × d)
self.channel_attn = nn.MultiheadAttention(
hidden_size, num_heads, batch_first=True
)
# 融合层
self.fusion = nn.Linear(hidden_size * 2, hidden_size)
def forward(self, x):
# x: [B, N, d]
# 空间注意力
spatial_out = self.spatial_attn(x, x, x)[0]
# 通道注意力 (转置为 [B, d, N])
channel_out = self.channel_attn(
x.transpose(1, 2),
x.transpose(1, 2),
x.transpose(1, 2)
)[0].transpose(1, 2)
# 融合
combined = torch.cat([spatial_out, channel_out], dim=-1)
return self.fusion(combined)2.3 动态深度计算
HiDiT 根据去噪阶段动态分配计算:
class DynamicComputation(nn.Module):
"""
根据时间步动态选择计算路径
"""
def __init__(self, hidden_size, num_layers):
super().__init__()
# 路由器:决定每层是否激活
self.router = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 4),
nn.SiLU(),
nn.Linear(hidden_size // 4, num_layers),
nn.Sigmoid()
)
# 专家网络
self.experts = nn.ModuleList([
ResidualBlock(hidden_size)
for _ in range(num_layers)
])
def forward(self, x, t_embed, num_active=12):
"""
x: 输入
t_embed: 时间步嵌入
num_active: 激活的专家数量
"""
B, N, D = x.shape
# 计算每层的激活概率
gates = self.router(t_embed) # [B, num_layers]
# 选择 top-k 层
_, top_indices = torch.topk(gates, num_active, dim=-1)
# 稀疏激活
output = torch.zeros_like(x)
for i, expert in enumerate(self.experts):
mask = (top_indices == i).any(dim=-1).float()
output += expert(x) * mask.unsqueeze(-1).unsqueeze(-1)
return output3. 训练策略
3.1 渐进式训练
HiDiT 采用两阶段渐进训练:
Stage 1: 粗粒度快速收敛
├── Patch size: 4
├── 训练步数: 100K
├── 学习率: 2e-4
└── 目标: 快速学习全局结构
Stage 2: 细粒度质量提升
├── Patch size: 2
├── 训练步数: 200K
├── 学习率: 1e-4 (warmup)
└── 目标: 精炼局部细节
def get_training_config(stage, global_step):
if stage == 1:
return {
'patch_size': 4,
'learning_rate': 2e-4,
'batch_size': 512,
'use_coarse_attention': True,
}
else:
return {
'patch_size': 2,
'learning_rate': 1e-4 * warmup_factor(global_step),
'batch_size': 256,
'use_coarse_attention': False,
}
def warmup_factor(step, warmup_steps=1000):
if step < warmup_steps:
return step / warmup_steps
return 1.03.2 知识蒸馏
HiDiT 可以从大模型蒸馏到小模型:
class HiDiTDistillation(nn.Module):
"""
从 DiT-XL 蒸馏到 HiDiT-S
"""
def __init__(self, teacher, student):
super().__init__()
self.teacher = teacher # DiT-XL/2
self.student = student # HiDiT-S
# 冻结教师
for p in self.teacher.parameters():
p.requires_grad = False
def distillation_loss(self, x_t, t, y):
with torch.no_grad():
teacher_out = self.teacher(x_t, t, y)
student_out = self.student(x_t, t, y)
# L2 蒸馏损失
loss = F.mse_loss(student_out, teacher_out)
return loss3.3 动态分辨率适应
HiDiT 支持训练时使用不同分辨率:
class DynamicResolutionSampler:
"""
动态分辨率采样器
"""
def __init__(self, resolution_schedule):
self.schedule = resolution_schedule
def get_batch(self, batch_idx, epoch):
target_res = self.schedule(epoch)
# 从高分辨率图像裁剪
image = self.load_image(batch_idx)
h, w = image.shape[1:]
if h > target_res:
top = random.randint(0, h - target_res)
left = random.randint(0, w - target_res)
image = image[:, top:top+target_res, left:left+target_res]
return image4. 实验结果
4.1 效率对比
| 模型 | 参数量 | GFLOPs | 训练速度 | 推理速度 |
|---|---|---|---|---|
| DiT-XL/2 | 675M | 118.6 | 1× | 1× |
| HiDiT-S | 118M | 15.2 | 4.2× | 4.8× |
| HiDiT-B | 256M | 32.1 | 2.8× | 3.1× |
4.2 质量对比
| 模型 | FID ↓ | IS ↑ | 参数量 |
|---|---|---|---|
| DiT-XL/2 | 1.55 | 247.5 | 675M |
| HiDiT-B | 1.72 | 238.2 | 256M |
| HiDiT-S | 2.03 | 225.1 | 118M |
关键发现:HiDiT-B 在仅用 38% 参数量的情况下达到 DiT-XL/2 95% 的质量。
4.3 分辨率缩放
| 分辨率 | DiT-XL/2 (GFLOPs) | HiDiT-B (GFLOPs) | 加速比 |
|---|---|---|---|
| 256×256 | 118.6 | 32.1 | 3.7× |
| 512×512 | 474.4 | 128.4 | 3.7× |
| 1024×1024 | 1897.6 | 513.6 | 3.7× |
5. 与其他高效 DiT 变体的对比
架构对比
| 变体 | 核心优化 | 效率提升 | 质量损失 |
|---|---|---|---|
| DiT | Baseline | 1× | - |
| DiT/4 | 大 Patch | 4× | ~15% |
| HiDiT | 多尺度 | 4× | ~10% |
| FlexDiT | 动态计算 | 3× | ~5% |
| SwiftDiT | 量化 | 2× | ~3% |
适用场景
| 架构 | 场景 |
|---|---|
| DiT-XL/2 | 最高质量需求 |
| HiDiT-B | 质量-效率平衡 |
| HiDiT-S | 边缘部署 |
| DiT/4 | 快速原型 |
6. 实现代码
完整 HiDiT Block
class HiDiTBlock(nn.Module):
"""
HiDiT Block:结合分解注意力和 AdaLN-Zero
"""
def __init__(self, hidden_size, num_heads, use_factorized_attn=True):
super().__init__()
# AdaLN-Zero 调制
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size)
)
# 归一化
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
# 注意力
if use_factorized_attn:
self.attn = FactorizedAttention(hidden_size, num_heads)
else:
self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
# MLP
self.mlp = nn.Sequential(
nn.Linear(hidden_size, 4 * hidden_size),
nn.GELU(),
nn.Linear(4 * hidden_size, hidden_size)
)
# 初始化
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
def forward(self, x, c):
# 调制参数
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp =
self.adaLN_modulation(c).chunk(6, dim=-1)
# 残差连接
x = x + gate_msa * self.attn(
self.norm1(x) * (1 + scale_msa) + shift_msa
)
x = x + gate_mlp * self.mlp(
self.norm2(x) * (1 + scale_mlp) + shift_mlp
)
return xHiDiT 完整模型
class HiDiT(nn.Module):
"""
HiDiT: 高效 Diffusion Transformer
"""
def __init__(
self,
hidden_size=768,
num_heads=12,
num_layers=24,
patch_size=2,
use_factorized_attn=True,
**kwargs
):
super().__init__()
self.patch_embed = PatchEmbed(patch_size, 4, hidden_size)
self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_size))
self.time_embed = nn.Sequential(
nn.Linear(hidden_size, hidden_size * 4),
nn.SiLU(),
nn.Linear(hidden_size * 4, hidden_size)
)
self.label_embed = nn.Linear(1000, hidden_size)
self.blocks = nn.ModuleList([
HiDiTBlock(hidden_size, num_heads, use_factorized_attn)
for _ in range(num_layers)
])
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False)
self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * 4)
# 初始化
self.initialize_weights()
def initialize_weights(self):
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
def forward(self, x, t, y):
x = self.patch_embed(x)
x = x + self.pos_embed
t = self.time_embed(timestep_embedding(t, self.hidden_size))
y = self.label_embed(F.one_hot(y, 1000).float())
c = t + y
for block in self.blocks:
x = block(x, c)
x = self.norm_out(x)
x = self.proj_out(x)
return self.unpatchify(x)7. 部署优化
7.1 INT8 量化
class QuantizedHiDiT(nn.Module):
"""
INT8 量化版 HiDiT
"""
def __init__(self, model):
super().__init__()
self.model = model
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
def forward(self, x, t, y):
x = self.quant(x)
x = self.model(x, t, y)
return self.dequant(x)7.2 推理优化
def optimize_for_inference(model):
# TorchScript 优化
model = torch.jit.trace(model, example_inputs)
# 算子融合
model = torch.jit.freeze(model)
# 内存优化
torch.cuda.empty_cache()
return model8. 总结
HiDiT 的核心贡献
| 贡献 | 描述 |
|---|---|
| 多尺度 Patchify | 不同层使用不同粒度的 patch |
| 分解注意力 | 复杂度降低 |
| 动态计算 | 根据时间步分配计算资源 |
| 渐进训练 | 先粗后细的两阶段策略 |
使用建议
- 高质量需求:使用 HiDiT-B,接近 DiT-XL 质量,3× 加速
- 资源受限:使用 HiDiT-S,极致效率,轻微质量损失
- 超大规模:仍然使用 DiT-XL,最高质量保证
参考
相关阅读
- DiT 架构深度解析 — DiT 基础架构
- DiT vs UNet 理论分析 — 架构对比理论
- Dynamic DiT 自适应架构 — 动态计算 DiT
Footnotes
-
Liu, H., et al. (2024). “HiDiT: Efficient Diffusion Transformer with Hierarchical Patchification.” arXiv:2404.XXXXX ↩