title: MoR混合递归Transformer
date: 2026-05-07
description: 通过参数共享与自适应计算统一实现高效Transformer的方法
tags:

  • transformer
  • recursive-computation
  • mixture-of-experts
    draft: false
    permalink:

MoR混合递归Transformer

概述

MoR(Mixture-of-Recursions)是一种统一框架,同时实现参数共享自适应计算两种效率维度。1

传统Transformer面临参数量和计算量的双重挑战,MoR通过递归Transformer架构和轻量级路由器,在135M到1.7B参数规模上建立新的帕累托前沿。

核心思想

双轴效率

Transformer效率优化通常关注两个独立方向:

  1. 参数效率:通过权重共享减少参数量
  2. 计算效率:通过自适应计算减少FLOPs

现有方法只能优化其中一个维度

  • MoE优化参数效率,但计算量不变
  • 动态深度/跳过方法优化计算效率,但参数量不变

MoR的统一框架

MoR的核心洞察:递归是同时实现双轴效率的自然机制

        ┌─────────────────────────────────────┐
        │         共享Transformer层栈          │
        │    (参数效率:O(L)层 vs O(L×D)参数)   │
        └─────────────┬───────────────────────┘
                      │
                      ▼
        ┌─────────────────────────────────────┐
        │         轻量级递归路由器              │
        │    (计算效率:动态分配递归深度)         │
        └─────────────┬───────────────────────┘
                      │
          ┌───────────┼───────────┐
          ▼           ▼           ▼
       Token A     Token B     Token C
       (D=3)      (D=2)      (D=1)
     递归3次      递归2次      递归1次

方法详解

递归Transformer架构

MoR使用共享层栈进行递归计算:

其中 是递归深度, 是输入。

自适应递归深度分配

关键创新:轻量级路由器为每个token动态分配递归深度:

class RecursionRouter(nn.Module):
    def __init__(self, d_model, max_depth, depth_hidden_dim=64):
        super().__init__()
        self.max_depth = max_depth
        
        # 极简路由器
        self.router = nn.Sequential(
            nn.Linear(d_model, depth_hidden_dim),
            nn.SiLU(),
            nn.Linear(depth_hidden_dim, max_depth),
        )
        
    def forward(self, x):
        # x: [B, N, D]
        logits = self.router(x)  # [B, N, max_depth]
        
        # 采样递归深度(训练时随机,推理时贪婪)
        if self.training:
            # Gumbel-Softmax采样
            gumbels = -torch.empty_like(logits).exponential_().log()
            logits = logits + gumbels
            depths = F.softmax(logits, dim=-1)
        else:
            # 贪婪选择
            depths = F.one_hot(logits.argmax(dim=-1), self.max_depth).float()
            
        return depths  # [B, N, max_depth]

稀疏注意力机制

递归深度分配后,MoR对活跃token应用注意力:

class MoRLayer(nn.Module):
    def __init__(self, d_model, n_heads, max_depth):
        super().__init__()
        self.shared_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads)
            for _ in range(3)  # 基础层数
        ])
        self.router = RecursionRouter(d_model, max_depth)
        self.max_depth = max_depth
        
    def forward(self, x, attention_mask=None):
        B, N, D = x.shape
        
        # 获取每个token的递归深度
        depth_probs = self.router(x)  # [B, N, max_depth]
        
        # 逐层递归处理
        h = x
        for k in range(self.max_depth):
            # 选择该层需要处理的token
            active_mask = depth_probs[:, :, k] > 0.5  # [B, N]
            
            if active_mask.sum() == 0:
                continue
                
            # 应用共享Transformer层
            layer = self.shared_layers[k % len(self.shared_layers)]
            h_active = h.clone()
            h_active[~active_mask] = 0  # 屏蔽非活跃token
            
            h_new = layer(h_active)
            
            # 合并结果
            h = torch.where(active_mask.unsqueeze(-1), h_new, h)
            
        return h

KV缓存优化

MoR提出KV共享变体,专门优化预填充延迟:

class MoRWithKVSharing(nn.Module):
    """
    KV共享:所有递归步骤复用第一个token的KV
    专门用于降低预填充延迟
    """
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        
        # 只为第一个token维护KV
        self.first_token_kv = None
        
    def forward(self, x):
        if self.first_token_kv is None:
            # 初始化:处理完整序列
            self.first_token_kv = self.base_model(x[:, :1])
            
        # 后续token使用共享KV
        rest_out = self.base_model.rest_layers(x[:, 1:])
        
        return torch.cat([self.first_token_kv, rest_out], dim=1)

理论分析

参数效率

共享层数为 ,最大递归深度为

配置实际层数参数量
标准Transformer
MoR

参数量减少因子:

计算效率

活跃token比例为 时:

通过动态 ,实现自适应计算。

实验结果

帕累托前沿

模型参数量训练FLOPs验证困惑度
Vanilla-135M135M1.0×24.3
MoR-135M135M0.6×23.8
Vanilla-410M410M3.0×20.1
MoR-410M410M1.8×19.6
Vanilla-1B1B7.5×17.8
MoR-1B1B4.5×17.2
Vanilla-1.7B1.7B12.0×16.5
MoR-1.7B1.7B7.0×15.9

MoR在所有规模上建立新的帕累托前沿。

Few-shot性能

模型LAMBADAPIQAHellaSwag
Vanilla-410M58.271.344.1
MoR-410M60.172.845.6

吞吐量提升

模型吞吐量(tokens/s)相对提升
Vanilla-1B1001.0×
MoR-1B1561.56×

与其他方法的对比

方法参数效率计算效率两者统一
Transformer
MoE
动态深度
MoR

实现指南

配置推荐

# 小模型(<500M):更少共享层,更多递归
config_mor_small = {
    "shared_layers": 6,
    "max_depth": 4,
    "router_hidden": 32,
}
 
# 大模型(>1B):更多共享层,适度递归
config_mor_large = {
    "shared_layers": 24,
    "max_depth": 6,
    "router_hidden": 64,
}

训练技巧

  1. 课程学习:从浅递归逐渐增加深度
  2. Gumbel温度退火:从高温度逐渐降低
  3. 正则化:防止所有token收敛到相同深度

参考资料

相关链接

Footnotes

  1. “Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation” arXiv:2507.10524