DiT: Diffusion Transformer 架构详解

概述

DiT (Diffusion Transformer) 由 William Peebles 和 Saining Xie 于 ICCV 2023 提出,是首个成功将 Transformer 架构应用于扩散模型的工作。1

DiT 的核心贡献是证明了:

  1. Transformer 可以作为扩散模型的主干网络
  2. DiT 遵循与 LLM 相同的 Scaling Laws
  3. U-Net 的归纳偏置在扩散模型中并非必须

1. DiT 整体架构

1.1 架构概览

┌─────────────────────────────────────────────────────────────────┐
│                        DiT 架构                                  │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  输入图像                                                         │
│      │                                                          │
│      ▼                                                          │
│  ┌─────────┐                                                   │
│  │   VAE   │  编码到 latent 空间 (如 32×32)                    │
│  └────┬────┘                                                   │
│       │                                                        │
│       ▼                                                        │
│  ┌─────────────┐                                               │
│  │   Patchify  │  将 latent 分割成 patches (如 2×2)            │
│  └──────┬──────┘                                               │
│         │                                                     │
│         ▼                                                     │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐       │
│  │ Linear Proj │     │  DiT Blocks │ ... │  DiT Blocks │       │
│  │ + Pos Emb   │ ──▶ │  (×N)      │     │  (×N)       │       │
│  └─────────────┘     └──────┬──────┘     └──────┬──────┘       │
│                            │                    │              │
│                            └──────────┬─────────┘              │
│                                       ▼                         │
│                            ┌─────────────────┐                 │
│                            │    Norm + Proj   │                 │
│                            └────────┬────────┘                  │
│                                     │                          │
│                                     ▼                          │
│  ┌─────────┐                                               │
│  │  Unpatch │  将 patches 合并回 latent                    │
│  └────┬────┘                                               │
│       │                                                    │
│       ▼                                                    │
│  ┌─────────┐                                               │
│  │   VAE   │  解码回图像空间                                 │
│  └─────────┘                                               │
│       │                                                    │
│       ▼                                                    │
│  输出图像                                                    │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1.2 核心组件

组件描述
VAE将图像编码到 latent 空间 (如 SD 的 VAE, 下采样 8x)
Patchify将 latent 分割成非重叠的 patches
DiT Block自适应 LayerNorm (AdaLN) + Self-Attention + FFN
位置编码傅里叶位置编码 (RoPE 或标准正弦)
Final Layer预测噪声 和方差

2. Patchify 层

2.1 作用

Patchify 将 2D latent 特征图转换为 patch 序列,类似于 ViT 对图像的处理。

输入

  • Latent 特征:
  • 例如 SD VAE 输出: (256×256 图像下采样 8x)

输出

  • Patch 序列:
  • 其中 , 是 patch size

2.2 实现

