概述

HiDiT (High-Fidelity Diffusion Transformer) 是对标准 DiT 架构的重要改进,旨在保持高质量生成的同时显著提升训练和推理效率。1

本文件系统介绍 HiDiT 的核心设计理念、关键技术、训练策略,以及与其他高效 DiT 变体的对比。


1. 设计背景与动机

标准 DiT 的计算瓶颈

DiT 虽然在生成质量上表现出色,但存在以下计算挑战:

瓶颈描述影响
Patch Embedding每个 patch 独立线性投影参数量大
全注意力 复杂度序列长度增长时剧增
冗余 Token低层特征存在大量冗余计算浪费

HiDiT 的核心洞察

核心观察:扩散模型的去噪过程在不同深度需要不同程度的精细度。

  • 浅层:需要全局感知,捕获粗略结构
  • 深层:需要局部精细化,关注细节纹理

2. 核心架构设计

2.1 多尺度 Patchify

HiDiT 采用渐进式 patch 大小策略:

class MultiScalePatchify(nn.Module):
    """
    不同层使用不同大小的 patch
    
    浅层: patch_size=4 (粗粒度)
    深层: patch_size=2 (细粒度)
    """
    
    def __init__(self, hidden_size):
        super().__init__()
        
        # 粗粒度 patch embedding (浅层)
        self.coarse_proj = nn.Linear(16 * 4, hidden_size)  # 4×4×4
        
        # 细粒度 patch embedding (深层)
        self.fine_proj = nn.Linear(4 * 4, hidden_size)  # 2×2×4
        
        # 投影层间转换
        self.coarse_to_fine = nn.Linear(hidden_size, hidden_size)
        self.fine_to_coarse = nn.Linear(hidden_size, hidden_size)
    
    def forward(self, x, layer_depth, total_layers):
        """
        x: 输入潜变量
        layer_depth: 当前层深度
        total_layers: 总层数
        """
        progress = layer_depth / total_layers
        
        if progress < 0.5:
            # 浅层使用粗粒度
            x = self.coarse_to_fine(x) if progress > 0.3 else x
            return self.coarse_proj(x) if progress < 0.3 else self.fine_proj(x)
        else:
            # 深层使用细粒度
            return self.fine_proj(x)

2.2 分解注意力机制

HiDiT 将注意力分解为空间注意力和通道注意力

class FactorizedAttention(nn.Module):
    """
    分解注意力:空间 × 通道
    
    原始: O(N²d)
    分解后: O(Nd + Nd) = O(Nd)
    """
    
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        
        # 空间注意力 (N × N)
        self.spatial_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        
        # 通道注意力 (d × d)
        self.channel_attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        
        # 融合层
        self.fusion = nn.Linear(hidden_size * 2, hidden_size)
    
    def forward(self, x):
        # x: [B, N, d]
        
        # 空间注意力
        spatial_out = self.spatial_attn(x, x, x)[0]
        
        # 通道注意力 (转置为 [B, d, N])
        channel_out = self.channel_attn(
            x.transpose(1, 2), 
            x.transpose(1, 2), 
            x.transpose(1, 2)
        )[0].transpose(1, 2)
        
        # 融合
        combined = torch.cat([spatial_out, channel_out], dim=-1)
        return self.fusion(combined)

2.3 动态深度计算

HiDiT 根据去噪阶段动态分配计算

class DynamicComputation(nn.Module):
    """
    根据时间步动态选择计算路径
    """
    
    def __init__(self, hidden_size, num_layers):
        super().__init__()
        
        # 路由器:决定每层是否激活
        self.router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.SiLU(),
            nn.Linear(hidden_size // 4, num_layers),
            nn.Sigmoid()
        )
        
        # 专家网络
        self.experts = nn.ModuleList([
            ResidualBlock(hidden_size) 
            for _ in range(num_layers)
        ])
    
    def forward(self, x, t_embed, num_active=12):
        """
        x: 输入
        t_embed: 时间步嵌入
        num_active: 激活的专家数量
        """
        B, N, D = x.shape
        
        # 计算每层的激活概率
        gates = self.router(t_embed)  # [B, num_layers]
        
        # 选择 top-k 层
        _, top_indices = torch.topk(gates, num_active, dim=-1)
        
        # 稀疏激活
        output = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = (top_indices == i).any(dim=-1).float()
            output += expert(x) * mask.unsqueeze(-1).unsqueeze(-1)
        
        return output

3. 训练策略

3.1 渐进式训练

HiDiT 采用两阶段渐进训练

Stage 1: 粗粒度快速收敛
├── Patch size: 4
├── 训练步数: 100K
├── 学习率: 2e-4
└── 目标: 快速学习全局结构

Stage 2: 细粒度质量提升
├── Patch size: 2
├── 训练步数: 200K
├── 学习率: 1e-4 (warmup)
└── 目标: 精炼局部细节
def get_training_config(stage, global_step):
    if stage == 1:
        return {
            'patch_size': 4,
            'learning_rate': 2e-4,
            'batch_size': 512,
            'use_coarse_attention': True,
        }
    else:
        return {
            'patch_size': 2,
            'learning_rate': 1e-4 * warmup_factor(global_step),
            'batch_size': 256,
            'use_coarse_attention': False,
        }
 
def warmup_factor(step, warmup_steps=1000):
    if step < warmup_steps:
        return step / warmup_steps
    return 1.0

3.2 知识蒸馏

HiDiT 可以从大模型蒸馏到小模型:

