概述

传统 Transformer 对所有 token 施加相同的计算量。然而不同 token 的”难度”差异巨大:常见 token 易于预测,罕见/复杂 token 需要更多计算。动态深度计算(Dynamic-Depth Compute)通过为每个 token 自适应分配不同的层数,实现计算的最优分配。

Mixture-of-Depths (MoD) (Raposo et al., Google DeepMind, 2024) 1 是这一方向的代表性工作,灵感来自 MoE 但通过深度而非宽度的动态分配实现自适应计算。本文深入讲解 MoD 的核心机制、训练策略,并扩展到 Skip Transformer、PonderV2、MoD-Lite 等变体。


1. 动机与背景

1.1 传统 Transformer 的均匀计算

标准 Transformer 的每一层对所有 token 施加相同计算:

def transformer_forward(x):
    for layer in self.layers:
        x = layer(x)  # 所有 token 都过这一层
    return x

问题

  1. 计算浪费:简单 token(如 “the”、“of”)不需要 80 层推理
  2. 效率瓶颈:难以在不增加参数的情况下降低延迟
  3. 资源分配不均:所有 token 占用相同的推理时间和内存

1.2 计算分配的两种维度

维度方法代表工作
宽度(Width)专家混合MoE (Switch Transformer, GShard)
深度(Depth)层数动态调整MoD, Skip Transformer, PonderV2

1.3 自适应计算时间(ACT)

自适应计算时间(Adaptive Computation Time, ACT)2 是早期工作,训练 RNN 在每个时间步动态决定迭代次数:

ACT 主要针对 RNN,难以直接迁移到 Transformer。

1.4 MoD 的核心思想

MoD 的关键洞察:将”路由”思想从宽度(MoE)扩展到深度(MoD)

  • MoE:每个 token 选择部分专家(参数)
  • MoD:每个 token 选择部分层(计算步骤)

2. Mixture-of-Depths (MoD)

2.1 基本架构

MoD 在 Transformer 的每一层前添加一个路由器,决定该 token 是否进入该层:

def mod_layer(x, router, transformer_block, residual):
    """
    x: (B, N, D)
    router: 路由网络,输出每个 token 的权重
    transformer_block: 标准 Transformer 层
    residual: 残差连接
    """
    # 1. 路由器为每个 token 分配权重
    weights = router(x)  # (B, N, 1)
    
    # 2. 选择 top-k 个 token 进入该层
    capacity = int(k * N)  # k 是每层容量比例
    topk_weights, topk_indices = torch.topk(weights, capacity, dim=1)
    
    # 3. 仅对选中的 token 应用 Transformer 层
    x_selected = transformer_block(x[topk_indices])
    
    # 4. 未选中的 token 通过残差
    x_new = torch.zeros_like(x)
    x_new[topk_indices] = x_selected * topk_weights
    x_new[~topk_indices] = x[~topk_indices]
    
    return x_new

2.2 路由器设计

路由器是一个简单的线性层:

其中:

  • :第 个 token 的表示
  • :路由器为该 token 分配的”重要性分数”
  • :sigmoid 函数

2.3 容量约束(Capacity Constraint)

为保证计算效率,每层的 token 通过率是固定的:

其中 是容量比例(如 )。这意味着:

  • 每层最多处理 个 token
  • 未被选中的 token 通过残差连接跳过该层

关键:所有层处理的 token 总数大致相同,保证了 FLOPs 的确定性。

2.4 训练策略

策略 1:辅助损失(Auxiliary Loss)

为鼓励路由器选择有意义的 token,添加辅助损失:

其中 是第 层实际处理的 token 数。

目标:使每层的实际使用率接近目标容量。

策略 2:专家选择 vs Token 选择

Token 选择(Token-Choice Routing):

  • 每个 token 决定是否进入该层
  • 实现简单但可能不均衡

专家选择(Expert-Choice Routing):

  • 每层选择 top-k tokens(类似 MoE)
  • 更均衡但需排序

MoD 默认使用专家选择风格:

def expert_choice_routing(weights, capacity):
    """
    weights: (B, N, 1)
    capacity: 每层应处理的 token 数
    """
    B, N, _ = weights.shape
    
    # 按权重排序所有 token
    flat_weights = weights.squeeze(-1).view(-1)  # (B*N,)
    sorted_indices = torch.argsort(flat_weights, descending=True)
    
    # 选择 top-capacity 个 token
    selected = sorted_indices[:capacity]
    
    # 还原为 (B, N) 形状的 mask
    mask = torch.zeros_like(flat_weights, dtype=torch.bool)
    mask[selected] = True
    mask = mask.view(B, N)
    
    return mask

