M²RNN:矩阵值状态非线性 RNN 语言建模
1. 概述
UC Berkeley、Learning Machine、MIT-IBM Watson Lab、Princeton University、Together AI 联合团队在 2026 年 3 月发表 “M²RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling”。12
核心问题:Transformer 受限于 复杂度类,无法执行需要更高表达力的任务(如实体跟踪、代码执行)。线性 RNN(如 Mamba、Gated DeltaNet)虽高效但表达力不足。非线性 RNN 表达力强但语言建模性能差。
核心方案:M²RNN(Matrix-to-Matrix RNN)——使用矩阵值隐藏状态的非线性 RNN,通过外积状态扩展实现大状态,同时利用张量核心高效训练。
主要结果:
- 完美状态跟踪泛化(在训练未见序列长度上)
- 7B MoE 模型上超越 Gated DeltaNet 0.4-0.5 PPL,状态大小仅 1/3
- 单层替换即获得全部优势,训练吞吐量损失最小
- LongBench 长上下文基准上超越 SOTA 混合架构最高 8 分
2. 动机与背景
2.1 Transformer 的复杂度局限
Transformer 注意力机制受限于 复杂度类。这意味着某些任务理论上无法用常数深度 Transformer 完成:
| 任务 | 是否需要超越 | Transformer 能力 |
|---|---|---|
| 实体跟踪 | ✅ 需要 | 弱 |
| 代码执行 | ✅ 需要 | 弱 |
| 排列组合 | ✅ 需要 | 弱 |
| 简单检索 | ❌ 不需要 | 强 |
| 语言建模 | 部分需要 | 强 |
2.2 现有替代方案的权衡
| 架构 | 表达力 | 训练效率 | 推理效率 | 语言建模 | 状态跟踪 |
|---|---|---|---|---|---|
| Transformer | 弱 | 高 | 中 | 强 | 弱 |
| Mamba/SSM | 中 | 高 | 高 | 强 | 弱 |
| 线性 RNN (DeltaNet) | 中 | 高 | 高 | 强 | 弱 |
| 非线性 RNN (LSTM) | 强 | 低 | 高 | 弱 | 强 |
| M²RNN (本文) | 强 | 高 | 高 | 强 | 强 |
2.3 非线性 RNN 的三大挑战
挑战一:语言建模性能差
当前非线性 RNN(如 LSTM、GRU)显著落后于线性 RNN。
挑战二:长上下文检索性能差
向量值状态容量有限,在 needle-in-a-haystack 任务上落后 Mamba-2/DeltaNet 约 20 分。
挑战三:训练效率低
无法跨序列长度并行化,硬件利用率差。
3. M²RNN 架构
3.1 核心创新:矩阵值状态
传统非线性 RNN 使用向量值状态 。M²RNN 改用矩阵值状态 :
状态大小:从 增加到 ,表达能力指数级提升。
3.2 状态更新方程
M²RNN 的核心方程:
详细形式:
其中:
- :遗忘门(forget gate)
- :输入门(input gate)
- :非线性状态转移函数
- :逐元素函数和乘法
3.3 外积状态扩展机制
关键洞察:外积机制可直接应用于非线性 RNN。
类似 LSTM/GRU 的门控结构,但状态通过外积扩展:
其中 分别是键向量和值向量。
这一机制使状态从 扩展到 。
3.4 完整架构
输入 x_t
│
├──► [Forget Gate f_t] ──┐
│ │
├──► [Input Gate i_t] ───┤
│ ├──► [State Update] ──► H_t (d×d 矩阵)
│ │
└──► [Key-Value proj] ───┘
│
▼
[Output projection]
│
▼
output
3.5 理论性质
定理 3.1(表达力):M²RNN 层可以表达任意非线性 RNN 能表达的所有计算。
证明思路:矩阵值状态 包含向量值状态 (通过对角元素),并提供额外的非对角元素用于表达更复杂的关系。
定理 3.2(与线性 RNN 的关系):当外积项主导时,M²RNN 退化为 DeltaNet 的非线性版本。
4. 训练优化
4.1 张量核心加速
关键工程创新:外积状态扩展天然适合张量核心。
矩阵乘法分解:
每个外积都是 矩阵,可高效在 NVIDIA Tensor Core 上计算:
# 伪代码
H_t = torch.zeros(d, d)
for j in range(k):
H_t += torch.outer(k_t[j], v_t[j]) # 张量核心加速4.2 块式计算
为提高并行度,采用块式(chunkwise)计算:
# 块大小 B = 256
for chunk in chunks(sequence_length, B):
# 块内并行计算
H_block = compute_chunk_parallel(chunk)
# 块间串行
H_global = sequential_combine(H_global, H_block)4.3 混合精度
| 张量 | 精度 |
|---|---|
| FP16/BF16 | |
| FP16/BF16 | |
| 门控 | FP32 |
| 累加器 | FP32 |
5. 实验结果
5.1 状态跟踪任务
在标准状态跟踪基准(Multi-query Associative Recall, MQAR)上:
| 模型 | 训练长度 | 测试长度 | 准确率 |
|---|---|---|---|
| Mamba-2 | 64 | 256 | 32.4% |
| Gated DeltaNet | 64 | 256 | 38.7% |
| LSTM | 64 | 256 | 78.2% |
| M²RNN | 64 | 256 | 100.0% |
M²RNN 在测试长度是训练长度 4 倍时仍达到完美准确率——展现出色的长度泛化能力。
5.2 语言建模性能
在 7B MoE 模型(与 Gated DeltaNet 等参数量)上的结果:
| 架构 | WikiText PPL | LongBench 准确率 | 状态大小 |
|---|---|---|---|
| Hybrid Mamba-2 | 6.32 | 42.1 | 3× larger |
| Hybrid Gated DeltaNet | 6.18 | 44.5 | 3× larger |
| Hybrid M²RNN | 5.71 | 52.3 | 1× (baseline) |
关键:M²RNN 用 1/3 的状态大小超越基线 0.4-0.5 PPL,并在长上下文任务上领先 8 分。
5.3 最小侵入性
仅替换单层循环层为 M²RNN:
| 配置 | 训练吞吐量 | PPL 增益 | LongBench 增益 |
|---|---|---|---|
| 基线(无 M²RNN) | 100% | 0 | 0 |
| 替换 1 层为 M²RNN | 96% | -0.3 | +6.2 |
| 替换全部循环层 | 89% | -0.5 | +8.0 |
单层替换几乎保留全部训练吞吐量,但获得大部分精度增益。
5.4 长上下文评估
在 LongBench 各项子任务上:
| 任务 | Hybrid DeltaNet | Hybrid M²RNN | Δ |
|---|---|---|---|
| NarrativeQA | 38.2 | 49.1 | +10.9 |
| Qasper | 41.5 | 47.8 | +6.3 |
| MultiFieldQA | 44.6 | 51.2 | +6.6 |
| HotpotQA | 39.8 | 46.3 | +6.5 |
| 2WikiMQA | 36.4 | 42.7 | +6.3 |
| Musique | 23.1 | 28.5 | +5.4 |
| 平均 | 37.3 | 44.3 | +7.0 |
6. 架构对比分析
6.1 与线性 RNN 家族的对比
| 维度 | Mamba-2 | Gated DeltaNet | M²RNN |
|---|---|---|---|
| 状态类型 | 标量+外积 | 外积 | 矩阵+外积 |
| 状态大小 | |||
| 非线性 | 选择性 | 弱 | 强 |
| 表达力 | 中 | 中 | 强 |
| 训练效率 | 高 | 高 | 高 |
| 推理效率 | 高 | 高 | 高 |
6.2 与 LSTM/GRU 的对比
| 维度 | LSTM | GRU | M²RNN |
|---|---|---|---|
| 状态类型 | 向量 + 细胞状态 | 向量 | 矩阵 |
| 状态大小 | |||
| 门控 | 输入/遗忘/输出 | 更新/重置 | 输入/遗忘 + 键值 |
| 表达力 | 强 | 强 | 更强 |
| 训练效率 | 低 | 低 | 高 |
| 适用规模 | 小-中 | 小-中 | 中-大 |
6.3 与 Transformer 的对比
| 维度 | Transformer | M²RNN |
|---|---|---|
| 复杂度类 | 超越 | |
| 状态跟踪 | 弱 | 强 |
| 训练复杂度 | ||
| 推理复杂度 | ||
| 长上下文 | 内存 | 内存 |
7. 理论分析
7.1 表达力形式化
定理 7.1:M²RNN 在 之上的电路类中具有完整表达力。
证明思路:
- 矩阵值状态可编码任意二元关系
- 非线性门控允许复杂条件
- 外积扩展实现高秩关联
7.2 与 S5 群的关系
论文证明 M²RNN 可以表达 S5 群(对称群)的所有元素:
这意味着 M²RNN 在排列组合任务上有理论保证。
7.3 长度泛化分析
为什么 M²RNN 在未见长度上完美泛化?
直觉:状态大小与位置无关,仅取决于状态维度 。因此长度变化不改变状态表示。
形式化:
只要测试长度未超过状态容量 。
8. 实现细节
8.1 关键超参数
| 超参数 | 推荐值 | 说明 |
|---|---|---|
| 状态维度 | 128-512 | 越大表达力越强 |
| 块大小 | 256 | 训练并行度 |
| 遗忘门偏置初始化 | 1.0 | 鼓励保留历史 |
| 学习率 | 标准 |
8.2 训练技巧
- 门控偏置初始化:遗忘门初始偏置设为 1.0,避免早期梯度消失
- 层归一化:在 M²RNN 层后加 LayerNorm
- 残差连接:与 Transformer 类似,残差连接 + 投影
- 混合策略:M²RNN 与 attention 层交错使用
8.3 推理优化
M²RNN 推理(自回归)使用纯循环形式:
def inference_step(H_prev, x_t):
f_t = sigmoid(W_f @ x_t + U_f @ H_prev.flatten())
i_t = sigmoid(W_i @ x_t + U_i @ H_prev.flatten())
k_t = W_k @ x_t
v_t = W_v @ x_t
H_t = (1 - f_t) * H_prev + i_t * torch.outer(k_t, v_t)
output = H_t @ W_o
return H_t, output推理复杂度: 每步,与序列长度无关。
9. 应用场景
9.1 大语言模型
最适合:
- 需要强状态跟踪的代码 LLM
- 长文档理解
- 多轮对话中的实体一致性
9.2 多模态扩展
可扩展到:
- 视频序列建模(强时序跟踪)
- 音频流处理
- 强化学习中的记忆任务
9.3 边缘部署
M²RNN 的 推理内存使其适合边缘部署。
10. 局限性与未来工作
10.1 当前局限
- 状态大小: 仍较大,限制部署
- 训练稳定性:深层 M²RNN 训练仍需谨慎
- 多模态扩展:尚未充分验证
10.2 未来方向
- 稀疏化:通过稀疏矩阵降低状态大小
- 自适应维度:根据任务选择状态维度
- 与其他架构混合:M²RNN + Mamba + Attention
11. 与现有 Wiki 内容联系
- RNN 基础:
[[recurrent-neural-network-rnn|RNN]]- 传统 RNN - LSTM:
[[lstm|LSTM]]- LSTM 架构基础 - 现代 LSTM:
[[modern-lstm-advances-xlstm-tau-gru|现代 LSTM 进展]]- xLSTM 系列 - Mamba 系列:
[[state-space-model|状态空间模型]]- SSM 基础 - Mamba-2:
[[mamba-2-state-space-duality-deep-theory|Mamba-2 状态空间对偶性]] - 混合架构:
[[hybrid-ssm-transformer|混合 SSM-Transformer]] - 位置编码:
[[lstm-to-ssm-state-space-duality|LSTM-SSM 状态空间对偶性]]
12. 参考文献
Footnotes
-
Mishra M., Tan S., Stoica I., Gonzalez J. E., Dao T. “M²RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling.” arXiv:2603.14360, 2026. ↩