概述

DiT (Diffusion Transformer) 的训练比传统 UNet 更加复杂,对稳定性技术的要求更高。本文系统性地介绍 DiT 训练中的关键稳定性技术。[1]


1. 训练稳定性挑战

DiT 训练的独特挑战

| 挑战 | 原因 | 影响 |
|---|---|---|
| 残差累积 | 多层 Transformer 堆叠 | 梯度爆炸/消失 |
| 注意力权重不稳定 | Softmax 可能溢出 | NaN/Inf |
| 条件冲突 | 多条件同时注入 | 训练震荡 |
| 长序列 | O(N²) 注意力计算 | 数值不稳定 |

症状识别

# Common symptoms of unstable training, keyed by a short symptom id.
symptoms = dict(
    loss_nan='损失变为 NaN',
    loss_explode='损失突然增大 100 倍',
    grad_nan='梯度包含 NaN',
    attention_nan='注意力权重溢出',
    output_explode='输出值过大',
)

2. EMA (指数移动平均)

2.1 理论背景

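EMA 在扩散模型训练中至关重要。每个训练步按下式更新一份参数的滑动平均副本:

$$\theta_{\text{EMA}} \leftarrow \beta\,\theta_{\text{EMA}} + (1-\beta)\,\theta$$

其中 $\beta$ 为衰减率(DiT 推荐 0.9999),$\theta$ 为当前模型参数,$\theta_{\text{EMA}}$ 为平滑后的参数副本(即下方代码中的 `shadow`)。$\beta$ 越接近 1,有效平均窗口越长(约 $1/(1-\beta)$ 步)。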

2.2 DiT 中的 EMA

class EMA:
    """
    Exponential moving average (EMA) of a model's trainable parameters.

    Recommended DiT settings:
    - decay: 0.9999
    - update frequency: every step
    - start delay: 5000 steps

    Update rule (once ``step >= update_after_step``):
        shadow = decay * shadow + (1 - decay) * param

    Before ``update_after_step`` the shadow copy simply tracks the live
    weights, so averaging starts from the current weights instead of from
    the random initialization captured at construction time.
    """

    def __init__(self, model, decay=0.9999, update_after_step=5000):
        self.model = model
        self.decay = decay
        # Previously assigned from an undefined name, raising NameError.
        self.update_after_step = update_after_step
        # name -> EMA copy of the parameter.
        self.shadow = {}
        # name -> live parameter saved by apply_shadow() for restore().
        self.backup = {}

        # Initialize shadow parameters from the current model weights.
        self.register()

    def register(self):
        """Copy every trainable model parameter into ``shadow``."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, step):
        """
        Advance the EMA for training step ``step``.

        While ``step < update_after_step`` the shadow weights are overwritten
        with the live weights; afterwards the decay rule is applied.

        Raises:
            KeyError: a trainable parameter was not registered in ``shadow``.
        """
        if step < self.update_after_step:
            # Keep the shadow in sync during the delay phase; otherwise the EMA
            # would later start averaging from the stale initial weights.
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    self.shadow[name].copy_(param.data)
            return

        for name, param in self.model.named_parameters():
            if param.requires_grad:
                if name not in self.shadow:
                    raise KeyError(f"parameter {name!r} is not registered in the EMA shadow")
                # In place: shadow = decay * shadow + (1 - decay) * param
                self.shadow[name].mul_(self.decay).add_(param.data, alpha=1.0 - self.decay)

    def apply_shadow(self):
        """Load the EMA weights into the model, backing up the live weights."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                # copy_ instead of rebinding .data: the model must not share
                # storage with the shadow, or any in-place change to the model
                # (e.g. an optimizer step) would silently corrupt the EMA.
                param.data.copy_(self.shadow[name])

    def restore(self):
        """Put back the live weights saved by ``apply_shadow``."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])
        self.backup = {}

2.3 EMA 的实际效果

EMA 衰减率训练稳定性生成质量推荐场景
0.9999极高稳定训练
0.999中高快速收敛
0.99实验阶段
无 EMA不推荐

3. 梯度裁剪

3.1 全局梯度裁剪

class GradientClipping:
    """
    Global gradient-norm clipping (thin wrapper over ``clip_grad_norm_``).

    Recommended DiT setting: ``max_norm = 1.0``.

    All gradients are treated as one concatenated vector. If its
    ``norm_type``-norm exceeds ``max_norm``, every gradient is rescaled by the
    same factor. No per-parameter-count normalization is applied.
    """

    def __init__(self, max_norm=1.0, norm_type=2.0):
        self.max_norm = max_norm
        self.norm_type = norm_type

    def clip(self, parameters):
        """
        Clip the gradients of ``parameters`` in place.

        Args:
            parameters: iterable of tensors (e.g. ``model.parameters()``).

        Returns:
            The total gradient norm measured *before* clipping, as a tensor.
            It is NaN/Inf when any gradient is non-finite.
        """
        return torch.nn.utils.clip_grad_norm_(
            parameters,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
        )
 
# Usage example
clipper = GradientClipping(max_norm=1.0)

for batch in dataloader:
    # Clear the previous step's gradients first. Without this, backward()
    # accumulates into the old gradients and every update uses their sum.
    optimizer.zero_grad()

    output = model(**batch)
    loss = compute_loss(output)
    loss.backward()

    # clip_grad_norm_ returns the total norm measured *before* clipping, so
    # it can be used after the call to detect a bad step.
    total_norm = clipper.clip(model.parameters())

    # isfinite catches both NaN and Inf. An Inf norm would otherwise be
    # "clipped" into NaN gradients (Inf * 0) and corrupt the weights.
    if not torch.isfinite(total_norm):
        print("警告: 梯度为 NaN/Inf,跳过此步")
        continue

    optimizer.step()

3.2 自适应梯度裁剪

class AdaptiveGradientClipping:
    """
    Gradient-norm clipping whose threshold depends on the training step.

    The threshold moves linearly from ``initial_clip`` (step 0) to
    ``final_clip`` (step ``warmup_steps``) and stays at ``final_clip``
    afterwards. ``norm_type`` selects the p-norm (default L2), matching
    ``torch.nn.utils.clip_grad_norm_``.
    """

    def __init__(self, initial_clip=1.0, final_clip=0.1, warmup_steps=10000, norm_type=2.0):
        self.initial_clip = initial_clip
        self.final_clip = final_clip
        self.warmup_steps = warmup_steps
        self.norm_type = norm_type

    def get_clip_norm(self, step):
        """Return the clipping threshold to use at ``step``."""
        if step >= self.warmup_steps or self.warmup_steps <= 0:
            return self.final_clip
        # Linear interpolation initial_clip -> final_clip over the schedule.
        ratio = max(step, 0) / self.warmup_steps
        return self.initial_clip + (self.final_clip - self.initial_clip) * ratio

    def clip(self, model, step):
        """
        Clip ``model``'s gradients in place with the threshold for ``step``.

        Returns:
            The total gradient norm measured before clipping (a tensor).
        """
        return torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=self.get_clip_norm(step),
            norm_type=self.norm_type,
        )

3.3 逐层梯度裁剪

def clip_gradients_by_layer(model, max_norms):
    """
    Clamp gradients element-wise using a per-transformer-block threshold.

    Despite the parameter name, this is value clipping: every gradient
    element is clamped into ``[-t, t]``. No norm is computed.

    Args:
        model: module whose parameters are visited via ``named_parameters()``.
        max_norms: non-empty sequence of thresholds. Entry ``i`` applies to
            parameters whose dotted name contains ``blocks.<i>`` (at any
            depth, so wrapped models such as ``model.blocks.3.attn.weight``
            work). Blocks past the end of the list, and all non-block
            parameters, use the last entry.

    Earlier layers are usually given stricter (smaller) thresholds.
    Parameters without a gradient are skipped.
    """
    fallback = max_norms[-1]
    last_index = len(max_norms) - 1

    for name, param in model.named_parameters():
        if param.grad is None:
            continue

        threshold = fallback
        parts = name.split('.')
        # Look for an exact "blocks" component followed by an integer index,
        # instead of assuming it is the first component of the name.
        for pos in range(len(parts) - 1):
            if parts[pos] == 'blocks' and parts[pos + 1].isdigit():
                threshold = max_norms[min(int(parts[pos + 1]), last_index)]
                break

        param.grad.data.clamp_(-threshold, threshold)

4. 混合精度训练

4.1 FP16/BF16 训练

# PyTorch mixed-precision setup.
# GradScaler guards against FP16 gradient underflow. BF16 shares FP32's
# exponent range and normally does not need it; scaling is harmless there and
# keeps this loop usable if the autocast dtype is switched to torch.float16.
scaler = torch.cuda.amp.GradScaler(
    init_scale=65536,  # initial loss-scale factor
    growth_factor=2.0,  # scale multiplier after growth_interval clean steps
    backoff_factor=0.5,  # scale multiplier when inf/NaN gradients are found
    growth_interval=2000  # consecutive clean steps before growing the scale
)

# Keep the master weights (and AdamW state) in FP32. Autocast runs eligible
# ops in BF16 on the fly. Casting the model itself with .bfloat16() would make
# this pure-BF16 training, where small weight updates get rounded away.
model = model.cuda()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in dataloader:
    # Reset gradients so they do not accumulate across iterations.
    optimizer.zero_grad()

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        output = model(**batch)
        loss = compute_loss(output)

    # Scale the loss and backpropagate.
    scaler.scale(loss).backward()

    # Unscale first so clipping operates on the true gradient magnitudes.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    # scaler.step skips the optimizer step if inf/NaN gradients were found.
    scaler.step(optimizer)
    scaler.update()

4.2 精度选择

| 精度格式 | 数值范围 | 稳定性 | 性能 | 推荐场景 |
|---|---|---|---|---|
| FP32 | 全范围 | 最高 | 最慢 | 基准 |
| BF16 | 全范围 | | | DiT 推荐 |
| FP16 | 有限范围 | | 最快 | 谨慎使用 |

4.3 精度转换工具

class PrecisionManager:
    """
    Manage the dtype of a model's weights for training and inference.

    Supported ``precision`` values:
    - ``'bf16-mixed'`` / ``'fp16-mixed'``: mixed precision. Weights stay in
      FP32 (master weights). Pass ``autocast_dtype`` to ``torch.autocast`` to
      run the forward pass in reduced precision.
    - ``'bf16'`` / ``'fp16'``: pure reduced precision; the weights are cast.
    - anything else: the model's dtype is left unchanged for training.
    """

    def __init__(self, model, precision='bf16-mixed'):
        self.model = model
        self.precision = precision

    @property
    def autocast_dtype(self):
        """The reduced-precision dtype for this setting (for ``torch.autocast``), or None."""
        if self.precision.startswith('bf16'):
            return torch.bfloat16
        if self.precision.startswith('fp16'):
            return torch.float16
        return None

    def to_train_precision(self):
        """Put the model's weights in the dtype used for training and return the model."""
        if self.precision.endswith('-mixed'):
            # Mixed precision keeps FP32 master weights; autocast does the
            # per-op casting. Casting the weights here would silently turn
            # this into pure reduced-precision training.
            self.model = self.model.to(torch.float32)
        elif self.precision in ('bf16', 'fp16'):
            self.model = self.model.to(self.autocast_dtype)
        return self.model

    def to_inference_precision(self):
        """Cast the model to FP32 for inference and return it."""
        # FP32 is the most numerically stable choice for inference.
        self.model = self.model.to(torch.float32)
        return self.model

5. 学习率调度

5.1 DiT 推荐调度

def dit_learning_rate_schedule(
    step,
    base_lr=1e-4,
    warmup_steps=5000,
    total_steps=400000,
    min_lr=1e-6
):
    """
    Learning rate for ``step``: linear warmup followed by cosine decay.

    - ``step < warmup_steps``: rises linearly from 0 to ``base_lr``.
    - ``warmup_steps <= step <= total_steps``: cosine decay from ``base_lr``
      down to ``min_lr``.
    - ``step > total_steps``: held at ``min_lr``.

    Returns:
        The learning rate as a float.
    """
    # Warmup phase
    if step < warmup_steps:
        return base_lr * (step / warmup_steps)

    # max(1, ...) avoids division by zero when total_steps <= warmup_steps.
    decay_steps = max(1, total_steps - warmup_steps)
    # Clamp progress to 1: cos(pi * p) rises again for p > 1, which would
    # otherwise push the learning rate back up after total_steps.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))

    return min_lr + (base_lr - min_lr) * cosine_decay

5.2 层级学习率

def set_layerwise_lr(model, base_lr=1e-4, decay_rate=0.9):
    """
    Build optimizer parameter groups with layer-wise learning rates.

    Layer numbering:
    - 0: parameters registered before the first ``blocks.<i>`` parameter
      (e.g. embeddings);
    - ``i + 1``: parameters of transformer block ``i`` (names containing
      ``blocks.<i>``);
    - ``len(model.blocks) + 1``: non-block parameters registered after the
      blocks (e.g. the final layer).

    With ``depth = layer / (len(model.blocks) + 1)`` in [0, 1]:

        lr = base_lr * decay_rate ** (1 - depth)

    For ``decay_rate <= 1``, earlier layers get smaller learning rates (down
    to ``base_lr * decay_rate``) and the last layer gets ``base_lr``.

    Returns:
        A list of param-group dicts (``params``, ``lr``, ``weight_decay``),
        one per layer, ordered first to last, suitable for a torch optimizer.
    """
    num_layers = len(model.blocks)
    last_layer = num_layers + 1
    groups = {}
    seen_block = False

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        parts = name.split('.')
        layer = None
        for pos in range(len(parts) - 1):
            if parts[pos] == 'blocks' and parts[pos + 1].isdigit():
                layer = int(parts[pos + 1]) + 1
                break
        if layer is None:
            # Non-block parameter: input side if no block has been seen yet,
            # output side otherwise.
            layer = last_layer if seen_block else 0
        else:
            seen_block = True

        if layer not in groups:
            # Depth comes from the layer index, not the parameter index, so
            # it stays within [0, 1]. Using the parameter index made depth
            # far exceed 1 and blew the lr up well above base_lr.
            depth = layer / last_layer
            groups[layer] = {
                'params': [],
                'lr': base_lr * (decay_rate ** (1 - depth)),
                'weight_decay': 0.0,  # DiT typically uses no weight decay
            }
        groups[layer]['params'].append(param)

    return [groups[layer] for layer in sorted(groups)]

5.3 调度可视化

学习率
     ↑
1e-4 ┤ ╭────╮
     │ │      ╲
     │ │        ╲
     │ │          ╲
     │ │            ╲
1e-6 ┤─╯              ╰────
     └─────────────────────→ 步数
     0 5K             400K
     (0–5K:线性 warmup;5K–400K:余弦衰减至 min_lr)

6. 注意力稳定性

6.1 Softmax 数值稳定化

def stable_attention(Q, K, V, scale=None):
    """
    Scaled dot-product attention with explicit max-subtraction before softmax.

    Args:
        Q, K, V: tensors whose last two dims are (sequence, feature); the
            feature size of Q is used as ``d_k``.
        scale: multiplier applied to ``Q @ K^T``. ``None`` (default) means
            ``d_k ** -0.5``. An explicit ``0.0`` is honored and gives uniform
            attention; previously any falsy value fell back to the default.

    Returns:
        ``(output, attn_weights)``.
    """
    if scale is None:
        scale = Q.shape[-1] ** -0.5

    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

    # Subtract the row-wise max so exp() cannot overflow. F.softmax is
    # typically implemented this way already; the explicit step documents
    # the technique.
    scores = scores - scores.max(dim=-1, keepdim=True).values

    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

6.2 QK-Norm

class QKNormAttention(nn.Module):
    """
    Attention with query-key normalization, to keep attention logits bounded.

    Q and K are L2-normalized along the last dim, so every raw score lies in
    [-1, 1] and cannot blow up. A learnable per-head scale (temperature) then
    sets how peaked the softmax can become.

    ``init_scale`` defaults to ``sqrt(head_dim)``. With unit-norm q/k, this
    makes the initial logits equal to standard ``1/sqrt(d)``-scaled attention
    on RMS-normalized q/k. An initial scale of 1.0 confines logits to [-1, 1],
    which forces nearly uniform attention early in training. Pass
    ``init_scale=1.0`` to reproduce that older behavior.
    """

    def __init__(self, hidden_size, num_heads, init_scale=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if init_scale is None:
            init_scale = self.head_dim ** 0.5
        # Learnable per-head scale (temperature) on the cosine scores.
        self.scale = nn.Parameter(torch.full((num_heads,), float(init_scale)))

    def forward(self, Q, K, V):
        # NOTE(review): the scale is broadcast as (1, num_heads, 1, 1), so
        # Q/K/V are assumed to be (batch, num_heads, seq, head_dim) — confirm
        # against the caller.
        # QK normalization (the epsilon guards against zero vectors).
        Q = Q / (Q.norm(dim=-1, keepdim=True) + 1e-8)
        K = K / (K.norm(dim=-1, keepdim=True) + 1e-8)

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale.view(1, -1, 1, 1)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        return output

6.3 注意力权重监控

class AttentionMonitor:
    """
    Record summary statistics of attention weights via forward hooks.

    A hook is attached to ``block.attn`` for every block in ``model.blocks``.
    If a hooked module returns a tuple, element 1 is treated as the attention
    weights; otherwise the whole output is used.
    NOTE(review): if ``attn`` returns only its output (no weights), the stats
    describe that output and the ``max > 1`` check is not meaningful.
    """

    def __init__(self, model):
        self.model = model
        # One dict (mean/std/max/min) per hooked forward call, oldest first.
        self.attention_stats = []
        # Handles from register_forward_hook, kept so hooks can be removed.
        self._hook_handles = []

    def register_hooks(self):
        """Attach the stats hook to every ``block.attn``; safe to call repeatedly."""
        # Drop any earlier hooks so repeated calls do not record duplicates.
        self.remove_hooks()

        def hook_fn(module, inputs, output):
            attn_weights = output[1] if isinstance(output, tuple) else output
            attn_weights = attn_weights.detach()
            self.attention_stats.append({
                'mean': attn_weights.mean().item(),
                'std': attn_weights.std().item(),
                'max': attn_weights.max().item(),
                'min': attn_weights.min().item(),
            })

        for block in self.model.blocks:
            self._hook_handles.append(block.attn.register_forward_hook(hook_fn))

    def remove_hooks(self):
        """Detach all hooks previously attached by ``register_hooks``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def check_stability(self):
        """
        Check the most recent stats entry.

        Returns:
            True if there are no stats yet or the latest entry looks sane.
            False (after printing a warning) if the max exceeds 1 (softmax
            weights cannot) or the mean is NaN.
        """
        import math

        if not self.attention_stats:
            return True

        latest = self.attention_stats[-1]

        # Softmax outputs are <= 1; allow a small tolerance for rounding.
        if latest['max'] > 1.0 + 1e-5:
            print("警告: 注意力权重可能不稳定")
            return False

        if math.isnan(latest['mean']):
            print("警告: 注意力权重为 NaN")
            return False

        return True

7. 初始化策略

7.1 DiT 初始化

def dit_init_weights(module):
    """
    DiT-style weight initialization; intended for ``model.apply(...)``.

    - ``nn.Linear``: Xavier-uniform weight, zero bias.
    - ``nn.LayerNorm``: identity affine transform (weight = 1, bias = 0).
      A zero weight would make every LayerNorm output exactly zero and block
      signal and gradients through the layer. LayerNorms without affine
      parameters (``weight``/``bias`` is None) are left untouched.
    - ``nn.Embedding``: normal(mean=0, std=0.02).
    Other module types are ignored.
    """
    if isinstance(module, nn.Linear):
        # Linear: Xavier initialization.
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    elif isinstance(module, nn.LayerNorm):
        # LayerNorm: unit scale, zero offset.
        if module.weight is not None:
            nn.init.ones_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

    elif isinstance(module, nn.Embedding):
        # Embedding: small-range normal initialization.
        nn.init.normal_(module.weight, mean=0, std=0.02)
 
# Apply the generic initializer to every submodule.
model.apply(dit_init_weights)

# AdaLN-Zero: zero-initialize the final layer of each block's adaLN network.
# NOTE(review): in DiT's AdaLN-Zero design this zeroes the modulation outputs
# (including the residual gates), so each block presumably starts as the
# identity map — confirm against this block implementation.
for block in model.blocks:
    modulation_out = block.adaLN[-1]
    for tensor in (modulation_out.weight, modulation_out.bias):
        nn.init.zeros_(tensor)

7.2 残差连接初始化

class ResidualBlockWithScale(nn.Module):
    """
    Residual wrapper with a learnable scalar on the branch.

    Computes ``x + scale * module(x)``. The small default ``scale`` (0.1)
    damps the residual branch's contribution at initialization.
    """

    def __init__(self, module, scale=0.1):
        super().__init__()
        self.module = module
        # Learnable scalar gain on the residual branch.
        self.scale = nn.Parameter(torch.tensor(scale))

    def forward(self, x):
        branch_out = self.module(x)
        scaled = self.scale * branch_out
        return x + scaled

8. 训练监控

8.1 关键指标

class TrainingMonitor:
    """
    Track DiT training metrics and flag common instability symptoms.

    Args (optional; the defaults reproduce the original fixed thresholds):
        window: number of most recent losses examined by ``check_stability``.
        loss_explode_factor: the latest loss is "exploded" when it exceeds
            this factor times the mean of the earlier losses in the window.
        max_grad_norm: the latest gradient norm above this is "exploded".
    """

    def __init__(self, window=100, loss_explode_factor=10.0, max_grad_norm=100.0):
        self.window = window
        self.loss_explode_factor = loss_explode_factor
        self.max_grad_norm = max_grad_norm
        # Each list holds (step, value) tuples in logging order.
        self.metrics = {
            'loss': [],
            'grad_norm': [],
            'attention_stats': [],
            'lr': [],
        }

    @staticmethod
    def _to_float(value):
        """Convert a Python number or single-element tensor/array to a float."""
        return float(value.item()) if hasattr(value, 'item') else float(value)

    def log_batch(self, step, loss, grad_norm, lr):
        """Record one step. ``loss``/``grad_norm`` may be numbers or 0-d tensors."""
        self.metrics['loss'].append((step, self._to_float(loss)))
        # clip_grad_norm_ returns a tensor; store a plain float so later
        # numpy reductions work (also for CUDA tensors).
        self.metrics['grad_norm'].append((step, self._to_float(grad_norm)))
        self.metrics['lr'].append((step, lr))

    def check_stability(self, step):
        """
        Check recent metrics for instability.

        ``step`` is accepted for API compatibility and is not used.

        Returns:
            ``(is_stable, reason)``, where ``reason`` is one of "Stable",
            "Loss is NaN", "Loss is Inf", "Loss exploded", "Gradient is NaN",
            or "Gradient exploded".
        """
        # Loss checks
        if self.metrics['loss']:
            recent_losses = [l for _, l in self.metrics['loss'][-self.window:]]
            if any(np.isnan(l) for l in recent_losses):
                return False, "Loss is NaN"
            if any(np.isinf(l) for l in recent_losses):
                return False, "Loss is Inf"

            # Compare the latest loss with the mean of the older part of the window.
            if len(recent_losses) > 10:
                baseline = np.mean(recent_losses[:-10])
                if recent_losses[-1] > self.loss_explode_factor * baseline:
                    return False, "Loss exploded"

        # Gradient checks (a NaN norm would otherwise compare False and pass).
        if self.metrics['grad_norm']:
            latest_grad = self.metrics['grad_norm'][-1][1]
            if np.isnan(latest_grad):
                return False, "Gradient is NaN"
            if latest_grad > self.max_grad_norm:
                return False, "Gradient exploded"

        return True, "Stable"

    def generate_report(self):
        """Summarize the last 1000 entries; averages are None when nothing is logged."""
        losses = [l for _, l in self.metrics['loss'][-1000:]]
        grads = [g for _, g in self.metrics['grad_norm'][-1000:]]
        report = {
            'avg_loss': np.mean(losses) if losses else None,
            'avg_grad_norm': np.mean(grads) if grads else None,
            'current_lr': self.metrics['lr'][-1][1] if self.metrics['lr'] else None,
        }
        return report

8.2 早停机制

class EarlyStopping:
    """
    Signal a stop once the loss has failed to improve ``patience`` times in a row.

    A loss counts as an improvement only if it is below
    ``best_loss - min_delta``; an improvement resets the counter.
    """

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')

    def check(self, loss):
        """Record ``loss`` and return True when training should stop."""
        improved = loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = loss
            self.counter = 0
            return False  # keep training

        self.counter += 1
        return self.counter >= self.patience

9. 完整训练配置

# Recommended DiT training configuration
TRAINING_CONFIG = {
    # Model: DiT-XL/2 (28 layers, hidden size 1152, 16 heads, patch size 2).
    # The previous hidden_size of 1024 with 28 layers matched no DiT variant.
    'model': {
        'hidden_size': 1152,
        'num_heads': 16,
        'num_layers': 28,
        'patch_size': 2,
    },

    # Optimizer
    'optimizer': {
        'lr': 1e-4,
        'weight_decay': 0.0,
        'beta1': 0.9,
        'beta2': 0.999,
    },

    # Learning-rate schedule
    'scheduler': {
        'warmup_steps': 5000,
        'total_steps': 400000,
        'min_lr': 1e-6,
    },

    # Stability
    'stability': {
        'ema_decay': 0.9999,
        'grad_clip_norm': 1.0,
        'precision': 'bf16-mixed',
    },

    # Data
    'data': {
        'batch_size': 256,
        'resolution': 256,
        'num_workers': 8,
    },
}

10. 总结

关键稳定性技术

| 技术 | 作用 | 推荐配置 |
|---|---|---|
| EMA | 平滑训练轨迹 | 衰减率 0.9999 |
| 梯度裁剪 | 防止梯度爆炸 | 范数 1.0 |
| 混合精度 | 加速 + 稳定 | BF16 |
| 学习率调度 | 稳定收敛 | Warmup + Cosine |
| QK-Norm | 注意力稳定 | 推荐使用 |

检查清单

  • 使用 EMA(衰减率 0.9999)
  • 启用梯度裁剪
  • 使用 BF16 混合精度
  • 添加 warmup 阶段
  • 监控关键指标
  • 实现早停机制

参考


相关阅读

Footnotes

  1. Peebles, W., & Xie, S. (2023). “Scalable Diffusion Models with Transformers.” ICCV 2023. arXiv:2212.09748