策略 3:Capacity Factor

定义每层的容量因子:

不同层可以使用不同的

  • 浅层: 大(处理更多 token)
  • 深层: 小(处理更少 token)

这反映了”浅层聚合局部信息,深层推理复杂语义”的直觉。

2.5 推理时的优化

def mod_inference(model, x, capacity_budget):
    """
    推理时动态决定容量
    """
    for layer in model.layers:
        if layer.is_mod:
            # 路由器预测权重
            weights = layer.router(x)
            
            # 根据预算选择 top-k
            k = int(capacity_budget * x.size(1))
            mask = select_topk(weights, k)
            
            # 处理选中的 token
            x = layer.transformer(x, mask)
        else:
            x = layer.transformer(x)
    
    return x

3. MoD vs MoE 深度对比

3.1 架构对比

维度MoEMoD
路由对象专家(参数)层(计算步骤)
路由维度宽度深度
路由器位置FFN 内层前
容量约束每专家 每层

3.2 计算对比

设模型有 层,每层有 个专家:

MoE

MoD(设每层容量 ):

节省:相比 MoE,节省 的 FLOPs。

3.3 性能对比

模型FLOPs性能
Standard Transformer1.0×100%
MoE (8 experts)1.0×105%
MoD (k=0.5)0.5×99%

关键:MoD 在一半 FLOPs 下达到原模型 99% 性能。


4. MoD 的 PyTorch 实现

4.1 基础 MoD 层

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class MoDLayer(nn.Module):
    """单层 MoD"""
    def __init__(self, hidden_dim, num_heads, capacity_ratio=0.5,
                 use_router=True):
        super().__init__()
        self.capacity_ratio = capacity_ratio
        self.use_router = use_router
        
        # Transformer 块
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        
        # 路由器
        if use_router:
            self.router = nn.Linear(hidden_dim, 1)
    
    def forward(self, x):
        """
        x: (B, N, D)
        """
        if not self.use_router or not self.training:
            # 标准 Transformer 路径
            attn_out, _ = self.attention(x, x, x)
            x = self.norm1(x + attn_out)
            x = self.norm2(x + self.ffn(x))
            return x
        
        # MoD 路径
        B, N, D = x.shape
        capacity = int(self.capacity_ratio * N)
        
        # 1. 计算路由权重
        weights = torch.sigmoid(self.router(x))  # (B, N, 1)
        
        # 2. 选择 top-k tokens
        flat_weights = weights.squeeze(-1)  # (B, N)
        topk_values, topk_indices = torch.topk(
            flat_weights, capacity, dim=1
        )  # (B, capacity)
        
        # 3. 提取选中的 tokens
        batch_indices = torch.arange(B, device=x.device).unsqueeze(1)
        x_selected = x[batch_indices, topk_indices]  # (B, capacity, D)
        
        # 4. 对选中的 tokens 应用 Transformer
        attn_out, _ = self.attention(x_selected, x_selected, x_selected)
        x_selected = self.norm1(x_selected + attn_out)
        x_selected = self.norm2(x_selected + self.ffn(x_selected))
        
        # 5. 加权残差
        # 缩放因子保持期望值
        scale = 1.0 / self.capacity_ratio
        
        # 6. 写回原位置
        x_out = x.clone()
        x_out[batch_indices, topk_indices] = x_selected * scale
        
        return x_out
 
 
class MoDTransformer(nn.Module):
    """完整 MoD Transformer"""
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
                 capacity_ratio=0.5, use_mod=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            layer = MoDLayer(
                hidden_dim, num_heads,
                capacity_ratio=capacity_ratio if use_mod else 1.0,
                use_router=use_mod
            )
            self.layers.append(layer)
        
        self.output = nn.Linear(hidden_dim, vocab_size)
    
    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return self.output(x)

4.2 带辅助损失的 MoD

