概述
DiT (Diffusion Transformer) 的训练比传统 UNet 更加复杂,对稳定性技术的要求更高。本文件系统系统性地介绍 DiT 训练中的关键稳定性技术。1
1. 训练稳定性挑战
DiT 训练的独特挑战
| 挑战 | 原因 | 影响 |
|---|---|---|
| 残差累积 | 多层 Transformer 堆叠 | 梯度爆炸/消失 |
| 注意力权重不稳定 | Softmax 可能溢出 | NaN/Inf |
| 条件冲突 | 多条件同时注入 | 训练震荡 |
| 长序列 | O(N²) 注意力计算 | 数值不稳定 |
症状识别
# 训练不稳定的常见症状
symptoms = {
'loss_nan': '损失变为 NaN',
'loss_explode': '损失突然增大 100 倍',
'grad_nan': '梯度包含 NaN',
'attention_nan': '注意力权重溢出',
'output_explode': '输出值过大',
}2. EMA (指数移动平均)
2.1 理论背景
EMA 在扩散模型训练中至关重要:
其中 。
2.2 DiT 中的 EMA
class EMA:
"""
指数移动平均
DiT 推荐配置:
- 衰减率: 0.9999
- 更新频率: 每步
- 启动延迟: 5000 步
"""
def __init__(self, model, decay=0.9999, update_after_step=5000):
self.model = model
self.decay = decay
self.update_after_step = update_step
self.shadow = {}
self.backup = {}
# 初始化 shadow 参数
self.register()
def register(self):
"""将模型参数注册到 shadow"""
for name, param in self.model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
def update(self, step):
"""更新 EMA"""
if step < self.update_after_step:
return
for name, param in self.model.named_parameters():
if param.requires_grad:
assert name in self.shadow
new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
self.shadow[name] = new_average.clone()
def apply_shadow(self):
"""用 shadow 参数替换模型参数"""
for name, param in self.model.named_parameters():
if param.requires_grad:
self.backup[name] = param.data.clone()
param.data = self.shadow[name]
def restore(self):
"""恢复原始参数"""
for name, param in self.model.named_parameters():
if param.requires_grad:
param.data = self.backup[name]
self.backup = {}2.3 EMA 的实际效果
| EMA 衰减率 | 训练稳定性 | 生成质量 | 推荐场景 |
|---|---|---|---|
| 0.9999 | 极高 | 高 | 稳定训练 |
| 0.999 | 高 | 中高 | 快速收敛 |
| 0.99 | 中 | 中 | 实验阶段 |
| 无 EMA | 低 | 低 | 不推荐 |
3. 梯度裁剪
3.1 全局梯度裁剪
class GradientClipping:
"""
梯度裁剪
DiT 推荐配置:
- 最大范数: 1.0
- 归一化: 按参数数量
"""
def __init__(self, max_norm=1.0, norm_type=2.0):
self.max_norm = max_norm
self.norm_type = norm_type
def clip(self, parameters):
"""裁剪梯度"""
total_norm = torch.nn.utils.clip_grad_norm_(
parameters,
self.max_norm,
self.norm_type
)
return total_norm
# 使用示例
clipper = GradientClipping(max_norm=1.0)
for batch in dataloader:
output = model(**batch)
loss = compute_loss(output)
loss.backward()
# 裁剪前检查
total_norm = clipper.clip(model.parameters())
if torch.isnan(total_norm):
print("警告: 梯度为 NaN,跳过此步")
optimizer.zero_grad()
continue
optimizer.step()3.2 自适应梯度裁剪
class AdaptiveGradientClipping:
"""
自适应梯度裁剪
根据训练阶段动态调整裁剪阈值
"""
def __init__(self, initial_clip=1.0, final_clip=0.1, warmup_steps=10000):
self.initial_clip = initial_clip
self.final_clip = final_clip
self.warmup_steps = warmup_steps
def get_clip_norm(self, step):
"""获取当前步的裁剪阈值"""
if step < self.warmup_steps:
# 线性 warmup
ratio = step / self.warmup_steps
return self.initial_clip + (self.final_clip - self.initial_clip) * ratio
else:
return self.final_clip
def clip(self, model, step):
"""裁剪并返回梯度范数"""
clip_norm = self.get_clip_norm(step)
total_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
clip_norm
)
return total_norm3.3 逐层梯度裁剪
def clip_gradients_by_layer(model, max_norms):
"""
逐层设置不同的梯度裁剪阈值
早期层通常需要更严格的裁剪
"""
for name, param in model.named_parameters():
if param.grad is None:
continue
# 解析层号
if 'blocks' in name:
layer_id = int(name.split('.')[1])
max_norm = max_norms[min(layer_id, len(max_norms)-1)]
else:
max_norm = max_norms[-1] # 默认
param.grad.data = torch.clamp(
param.grad.data,
-max_norm,
max_norm
)4. 混合精度训练
4.1 FP16/BF16 训练
# PyTorch 混合精度配置
scaler = torch.cuda.amp.GradScaler(
init_scale=65536, # 初始缩放因子
growth_factor=2.0, # 增长因子
backoff_factor=0.5, # 下降因子
growth_interval=2000 # 增长间隔
)
model = model.cuda()
model = model.bfloat16() # BF16 比 FP16 更稳定
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for batch in dataloader:
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
output = model(**batch)
loss = compute_loss(output)
# 缩放损失并反向传播
scaler.scale(loss).backward()
# 梯度裁剪
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# 参数更新
scaler.step(optimizer)
scaler.update()4.2 精度选择
| 精度格式 | 数值范围 | 稳定性 | 性能 | 推荐场景 |
|---|---|---|---|---|
| FP32 | 全范围 | 最高 | 最慢 | 基准 |
| BF16 | 全范围 | 高 | 快 | DiT 推荐 |
| FP16 | 有限范围 | 中 | 最快 | 谨慎使用 |
4.3 精度转换工具
class PrecisionManager:
"""
管理模型精度
"""
def __init__(self, model, precision='bf16-mixed'):
self.model = model
self.precision = precision
def to_train_precision(self):
"""转换为训练精度"""
if self.precision == 'bf16-mixed':
self.model = self.model.to(torch.bfloat16)
elif self.precision == 'fp16-mixed':
self.model = self.model.to(torch.float16)
return self.model
def to_inference_precision(self):
"""转换为推理精度"""
# 推理时使用 FP32 更稳定
self.model = self.model.to(torch.float32)
return self.model5. 学习率调度
5.1 DiT 推荐调度
def dit_learning_rate_schedule(
step,
base_lr=1e-4,
warmup_steps=5000,
total_steps=400000,
min_lr=1e-6
):
"""
DiT 推荐学习率调度
包含 warmup 和余弦衰减
"""
# Warmup 阶段
if step < warmup_steps:
return base_lr * (step / warmup_steps)
# 余弦衰减
progress = (step - warmup_steps) / (total_steps - warmup_steps)
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
return min_lr + (base_lr - min_lr) * cosine_decay5.2 层级学习率
def set_layerwise_lr(model, base_lr=1e-4, decay_rate=0.9):
"""
设置层级学习率
早期层使用较小学习率,稳定训练
"""
optimizer_grouped_parameters = []
num_layers = len(model.blocks)
for i, (name, param) in enumerate(model.named_parameters()):
if not param.requires_grad:
continue
# 计算层级衰减
depth = i / num_layers
lr = base_lr * (decay_rate ** (1 - depth))
optimizer_grouped_parameters.append({
'params': [param],
'lr': lr,
'weight_decay': 0.0, # DiT 通常不用 weight decay
})
return optimizer_grouped_parameters5.3 调度可视化
学习率
↑
1e-4 ┤───────────────────╱
│ ╱──
│ ╱──
│ ╱──
│ ╱──
1e-6 ┤╱─────────────────────
└───────────────────────→ 步数
0 5K 50K 400K
6. 注意力稳定性
6.1 Softmax 数值稳定化
def stable_attention(Q, K, V, scale=None):
"""
数值稳定的注意力计算
"""
d_k = Q.shape[-1]
scale = scale or d_k ** -0.5
# 数值稳定的 softmax
scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
# 减去最大值避免溢出
scores = scores - scores.max(dim=-1, keepdim=True)[0]
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output, attn_weights6.2 QK-Norm
class QKNormAttention(nn.Module):
"""
Query-Key 归一化
防止注意力权重爆炸
"""
def __init__(self, hidden_size, num_heads):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
# 可学习的缩放参数
self.scale = nn.Parameter(torch.ones(num_heads))
def forward(self, Q, K, V):
# QK 归一化
Q = Q / (Q.norm(dim=-1, keepdim=True) + 1e-8)
K = K / (K.norm(dim=-1, keepdim=True) + 1e-8)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale.view(1, -1, 1, 1)
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output6.3 注意力权重监控
class AttentionMonitor:
"""
监控注意力权重稳定性
"""
def __init__(self, model):
self.model = model
self.attention_stats = []
def register_hooks(self):
"""注册注意力监控 hooks"""
def hook_fn(module, input, output):
attn_weights = output[1] if isinstance(output, tuple) else output
stats = {
'mean': attn_weights.mean().item(),
'std': attn_weights.std().item(),
'max': attn_weights.max().item(),
'min': attn_weights.min().item(),
}
self.attention_stats.append(stats)
for block in self.model.blocks:
block.attn.register_forward_hook(hook_fn)
def check_stability(self):
"""检查注意力稳定性"""
if not self.attention_stats:
return True
latest = self.attention_stats[-1]
# 检查是否异常
if latest['max'] > 1.0 + 1e-5:
print("警告: 注意力权重可能不稳定")
return False
if torch.isnan(torch.tensor(latest['mean'])):
print("警告: 注意力权重为 NaN")
return False
return True7. 初始化策略
7.1 DiT 初始化
def dit_init_weights(module):
"""
DiT 推荐的权重初始化
"""
if isinstance(module, nn.Linear):
# 线性层: Xavier 初始化
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
# LayerNorm: 初始化为零偏移
nn.init.zeros_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
# Embedding: 小范围初始化
nn.init.normal_(module.weight, mean=0, std=0.02)
# 应用初始化
model.apply(dit_init_weights)
# AdaLN-Zero 特殊初始化
for block in model.blocks:
nn.init.zeros_(block.adaLN[-1].weight)
nn.init.zeros_(block.adaLN[-1].bias)7.2 残差连接初始化
class ResidualBlockWithScale(nn.Module):
"""
带缩放的残差连接
"""
def __init__(self, module, scale=0.1):
super().__init__()
self.module = module
self.scale = nn.Parameter(torch.tensor(scale))
def forward(self, x):
return x + self.scale * self.module(x)8. 训练监控
8.1 关键指标
class TrainingMonitor:
"""
DiT 训练监控
"""
def __init__(self):
self.metrics = {
'loss': [],
'grad_norm': [],
'attention_stats': [],
'lr': [],
}
def log_batch(self, step, loss, grad_norm, lr):
self.metrics['loss'].append((step, loss.item()))
self.metrics['grad_norm'].append((step, grad_norm))
self.metrics['lr'].append((step, lr))
def check_stability(self, step):
"""检查训练稳定性"""
# 检查损失
if len(self.metrics['loss']) > 0:
recent_losses = [l for _, l in self.metrics['loss'][-100:]]
if any(np.isnan(l) for l in recent_losses):
return False, "Loss is NaN"
# 检查损失爆炸
if len(recent_losses) > 10:
if recent_losses[-1] > 10 * np.mean(recent_losses[:-10]):
return False, "Loss exploded"
# 检查梯度
if len(self.metrics['grad_norm']) > 0:
recent_grads = [g for _, g in self.metrics['grad_norm'][-100:]]
if recent_grads[-1] > 100:
return False, "Gradient exploded"
return True, "Stable"
def generate_report(self):
"""生成训练报告"""
report = {
'avg_loss': np.mean([l for _, l in self.metrics['loss'][-1000:]]),
'avg_grad_norm': np.mean([g for _, g in self.metrics['grad_norm'][-1000:]]),
'current_lr': self.metrics['lr'][-1][1] if self.metrics['lr'] else None,
}
return report8.2 早停机制
class EarlyStopping:
"""
早停机制
"""
def __init__(self, patience=10, min_delta=0.001):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = float('inf')
def check(self, loss):
"""检查是否应该停止"""
if loss < self.best_loss - self.min_delta:
self.best_loss = loss
self.counter = 0
return False # 继续训练
self.counter += 1
if self.counter >= self.patience:
return True # 停止训练
return False9. 完整训练配置
# DiT 推荐训练配置
TRAINING_CONFIG = {
# 模型
'model': {
'hidden_size': 1024,
'num_heads': 16,
'num_layers': 28,
'patch_size': 2,
},
# 优化器
'optimizer': {
'lr': 1e-4,
'weight_decay': 0.0,
'beta1': 0.9,
'beta2': 0.999,
},
# 学习率调度
'scheduler': {
'warmup_steps': 5000,
'total_steps': 400000,
'min_lr': 1e-6,
},
# 稳定性
'stability': {
'ema_decay': 0.9999,
'grad_clip_norm': 1.0,
'precision': 'bf16-mixed',
},
# 数据
'data': {
'batch_size': 256,
'resolution': 256,
'num_workers': 8,
},
}10. 总结
关键稳定性技术
| 技术 | 作用 | 推荐配置 |
|---|---|---|
| EMA | 平滑训练轨迹 | 衰减率 0.9999 |
| 梯度裁剪 | 防止梯度爆炸 | 范数 1.0 |
| 混合精度 | 加速 + 稳定 | BF16 |
| 学习率调度 | 稳定收敛 | Warmup + Cosine |
| QK-Norm | 注意力稳定 | 推荐使用 |
检查清单
- 使用 EMA(衰减率 0.9999)
- 启用梯度裁剪
- 使用 BF16 混合精度
- 添加 warmup 阶段
- 监控关键指标
- 实现早停机制
参考
相关阅读
- DiT 架构深度解析 — DiT 架构细节
- DiT vs UNet 理论分析 — 架构对比
- 扩散模型缩放定律 — 缩放分析
Footnotes
-
Peebles, W., & Xie, S. (2023). “Scalable Diffusion Models with Transformers.” ICCV 2023. arXiv:2212.09748 ↩