title: MoR混合递归Transformer
date: 2026-05-07
description: 通过参数共享与自适应计算统一实现高效Transformer的方法
tags:
- transformer
- recursive-computation
- mixture-of-experts
draft: false
permalink:
MoR混合递归Transformer
概述
MoR(Mixture-of-Recursions)是一种统一框架,同时实现参数共享和自适应计算两种效率维度。1
传统Transformer面临参数量和计算量的双重挑战,MoR通过递归Transformer架构和轻量级路由器,在135M到1.7B参数规模上建立新的帕累托前沿。
核心思想
双轴效率
Transformer效率优化通常关注两个独立方向:
- 参数效率:通过权重共享减少参数量
- 计算效率:通过自适应计算减少FLOPs
现有方法只能优化其中一个维度:
- MoE优化参数效率,但计算量不变
- 动态深度/跳过方法优化计算效率,但参数量不变
MoR的统一框架
MoR的核心洞察:递归是同时实现双轴效率的自然机制
┌─────────────────────────────────────┐
│ 共享Transformer层栈 │
│ (参数效率:O(L)层 vs O(L×D)参数) │
└─────────────┬───────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ 轻量级递归路由器 │
│ (计算效率:动态分配递归深度) │
└─────────────┬───────────────────────┘
│
┌───────────┼───────────┐
▼ ▼ ▼
Token A Token B Token C
(D=3) (D=2) (D=1)
递归3次 递归2次 递归1次
方法详解
递归Transformer架构
MoR使用共享层栈进行递归计算:
其中 是递归深度, 是输入。
自适应递归深度分配
关键创新:轻量级路由器为每个token动态分配递归深度:
class RecursionRouter(nn.Module):
def __init__(self, d_model, max_depth, depth_hidden_dim=64):
super().__init__()
self.max_depth = max_depth
# 极简路由器
self.router = nn.Sequential(
nn.Linear(d_model, depth_hidden_dim),
nn.SiLU(),
nn.Linear(depth_hidden_dim, max_depth),
)
def forward(self, x):
# x: [B, N, D]
logits = self.router(x) # [B, N, max_depth]
# 采样递归深度(训练时随机,推理时贪婪)
if self.training:
# Gumbel-Softmax采样
gumbels = -torch.empty_like(logits).exponential_().log()
logits = logits + gumbels
depths = F.softmax(logits, dim=-1)
else:
# 贪婪选择
depths = F.one_hot(logits.argmax(dim=-1), self.max_depth).float()
return depths # [B, N, max_depth]稀疏注意力机制
递归深度分配后,MoR对活跃token应用注意力:
class MoRLayer(nn.Module):
def __init__(self, d_model, n_heads, max_depth):
super().__init__()
self.shared_layers = nn.ModuleList([
TransformerEncoderLayer(d_model, n_heads)
for _ in range(3) # 基础层数
])
self.router = RecursionRouter(d_model, max_depth)
self.max_depth = max_depth
def forward(self, x, attention_mask=None):
B, N, D = x.shape
# 获取每个token的递归深度
depth_probs = self.router(x) # [B, N, max_depth]
# 逐层递归处理
h = x
for k in range(self.max_depth):
# 选择该层需要处理的token
active_mask = depth_probs[:, :, k] > 0.5 # [B, N]
if active_mask.sum() == 0:
continue
# 应用共享Transformer层
layer = self.shared_layers[k % len(self.shared_layers)]
h_active = h.clone()
h_active[~active_mask] = 0 # 屏蔽非活跃token
h_new = layer(h_active)
# 合并结果
h = torch.where(active_mask.unsqueeze(-1), h_new, h)
return hKV缓存优化
MoR提出KV共享变体,专门优化预填充延迟:
class MoRWithKVSharing(nn.Module):
"""
KV共享:所有递归步骤复用第一个token的KV
专门用于降低预填充延迟
"""
def __init__(self, base_model):
super().__init__()
self.base_model = base_model
# 只为第一个token维护KV
self.first_token_kv = None
def forward(self, x):
if self.first_token_kv is None:
# 初始化:处理完整序列
self.first_token_kv = self.base_model(x[:, :1])
# 后续token使用共享KV
rest_out = self.base_model.rest_layers(x[:, 1:])
return torch.cat([self.first_token_kv, rest_out], dim=1)理论分析
参数效率
共享层数为 ,最大递归深度为 :
| 配置 | 实际层数 | 参数量 |
|---|---|---|
| 标准Transformer | ||
| MoR |
参数量减少因子:
计算效率
活跃token比例为 时:
通过动态 和 ,实现自适应计算。
实验结果
帕累托前沿
| 模型 | 参数量 | 训练FLOPs | 验证困惑度 |
|---|---|---|---|
| Vanilla-135M | 135M | 1.0× | 24.3 |
| MoR-135M | 135M | 0.6× | 23.8 |
| Vanilla-410M | 410M | 3.0× | 20.1 |
| MoR-410M | 410M | 1.8× | 19.6 |
| Vanilla-1B | 1B | 7.5× | 17.8 |
| MoR-1B | 1B | 4.5× | 17.2 |
| Vanilla-1.7B | 1.7B | 12.0× | 16.5 |
| MoR-1.7B | 1.7B | 7.0× | 15.9 |
MoR在所有规模上建立新的帕累托前沿。
Few-shot性能
| 模型 | LAMBADA | PIQA | HellaSwag |
|---|---|---|---|
| Vanilla-410M | 58.2 | 71.3 | 44.1 |
| MoR-410M | 60.1 | 72.8 | 45.6 |
吞吐量提升
| 模型 | 吞吐量(tokens/s) | 相对提升 |
|---|---|---|
| Vanilla-1B | 100 | 1.0× |
| MoR-1B | 156 | 1.56× |
与其他方法的对比
| 方法 | 参数效率 | 计算效率 | 两者统一 |
|---|---|---|---|
| Transformer | ✗ | ✗ | ✗ |
| MoE | ✓ | ✗ | ✗ |
| 动态深度 | ✗ | ✓ | ✗ |
| MoR | ✓ | ✓ | ✓ |
实现指南
配置推荐
# 小模型(<500M):更少共享层,更多递归
config_mor_small = {
"shared_layers": 6,
"max_depth": 4,
"router_hidden": 32,
}
# 大模型(>1B):更多共享层,适度递归
config_mor_large = {
"shared_layers": 24,
"max_depth": 6,
"router_hidden": 64,
}训练技巧
- 课程学习:从浅递归逐渐增加深度
- Gumbel温度退火:从高温度逐渐降低
- 正则化:防止所有token收敛到相同深度
参考资料
相关链接
Footnotes
-
“Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation” arXiv:2507.10524 ↩