DiT 变体与 2025 年新进展

概述

DiT (Diffusion Transformer) 自 2023 年提出以来,在图像生成领域取得了巨大成功。2024-2025 年,研究者们在 DiT 基础上提出了多种改进变体,涵盖效率优化、架构创新、应用扩展等多个方向。

本文梳理了近年来最重要的 DiT 变体及其核心创新。


1. U-DiT: U-Net 遇见 DiT

1.1 核心思想

U-DiT 由 Tian 等人提出(CVPR 2025),核心观察是:U-Net 的层次结构和 skip connection 在 DiT 中同样有效1

1.2 架构设计

┌──────────────────────────────────────────────────────────────┐
│                        U-DiT 架构                            │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│  输入 (H×W) ──┬── Tokenize ──▶ Patch Embed ──▶ Level 1     │
│               │                                              │
│               │                                              │
│               ▼                                              │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              DiT Block × N₁                         │    │
│  └─────────────────────────────────────────────────────┘    │
│                          │                                   │
│                          │ Downsample                        │
│                          ▼                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              DiT Block × N₂                         │    │
│  └─────────────────────────────────────────────────────┘    │
│                          │                                   │
│                          │ Skip Connection                   │
│                          ▼                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              DiT Block × N₁                         │    │
│  └─────────────────────────────────────────────────────┘    │
│                          │                                   │
│                          │ Upsample + Skip                   │
│                          ▼                                   │
│  ┌─────────────────────────────────────────────────────┐    │
│  │              DiT Block × N₁                         │    │
│  └─────────────────────────────────────────────────────┘    │
│                          │                                   │
│                          ▼                                   │
│                    Patch Detokenize ──▶ 输出                  │
│                                                              │
└──────────────────────────────────────────────────────────────┘

1.3 关键创新

创新点描述效果
层次化 Token不同层使用不同的 patch size计算量减少 6 倍
QKV 下采样Attention 中对 K, V 进行下采样加速注意力计算
Skip Connection引入类似 U-Net 的跳连保留多尺度信息

1.4 实验结果

模型GFLOPsFID质量对比
DiT-XL/211212.27基准
U-DiT-L/2~190~2.1质量相当,计算量减少

1.5 实现细节

class UDiTBlock(nn.Module):
    """
    U-DiT Block: 带层次化结构的 DiT Block
    """
    def __init__(self, dim, num_heads, downsample_kv=True):
        super().__init__()
        self.norm1 = AdaLNZero(dim)
        self.attn = Attention(dim, num_heads)
        self.norm2 = AdaLNZero(dim)
        self.ffn = MLP(dim)
        
        # QKV 下采样(可选)
        self.downsample_kv = downsample_kv
        if downsample_kv:
            self.kv_downsample = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)
    
    def forward(self, x, cond, kv_cache=None):
        # 自注意力
        x = x + self.norm1.gate * self.attn(self.norm1(x, cond), kv_cache)
        
        # FFN
        x = x + self.norm2.gate * self.ffn(self.norm2(x, cond))
        
        return x

2. DyDiT++: 动态 DiT

2.1 核心思想

DyDiT++ 由阿里巴巴达摩院提出(ICLR 2025),核心观察是:不同时间步和空间区域需要的计算量不同2

2.2 两个维度的动态性

Timestep-wise Dynamic Width (TDW)

时间步 t=0.9 (早期): 需要大模型宽度 ────────────████████████ 100%
时间步 t=0.5 (中期): 需要中等模型宽度 ────────────████████ 60%
时间步 t=0.1 (晚期): 小模型宽度足够 ──────────────████ 30%

关键洞察:早期 timestep 需要更多表达能力(生成全局结构),晚期 timestep 只需精细调整(局部纹理)。

Spatial-wise Dynamic Token (SDT)

图像区域:
┌─────────────┐
│  主体 (高重要性)  │ ─── 需要完整计算 ████████████
│               │
│  背景 (低重要性)  │ ─── 可跳过计算   ░░░░░░░░░░░
└─────────────┘

判断标准:背景区域 loss 贡献低,可以跳过其计算。

2.3 架构设计