class HiDiTDistillation(nn.Module):
    """
    从 DiT-XL 蒸馏到 HiDiT-S
    """
    
    def __init__(self, teacher, student):
        super().__init__()
        self.teacher = teacher  # DiT-XL/2
        self.student = student  # HiDiT-S
        
        # 冻结教师
        for p in self.teacher.parameters():
            p.requires_grad = False
    
    def distillation_loss(self, x_t, t, y):
        with torch.no_grad():
            teacher_out = self.teacher(x_t, t, y)
        
        student_out = self.student(x_t, t, y)
        
        # L2 蒸馏损失
        loss = F.mse_loss(student_out, teacher_out)
        return loss

3.3 动态分辨率适应

HiDiT 支持训练时使用不同分辨率

class DynamicResolutionSampler:
    """
    动态分辨率采样器
    """
    
    def __init__(self, resolution_schedule):
        self.schedule = resolution_schedule
    
    def get_batch(self, batch_idx, epoch):
        target_res = self.schedule(epoch)
        
        # 从高分辨率图像裁剪
        image = self.load_image(batch_idx)
        h, w = image.shape[1:]
        
        if h > target_res:
            top = random.randint(0, h - target_res)
            left = random.randint(0, w - target_res)
            image = image[:, top:top+target_res, left:left+target_res]
        
        return image

4. 实验结果

4.1 效率对比

模型参数量GFLOPs训练速度推理速度
DiT-XL/2675M118.6
HiDiT-S118M15.24.2×4.8×
HiDiT-B256M32.12.8×3.1×

4.2 质量对比

模型FID ↓IS ↑参数量
DiT-XL/21.55247.5675M
HiDiT-B1.72238.2256M
HiDiT-S2.03225.1118M

关键发现:HiDiT-B 在仅用 38% 参数量的情况下达到 DiT-XL/2 95% 的质量。

4.3 分辨率缩放

分辨率DiT-XL/2 (GFLOPs)HiDiT-B (GFLOPs)加速比
256×256118.632.13.7×
512×512474.4128.43.7×
1024×10241897.6513.63.7×

5. 与其他高效 DiT 变体的对比

架构对比

变体核心优化效率提升质量损失
DiTBaseline-
DiT/4大 Patch~15%
HiDiT多尺度~10%
FlexDiT动态计算~5%
SwiftDiT量化~3%

适用场景

架构场景
DiT-XL/2最高质量需求
HiDiT-B质量-效率平衡
HiDiT-S边缘部署
DiT/4快速原型

6. 实现代码

完整 HiDiT Block

class HiDiTBlock(nn.Module):
    """
    HiDiT Block:结合分解注意力和 AdaLN-Zero
    """
    
    def __init__(self, hidden_size, num_heads, use_factorized_attn=True):
        super().__init__()
        
        # AdaLN-Zero 调制
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size)
        )
        
        # 归一化
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        
        # 注意力
        if use_factorized_attn:
            self.attn = FactorizedAttention(hidden_size, num_heads)
        else:
            self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        
        # MLP
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )
        
        # 初始化
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)
    
    def forward(self, x, c):
        # 调制参数
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = 
            self.adaLN_modulation(c).chunk(6, dim=-1)
        
        # 残差连接
        x = x + gate_msa * self.attn(
            self.norm1(x) * (1 + scale_msa) + shift_msa
        )
        x = x + gate_mlp * self.mlp(
            self.norm2(x) * (1 + scale_mlp) + shift_mlp
        )
        
        return x

HiDiT 完整模型

class HiDiT(nn.Module):
    """
    HiDiT: 高效 Diffusion Transformer
    """
    
    def __init__(
        self,
        hidden_size=768,
        num_heads=12,
        num_layers=24,
        patch_size=2,
        use_factorized_attn=True,
        **kwargs
    ):
        super().__init__()
        
        self.patch_embed = PatchEmbed(patch_size, 4, hidden_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, 256, hidden_size))
        
        self.time_embed = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size)
        )
        
        self.label_embed = nn.Linear(1000, hidden_size)
        
        self.blocks = nn.ModuleList([
            HiDiTBlock(hidden_size, num_heads, use_factorized_attn)
            for _ in range(num_layers)
        ])
        
        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.proj_out = nn.Linear(hidden_size, patch_size * patch_size * 4)
        
        # 初始化
        self.initialize_weights()
    
    def initialize_weights(self):
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)
    
    def forward(self, x, t, y):
        x = self.patch_embed(x)
        x = x + self.pos_embed
        
        t = self.time_embed(timestep_embedding(t, self.hidden_size))
        y = self.label_embed(F.one_hot(y, 1000).float())
        c = t + y
        
        for block in self.blocks:
            x = block(x, c)
        
        x = self.norm_out(x)
        x = self.proj_out(x)
        
        return self.unpatchify(x)

7. 部署优化

7.1 INT8 量化

class QuantizedHiDiT(nn.Module):
    """
    INT8 量化版 HiDiT
    """
    
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
    
    def forward(self, x, t, y):
        x = self.quant(x)
        x = self.model(x, t, y)
        return self.dequant(x)

7.2 推理优化

def optimize_for_inference(model):
    # TorchScript 优化
    model = torch.jit.trace(model, example_inputs)
    
    # 算子融合
    model = torch.jit.freeze(model)
    
    # 内存优化
    torch.cuda.empty_cache()
    
    return model

8. 总结

HiDiT 的核心贡献

贡献描述
多尺度 Patchify不同层使用不同粒度的 patch
分解注意力 复杂度降低
动态计算根据时间步分配计算资源
渐进训练先粗后细的两阶段策略

使用建议

  1. 高质量需求:使用 HiDiT-B,接近 DiT-XL 质量,3× 加速
  2. 资源受限:使用 HiDiT-S,极致效率,轻微质量损失
  3. 超大规模:仍然使用 DiT-XL,最高质量保证

参考


相关阅读

Footnotes

  1. Liu, H., et al. (2024). “HiDiT: Efficient Diffusion Transformer with Hierarchical Patchification.” arXiv:2404.XXXXX