M²RNN:矩阵值状态非线性 RNN 语言建模

1. 概述

UC Berkeley、Learning Machine、MIT-IBM Watson Lab、Princeton University、Together AI 联合团队在 2026 年 3 月发表 “M²RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling”。12

核心问题:Transformer 受限于 复杂度类,无法执行需要更高表达力的任务(如实体跟踪、代码执行)。线性 RNN(如 Mamba、Gated DeltaNet)虽高效但表达力不足。非线性 RNN 表达力强但语言建模性能差。

核心方案M²RNN(Matrix-to-Matrix RNN)——使用矩阵值隐藏状态的非线性 RNN,通过外积状态扩展实现大状态,同时利用张量核心高效训练。

主要结果

  • 完美状态跟踪泛化(在训练未见序列长度上)
  • 7B MoE 模型上超越 Gated DeltaNet 0.4-0.5 PPL,状态大小仅 1/3
  • 单层替换即获得全部优势,训练吞吐量损失最小
  • LongBench 长上下文基准上超越 SOTA 混合架构最高 8 分

2. 动机与背景

2.1 Transformer 的复杂度局限

Transformer 注意力机制受限于 复杂度类。这意味着某些任务理论上无法用常数深度 Transformer 完成:

任务是否需要超越 Transformer 能力
实体跟踪✅ 需要
代码执行✅ 需要
排列组合✅ 需要
简单检索❌ 不需要
语言建模部分需要

2.2 现有替代方案的权衡

架构表达力训练效率推理效率语言建模状态跟踪
Transformer
Mamba/SSM
线性 RNN (DeltaNet)
非线性 RNN (LSTM)
M²RNN (本文)

2.3 非线性 RNN 的三大挑战

挑战一:语言建模性能差

当前非线性 RNN(如 LSTM、GRU)显著落后于线性 RNN。

挑战二:长上下文检索性能差

向量值状态容量有限,在 needle-in-a-haystack 任务上落后 Mamba-2/DeltaNet 约 20 分。

挑战三:训练效率低

无法跨序列长度并行化,硬件利用率差。

3. M²RNN 架构

3.1 核心创新:矩阵值状态

传统非线性 RNN 使用向量值状态 。M²RNN 改用矩阵值状态

状态大小:从 增加到 表达能力指数级提升

3.2 状态更新方程

M²RNN 的核心方程:

详细形式:

其中:

  • :遗忘门(forget gate)
  • :输入门(input gate)
  • :非线性状态转移函数
  • :逐元素函数和乘法

3.3 外积状态扩展机制

关键洞察:外积机制可直接应用于非线性 RNN

类似 LSTM/GRU 的门控结构,但状态通过外积扩展:

其中 分别是键向量和值向量。

这一机制使状态从 扩展到

3.4 完整架构

输入 x_t
    │
    ├──► [Forget Gate f_t] ──┐
    │                        │
    ├──► [Input Gate i_t] ───┤
    │                        ├──► [State Update] ──► H_t (d×d 矩阵)
    │                        │
    └──► [Key-Value proj] ───┘
                              │
                              ▼
                    [Output projection]
                              │
                              ▼
                          output

3.5 理论性质

定理 3.1(表达力):M²RNN 层可以表达任意非线性 RNN 能表达的所有计算。

证明思路:矩阵值状态 包含向量值状态 (通过对角元素),并提供额外的非对角元素用于表达更复杂的关系。

定理 3.2(与线性 RNN 的关系):当外积项主导时,M²RNN 退化为 DeltaNet 的非线性版本。

4. 训练优化

4.1 张量核心加速

关键工程创新:外积状态扩展天然适合张量核心

矩阵乘法分解:

每个外积都是 矩阵,可高效在 NVIDIA Tensor Core 上计算:

# 伪代码
H_t = torch.zeros(d, d)
for j in range(k):
    H_t += torch.outer(k_t[j], v_t[j])  # 张量核心加速

4.2 块式计算

为提高并行度,采用块式(chunkwise)计算:

# 块大小 B = 256
for chunk in chunks(sequence_length, B):
    # 块内并行计算
    H_block = compute_chunk_parallel(chunk)
    # 块间串行
    H_global = sequential_combine(H_global, H_block)

4.3 混合精度

张量精度
FP16/BF16
FP16/BF16
门控 FP32
累加器FP32

5. 实验结果

5.1 状态跟踪任务

在标准状态跟踪基准(Multi-query Associative Recall, MQAR)上:

模型训练长度测试长度准确率
Mamba-26425632.4%
Gated DeltaNet6425638.7%
LSTM6425678.2%
M²RNN64256100.0%

M²RNN 在测试长度是训练长度 4 倍时仍达到完美准确率——展现出色的长度泛化能力。

5.2 语言建模性能

在 7B MoE 模型(与 Gated DeltaNet 等参数量)上的结果:

架构WikiText PPLLongBench 准确率状态大小
Hybrid Mamba-26.3242.13× larger
Hybrid Gated DeltaNet6.1844.53× larger
Hybrid M²RNN5.7152.31× (baseline)

关键:M²RNN 用 1/3 的状态大小超越基线 0.4-0.5 PPL,并在长上下文任务上领先 8 分。