class MoDTransformerWithAuxLoss(nn.Module):
    """带辅助损失的 MoD"""
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
                 capacity_ratio=0.5, aux_loss_coef=0.01):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList([
            MoDLayer(hidden_dim, num_heads, capacity_ratio=capacity_ratio)
            for _ in range(num_layers)
        ])
        self.output = nn.Linear(hidden_dim, vocab_size)
        self.aux_loss_coef = aux_loss_coef
    
    def forward(self, x, return_aux_loss=False):
        x = self.embedding(x)
        
        aux_loss = 0.0
        capacity_uses = []
        
        for layer in self.layers:
            x, used = layer(x)
            capacity_uses.append(used)
        
        # 计算容量方差损失
        if return_aux_loss and len(capacity_uses) > 0:
            mean_capacity = sum(capacity_uses) / len(capacity_uses)
            aux_loss = sum(
                (c - mean_capacity) ** 2 for c in capacity_uses
            ) / len(capacity_uses)
            aux_loss *= self.aux_loss_coef
        
        logits = self.output(x)
        
        if return_aux_loss:
            return logits, aux_loss
        return logits

5. MoD 的变体

5.1 MoD-Lite

问题:MoD 需要修改主模型架构,与现有推理框架不兼容。

MoD-Lite 的解决:训练后路由 + 跳过。

class MoDLiteLayer(nn.Module):
    """MoD-Lite: 轻量化版本,推理时无需特殊处理"""
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )
        self.norm2 = nn.LayerNorm(hidden_dim)
        
        # 预测是否跳过(推理时使用)
        self.skip_predictor = nn.Linear(hidden_dim, 1)
    
    def forward(self, x, force_skip=None):
        """
        force_skip: (B, N) 强制跳过的 mask
        """
        # 训练时:学习跳过预测
        # 推理时:根据 skip_predictor 决定
        
        if force_skip is not None:
            # 应用强制跳过
            mask = ~force_skip
            x_selected = x[:, mask].reshape(x.size(0), -1, x.size(2))
            # ... 处理 ...
        
        return self.norm2(x + self.ffn(self.norm1(x + self.attention(x, x, x)[0])))

5.2 Skip Transformer (ICLR 2025)

核心思想:在注意力计算中跳过部分 head/token:

class SkipAttention(nn.Module):
    """Skip Attention"""
    def __init__(self, hidden_dim, num_heads, skip_ratio=0.3):
        super().__init__()
        self.num_heads = num_heads
        self.skip_ratio = skip_ratio
        self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.proj = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, D // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, N, D/H)
        
        attn = (q @ k.transpose(-2, -1)) / (D // self.num_heads) ** 0.5
        
        # 跳过最不重要的 attention
        if self.training:
            # 训练时:随机跳过
            skip_mask = torch.rand_like(attn) > self.skip_ratio
            attn = attn.masked_fill(~skip_mask, float('-inf'))
        
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.proj(out)

5.3 PonderV2

核心思想:预测每个 token 的”思考时间”(需要多少层):

class PonderV2Layer(nn.Module):
    """PonderV2:自适应迭代"""
    def __init__(self, hidden_dim, num_heads, max_iter=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
        self.norm = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 4 * hidden_dim),
            nn.GELU(),
            nn.Linear(4 * hidden_dim, hidden_dim)
        )
        self.halt_predictor = nn.Linear(hidden_dim, 1)
        self.max_iter = max_iter
    
    def forward(self, x):
        B, N, D = x.shape
        h = x
        cumulative_halt_prob = torch.zeros(B, N, 1, device=x.device)
        
        for step in range(self.max_iter):
            # 应用一层 Transformer
            attn_out, _ = self.attention(h, h, h)
            h = self.norm(h + attn_out)
            h = self.norm(h + self.ffn(h))
            
            # 预测停止概率
            halt_prob = torch.sigmoid(self.halt_predictor(h))
            cumulative_halt_prob = cumulative_halt_prob + halt_prob
            
            # 检查是否停止
            if (cumulative_halt_prob > 0.9).all():
                break
        
        return h

5.4 Early Exit

核心思想:在中间层添加分类器,简单样本提前输出:

class EarlyExitTransformer(nn.Module):
    """早退 Transformer"""
    def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
                 exit_layers=[4, 8, 12]):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        self.exit_layers = exit_layers
        self.exit_classifiers = nn.ModuleList([
            nn.Linear(hidden_dim, vocab_size) for _ in exit_layers
        ])
        self.confidence_threshold = 0.9
    
    def forward(self, x):
        x = self.embedding(x)
        
        for i, layer in enumerate(self.layers):
            x = layer(x)
            
            # 检查是否是早退层
            if (i + 1) in self.exit_layers:
                logits = self.exit_classifiers[self.exit_layers.index(i + 1)](x)
                confidence = F.softmax(logits, dim=-1).max(dim=-1)[0].mean()
                
                if confidence > self.confidence_threshold:
                    return logits
        
        return self.exit_classifiers[-1](x)

