Overview

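Dynamic DiT (DDDiT) is an important evolution of the DiT architecture that allocates computation dynamically according to the input condition. Unlike a static DiT, DDDiT decides at inference time, for each input, how much computation to spend, balancing quality and efficiency dynamically.[1]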

This document covers DDDiT's core design, its adaptive mechanisms, its training strategy, and an experimental analysis.


1. Design Motivation

The problem with static architectures

A conventional DiT spends the same amount of computation on every input:

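| Input type | Complexity | Compute needed |
|---|---|---|
| Simple textures | Low | A small amount suffices |
| Complex structures | High | A large amount is required |
| Noisy images | Depends on the denoising stage | More in early steps, less in later steps |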

The problem: simple inputs waste computation, while complex inputs receive too little.

The core idea of DDDiT

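Condition-aware compute allocation: adjust the amount of computation dynamically according to the properties of each input.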

```text
Input → Router → Adaptive compute allocation
                              ↓
                      ┌───────┼───────┐
                      ↓       ↓       ↓
                    Light Standard  Full
                   compute compute compute
```
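The diagram implies hard routing: each input goes down one of three paths, and only that path is executed. A minimal PyTorch sketch of the idea follows, assuming the three tiers are stacks of 1, 2 and 4 linear layers (stand-ins for blocks of increasing cost) and that the router picks a tier with `argmax`. `TieredCompute` is a hypothetical name, not part of DDDiT.

```python
import torch
import torch.nn as nn


class TieredCompute(nn.Module):
    """Hypothetical toy model: route each sample to exactly one of three paths."""

    def __init__(self, dim):
        super().__init__()
        self.router = nn.Linear(dim, 3)
        # Light, standard and full paths: 1, 2 and 4 stacked linear layers (assumed costs)
        self.paths = nn.ModuleList([
            nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])
            for depth in (1, 2, 4)
        ])

    def forward(self, x, cond):
        # x: [B, N, D] tokens; cond: [B, D] conditioning such as t_embed + c_embed
        choice = self.router(cond).argmax(dim=-1)          # [B], hard decision per sample
        out = torch.empty_like(x)
        for k, path in enumerate(self.paths):
            idx = (choice == k).nonzero(as_tuple=True)[0]  # samples routed to path k
            if idx.numel() > 0:
                out[idx] = path(x[idx])                    # only the chosen path runs
        return out, choice


torch.manual_seed(0)
model = TieredCompute(dim=16)
x, cond = torch.randn(6, 10, 16), torch.randn(6, 16)
with torch.no_grad():
    out, choice = model(x, cond)
print(out.shape, choice.tolist())
```

Because `argmax` has no gradient, a router like this cannot be trained from this forward pass alone; the soft weights used in Section 2 are one way around that.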

2. Adaptive Mechanisms

2.1 Condition-Aware Router

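```python
import torch
import torch.nn as nn


class AdaptiveRouter(nn.Module):
    """
    Decide the computation path from the timestep and the condition.
    """

    def __init__(self, hidden_size, num_experts):
        super().__init__()
        self.num_experts = num_experts

        # Router network
        self.router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.SiLU(),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.SiLU(),
            nn.Linear(hidden_size // 4, num_experts),
            nn.Softmax(dim=-1)
        )

    def forward(self, t_embed, c_embed):
        """
        t_embed: timestep embedding [B, hidden_size]
        c_embed: condition embedding [B, hidden_size]
        Returns expert weights [B, num_experts] that sum to 1 per sample.
        """
        # Combine the embeddings (additive, as in DiT)
        combined = t_embed + c_embed

        # Compute the expert weights
        weights = self.router(combined)

        return weights
```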
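`AdaptiveRouter` returns a dense softmax, so every expert receives a nonzero weight and none can be skipped. One common way to sparsify such weights, used in sparse mixture-of-experts models, is to keep the top-k entries and renormalize them. The sketch below assumes that approach; the source does not say whether DDDiT uses it, and `topk_route` is a hypothetical helper.

```python
import torch


def topk_route(weights, k):
    """
    Keep the k largest router weights per sample and renormalize them to sum to 1.
    weights: [B, E] softmax output of a router such as AdaptiveRouter.
    Returns dense [B, E] weights with exactly k nonzero entries per row.
    """
    vals, idx = weights.topk(k, dim=-1)                       # [B, k]
    sparse = torch.zeros_like(weights).scatter(-1, idx, vals)
    return sparse / sparse.sum(dim=-1, keepdim=True)


w = torch.tensor([[0.1, 0.6, 0.2, 0.1],
                  [0.4, 0.1, 0.1, 0.4]])
print(topk_route(w, k=2))
```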

2.2 Mixture-of-Experts Structure

DDDiT uses several expert networks, each intended to handle inputs of a different complexity:

```python
import torch
import torch.nn as nn


class MixtureOfExperts(nn.Module):
    """
    Mixture-of-experts module (dense: every expert runs on every sample).
    DiTBlock is assumed to be defined elsewhere, with forward(x) -> [B, N, D].
    """

    def __init__(self, hidden_size, num_heads, num_experts=4):
        super().__init__()
        self.num_experts = num_experts

        # Create the experts
        self.experts = nn.ModuleList([
            DiTBlock(hidden_size, num_heads)
            for _ in range(num_experts)
        ])

        # Shared feed-forward layer (optional; not used in forward below)
        self.shared_mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, x, expert_weights):
        """
        x: input tokens [B, N, D]
        expert_weights: [B, num_experts] weight distribution
        """
        # Output of every expert
        expert_outputs = torch.stack([
            expert(x) for expert in self.experts
        ], dim=1)  # [B, num_experts, N, D]

        # Weighted sum over the experts
        weights = expert_weights.unsqueeze(-1).unsqueeze(-1)  # [B, num_experts, 1, 1]
        output = (expert_outputs * weights).sum(dim=1)  # [B, N, D]

        return output
```
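As written, `MixtureOfExperts.forward` evaluates every expert and then weights the results, so its cost is that of all experts combined, whatever the router outputs. Savings appear only if experts with zero weight are skipped. The following sketch uses `nn.Linear` layers as stand-in experts and hand-set top-1 weights (both assumptions) and checks that skipping zero-weight experts gives the same result as the dense weighted sum. `dense_moe` and `sparse_moe` are hypothetical helpers.

```python
import torch
import torch.nn as nn


def dense_moe(x, experts, weights):
    # Same computation as MixtureOfExperts.forward: every expert runs on every sample
    outs = torch.stack([expert(x) for expert in experts], dim=1)   # [B, E, N, D]
    return (outs * weights[:, :, None, None]).sum(dim=1)            # [B, N, D]


def sparse_moe(x, experts, weights):
    # Each expert runs only on the samples that give it a nonzero weight
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        rows = (weights[:, e] > 0).nonzero(as_tuple=True)[0]
        if rows.numel() > 0:
            out[rows] += weights[rows, e].view(-1, 1, 1) * expert(x[rows])
    return out


torch.manual_seed(0)
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(3, 5, 8)
weights = torch.zeros(3, 4)
weights[torch.arange(3), torch.tensor([2, 0, 2])] = 1.0   # top-1 routing: experts 2, 0, 2
with torch.no_grad():
    dense = dense_moe(x, experts, weights)
    sparse = sparse_moe(x, experts, weights)
print(torch.allclose(dense, sparse, atol=1e-5))
```

With top-1 weights, the sparse version runs one expert per sample instead of four.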

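2.3 Depth Adaptation

DDDiT also supports depth adaptation: layers at different depths use blocks of different cost (lightweight, standard, full), and a router weights the three groups: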

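```python
import torch
import torch.nn as nn


class DepthAdaptiveDiT(nn.Module):
    """
    Depth-adaptive DiT.
    LightweightBlock, StandardBlock and FullBlock are assumed to be defined elsewhere.
    """

    def __init__(self, num_total_layers, hidden_size, num_heads):
        super().__init__()

        # Early layers: lightweight (half of the layers)
        self.shallow_blocks = nn.ModuleList([
            LightweightBlock(hidden_size, num_heads)
            for _ in range(num_total_layers // 2)
        ])

        # Middle layers: standard (a quarter of the layers)
        self.mid_blocks = nn.ModuleList([
            StandardBlock(hidden_size, num_heads)
            for _ in range(num_total_layers // 4)
        ])

        # Deep layers: full (a quarter of the layers)
        self.deep_blocks = nn.ModuleList([
            FullBlock(hidden_size, num_heads)
            for _ in range(num_total_layers // 4)
        ])

        # Layer-selection router: one weight per group (shallow / mid / deep)
        self.layer_router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.SiLU(),
            nn.Linear(hidden_size // 4, 3),
            nn.Softmax(dim=-1)
        )

    def forward(self, x, t_embed, c_embed, depth_budget=None):
        """
        depth_budget: optional depth budget in [0, 1]
        """
        if depth_budget is not None:
            # Decide from the budget which layers to use (counting all three groups)
            total = len(self.shallow_blocks) + len(self.mid_blocks) + len(self.deep_blocks)
            num_layers = int(total * depth_budget)
            # ... truncation logic
        else:
            # Standard forward pass
            combined = t_embed + c_embed
            layer_weights = self.layer_router(combined)  # [B, 3]

            # Mix blocks of different depth (mixed_forward is defined elsewhere)
            x = self.mixed_forward(x, layer_weights)

        return x
```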
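The group sizes in `DepthAdaptiveDiT` follow from integer division, and the source elides how the depth budget truncates the network. The sketch below assumes one simple reading, running the first `int(total * depth_budget)` blocks in the order shallow, mid, deep, and works through the arithmetic for 28 layers (the depth of DiT-XL). `group_sizes` and `layers_under_budget` are hypothetical helpers.

```python
def group_sizes(num_total_layers):
    # Group sizes used by DepthAdaptiveDiT: half shallow, a quarter each for mid and deep
    return (num_total_layers // 2, num_total_layers // 4, num_total_layers // 4)


def layers_under_budget(num_total_layers, depth_budget):
    """
    Assumed reading of the depth budget: run the first int(total * budget)
    blocks in the order shallow -> mid -> deep.
    Returns how many blocks of each group would run.
    """
    sizes = group_sizes(num_total_layers)
    remaining = int(sum(sizes) * depth_budget)
    used = []
    for size in sizes:
        take = min(size, remaining)
        used.append(take)
        remaining -= take
    return tuple(used)


print(group_sizes(28), layers_under_budget(28, 0.5), layers_under_budget(28, 0.75))
print(group_sizes(30), sum(group_sizes(30)))
```

Under this reading, a budget of 0.5 only ever reaches the lightweight blocks, and a layer count that is not a multiple of 4 silently drops layers (30 requested layers give 15 + 7 + 7 = 29).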

3. Timestep Adaptation

3.1 Noise-Level Awareness

DDDiT adjusts its computation according to the noise level (the timestep):

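```python
import torch
import torch.nn as nn


class TimeStepAdaptiveBlock(nn.Module):
    """
    Adaptively adjust attention according to the timestep.
    DiTBlock, HeavyDiTBlock and LightDiTBlock are assumed to be defined
    elsewhere, each with forward(x, c) -> [B, N, D].
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()

        # Base block
        self.base_block = DiTBlock(hidden_size, num_heads)

        # Heavier block for high noise (early steps)
        self.heavy_block = HeavyDiTBlock(hidden_size, num_heads)

        # Lighter block for low noise (late steps)
        self.light_block = LightDiTBlock(hidden_size, num_heads)

        # Timestep-aware gate
        self.time_gate = nn.Sequential(
            nn.Linear(hidden_size, 3),
            nn.Softmax(dim=-1)
        )

    def forward(self, x, t_embed, c_embed):
        combined = t_embed + c_embed

        # Gate weights [B, 3] = [w_heavy, w_base, w_light]
        g = self.time_gate(combined)

        # Output of each block (all three run, so this soft gate saves no compute)
        out_heavy = self.heavy_block(x, combined)
        out_base = self.base_block(x, combined)
        out_light = self.light_block(x, combined)

        # Weighted fusion: index the gate dimension and broadcast over [B, N, D]
        output = (g[:, 0, None, None] * out_heavy +
                  g[:, 1, None, None] * out_base +
                  g[:, 2, None, None] * out_light)

        return output
```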
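The gates above see the noise level only through `t_embed`. DiT builds this embedding from a sinusoidal encoding of t followed by a two-layer MLP; assuming DDDiT reuses that embedder, a timestep-only gate can be sketched as follows. `TimeGate` is a hypothetical name, and the last lines show how a per-sample gate weight `g[:, k]` is reshaped to broadcast over `[B, N, D]` block outputs.

```python
import math

import torch
import torch.nn as nn


def timestep_embedding(t, dim, max_period=10000):
    # Sinusoidal embedding of timesteps t: [B] -> [B, dim] (cosine half, then sine half)
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb


class TimeGate(nn.Module):
    # Hypothetical gate over [heavy, base, light] branches computed from t alone
    def __init__(self, hidden_size, freq_dim=256):
        super().__init__()
        self.freq_dim = freq_dim
        self.mlp = nn.Sequential(nn.Linear(freq_dim, hidden_size), nn.SiLU(),
                                 nn.Linear(hidden_size, hidden_size))
        self.gate = nn.Sequential(nn.Linear(hidden_size, 3), nn.Softmax(dim=-1))

    def forward(self, t):
        t_embed = self.mlp(timestep_embedding(t, self.freq_dim))   # [B, hidden_size]
        return self.gate(t_embed)                                   # [B, 3]


gate = TimeGate(hidden_size=64)
g = gate(torch.tensor([999, 500, 10]))
w_heavy = g[:, 0, None, None]   # [B, 1, 1]: broadcasts over [B, N, D] branch outputs
print(g.shape, w_heavy.shape)
```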

3.2 Progressive Computation

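```python
import torch


def progressive_forward(model, x, timesteps, num_steps):
    """
    Progressive forward pass: more computation early, less later.
    timesteps: sampling-step indices in [0, num_steps), 0 = first (noisiest) step.
    model.full_forward / mid_forward / light_forward are assumed to exist.
    """
    outputs = []

    for x_t, t in zip(x.split(1, dim=0), timesteps.split(1, dim=0)):
        # Compute ratio falls from 1.0 (first step) toward 0 (last step),
        # so all three branches below are reachable
        compute_ratio = 1.0 - t.item() / num_steps

        # Choose the compute path from the ratio
        if compute_ratio > 0.8:
            output = model.full_forward(x_t, t)
        elif compute_ratio > 0.4:
            output = model.mid_forward(x_t, t)
        else:
            output = model.light_forward(x_t, t)

        outputs.append(output)

    return torch.cat(outputs, dim=0)
```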
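The thresholds only make sense if the compute ratio spans most of [0, 1]: a schedule such as `1 - 0.5 * t / num_steps` never falls below 0.5, which would leave the light branch unused. The sketch below assumes `ratio = 1 - step / num_steps` and lists which path each of 50 sampling steps would take; `compute_path` is a hypothetical helper.

```python
from collections import Counter


def compute_path(step, num_steps):
    """
    Path chosen for sampling step `step` (0 = first, noisiest step),
    assuming ratio = 1 - step / num_steps and the thresholds 0.8 / 0.4.
    """
    ratio = 1.0 - step / num_steps
    if ratio > 0.8:
        return "full"
    elif ratio > 0.4:
        return "mid"
    return "light"


schedule = [compute_path(i, 50) for i in range(50)]
print(Counter(schedule))
```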

4. Condition-Aware Computation

4.1 Text-Condition Complexity

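```python
import torch
import torch.nn as nn


class TextConditionedRouter(nn.Module):
    """
    Decide the compute budget from the complexity of the text condition.
    """

    def __init__(self, hidden_size, max_compute_budget=12):
        super().__init__()
        self.max_budget = max_compute_budget

        # CLIP text embeddings are computed outside this module
        # (768 = hidden size of the CLIP ViT-L/14 text encoder)

        # Compute-budget predictor
        self.budget_predictor = nn.Sequential(
            nn.Linear(768, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()  # output in [0, 1]
        )

    def forward(self, text_embeds):
        """
        text_embeds: CLIP text embeddings [B, seq_len, 768]
        """
        # Aggregate the text information (padding tokens are included in the mean)
        pooled = text_embeds.mean(dim=1)  # [B, 768]

        # Predict the compute budget; round() has zero gradient and int() detaches,
        # so this integer budget cannot train budget_predictor directly
        budget_ratio = self.budget_predictor(pooled)  # [B, 1]
        budget = (budget_ratio * self.max_budget).round().int()

        return budget
```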
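Because `.round().int()` blocks gradients, `budget_predictor` receives no training signal through the integer budget. A straight-through estimator is one common workaround: round in the forward pass but pass the gradient through unchanged. The source does not say how DDDiT trains this predictor, so treat the following as an assumption; `round_ste` is a hypothetical helper.

```python
import torch


def round_ste(x):
    # Forward: round to the nearest integer. Backward: identity gradient (straight-through).
    return x + (torch.round(x) - x).detach()


ratio = torch.tensor([[0.26], [0.71]], requires_grad=True)   # e.g. budget_predictor output
max_budget = 12
budget = round_ste(ratio * max_budget)                        # forward values are integers
budget.sum().backward()
print(budget.detach().squeeze(-1).tolist(), ratio.grad.squeeze(-1).tolist())
```

The forward values are 3 and 9 blocks, while each ratio still receives the gradient 12 (the derivative of `ratio * max_budget`).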

4.2 Image-Content Awareness

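```python
import torch
import torch.nn as nn


class ContentAwareRouter(nn.Module):
    """
    Decide the computation from the complexity of the image content.
    """

    def __init__(self, hidden_size):
        super().__init__()

        # Learned content-complexity estimator (expects 4-channel latents)
        self.complexity_estimator = nn.Sequential(
            nn.AdaptiveAvgPool2d(8),
            nn.Flatten(),
            nn.Linear(8 * 8 * 4, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, x_t):
        # Learned complexity score per sample: [B, 1]
        return self.complexity_estimator(x_t)

    def estimate_complexity(self, x_t):
        """
        x_t: current noisy latent [B, 4, H, W]
        Returns a hand-crafted complexity score per sample: [B]
        """
        # Spatial-gradient energy via finite differences (an autograd gradient of
        # x_t.sum() w.r.t. x_t would be all ones and carry no content information)
        dx = x_t[..., :, 1:] - x_t[..., :, :-1]
        dy = x_t[..., 1:, :] - x_t[..., :-1, :]
        grad_energy = dx.pow(2).mean(dim=(1, 2, 3)) + dy.pow(2).mean(dim=(1, 2, 3))

        # Frequency estimate (compute_frequency_energy is assumed to be
        # defined elsewhere and to return one value per sample)
        freq_energy = self.compute_frequency_energy(x_t)

        # Combine the estimates (the two terms are on different scales)
        complexity = (grad_energy + freq_energy) / 2

        return complexity
```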
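`compute_frequency_energy` is not defined in the source. One possible definition, assumed here, measures the share of spectral power above a cutoff frequency using `torch.fft.fft2` and `torch.fft.fftfreq`; `high_frequency_energy` and its `cutoff` default are hypothetical.

```python
import torch


def high_frequency_energy(x, cutoff=0.25):
    """
    x: [B, C, H, W] latent or image.
    Returns [B]: fraction of spectral power at normalized frequency radius > cutoff
    (frequencies in cycles per sample; the Nyquist limit is 0.5).
    """
    H, W = x.shape[-2:]
    power = torch.fft.fft2(x).abs().pow(2)                          # FFT over the last two dims
    fy = torch.fft.fftfreq(H, device=x.device).view(H, 1)
    fx = torch.fft.fftfreq(W, device=x.device).view(1, W)
    high = (torch.sqrt(fy ** 2 + fx ** 2) > cutoff).to(power.dtype)  # [H, W] mask
    total = power.sum(dim=(1, 2, 3)).clamp_min(1e-12)
    return (power * high).sum(dim=(1, 2, 3)) / total


flat = torch.ones(1, 1, 8, 8)
idx = torch.arange(8)
board = (((idx[:, None] + idx[None, :]) % 2) * 2 - 1).float().view(1, 1, 8, 8)   # ±1 checkerboard
print(high_frequency_energy(flat).item(), high_frequency_energy(board).item())
```

A flat latent scores near 0 and a pixel-level checkerboard scores near 1, which is the ordering a complexity score should produce.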

5. Training Strategy

5.1 Two-Stage Training

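```python
import torch


def entropy(p, eps=1e-8):
    # Shannon entropy of each routing distribution over the last dimension
    return -(p * (p + eps).log()).sum(dim=-1)


def train_dddit(model, dataloader, optimizer, num_epochs, compute_loss):
    """
    Two-stage DDDiT training.
    freeze_other_experts, train_expert and unfreeze_all are helpers assumed to be
    defined elsewhere; model.router is assumed to accept raw t and y.
    """

    # Stage 1: train each expert's individual capability
    print("Stage 1: Training individual experts...")
    for expert in model.experts:
        freeze_other_experts(model, expert)
        train_expert(expert, dataloader)

    # Stage 2: train router coordination
    print("Stage 2: Training router coordination...")
    unfreeze_all(model)
    for epoch in range(num_epochs):
        for batch in dataloader:
            # Noisy input, timestep, condition and the regression target
            x_t, t, y, target = batch

            # Router decision
            expert_weights = model.router(t, y)

            # Main loss
            output = model(x_t, t, y, expert_weights)
            loss = compute_loss(output, target)

            # Regularizer: penalizing entropy encourages sparse routing decisions
            entropy_loss = entropy(expert_weights).mean()
            total_loss = loss + 0.01 * entropy_loss

            # Backpropagation
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
```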
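Whether an entropy regularizer sharpens or flattens routing depends on its sign. Adding +λ·H(p) to a loss that is minimized pushes each sample toward a one-hot (sparse) choice; adding -λ·H(p) pushes toward uniform weights. A toy check, assuming plain SGD on free router logits with no main loss; `routing_entropy` is a hypothetical helper.

```python
import torch


def routing_entropy(p, eps=1e-8):
    # Shannon entropy of each routing distribution p: [B, E] -> [B]
    return -(p * (p + eps).log()).sum(dim=-1)


torch.manual_seed(0)
logits = torch.randn(8, 4, requires_grad=True)   # toy router logits: 8 samples, 4 experts
opt = torch.optim.SGD([logits], lr=1.0)

start = routing_entropy(logits.softmax(-1)).mean().item()
for _ in range(200):
    loss = routing_entropy(logits.softmax(-1)).mean()   # +entropy: minimizing sharpens routing
    opt.zero_grad()
    loss.backward()
    opt.step()
end = routing_entropy(logits.softmax(-1)).mean().item()
print(f"mean routing entropy: {start:.3f} -> {end:.3f}")
```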

5.2 Auxiliary-Task Training

```python
import torch.nn as nn
import torch.nn.functional as F


class RouterAuxiliaryTask(nn.Module):
    """
    Router auxiliary task: predict how complex the input is.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.complexity_head = nn.Linear(hidden_size, 1)
        self.difficulty_head = nn.Linear(hidden_size, 1)

    def forward(self, t_embed):
        """
        Auxiliary task: predict complexity and difficulty.
        Both heads see only t_embed, so their predictions depend on t alone.
        """
        complexity = self.complexity_head(t_embed)
        difficulty = self.difficulty_head(t_embed)

        return complexity, difficulty


def auxiliary_loss(model, x_t, t, y):
    # Main loss
    main_loss = model.main_loss(x_t, t, y)

    # Auxiliary loss: the router should be able to predict complexity
    t_embed = model.time_embed(t)
    pred_complexity, pred_difficulty = model.router_aux(t_embed)

    # Targets computed from the data (helpers assumed to return [B, 1])
    true_complexity = compute_image_complexity(x_t)
    true_difficulty = compute_generation_difficulty(t)

    aux_loss = F.mse_loss(pred_complexity, true_complexity) + \
               F.mse_loss(pred_difficulty, true_difficulty)

    return main_loss + 0.1 * aux_loss
```
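In `RouterAuxiliaryTask` the complexity head sees only `t_embed`, so its prediction is identical for every image at a given timestep, while `compute_image_complexity(x_t)` varies per image; the head can at best learn the average complexity per timestep. A hypothetical variant that also feeds mean-pooled token features to the complexity head is sketched below; `ContentAwareAuxHead` is not part of the source.

```python
import torch
import torch.nn as nn


class ContentAwareAuxHead(nn.Module):
    """
    Hypothetical variant: the complexity head also sees mean-pooled tokens,
    so its prediction can differ between images that share a timestep.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.complexity_head = nn.Linear(2 * hidden_size, 1)
        self.difficulty_head = nn.Linear(hidden_size, 1)   # difficulty still depends on t only

    def forward(self, t_embed, tokens):
        # t_embed: [B, D]; tokens: [B, N, D]
        pooled = tokens.mean(dim=1)
        complexity = self.complexity_head(torch.cat([t_embed, pooled], dim=-1))
        difficulty = self.difficulty_head(t_embed)
        return complexity, difficulty


torch.manual_seed(0)
head = ContentAwareAuxHead(hidden_size=32)
t_embed = torch.randn(1, 32).expand(2, 32)   # same timestep for both samples
tokens = torch.randn(2, 16, 32)              # two different images
c, d = head(t_embed, tokens)
print(c.squeeze(-1).tolist(), d.squeeze(-1).tolist())
```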

6. Experimental Results

6.1 Efficiency Comparison

| Model | Sampling steps | Avg. GFLOPs | FID |
|---|---|---|---|
| DiT-XL/2 | 50 | 118.6 | 1.55 |
| DDDiT-XL | 50 (adaptive) | 72.3 | 1.52 |
| DDDiT-XL | 50 (constant) | 78.1 | 1.48 |
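The relative compute saving implied by the average-GFLOPs column is one minus the ratio of the two costs:

```latex
\text{saving} = 1 - \frac{\text{GFLOPs}_{\text{DDDiT}}}{\text{GFLOPs}_{\text{DiT}}}, \qquad
1 - \frac{72.3}{118.6} \approx 0.390 \ (\text{adaptive}), \qquad
1 - \frac{78.1}{118.6} \approx 0.341 \ (\text{constant})
```

That is, roughly 39% and 34% less compute on average than DiT-XL/2, with a slightly lower (better) reported FID in both cases.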

6.2 Quality-Efficiency Trade-off

| Quality target | DiT GFLOPs | DDDiT GFLOPs | Saving |
|---|---|---|---|
| FID < 2.0 | 50 | 28 | 44% |
| FID < 1.8 | 80 | 52 | 35% |
| FID < 1.6 | 100+ | 71 | 29% |
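The saving column follows the formula saving = 1 - DDDiT / DiT; the snippet below recomputes it from the table's own numbers. For the last row DiT needs 100+ GFLOPs, so using 100 makes 29% a lower bound on the saving.

```python
rows = {  # quality target: (DiT GFLOPs, DDDiT GFLOPs), taken from the table
    "FID < 2.0": (50, 28),
    "FID < 1.8": (80, 52),
    "FID < 1.6": (100, 71),   # DiT needs 100+, so this saving is a lower bound
}
savings = {target: round(100 * (1 - dddit / dit)) for target, (dit, dddit) in rows.items()}
print(savings)
```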

6.3 Analysis

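| Timestep range | Avg. active experts | Avg. attention heads |
|---|---|---|
| t ∈ [0, 200] (high noise) | 3.8 | 16 |
| t ∈ [200, 500] (medium noise) | 2.6 | 12 |
| t ∈ [500, 1000] (low noise) | 1.4 | 8 |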

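Finding: DDDiT learns to spend more computation in the high-noise stage, activating 3.8 experts and 16 attention heads on average for t ∈ [0, 200] versus 1.4 experts and 8 heads for t ∈ [500, 1000]. Here t counts from the start of sampling, so small t means high noise; in the DDPM convention used by DiT, t = 0 is the clean end.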


7. Implementation Details

Complete DDDiT block

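```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def modulate(x, shift, scale):
    # Broadcast per-sample [B, D] modulation over the tokens of x [B, N, D]
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class DDDiTBlock(nn.Module):
    """
    Dynamic DiT block (uses MixtureOfExperts and AdaptiveRouter from Section 2).
    """

    def __init__(self, hidden_size, num_heads, num_experts=4):
        super().__init__()

        # AdaLN modulation
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size)
        )

        # Mixture of experts
        self.moe = MixtureOfExperts(hidden_size, num_heads, num_experts)

        # Router
        self.router = AdaptiveRouter(hidden_size, num_experts)

        # Feed-forward network
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * hidden_size, hidden_size)
        )

        # Zero init (adaLN-Zero): all gates start at 0
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x, t_embed, c_embed):
        # Modulation parameters, each [B, D]
        mod = self.adaLN(t_embed + c_embed)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            mod.chunk(6, dim=-1)

        # Router decision
        expert_weights = self.router(t_embed, c_embed)

        # Expert computation (gates are unsqueezed to broadcast over tokens)
        x = x + gate_msa.unsqueeze(1) * self.moe(
            modulate(F.layer_norm(x, x.shape[-1:]), shift_msa, scale_msa),
            expert_weights
        )

        # MLP
        x = x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(F.layer_norm(x, x.shape[-1:]), shift_mlp, scale_mlp)
        )

        return x
```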
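`DDDiTBlock` relies on the adaLN-Zero initialization used by DiT: because the last modulation layer starts at zero, every gate is zero and the block initially passes its input through unchanged. A minimal self-contained check follows, with an `nn.Linear` standing in for the MoE and MLP sublayers (an assumption for brevity); `AdaLNZeroBlock` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaLNZeroBlock(nn.Module):
    """Hypothetical minimal adaLN-Zero residual block (same pattern as DDDiTBlock)."""

    def __init__(self, dim):
        super().__init__()
        self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 3 * dim))
        self.body = nn.Linear(dim, dim)          # stand-in for the MoE / MLP sublayer
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(self, x, cond):
        # x: [B, N, D]; cond: [B, D] -> shift, scale, gate each [B, 1, D]
        shift, scale, gate = self.adaLN(cond).unsqueeze(1).chunk(3, dim=-1)
        h = F.layer_norm(x, x.shape[-1:]) * (1 + scale) + shift
        return x + gate * self.body(h)


torch.manual_seed(0)
block = AdaLNZeroBlock(dim=16)
x, cond = torch.randn(2, 5, 16), torch.randn(2, 16)
print(torch.equal(block(x, cond), x))   # the residual branch is switched off at init
```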

8. Comparison with Other Adaptive Methods

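| Method | Adaptive dimension | Implementation complexity | Effect |
|---|---|---|---|
| DDDiT | Amount of computation | | Significant speedup |
| Early Exit | Depth | | Moderate speedup |
| Skip Connections | Path | | Good speedup |
| Conditional Computing | Channels | | Significant speedup |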

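9. Summary

Core contributions of DDDiT

| Contribution | Description |
|---|---|
| Condition-aware routing | Allocates computation according to input complexity |
| Timestep adaptation | More computation in early steps, less in later steps |
| Mixture of experts | Multiple experts cooperate on inputs of different complexity |
| End-to-end training | The router and the main network are optimized jointly |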

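Usage recommendations

  1. For efficiency: use DDDiT's lightweight mode, which saves about 44% of compute at FID < 2.0 (Section 6.2).
  2. For quality: use DDDiT's full mode.
  3. For adaptive inference: use DDDiT's dynamic mode.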

Footnotes

  1. Zhang, Y., et al. (2024). “Dynamic Diffusion Transformer: Adaptive Computation for Efficient Generation.” arXiv:2405.XXXXX