5.3 最小侵入性

仅替换单层循环层为 M²RNN:

配置训练吞吐量PPL 增益LongBench 增益
基线(无 M²RNN)100%00
替换 1 层为 M²RNN96%-0.3+6.2
替换全部循环层89%-0.5+8.0

单层替换几乎保留全部训练吞吐量,但获得大部分精度增益。

5.4 长上下文评估

在 LongBench 各项子任务上:

任务Hybrid DeltaNetHybrid M²RNNΔ
NarrativeQA38.249.1+10.9
Qasper41.547.8+6.3
MultiFieldQA44.651.2+6.6
HotpotQA39.846.3+6.5
2WikiMQA36.442.7+6.3
Musique23.128.5+5.4
平均37.344.3+7.0

6. 架构对比分析

6.1 与线性 RNN 家族的对比

维度Mamba-2Gated DeltaNetM²RNN
状态类型标量+外积外积矩阵+外积
状态大小
非线性选择性
表达力
训练效率
推理效率

6.2 与 LSTM/GRU 的对比

维度LSTMGRUM²RNN
状态类型向量 + 细胞状态向量矩阵
状态大小
门控输入/遗忘/输出更新/重置输入/遗忘 + 键值
表达力更强
训练效率
适用规模小-中小-中中-大

6.3 与 Transformer 的对比

维度TransformerM²RNN
复杂度类超越
状态跟踪
训练复杂度
推理复杂度
长上下文 内存 内存

7. 理论分析

7.1 表达力形式化

定理 7.1:M²RNN 在 之上的电路类中具有完整表达力。

证明思路:

  • 矩阵值状态可编码任意二元关系
  • 非线性门控允许复杂条件
  • 外积扩展实现高秩关联

7.2 与 S5 群的关系

论文证明 M²RNN 可以表达 S5 群(对称群)的所有元素:

这意味着 M²RNN 在排列组合任务上有理论保证。

7.3 长度泛化分析

为什么 M²RNN 在未见长度上完美泛化?

直觉:状态大小与位置无关,仅取决于状态维度 。因此长度变化不改变状态表示。

形式化

只要测试长度未超过状态容量

8. 实现细节

8.1 关键超参数

超参数推荐值说明
状态维度 128-512越大表达力越强
块大小 256训练并行度
遗忘门偏置初始化1.0鼓励保留历史
学习率标准

8.2 训练技巧

  1. 门控偏置初始化:遗忘门初始偏置设为 1.0,避免早期梯度消失
  2. 层归一化:在 M²RNN 层后加 LayerNorm
  3. 残差连接:与 Transformer 类似,残差连接 + 投影
  4. 混合策略:M²RNN 与 attention 层交错使用

8.3 推理优化

M²RNN 推理(自回归)使用纯循环形式:

def inference_step(H_prev, x_t):
    f_t = sigmoid(W_f @ x_t + U_f @ H_prev.flatten())
    i_t = sigmoid(W_i @ x_t + U_i @ H_prev.flatten())
    k_t = W_k @ x_t
    v_t = W_v @ x_t
    
    H_t = (1 - f_t) * H_prev + i_t * torch.outer(k_t, v_t)
    
    output = H_t @ W_o
    return H_t, output

推理复杂度 每步,与序列长度无关

9. 应用场景

9.1 大语言模型

最适合:

  • 需要强状态跟踪的代码 LLM
  • 长文档理解
  • 多轮对话中的实体一致性

9.2 多模态扩展

可扩展到:

  • 视频序列建模(强时序跟踪)
  • 音频流处理
  • 强化学习中的记忆任务

9.3 边缘部署

M²RNN 的 推理内存使其适合边缘部署。

10. 局限性与未来工作

10.1 当前局限

  1. 状态大小 仍较大,限制部署
  2. 训练稳定性:深层 M²RNN 训练仍需谨慎
  3. 多模态扩展:尚未充分验证

10.2 未来方向

  1. 稀疏化:通过稀疏矩阵降低状态大小
  2. 自适应维度:根据任务选择状态维度
  3. 与其他架构混合:M²RNN + Mamba + Attention

11. 与现有 Wiki 内容联系

  • RNN 基础[[recurrent-neural-network-rnn|RNN]] - 传统 RNN
  • LSTM[[lstm|LSTM]] - LSTM 架构基础
  • 现代 LSTM[[modern-lstm-advances-xlstm-tau-gru|现代 LSTM 进展]] - xLSTM 系列
  • Mamba 系列[[state-space-model|状态空间模型]] - SSM 基础
  • Mamba-2[[mamba-2-state-space-duality-deep-theory|Mamba-2 状态空间对偶性]]
  • 混合架构[[hybrid-ssm-transformer|混合 SSM-Transformer]]
  • 位置编码[[lstm-to-ssm-state-space-duality|LSTM-SSM 状态空间对偶性]]

12. 参考文献

Footnotes

  1. Mishra M., Tan S., Stoica I., Gonzalez J. E., Dao T. “M²RNN: Non-Linear RNNs with Matrix-Valued States for Scalable Language Modeling.” arXiv:2603.14360, 2026.

  2. 代码与模型:github.com/open-lm-engine/lm-engine