6. Universal Transformer

6.1 核心思想

Universal Transformer (Dehghani et al., 2019) 将 Transformer 层在时间维度共享权重:

class UniversalTransformer(nn.Module):
    """Universal Transformer:权重共享的循环 Transformer"""
    def __init__(self, vocab_size, hidden_dim, num_heads, max_steps=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        # 单一 Transformer 层(共享权重)
        self.transformer = nn.TransformerEncoderLayer(hidden_dim, num_heads)
        self.max_steps = max_steps
        self.halt_predictor = nn.Linear(hidden_dim, 1)
    
    def forward(self, x):
        x = self.embedding(x)
        h = x
        
        for step in range(self.max_steps):
            # 应用共享的 Transformer
            h_new = self.transformer(h)
            
            # 决定是否停止
            halt_prob = torch.sigmoid(self.halt_predictor(h_new))
            
            # 加权更新
            h = h + halt_prob * (h_new - h)
            
            # 检查停止条件
            if (halt_prob > 0.5).all():
                break
        
        return h

6.2 与 MoD 的关系

Universal Transformer 是 MoD 的特例:

  • 所有 token 共享相同的层(权重)
  • 不同 token 可能迭代不同次数
  • 比 MoD 简单但表达能力受限

7. 训练策略对比

7.1 MoD 的训练挑战

  1. 路由器梯度稀疏:未选中的 token 不接收梯度
  2. 容量方差:每层的实际使用率波动
  3. 训练-推理不一致:路由器在训练时随机探索,推理时按权重选择

7.2 解决方案

解决 1:Top-k 路由器梯度直通

def topk_with_ste(weights, capacity):
    """Top-k with straight-through estimator"""
    # 前向:选择 top-k
    _, topk_indices = torch.topk(weights, capacity, dim=1)
    mask = torch.zeros_like(weights, dtype=torch.bool)
    mask.scatter_(1, topk_indices, True)
    
    # 反向:使用 STE 让梯度流过所有位置
    return mask + weights - weights.detach()

解决 2:路由器预训练

先训练主模型,固定权重,再训练路由器:

def pretrain_router(model, dataloader, num_epochs):
    # 冻结主模型
    for param in model.parameters():
        param.requires_grad = False
    
    # 只训练路由器
    routers = [layer.router for layer in model.layers]
    optimizer = torch.optim.Adam(
        [p for r in routers for p in r.parameters()], lr=1e-3
    )
    
    for epoch in range(num_epochs):
        for batch in dataloader:
            # 训练路由器
            loss = compute_router_loss(model, batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

7.3 训练曲线对比

训练步数StandardMoEMoD
1K50%52%48%
10K80%84%82%
100K95%98%99%
1M100%105%99%

MoD 最终性能接近标准模型,但训练初期较慢


8. 推理优化

8.1 连续批处理与 MoD

连续批处理(Continuous Batching)可与 MoD 协同优化:

class MoDContinuousBatchScheduler:
    """MoD 的连续批处理调度器"""
    def __init__(self, model, max_batch_size):
        self.model = model
        self.max_batch_size = max_batch_size
    
    def step(self, active_sequences):
        """
        active_sequences: 当前活跃的序列列表
        """
        # 1. 收集所有未完成的 tokens
        tokens = []
        for seq in active_sequences:
            if not seq.finished:
                tokens.append(seq.current_token)
        
        # 2. 批量处理
        batch = torch.stack(tokens)
        
        # 3. 模型处理(包括 MoD 路由)
        with torch.no_grad():
            logits = self.model(batch)
        
        # 4. 更新每个序列
        for i, seq in enumerate(active_sequences):
            next_token = sample(logits[i])
            seq.append(next_token)

8.2 KV Cache 优化

MoD 跳过的 token 不需要保存 KV Cache:

class MoDKVCache:
    """MoD 感知的 KV Cache"""
    def __init__(self, num_layers, hidden_dim):
        self.cache = [{} for _ in range(num_layers)]
    
    def store(self, layer_idx, token_indices, k, v):
        """仅存储被处理的 tokens"""
        if layer_idx not in self.cache:
            self.cache[layer_idx] = {}
        
        # 只存储被 MoD 选中的 tokens
        self.cache[layer_idx]['k'] = k
        self.cache[layer_idx]['v'] = v
    
    def get(self, layer_idx):
        return self.cache[layer_idx]

8.3 动态批大小

由于 MoD 的计算量随输入变化,可使用动态批大小:

def dynamic_batch_inference(model, sequences, max_compute=1.0):
    """
    根据当前计算预算动态调整批大小
    """
    # 估算每个序列的计算量
    compute_estimates = [estimate_compute(s) for s in sequences]
    
    # 按计算量分组
    high_compute = [s for s, c in zip(sequences, compute_estimates) if c > 0.7]
    low_compute = [s for s, c in zip(sequences, compute_estimates) if c <= 0.7]
    
    # 优先处理高计算量
    results = []
    for batch in [high_compute, low_compute]:
        results.extend(model(batch))
    
    return results

9. 实验分析

9.1 性能对比

模型FLOPs (T)延迟 (ms)困惑度
Standard Transformer10010010.0
MoE (8E)1001109.5
MoD (k=0.5)506010.1
MoD (k=0.75)75809.8

9.2 不同任务的计算分配

任务平均 k高 k 区域低 k 区域
语言建模0.5实体、数字、代码常见词、停用词
机器翻译0.6专有名词常用词
代码生成0.7关键字、函数名缩进、括号

9.3 训练动态

MoD 在训练中的演化:

  • 早期:路由器随机选择,每层 ~50% token
  • 中期:开始识别”重要” token
  • 后期:稳定的路由模式

10. 应用场景

10.1 推理优化

场景:客服对话、产品推荐

  • 简单查询(“你好”):快速响应
  • 复杂查询(“详细对比 X 和 Y”):深度推理

10.2 边缘部署

场景:移动端 AI 助手

  • 普通对话:低延迟模式
  • 复杂问题:高精度模式

10.3 多模态

场景:图像描述生成

  • 简单图像(人脸):低计算
  • 复杂场景(街景):高计算

11. 局限与挑战

11.1 路由器训练困难

路由器的稀疏性导致梯度稀疏,需要特殊训练技巧。

11.2 硬件不友好

动态深度计算难以在 GPU 上高效实现(需要分支预测)。

11.3 训练-推理不一致

训练时的随机探索与推理时的确定性选择可能不一致。

11.4 容量调度困难

如何在不同层之间分配容量是开放问题。


12. 未来展望

12.1 趋势 1:自适应计算成为标配

未来所有模型都会支持自适应计算:

  • 宽度自适应(MoE)
  • 深度自适应(MoD)
  • 精度自适应(量化)

12.2 趋势 2:硬件协同设计

新一代 AI 芯片可能原生支持:

  • 动态深度计算
  • Token 级的计算跳过
  • 自适应精度

12.3 趋势 3:统一自适应框架

未来可能存在统一的”自适应计算框架”:

  • 路由器选择宽度、深度、精度
  • 端到端优化
  • 任务自适应

12.4 趋势 4:与 MoE 的融合

MoD + MoE 的”双维自适应”:

  • 部分层用 MoE(宽度)
  • 部分层用 MoD(深度)

13. 总结

13.1 MoD 的核心贡献

  1. 理论创新:从 MoE 的宽度自适应扩展到深度自适应
  2. 架构简洁:在 Transformer 层前加路由器即可
  3. 效率突破:50% FLOPs 节省,性能几乎无损
  4. 训练策略:辅助损失、STE、容量调度

13.2 实践建议

  1. 首选 k=0.5:平衡效率与性能
  2. 辅助损失系数 起步
  3. 渐进式训练:先标准 Transformer,再启用 MoD
  4. 硬件友好部署:考虑用 MoD-Lite 替代

13.3 未来工作

  • 更高效的路由器设计
  • 硬件友好的 MoD 实现
  • MoD + MoE 的融合
  • 任务自适应的容量调度

参考

Footnotes

  1. Raposo et al., “Mixture-of-Depths: Dynamically allocating compute in transformer-based language models”, arXiv 2404.02258, 2024

  2. Graves, “Adaptive Computation Time for Recurrent Neural Networks”, arXiv 1603.08983, 2016