class Patchify(nn.Module):
    """
    将 latent 特征图分割成 patches
    """
    def __init__(self, in_channels=4, patch_size=2, hidden_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = (32 // patch_size) ** 2  # 取决于 VAE
        
        # 每个 patch 映射到 hidden_dim
        self.proj = nn.Conv2d(
            in_channels,
            hidden_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    
    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) latent 特征
        
        Returns:
            patches: (B, N, D) 其中 N = H*W/patch_size²
        """
        # Conv2d 实现 patchify
        x = self.proj(x)  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x

2.3 Patch Size 的影响

Patch SizeToken 数计算量质量
816最低较低
464中等中等
2256较高最高

发现:Patch size=2 的 DiT 生成质量最高,但计算量也最大。


3. 位置编码

3.1 傅里叶位置编码

DiT 使用标准的正弦位置编码(与 ViT 相同):

def get_positional_encoding(seq_len, dim, device):
    """
    生成傅里叶位置编码
    
    Args:
        seq_len: 序列长度 (patch 数)
        dim: 编码维度
        device: 设备
    
    Returns:
        pos_emb: (1, seq_len, dim)
    """
    position = torch.arange(seq_len, device=device).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, dim, 2, device=device) * (-np.log(10000) / dim)
    )
    
    pe = torch.zeros(seq_len, dim, device=device)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    
    return pe.unsqueeze(0)

3.2 条件位置编码

对于时间步 和类别 的条件,DiT 使用自适应方式注入位置编码的缩放/偏移。


4. DiT Block:自适应条件注入

4.1 四种 Conditioning 方案对比

DiT 论文系统比较了四种条件注入方式:

方法描述效果
In-context 作为额外 token 拼接较差
Cross-attention添加独立的 cross-attention 层中等
AdaLN自适应 LayerNorm较好
AdaLN-ZeroAdaLN + 门控初始化为 0最优

4.2 AdaLN (Adaptive Layer Norm)

AdaLN 根据时间步 和类别 自适应调整 LayerNorm 的参数:

class AdaLN(nn.Module):
    """
    自适应 LayerNorm
    
    根据条件 c 动态调整 scale 和 shift
    """
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # 预测 scale, shift, gate
        self.linear = nn.Linear(cond_dim, dim * 3)
    
    def forward(self, x, cond):
        """
        Args:
            x: (B, N, D) 输入特征
            cond: (B, cond_dim) 条件 (t, c)
        
        Returns:
            自适应归一化后的特征
        """
        # 预测调制参数
        scale, shift, gate = self.linear(cond).chunk(3, dim=-1)
        
        # LayerNorm
        x = self.norm(x)
        
        # 调制
        x = x * (1 + scale) + shift
        x = x * (1 + gate)  # 门控
        
        return x

4.3 AdaLN-Zero

AdaLN-Zero 是 DiT 的核心创新:

class AdaLNZero(nn.Module):
    """
    AdaLN-Zero: 门控初始化为 0
    
    核心洞察:每个 block 训练初期应该接近恒等映射
    """
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        # 预测 scale, shift, gate (gate 初始化为 0)
        self.linear = nn.Linear(cond_dim, dim * 3)
        
        # 初始化:gate 为 0,其他为标准初始化
        nn.init.zeros_(self.linear.weight[:, 2*dim:])  # gate 置 0
        nn.init.zeros_(self.linear.bias[2*dim:])         # gate bias 置 0
    
    def forward(self, x, cond):
        # 预测调制参数
        scale, shift, gate = self.linear(cond).chunk(3, dim=-1)
        
        # LayerNorm
        x = self.norm(x)
        
        # 调制
        x = x * (1 + scale) + shift
        x = gate * x  # 训练初期 gate≈0,接近恒等映射
        
        return x

AdaLN-Zero 的优势

  1. 训练初期每个 block 等效为恒等映射
  2. 随着训练进行,门控逐渐学习到有用调制
  3. 显著提升深度网络的训练稳定性

4.4 完整的 DiT Block

class DiTBlock(nn.Module):
    """
    DiT Block: AdaLN + Self-Attention + FFN
    """
    def __init__(self, dim, num_heads, mlp_ratio=4.0, cond_dim=256):
        super().__init__()
        self.norm1 = AdaLNZero(dim, cond_dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = AdaLNZero(dim, cond_dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(dim * mlp_ratio, dim)
        )
        # MLP 门控也初始化为 0
        nn.init.zeros_(self.mlp[0].weight)
        nn.init.zeros_(self.mlp[0].bias)
    
    def forward(self, x, cond):
        # Self-Attention
        x = x + self.norm1.gate * self.attn(
            self.norm1(x, cond),
            self.norm1(x, cond),
            self.norm1(x, cond)
        )[0]
        
        # FFN
        x = x + self.norm2.gate * self.mlp(self.norm2(x, cond))
        
        return x

5. 时间步与类别 Embedding

5.1 时间步 Embedding

class TimestepEmbedder(nn.Module):
    """
    时间步 embedding: MLP 编码正弦位置编码
    """
    def __init__(self, hidden_dim, frequency_dim=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.frequency_dim = frequency_dim
    
    def forward(self, t):
        """
        Args:
            t: (B,) 时间步 (整数或连续值)
        
        Returns:
            emb: (B, hidden_dim)
        """
        # 正弦位置编码
        freqs = torch.exp(
            -torch.log(torch.tensor(10000)) * 
            torch.arange(0, self.frequency_dim, 2, device=t.device) / self.frequency_dim
        )
        
        # 编码
        t_emb = t[:, None] * freqs[None, :]
        t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=-1)
        
        return self.mlp(t_emb)

5.2 类别 Embedding (可选)

对于分类条件,使用类似的 embedding 层:

class LabelEmbedder(nn.Module):
    """类别 embedding"""
    def __init__(self, num_classes, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, hidden_dim)
    
    def forward(self, y):
        return self.embedding(y)

5.3 条件融合

def combine_condition(t_emb, y_emb=None):
    """融合时间和类别条件"""
    if y_emb is not None:
        # 拼接
        return torch.cat([t_emb, y_emb], dim=-1)
    else:
        return t_emb

6. 模型变体与缩放

6.1 DiT 模型系列

模型DepthHidden DimHeads参数量GFLOPsFID
DiT-S/212384633M11868.4
DiT-B/21276812130M118543.5
DiT-L/224102416458M32919.62
DiT-XL/228115216675M11812.27

关键发现:DiT-XL/2 使用与 DiT-B/2 相似的计算量,但质量大幅提升!

6.2 Scaling Laws

DiT 遵循与语言模型相同的 scaling laws:

FID (越低越好)
│                              ★ DiT-XL/2
│                         ★ DiT-L/2
│                    ★ DiT-B/2
│               ★ DiT-S/2
│
└───────────────────────────────────────▶ 计算量 (GFLOPs)
         100              1000        10000

实验结论

  1. 固定计算预算下,更大的模型 + 更少的步骤优于小模型 + 多步骤
  2. FID 与 GFLOPs 呈幂律关系
  3. 模型大小比 token 数量更重要

7. 输出预测:噪声与方差

7.1 两种输出方式

DiT 需要同时预测:

  1. 噪声
  2. 方差 (或
class FinalLayer(nn.Module):
    """
    最终层:预测噪声和方差
    """
    def __init__(self, hidden_dim, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False)
        self.linear = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)
        
        # 方差预测
        self.linear_out = nn.Linear(hidden_dim, patch_size * patch_size * out_channels)
        
        # 初始化
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)
    
    def forward(self, x):
        """
        Returns:
            noise: (B, N, P*D)
            variance: (B, N, P*D) 或标量
        """
        x = self.norm_final(x)
        noise = self.linear(x)
        
        # 可以预测 log variance 或直接预测 beta
        # 取决于噪声调度
        return noise, None  # 简化版

7.2 条件独立预测

更常见的实现是条件独立预测噪声和方差:

class DiTOutput(nn.Module):
    """DiT 输出头"""
    def __init__(self, hidden_dim, patch_size, out_channels=4):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_dim)
        
        # 噪声预测头
        self.head_noise = nn.Linear(hidden_dim, patch_size**2 * out_channels)
        
        # 方差预测头
        self.head_variance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x):
        x = self.norm(x)
        noise = self.head_noise(x)
        log_var = self.head_variance(x)
        return noise, log_var

8. 完整 DiT 实现

class DiT(nn.Module):
    """
    完整的 DiT 模型
    """
    def __init__(
        self,
        in_channels=4,
        patch_size=2,
        hidden_dim=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        num_classes=None,
        cond_dim=None
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_channels = in_channels
        self.patch_size = patch_size
        
        # Patchify
        self.patchify = Patchify(in_channels, patch_size, hidden_dim)
        self.num_patches = self.patchify.num_patches
        
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_dim))
        nn.init.normal_(self.pos_embed, std=0.02)
        
        # 时间步和类别 embedding
        self.t_embed = TimestepEmbedder(hidden_dim)
        if num_classes is not None:
            self.y_embed = LabelEmbedder(num_classes, hidden_dim)
            cond_dim = hidden_dim * 2
        else:
            cond_dim = hidden_dim
        
        # DiT Blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_dim, num_heads, mlp_ratio, cond_dim)
            for _ in range(depth)
        ])
        
        # Final Layer
        self.final_layer = FinalLayer(hidden_dim, patch_size, in_channels)
        
        self.initialize_weights()
    
    def initialize_weights(self):
        """权重初始化"""
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
        
        self.apply(_basic_init)
        
        # AdaLN-Zero 门控初始化为 0
        for block in self.blocks:
            nn.init.zeros_(block.norm1.linear.weight[:, 2*self.hidden_dim:])
            nn.init.zeros_(block.norm1.linear.bias[2*self.hidden_dim:])
    
    def unpatchify(self, x):
        """将 patches 合并回特征图"""
        B, N, D = x.shape
        H = W = int(N ** 0.5)
        p = self.patch_size
        
        x = x.reshape(B, H, W, p, p, self.out_channels)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.reshape(B, self.out_channels, H*p, W*p)
        return x
    
    def forward(self, x, t, y=None):
        """
        前向传播
        
        Args:
            x: (B, C, H, W) latent 特征
            t: (B,) 时间步
            y: (B,) 类别标签 (可选)
        
        Returns:
            noise: (B, C, H, W) 预测的噪声
        """
        B = x.shape[0]
        
        # Patchify
        x = self.patchify(x)  # (B, N, D)
        x = x + self.pos_embed
        
        # 时间步 embedding
        t_emb = self.t_embed(t)
        
        # 类别 embedding
        if y is not None:
            y_emb = self.y_embed(y)
            cond = torch.cat([t_emb, y_emb], dim=-1)
        else:
            cond = t_emb
        
        # DiT Blocks
        for block in self.blocks:
            x = block(x, cond)
        
        # Final Layer
        x = self.final_layer(x)
        x = self.unpatchify(x)
        
        return x

9. 训练与采样

9.1 训练目标

def training_step(model, batch, optimizer):
    """
    DiT 训练步骤
    """
    x0, y = batch  # 原始图像和标签
    
    # VAE 编码
    with torch.no_grad():
        z = vae.encode(x0).latent_dist.sample()
        z = z * vae.config.scaling_factor
    
    # 采样时间步
    t = torch.randint(0, model.num_timesteps, (B,), device=device)
    
    # 添加噪声
    noise = torch.randn_like(z)
    z_t = add_noise(z, t, alpha_bar)
    
    # 预测噪声
    noise_pred = model(z_t, t, y)
    
    # 损失
    loss = F.mse_loss(noise_pred, noise)
    
    # 反向传播
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    return loss

9.2 采样

@torch.no_grad()
def sampling(model, z_T, y=None, num_steps=50):
    """
    DDIM 采样
    """
    model.eval()
    
    timesteps = torch.linspace(model.num_timesteps-1, 0, num_steps).long()
    
    z = z_T
    for i, t in enumerate(tqdm(timesteps)):
        t_batch = torch.full((B,), t, device=device, dtype=torch.long)
        
        # 预测噪声
        noise_pred = model(z, t_batch, y)
        
        # DDIM 更新 (简化版)
        alpha_bar_t = alphas_cumprod[t]
        alpha_bar_t_prev = alphas_cumprod[timesteps[i-1]] if i > 0 else 1
        
        # 预测 x0
        pred_x0 = (z - torch.sqrt(1-alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)
        
        # 方向指向 x0
        direction = torch.sqrt(1 - alpha_bar_t_prev) * noise_pred
        
        # 组合
        z = torch.sqrt(alpha_bar_t_prev) * pred_x0 + direction
    
    # VAE 解码
    x = vae.decode(z / vae.config.scaling_factor)
    
    return x

10. 与 U-Net 的对比

10.1 架构差异

维度U-NetDiT
结构Encoder-Decoder + Skip等向性 Transformer
位置信息内置于卷积显式位置编码
感受野局部到全局全局自注意力
层次结构多尺度特征融合统一表示

10.2 优缺点对比

方面U-NetDiT
收敛速度较慢
Scaling无明确规律遵循 Scaling Laws
计算效率高分辨率下效率下降可通过 patch_size 调节
长距离依赖需要空洞卷积/注意力自然建模
实现复杂度中等较高

10.3 融合趋势:U-DiT

2024-2025 年的研究开始融合两者的优点:

  • U-DiT:在 DiT 中引入 U-Net 的层次结构和 skip connection
  • U-ViT:在 ViT/DiT 中引入类似 U-Net 的下采样/上采样路径
  • MM-DiT(SD3):多模态 DiT,融合文本和图像 embedding

11. 参考文献


相关链接

Footnotes

  1. Peebles & Xie, “Scalable Diffusion Models with Transformers”, ICCV 2023. https://arxiv.org/abs/2212.09748