class DyDiTPPBlock(nn.Module):
    """
    DyDiT++ Block: 支持动态宽度和动态 Token
    """
    def __init__(self, dim, num_heads, max_width_ratio=1.0):
        super().__init__()
        self.max_width = int(dim * max_width_ratio)
        self.current_width = dim
        
        # 宽度控制器
        self.width_predictor = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.SiLU(),
            nn.Linear(dim // 4, 1),
            nn.Sigmoid()
        )
        
        # 标准组件
        self.norm = AdaLNZero(dim)
        self.attn = SparseAttention(dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
    
    def forward(self, x, t, importance_mask=None):
        # 1. 动态宽度调整
        width_ratio = self.width_predictor(x.mean(dim=1, keepdim=True))
        current_dim = int(self.max_width * width_ratio)
        
        # 2. 空间动态 Token(跳过低重要性区域)
        if importance_mask is not None:
            # 稀疏注意力:只计算重要区域的 attention
            x = self.sparse_attention(x, importance_mask)
        
        # 3. 标准前向
        x = self.norm(x, t) * self.norm.gate + x
        x = x + self.ffn(x)
        
        return x
    
    def sparse_attention(self, x, mask, threshold=0.3):
        """
        稀疏注意力:跳过低重要性 token
        """
        # 计算重要性分数
        scores = torch.sigmoid(self.importance_net(x))
        
        # 掩码:只保留高于阈值的 token
        active_mask = scores > threshold
        
        # 对活跃 token 进行注意力计算
        active_indices = active_mask.nonzero(as_tuple=True)
        if len(active_indices[0]) > 0:
            x_active = x[active_indices]
            # 注意力计算
            attn_out = self.attn(x_active)
            # 写回
            x = torch.scatter(x, 0, active_indices[0].unsqueeze(-1).expand_as(attn_out), attn_out)
        
        return x

2.4 TD-LoRA: 参数高效微调

DyDiT++ 还提出了 TD-LoRA,在时间维度上分解 LoRA 矩阵:

class TDLoRA(nn.Module):
    """
    Time-Decomposed LoRA: 按时间步分解的适配器
    """
    def __init__(self, dim, rank=16, num_timesteps=1000):
        super().__init__()
        self.dim = dim
        self.rank = rank
        
        # 时间步 embedding
        self.t_emb = nn.Embedding(num_timesteps, rank)
        
        # 基础 LoRA 矩阵
        self.A = nn.Linear(dim, rank, bias=False)
        self.B = nn.Linear(rank, dim, bias=False)
        
        # 时间依赖调制
        self.time_scale = nn.Linear(rank, rank)
        self.time_shift = nn.Linear(rank, rank)
    
    def forward(self, x, t):
        # 获取时间步的 rank 向量
        t_vec = self.t_emb(t)  # (B, rank)
        
        # 时间依赖的 A, B
        A_t = self.A.weight * (1 + self.time_scale(t_vec))
        B_t = self.B.weight + self.time_shift(t_vec)
        
        # LoRA 输出
        return (x @ A_t.T) @ B_t.T

2.5 实验结果

模型GFLOPs 减少硬件加速FID
DiT-XL/20%1.0×2.27
DyDiT++-XL51%1.73×2.07
DyDiT++-L48%1.65×2.18

3. D²iT: 动态压缩 DiT

3.1 核心思想

D²iT(CVPR 2025)关注的问题是:不同图像区域应该使用不同的压缩率3

3.2 问题分析

压缩率优点缺点
大压缩计算量小、全局一致性局部细节丢失
小压缩保留局部细节计算量大

核心洞察:主体区域需要小压缩(细节重要),背景区域可以使用大压缩(细节不重要)。

3.3 自适应压缩策略

class D2iT(nn.Module):
    """
    Dynamic Diffusion Transformer: 自适应压缩
    """
    def __init__(self, dim, num_classes=1000):
        super().__init__()
        self.dim = dim
        
        # 重要性预测器
        self.importance_predictor = nn.Sequential(
            nn.Linear(dim, dim // 2),
            nn.GELU(),
            nn.Linear(dim // 2, 1),
            nn.Sigmoid()
        )
        
        # 多尺度压缩器
        self.compressors = nn.ModuleList([
            PatchifyCompressor(patch_size=p) for p in [1, 2, 4]
        ])
        
        # DiT blocks
        self.blocks = nn.ModuleList([DiTBlock(dim) for _ in range(28)])
    
    def forward(self, x, t, y=None):
        B, C, H, W = x.shape
        
        # 1. 编码到 latent
        x = self.encoder(x)
        
        # 2. 预测每个区域的重要性
        importance = self.importance_predictor(x)  # (B, N, 1)
        
        # 3. 自适应压缩
        compressed_tokens = []
        for p_size in [1, 2, 4]:
            tokens_p = self.compressors[p_size](x)
            weight = (importance > 1/(p_size**2)).float()  # 高重要性用小 patch
            compressed_tokens.append((tokens_p, weight, p_size))
        
        # 4. 融合多尺度 token
        x = self.fuse_tokens(compressed_tokens)
        
        # 5. 标准 DiT 处理
        for block in self.blocks:
            x = block(x, t, y)
        
        return x
    
    def fuse_tokens(self, multi_scale_tokens):
        """
        融合多尺度的 token 表示
        """
        # 简单的加权平均
        all_tokens = []
        all_weights = []
        
        for tokens, weights, p_size in multi_scale_tokens:
            # 上采样到统一分辨率
            tokens_up = F.interpolate(
                tokens.permute(0, 2, 1).reshape(B, -1, H//2, W//2),
                scale_factor=p_size/2,
                mode='bilinear'
            )
            all_tokens.append(tokens_up)
            all_weights.append(weights)
        
        # 加权平均
        fused = sum(w * t for w, t in zip(all_weights, all_tokens))
        weights_sum = sum(all_weights)
        
        return fused / (weights_sum + 1e-6)

3.4 实验结果

方法Patch SizeFID计算量
固定 p=111.95100%
固定 p=222.1525%
固定 p=442.426.25%
D²iT自适应2.05~15%

4. MM-DiT: 多模态 DiT (Stable Diffusion 3)

4.1 核心思想

MM-DiT(Multi-Modal DiT)由 Stability AI 提出,用于 Stable Diffusion 3,核心创新是分别处理文本和图像 embedding,然后融合4

4.2 架构设计

┌────────────────────────────────────────────────────────────────┐
│                      MM-DiT 架构                               │
├────────────────────────────────────────────────────────────────┤
│                                                                 │
│  文本输入                                                          │
│      │                                                           │
│      ▼                                                           │
│  ┌───────────┐                                                   │
│  │ Text Enc  │ T5-XXL / CLIP Text                               │
│  └─────┬─────┘                                                   │
│        │                                                         │
│        ▼                                                         │
│  ┌───────────┐                                                   │
│  │  Text Proj │ 线性投影到 hidden_dim                            │
│  └─────┬─────┘                                                   │
│        │                                                         │
│        ▼                                                         │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │            文本 DiT Block × N                             │   │
│  │  (独立权重)                                               │   │
│  └─────────────────────────────────────────────────────────┘   │
│                           │                                     │
│                           │ 残差相加                             │
│                           ▼                                     │
│  ┌─────────────────────────────────────────────────────────┐   │
│  │            图像 DiT Block × N                             │   │
│  │  (独立权重)                                               │   │
│  └─────────────────────────────────────────────────────────┘   │
│                           │                                     │
│                           ▼                                     │
│                   MM-DiT Block                                   │
│                 ┌─────────────────┐                            │
│                 │   共享 LayerNorm │                            │
│                 │   + 交叉残差      │                            │
│                 └─────────────────┘                            │
│                           │                                     │
│                           ▼                                     │
│                      输出层                                      │
│                                                                 │
└────────────────────────────────────────────────────────────────┘

4.3 关键创新

创新点描述
独立文本/图像路径避免跨模态干扰
共享 AdaLNLayerNorm 参数共享,降低参数量
Q-Former 集成可选使用 FLAN-T5 的 frozen encoder

4.4 MM-DiT Block 实现

class MMDiTBlock(nn.Module):
    """
    Multi-Modal DiT Block
    """
    def __init__(self, dim, num_heads, text_dim):
        super().__init__()
        
        # 文本路径(独立权重)
        self.text_norm = nn.LayerNorm(dim)
        self.text_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.text_ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        
        # 图像路径(独立权重)
        self.image_norm = nn.LayerNorm(dim)
        self.image_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.image_ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )
        
        # 共享的 AdaLN(关键创新)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim, 6 * dim)
        )
    
    def forward(self, x_img, x_text, t):
        """
        Args:
            x_img: (B, N_img, D) 图像 token
            x_text: (B, N_text, D) 文本 token
            t: (B, D) 时间步 embedding
        """
        # 共享的条件
        c = self.adaLN_modulation(t)
        c_img, c_text, c_img_gate, c_text_gate = c.chunk(4, dim=-1)
        
        # 图像路径
        x_img = x_img + c_img_gate * self.image_attn(
            self.image_norm(x_img) * (1 + c_img),
            self.image_norm(x_img),
            self.image_norm(x_img)
        )[0]
        x_img = x_img + c_img_gate * self.image_ffn(
            self.image_norm(x_img) * (1 + c_img)
        )
        
        # 文本路径
        x_text = x_text + c_text_gate * self.text_attn(
            self.text_norm(x_text) * (1 + c_text),
            self.text_norm(x_text),
            self.text_norm(x_text)
        )[0]
        x_text = x_text + c_text_gate * self.text_ffn(
            self.text_norm(x_text) * (1 + c_text)
        )
        
        # 残差交叉(图像从文本获取上下文)
        x_img = x_img + 0.1 * x_text.mean(dim=1, keepdim=True)
        
        return x_img, x_text

4.5 SD3 的改进

组件SD 1.x / SDXLSD3 (MM-DiT)
文本编码器CLIP TextCLIP + T5-XXL
架构U-NetMM-DiT
位置编码SinusoidalRoPE
采样器DDPM/DDIMEuler
参数~1B~2B

5. 其他重要变体

5.1 ScaleDiT: 超高分辨率生成

class ScaleDiT(nn.Module):
    """
    ScaleDiT: 分层局部注意力,支持 4K+ 分辨率
    """
    def __init__(self, dim):
        super().__init__()
        # 全局注意力(低分辨率)
        self.global_attn = GlobalAttention(dim)
        
        # 局部注意力(高分辨率分块)
        self.local_attn = LocalAttention(dim, window_size=32)
    
    def forward(self, x, resolution):
        if resolution > 1024:
            # 分块处理 + 全局引导
            return self.hierarchical_generate(x)
        else:
            return self.blocks(x)

5.2 FlexiDiT: 弹性计算

FlexiDiT 的核心思想是单一模型支持可变计算预算

class FlexiDiT(nn.Module):
    """
    FlexiDiT: 弹性计算预算
    """
    def forward(self, x, t, budget=None):
        if budget is None:
            # 完整计算
            return self.full_forward(x, t)
        else:
            # 根据预算动态跳过
            return self.adaptive_forward(x, t, budget)

5.3 PixelDiT: 端到端像素建模

PixelDiT 尝试消除 VAE,直接在像素空间建模:

组件Latent DiT (SD)PixelDiT
输入空间Latent (压缩 8x)像素
序列长度32×32 = 1024256×256 = 65536
计算量
重建误差VAE 引入

6. 性能对比汇总

模型任务FID ↓参数量GFLOPs ↓
DiT-XL/2ImageNet 256²2.27675M1181
U-DiT-L/2ImageNet 256²~2.1~400M~190
DyDiT++-XLImageNet 256²2.07675M~580
D²iTImageNet 256²2.05~700M~180
SD3 (MM-DiT)T2I 1024²SOTA~2B-
MAR-HImageNet 256²~1.6--

7. 总结与展望

7.1 主要研究方向

方向代表工作核心贡献
效率优化U-DiT, DyDiT++, D²iT减少计算量
多模态融合MM-DiT统一文本图像建模
超高分辨率ScaleDiT, PixelDiT突破分辨率限制
弹性计算FlexiDiT单一模型多预算

7.2 未来趋势

  1. 动态计算成为主流:根据内容和时间步自适应分配计算资源
  2. 多模态统一:文本、图像、视频在统一框架下建模
  3. 端到端建模:减少对 VAE 的依赖
  4. 硬件协同设计:针对 NPU/GPU 特性优化架构

7.3 实践建议

场景推荐架构
快速实验DiT-B/2
高质量生成DiT-XL/2
资源受限U-DiT-L/2
多模态生成MM-DiT (SD3)
少步采样结合 Consistency Models

参考文献


相关链接

Footnotes

  1. Tian et al., “U-DiT: U-Net Meets DiT”, CVPR 2025.

  2. Alibaba DAMO, “DyDiT++: Dynamic Diffusion Transformer”, ICLR 2025.

  3. CVPR 2025, “D²iT: Dynamic Diffusion Transformer with Adaptive Compression”.

  4. Stability AI, “Stable Diffusion 3”, 2024.