概述
传统 Transformer 对所有 token 施加相同的计算量。然而不同 token 的”难度”差异巨大:常见 token 易于预测,罕见/复杂 token 需要更多计算。动态深度计算(Dynamic-Depth Compute)通过为每个 token 自适应分配不同的层数,实现计算的最优分配。
Mixture-of-Depths (MoD) (Raposo et al., Google DeepMind, 2024) 1 是这一方向的代表性工作,灵感来自 MoE 但通过深度而非宽度的动态分配实现自适应计算。本文深入讲解 MoD 的核心机制、训练策略,并扩展到 Skip Transformer、PonderV2、MoD-Lite 等变体。
1. 动机与背景
1.1 传统 Transformer 的均匀计算
标准 Transformer 的每一层对所有 token 施加相同计算:
def transformer_forward(x):
for layer in self.layers:
x = layer(x) # 所有 token 都过这一层
return x问题:
- 计算浪费:简单 token(如 “the”、“of”)不需要 80 层推理
- 效率瓶颈:难以在不增加参数的情况下降低延迟
- 资源分配不均:所有 token 占用相同的推理时间和内存
1.2 计算分配的两种维度
| 维度 | 方法 | 代表工作 |
|---|---|---|
| 宽度(Width) | 专家混合 | MoE (Switch Transformer, GShard) |
| 深度(Depth) | 层数动态调整 | MoD, Skip Transformer, PonderV2 |
1.3 自适应计算时间(ACT)
自适应计算时间(Adaptive Computation Time, ACT)2 是早期工作,训练 RNN 在每个时间步动态决定迭代次数:
ACT 主要针对 RNN,难以直接迁移到 Transformer。
1.4 MoD 的核心思想
MoD 的关键洞察:将”路由”思想从宽度(MoE)扩展到深度(MoD):
- MoE:每个 token 选择部分专家(参数)
- MoD:每个 token 选择部分层(计算步骤)
2. Mixture-of-Depths (MoD)
2.1 基本架构
MoD 在 Transformer 的每一层前添加一个路由器,决定该 token 是否进入该层:
def mod_layer(x, router, transformer_block, residual):
"""
x: (B, N, D)
router: 路由网络,输出每个 token 的权重
transformer_block: 标准 Transformer 层
residual: 残差连接
"""
# 1. 路由器为每个 token 分配权重
weights = router(x) # (B, N, 1)
# 2. 选择 top-k 个 token 进入该层
capacity = int(k * N) # k 是每层容量比例
topk_weights, topk_indices = torch.topk(weights, capacity, dim=1)
# 3. 仅对选中的 token 应用 Transformer 层
x_selected = transformer_block(x[topk_indices])
# 4. 未选中的 token 通过残差
x_new = torch.zeros_like(x)
x_new[topk_indices] = x_selected * topk_weights
x_new[~topk_indices] = x[~topk_indices]
return x_new2.2 路由器设计
路由器是一个简单的线性层:
其中:
- :第 个 token 的表示
- :路由器为该 token 分配的”重要性分数”
- :sigmoid 函数
2.3 容量约束(Capacity Constraint)
为保证计算效率,每层的 token 通过率是固定的:
其中 是容量比例(如 )。这意味着:
- 每层最多处理 个 token
- 未被选中的 token 通过残差连接跳过该层
关键:所有层处理的 token 总数大致相同,保证了 FLOPs 的确定性。
2.4 训练策略
策略 1:辅助损失(Auxiliary Loss)
为鼓励路由器选择有意义的 token,添加辅助损失:
其中 是第 层实际处理的 token 数。
目标:使每层的实际使用率接近目标容量。
策略 2:专家选择 vs Token 选择
Token 选择(Token-Choice Routing):
- 每个 token 决定是否进入该层
- 实现简单但可能不均衡
专家选择(Expert-Choice Routing):
- 每层选择 top-k tokens(类似 MoE)
- 更均衡但需排序
MoD 默认使用专家选择风格:
def expert_choice_routing(weights, capacity):
"""
weights: (B, N, 1)
capacity: 每层应处理的 token 数
"""
B, N, _ = weights.shape
# 按权重排序所有 token
flat_weights = weights.squeeze(-1).view(-1) # (B*N,)
sorted_indices = torch.argsort(flat_weights, descending=True)
# 选择 top-capacity 个 token
selected = sorted_indices[:capacity]
# 还原为 (B, N) 形状的 mask
mask = torch.zeros_like(flat_weights, dtype=torch.bool)
mask[selected] = True
mask = mask.view(B, N)
return mask策略 3:Capacity Factor
定义每层的容量因子:
不同层可以使用不同的 :
- 浅层: 大(处理更多 token)
- 深层: 小(处理更少 token)
这反映了”浅层聚合局部信息,深层推理复杂语义”的直觉。
2.5 推理时的优化
def mod_inference(model, x, capacity_budget):
"""
推理时动态决定容量
"""
for layer in model.layers:
if layer.is_mod:
# 路由器预测权重
weights = layer.router(x)
# 根据预算选择 top-k
k = int(capacity_budget * x.size(1))
mask = select_topk(weights, k)
# 处理选中的 token
x = layer.transformer(x, mask)
else:
x = layer.transformer(x)
return x3. MoD vs MoE 深度对比
3.1 架构对比
| 维度 | MoE | MoD |
|---|---|---|
| 路由对象 | 专家(参数) | 层(计算步骤) |
| 路由维度 | 宽度 | 深度 |
| 路由器位置 | FFN 内 | 层前 |
| 容量约束 | 每专家 | 每层 |
3.2 计算对比
设模型有 层,每层有 个专家:
MoE:
MoD(设每层容量 ):
节省:相比 MoE,节省 的 FLOPs。
3.3 性能对比
| 模型 | FLOPs | 性能 |
|---|---|---|
| Standard Transformer | 1.0× | 100% |
| MoE (8 experts) | 1.0× | 105% |
| MoD (k=0.5) | 0.5× | 99% |
关键:MoD 在一半 FLOPs 下达到原模型 99% 性能。
4. MoD 的 PyTorch 实现
4.1 基础 MoD 层
import torch
import torch.nn as nn
import torch.nn.functional as F
class MoDLayer(nn.Module):
"""单层 MoD"""
def __init__(self, hidden_dim, num_heads, capacity_ratio=0.5,
use_router=True):
super().__init__()
self.capacity_ratio = capacity_ratio
self.use_router = use_router
# Transformer 块
self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
self.norm1 = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, 4 * hidden_dim),
nn.GELU(),
nn.Linear(4 * hidden_dim, hidden_dim)
)
self.norm2 = nn.LayerNorm(hidden_dim)
# 路由器
if use_router:
self.router = nn.Linear(hidden_dim, 1)
def forward(self, x):
"""
x: (B, N, D)
"""
if not self.use_router or not self.training:
# 标准 Transformer 路径
attn_out, _ = self.attention(x, x, x)
x = self.norm1(x + attn_out)
x = self.norm2(x + self.ffn(x))
return x
# MoD 路径
B, N, D = x.shape
capacity = int(self.capacity_ratio * N)
# 1. 计算路由权重
weights = torch.sigmoid(self.router(x)) # (B, N, 1)
# 2. 选择 top-k tokens
flat_weights = weights.squeeze(-1) # (B, N)
topk_values, topk_indices = torch.topk(
flat_weights, capacity, dim=1
) # (B, capacity)
# 3. 提取选中的 tokens
batch_indices = torch.arange(B, device=x.device).unsqueeze(1)
x_selected = x[batch_indices, topk_indices] # (B, capacity, D)
# 4. 对选中的 tokens 应用 Transformer
attn_out, _ = self.attention(x_selected, x_selected, x_selected)
x_selected = self.norm1(x_selected + attn_out)
x_selected = self.norm2(x_selected + self.ffn(x_selected))
# 5. 加权残差
# 缩放因子保持期望值
scale = 1.0 / self.capacity_ratio
# 6. 写回原位置
x_out = x.clone()
x_out[batch_indices, topk_indices] = x_selected * scale
return x_out
class MoDTransformer(nn.Module):
"""完整 MoD Transformer"""
def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
capacity_ratio=0.5, use_mod=True):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.layers = nn.ModuleList()
for i in range(num_layers):
layer = MoDLayer(
hidden_dim, num_heads,
capacity_ratio=capacity_ratio if use_mod else 1.0,
use_router=use_mod
)
self.layers.append(layer)
self.output = nn.Linear(hidden_dim, vocab_size)
def forward(self, x):
x = self.embedding(x)
for layer in self.layers:
x = layer(x)
return self.output(x)4.2 带辅助损失的 MoD
class MoDTransformerWithAuxLoss(nn.Module):
"""带辅助损失的 MoD"""
def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
capacity_ratio=0.5, aux_loss_coef=0.01):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.layers = nn.ModuleList([
MoDLayer(hidden_dim, num_heads, capacity_ratio=capacity_ratio)
for _ in range(num_layers)
])
self.output = nn.Linear(hidden_dim, vocab_size)
self.aux_loss_coef = aux_loss_coef
def forward(self, x, return_aux_loss=False):
x = self.embedding(x)
aux_loss = 0.0
capacity_uses = []
for layer in self.layers:
x, used = layer(x)
capacity_uses.append(used)
# 计算容量方差损失
if return_aux_loss and len(capacity_uses) > 0:
mean_capacity = sum(capacity_uses) / len(capacity_uses)
aux_loss = sum(
(c - mean_capacity) ** 2 for c in capacity_uses
) / len(capacity_uses)
aux_loss *= self.aux_loss_coef
logits = self.output(x)
if return_aux_loss:
return logits, aux_loss
return logits5. MoD 的变体
5.1 MoD-Lite
问题:MoD 需要修改主模型架构,与现有推理框架不兼容。
MoD-Lite 的解决:训练后路由 + 跳过。
class MoDLiteLayer(nn.Module):
"""MoD-Lite: 轻量化版本,推理时无需特殊处理"""
def __init__(self, hidden_dim, num_heads):
super().__init__()
self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
self.norm1 = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, 4 * hidden_dim),
nn.GELU(),
nn.Linear(4 * hidden_dim, hidden_dim)
)
self.norm2 = nn.LayerNorm(hidden_dim)
# 预测是否跳过(推理时使用)
self.skip_predictor = nn.Linear(hidden_dim, 1)
def forward(self, x, force_skip=None):
"""
force_skip: (B, N) 强制跳过的 mask
"""
# 训练时:学习跳过预测
# 推理时:根据 skip_predictor 决定
if force_skip is not None:
# 应用强制跳过
mask = ~force_skip
x_selected = x[:, mask].reshape(x.size(0), -1, x.size(2))
# ... 处理 ...
return self.norm2(x + self.ffn(self.norm1(x + self.attention(x, x, x)[0])))5.2 Skip Transformer (ICLR 2025)
核心思想:在注意力计算中跳过部分 head/token:
class SkipAttention(nn.Module):
"""Skip Attention"""
def __init__(self, hidden_dim, num_heads, skip_ratio=0.3):
super().__init__()
self.num_heads = num_heads
self.skip_ratio = skip_ratio
self.qkv = nn.Linear(hidden_dim, 3 * hidden_dim)
self.proj = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x):
B, N, D = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, D // self.num_heads)
q, k, v = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, N, D/H)
attn = (q @ k.transpose(-2, -1)) / (D // self.num_heads) ** 0.5
# 跳过最不重要的 attention
if self.training:
# 训练时:随机跳过
skip_mask = torch.rand_like(attn) > self.skip_ratio
attn = attn.masked_fill(~skip_mask, float('-inf'))
attn = F.softmax(attn, dim=-1)
out = (attn @ v).transpose(1, 2).reshape(B, N, D)
return self.proj(out)5.3 PonderV2
核心思想:预测每个 token 的”思考时间”(需要多少层):
class PonderV2Layer(nn.Module):
"""PonderV2:自适应迭代"""
def __init__(self, hidden_dim, num_heads, max_iter=4):
super().__init__()
self.attention = nn.MultiheadAttention(hidden_dim, num_heads)
self.norm = nn.LayerNorm(hidden_dim)
self.ffn = nn.Sequential(
nn.Linear(hidden_dim, 4 * hidden_dim),
nn.GELU(),
nn.Linear(4 * hidden_dim, hidden_dim)
)
self.halt_predictor = nn.Linear(hidden_dim, 1)
self.max_iter = max_iter
def forward(self, x):
B, N, D = x.shape
h = x
cumulative_halt_prob = torch.zeros(B, N, 1, device=x.device)
for step in range(self.max_iter):
# 应用一层 Transformer
attn_out, _ = self.attention(h, h, h)
h = self.norm(h + attn_out)
h = self.norm(h + self.ffn(h))
# 预测停止概率
halt_prob = torch.sigmoid(self.halt_predictor(h))
cumulative_halt_prob = cumulative_halt_prob + halt_prob
# 检查是否停止
if (cumulative_halt_prob > 0.9).all():
break
return h5.4 Early Exit
核心思想:在中间层添加分类器,简单样本提前输出:
class EarlyExitTransformer(nn.Module):
"""早退 Transformer"""
def __init__(self, vocab_size, hidden_dim, num_layers, num_heads,
exit_layers=[4, 8, 12]):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_dim)
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(hidden_dim, num_heads)
for _ in range(num_layers)
])
self.exit_layers = exit_layers
self.exit_classifiers = nn.ModuleList([
nn.Linear(hidden_dim, vocab_size) for _ in exit_layers
])
self.confidence_threshold = 0.9
def forward(self, x):
x = self.embedding(x)
for i, layer in enumerate(self.layers):
x = layer(x)
# 检查是否是早退层
if (i + 1) in self.exit_layers:
logits = self.exit_classifiers[self.exit_layers.index(i + 1)](x)
confidence = F.softmax(logits, dim=-1).max(dim=-1)[0].mean()
if confidence > self.confidence_threshold:
return logits
return self.exit_classifiers[-1](x)6. Universal Transformer
6.1 核心思想
Universal Transformer (Dehghani et al., 2019) 将 Transformer 层在时间维度共享权重:
class UniversalTransformer(nn.Module):
"""Universal Transformer:权重共享的循环 Transformer"""
def __init__(self, vocab_size, hidden_dim, num_heads, max_steps=8):
super().__init__()
self.embedding = nn.Embedding(vocab_size, hidden_dim)
# 单一 Transformer 层(共享权重)
self.transformer = nn.TransformerEncoderLayer(hidden_dim, num_heads)
self.max_steps = max_steps
self.halt_predictor = nn.Linear(hidden_dim, 1)
def forward(self, x):
x = self.embedding(x)
h = x
for step in range(self.max_steps):
# 应用共享的 Transformer
h_new = self.transformer(h)
# 决定是否停止
halt_prob = torch.sigmoid(self.halt_predictor(h_new))
# 加权更新
h = h + halt_prob * (h_new - h)
# 检查停止条件
if (halt_prob > 0.5).all():
break
return h6.2 与 MoD 的关系
Universal Transformer 是 MoD 的特例:
- 所有 token 共享相同的层(权重)
- 不同 token 可能迭代不同次数
- 比 MoD 简单但表达能力受限
7. 训练策略对比
7.1 MoD 的训练挑战
- 路由器梯度稀疏:未选中的 token 不接收梯度
- 容量方差:每层的实际使用率波动
- 训练-推理不一致:路由器在训练时随机探索,推理时按权重选择
7.2 解决方案
解决 1:Top-k 路由器梯度直通
def topk_with_ste(weights, capacity):
"""Top-k with straight-through estimator"""
# 前向:选择 top-k
_, topk_indices = torch.topk(weights, capacity, dim=1)
mask = torch.zeros_like(weights, dtype=torch.bool)
mask.scatter_(1, topk_indices, True)
# 反向:使用 STE 让梯度流过所有位置
return mask + weights - weights.detach()解决 2:路由器预训练
先训练主模型,固定权重,再训练路由器:
def pretrain_router(model, dataloader, num_epochs):
# 冻结主模型
for param in model.parameters():
param.requires_grad = False
# 只训练路由器
routers = [layer.router for layer in model.layers]
optimizer = torch.optim.Adam(
[p for r in routers for p in r.parameters()], lr=1e-3
)
for epoch in range(num_epochs):
for batch in dataloader:
# 训练路由器
loss = compute_router_loss(model, batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()7.3 训练曲线对比
| 训练步数 | Standard | MoE | MoD |
|---|---|---|---|
| 1K | 50% | 52% | 48% |
| 10K | 80% | 84% | 82% |
| 100K | 95% | 98% | 99% |
| 1M | 100% | 105% | 99% |
MoD 最终性能接近标准模型,但训练初期较慢。
8. 推理优化
8.1 连续批处理与 MoD
连续批处理(Continuous Batching)可与 MoD 协同优化:
class MoDContinuousBatchScheduler:
"""MoD 的连续批处理调度器"""
def __init__(self, model, max_batch_size):
self.model = model
self.max_batch_size = max_batch_size
def step(self, active_sequences):
"""
active_sequences: 当前活跃的序列列表
"""
# 1. 收集所有未完成的 tokens
tokens = []
for seq in active_sequences:
if not seq.finished:
tokens.append(seq.current_token)
# 2. 批量处理
batch = torch.stack(tokens)
# 3. 模型处理(包括 MoD 路由)
with torch.no_grad():
logits = self.model(batch)
# 4. 更新每个序列
for i, seq in enumerate(active_sequences):
next_token = sample(logits[i])
seq.append(next_token)8.2 KV Cache 优化
MoD 跳过的 token 不需要保存 KV Cache:
class MoDKVCache:
"""MoD 感知的 KV Cache"""
def __init__(self, num_layers, hidden_dim):
self.cache = [{} for _ in range(num_layers)]
def store(self, layer_idx, token_indices, k, v):
"""仅存储被处理的 tokens"""
if layer_idx not in self.cache:
self.cache[layer_idx] = {}
# 只存储被 MoD 选中的 tokens
self.cache[layer_idx]['k'] = k
self.cache[layer_idx]['v'] = v
def get(self, layer_idx):
return self.cache[layer_idx]8.3 动态批大小
由于 MoD 的计算量随输入变化,可使用动态批大小:
def dynamic_batch_inference(model, sequences, max_compute=1.0):
"""
根据当前计算预算动态调整批大小
"""
# 估算每个序列的计算量
compute_estimates = [estimate_compute(s) for s in sequences]
# 按计算量分组
high_compute = [s for s, c in zip(sequences, compute_estimates) if c > 0.7]
low_compute = [s for s, c in zip(sequences, compute_estimates) if c <= 0.7]
# 优先处理高计算量
results = []
for batch in [high_compute, low_compute]:
results.extend(model(batch))
return results9. 实验分析
9.1 性能对比
| 模型 | FLOPs (T) | 延迟 (ms) | 困惑度 |
|---|---|---|---|
| Standard Transformer | 100 | 100 | 10.0 |
| MoE (8E) | 100 | 110 | 9.5 |
| MoD (k=0.5) | 50 | 60 | 10.1 |
| MoD (k=0.75) | 75 | 80 | 9.8 |
9.2 不同任务的计算分配
| 任务 | 平均 k | 高 k 区域 | 低 k 区域 |
|---|---|---|---|
| 语言建模 | 0.5 | 实体、数字、代码 | 常见词、停用词 |
| 机器翻译 | 0.6 | 专有名词 | 常用词 |
| 代码生成 | 0.7 | 关键字、函数名 | 缩进、括号 |
9.3 训练动态
MoD 在训练中的演化:
- 早期:路由器随机选择,每层 ~50% token
- 中期:开始识别”重要” token
- 后期:稳定的路由模式
10. 应用场景
10.1 推理优化
场景:客服对话、产品推荐
- 简单查询(“你好”):快速响应
- 复杂查询(“详细对比 X 和 Y”):深度推理
10.2 边缘部署
场景:移动端 AI 助手
- 普通对话:低延迟模式
- 复杂问题:高精度模式
10.3 多模态
场景:图像描述生成
- 简单图像(人脸):低计算
- 复杂场景(街景):高计算
11. 局限与挑战
11.1 路由器训练困难
路由器的稀疏性导致梯度稀疏,需要特殊训练技巧。
11.2 硬件不友好
动态深度计算难以在 GPU 上高效实现(需要分支预测)。
11.3 训练-推理不一致
训练时的随机探索与推理时的确定性选择可能不一致。
11.4 容量调度困难
如何在不同层之间分配容量是开放问题。
12. 未来展望
12.1 趋势 1:自适应计算成为标配
未来所有模型都会支持自适应计算:
- 宽度自适应(MoE)
- 深度自适应(MoD)
- 精度自适应(量化)
12.2 趋势 2:硬件协同设计
新一代 AI 芯片可能原生支持:
- 动态深度计算
- Token 级的计算跳过
- 自适应精度
12.3 趋势 3:统一自适应框架
未来可能存在统一的”自适应计算框架”:
- 路由器选择宽度、深度、精度
- 端到端优化
- 任务自适应
12.4 趋势 4:与 MoE 的融合
MoD + MoE 的”双维自适应”:
- 部分层用 MoE(宽度)
- 部分层用 MoD(深度)
13. 总结
13.1 MoD 的核心贡献
- 理论创新:从 MoE 的宽度自适应扩展到深度自适应
- 架构简洁:在 Transformer 层前加路由器即可
- 效率突破:50% FLOPs 节省,性能几乎无损
- 训练策略:辅助损失、STE、容量调度
13.2 实践建议
- 首选 k=0.5:平衡效率与性能
- 辅助损失系数: 起步
- 渐进式训练:先标准 Transformer,再启用 MoD
- 硬件友好部署:考虑用 MoD-Lite 替代
13.3 未来工作
- 更高效的路由器设计
- 硬件友好的 MoD 实现
- MoD + MoE 的融合
- 任务自适应的容量调度