序列建模统一分类学:从RNN到TTT的范式整合
引言
序列建模是深度学习的核心问题之一。从 2017 年 Transformer 问世以来,序列建模领域经历了范式之争:RNN/LSTM 的循环归纳偏置、Attention 的全连接上下文、SSM 的连续-离散状态空间、TTT 的测试时学习——每种范式都有自己的优势和局限。
2025-2026 年的研究带来了两个根本性新洞察:
- 统一性:Mamba-2 的状态空间对偶性 (SSD) 证明了 SSM 与线性注意力在数学上等价;Titans/MIRAS 进一步将所有序列模型统一为”联想记忆模块”
- 新维度:测试时训练 (TTT) 引入了”测试时持续学习”的新维度,使序列模型从静态函数变为持续演化的动力学系统
本文旨在建立序列建模的统一分类学,把 30+ 看似零散的架构整合到同一框架内,并给出选型决策树。1
一、序列建模问题形式化
1.1 任务定义
序列建模的核心任务是学习一个映射:
或单输出形式 。其中 ,。
1.2 自回归分解
任何序列模型都可以分解为自回归形式:
其中每步的条件分布由模型参数 决定。
1.3 状态空间抽象
几乎所有序列模型都可以表示为状态递推:
其中 是状态, 是状态转移函数, 是输出函数。
关键差异在于:
- 的维度( vs )
- 是否可学习、是否依赖输入
- 是否维护显式历史 vs 压缩状态
1.4 三层抽象视图
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SequenceModelInterface(nn.Module):
"""所有序列模型的统一接口"""
def __init__(self, d_in, d_state, d_out, kind):
super().__init__()
self.d_in = d_in
self.d_state = d_state
self.d_out = d_out
self.kind = kind # 'rnn', 'ssm', 'attention', 'ttt'
def init_state(self, batch_size, device):
"""初始化状态"""
if self.kind == 'rnn':
# 固定维度状态 h_0
return torch.zeros(batch_size, self.d_state, device=device)
elif self.kind == 'ssm':
return torch.zeros(batch_size, self.d_state, device=device)
elif self.kind == 'attention':
# 状态 = 完整的 K, V 缓存
return {'K': [], 'V': []}
elif self.kind == 'ttt':
# 状态 = 可学习的参数
return {'memory': nn.Parameter(torch.zeros(self.d_state, self.d_state))}
def step(self, x_t, state):
"""单步前向传播"""
if self.kind == 'rnn':
h = state
i = self.W_ih(x_t)
h_new = torch.tanh(self.W_hh(h) + i)
return h_new, self.W_out(h_new)
elif self.kind == 'ssm':
# 状态空间更新
h = state
A = self.A # (d_state, d_state)
B = self.B(x_t) # (d_state, d_in)
h_new = h @ A + B
return h_new, self.C(h_new)
elif self.kind == 'attention':
K, V = state['K'], state['V']
K.append(self.W_k(x_t))
V.append(self.W_v(x_t))
K_stack = torch.stack(K, dim=1) # (B, T, d_k)
V_stack = torch.stack(V, dim=1)
q = self.W_q(x_t)
scores = q @ K_stack.transpose(-2, -1) / math.sqrt(self.d_k)
attn = F.softmax(scores, dim=-1)
y = attn @ V_stack
return {'K': K, 'V': V}, self.W_o(y)
elif self.kind == 'ttt':
# 测试时学习:根据当前输入更新记忆
memory = state['memory']
# 简单的 Hebbian 更新
x_proj = self.proj(x_t)
memory = memory + 0.01 * torch.outer(x_proj, x_proj)
y = x_proj @ memory
return {'memory': memory}, y
def forward(self, x):
"""完整序列前向"""
B, T, _ = x.shape
state = self.init_state(B, x.device)
outputs = []
for t in range(T):
state, y_t = self.step(x[:, t], state)
outputs.append(y_t)
return torch.stack(outputs, dim=1)二、四大范式分类
2.1 范式 1:循环神经网络(RNN 范式)
核心思想:维护固定维度状态 ,每步更新。
2.1.1 Vanilla RNN
- 复杂度:每步 ,总
- 归纳偏置:时间局部性、权重共享
- 训练:BPTT(Back-Propagation Through Time)
- 问题:梯度消失/爆炸
2.1.2 LSTM(Hochreiter & Schmidhuber 1997)
LSTM 通过门控机制解决梯度问题:
关键洞察:加性更新 使梯度通过 通路不衰减。
2.1.3 GRU(Cho et al. 2014)
GRU 简化为两个门:
2.1.4 xLSTM(Beck et al. 2024, ICLR 2026 Outstanding)
xLSTM 是 LSTM 的现代复兴:
- 指数门控(exponential gating)
- 矩阵记忆(matrix memory)替代标量
- 协方差更新规则
- xLSTM-7B 工业实现,scaling laws 拟合
class XLSTMCell(nn.Module):
"""xLSTM Cell(简化版)"""
def __init__(self, d_input, d_state):
super().__init__()
self.d_input = d_input
self.d_state = d_state
# 门控参数
self.W_i = nn.Linear(d_input, d_state)
self.W_f = nn.Linear(d_input, d_state)
self.W_o = nn.Linear(d_input, d_state)
self.W_z = nn.Linear(d_input, d_state)
# 矩阵记忆
self.C = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
def forward(self, x_t, h_prev, C_prev):
i = torch.exp(self.W_i(x_t)) # 指数门控
f = torch.exp(self.W_f(x_t))
o = torch.sigmoid(self.W_o(x_t))
z = self.W_z(x_t)
# 协方差更新规则
C_new = f.unsqueeze(-1) * C_prev + i.unsqueeze(-1) * torch.outer(torch.tanh(z), x_t)
h_new = o * (C_new @ x_t)
return h_new, C_new2.1.5 M2RNN 矩阵值 RNN
M2RNN 使用 矩阵作为状态:
表达力:理论上比 vanilla RNN 指数级强,能记忆更多信息。
2.1.6 复杂度分析
| 模型 | 每步复杂度 | 总复杂度 | 并行化 |
|---|---|---|---|
| Vanilla RNN | 低(需 BPTT) | ||
| LSTM | 低 | ||
| xLSTM | 取决于实现 | 中等 | |
| M2RNN | 低 |
2.2 范式 2:状态空间模型(SSM 范式)
核心思想:连续-离散状态空间,将序列视为 ODE 的离散采样。
2.2.1 连续状态空间
其中 ,,。
2.2.2 离散化
用零阶保持(ZOH)或双线性变换离散化:
2.2.3 S4(Gu, Goel, Re 2022)
S4 引入对角 + 低秩 (DPLR) 参数化 ,使 的高次幂可快速计算(基于 HiPPO 初始化)。
2.2.4 Mamba(Gu & Dao 2023)
Mamba 的关键创新是输入依赖的选择性 SSM:
关键差异(与 S4): 依赖输入,实现选择性信息过滤。
2.2.5 Mamba-2 与 SSD 框架
Mamba-2 (Dao & Gu 2024) 引入状态空间对偶性 (SSD),证明:
SSM(半可分矩阵)= 结构化注意力(掩码 1-SS)
具体地,SSM 的输出 ,其中 是半可分矩阵(可分解为 个秩一阵子之和)。这与结构化掩码注意力的矩阵形式完全等价。
Mamba-2 块用 SSD 视角实现,比 Mamba-1 快 2-8 倍。
2.2.6 Mamba-3 (ICLR 2026 Best Paper)
Mamba-3 的进一步创新:
- 复值状态:,更丰富的表达力
- MIMO 架构:单步多输入多输出
- 选择性增强:更精细的输入依赖机制
2.2.7 复杂度分析
| 模型 | 训练复杂度 | 推理复杂度 | 内存 |
|---|---|---|---|
| S4 | |||
| Mamba-1 | |||
| Mamba-2 | |||
| Mamba-3 |
2.3 范式 3:注意力(Attention 范式)
核心思想:每步关注整个历史,通过 query-key 匹配检索。
2.3.1 标准注意力
- 训练复杂度:
- 推理复杂度:(带 KV Cache)
- 归纳偏置:排列等变(无内置顺序概念)
- 位置编码:RoPE、ALiBi、相对位置等
2.3.2 线性注意力(Linear Attention)
核心思想:把 softmax 替换为可分解的特征映射 :
计算重排:
- 朴素:,
- 线性:,
class LinearAttention(nn.Module):
"""线性注意力实现"""
def __init__(self, d_model, d_k, eps=1e-6):
super().__init__()
self.W_q = nn.Linear(d_model, d_k, bias=False)
self.W_k = nn.Linear(d_model, d_k, bias=False)
self.W_v = nn.Linear(d_model, d_model)
self.eps = eps
def feature_map(self, x):
"""特征映射:x -> elu(x) + 1(保证非负)"""
return F.elu(x) + 1
def forward(self, x):
B, T, d = x.shape
Q = self.feature_map(self.W_q(x)) # (B, T, d_k)
K = self.feature_map(self.W_k(x)) # (B, T, d_k)
V = self.W_v(x) # (B, T, d)
# 线性化:Q (K^T V) 而非 (Q K^T) V
KV = K.transpose(-2, -1) @ V # (B, d_k, d)
numerator = Q @ KV # (B, T, d)
# 归一化
K_sum = K.sum(dim=1, keepdim=True) # (B, 1, d_k)
denominator = Q @ K_sum.transpose(-2, -1) # (B, T, 1)
return numerator / (denominator + self.eps)2.3.3 各种 Linear Attention 变体
| 变体 | 特征映射 | 论文 |
|---|---|---|
| Linear Transformer | Katharopoulos 2020 | |
| Performer | 随机特征 | Choromanski 2021 |
| Random Feature Attention | 高斯随机特征 | Peng 2021 |
| RetNet | 衰减 | Sun 2023 |
| Gated DeltaNet | 门控 + Delta 规则 | Yang 2024 (ICLR 2025) |
2.3.4 状态视角
线性注意力可视为矩阵状态递推:
这与 SSM 的形式完全平行——这是 SSD 框架的核心。
2.3.5 复杂度分析
| 模型 | 训练 | 推理 | 表达力 |
|---|---|---|---|
| Full Attention | 最高 | ||
| Linear Attention | 中等 | ||
| Performer | 近似 Full | ||
| RetNet | 衰减记忆 | ||
| Gated DeltaNet | 强 |
2.4 范式 4:测试时训练(TTT 范式)
核心思想:状态在测试时持续学习,而非固定参数。
2.4.1 TTT 的核心洞察
Sun et al. 2024 提出:
任何序列模型的状态都可以视为”模型权重”——即状态本身可以是通过梯度下降学习的参数。
这意味着:
- RNN 的 可以是测试时学习的参数
- SSM 的 可以是测试时学习的参数
- Attention 的 KV Cache 可以是测试时学习的参数
2.4.2 TTT 层
TTT 层在测试时执行以下操作:
1. 维护一个"内部权重" W_t(状态)
2. 对当前输入 x_t 计算损失:L(W_t; x_t)
3. 用 SGD 更新:W_{t+1} = W_t - η ∇L(W_t; x_t)
4. 用 W_{t+1} 处理后续输入
双层优化:
- 外层:训练 TTT 层的”超参数”(学习率、初始化、损失函数)
- 内层:测试时持续学习 W_t
2.4.3 Titans(Behrouz et al., NeurIPS 2025)
Titans 引入神经长期记忆 (Neural Long-Term Memory):
class NeuralMemory(nn.Module):
"""Titans 神经长期记忆模块"""
def __init__(self, d_model, d_memory, mlp_hidden):
super().__init__()
# 记忆是 MLP
self.memory_mlp = nn.Sequential(
nn.Linear(d_model, mlp_hidden),
nn.SiLU(),
nn.Linear(mlp_hidden, d_memory),
)
# 三个学习信号
self.theta_q = nn.Linear(d_model, d_memory) # 查询
self.theta_k = nn.Linear(d_model, d_memory) # 键
self.theta_v = nn.Linear(d_model, d_memory) # 值
self.eta = nn.Parameter(torch.zeros(d_memory)) # 门控学习率
def forward(self, x_t, M_t):
"""
x_t: (B, d_model) 当前输入
M_t: (d_memory, d_memory) 记忆状态
"""
# Surprise-based gating
q = self.theta_q(x_t)
k = self.theta_k(x_t)
v = self.theta_v(x_t)
# 计算 surprise
with torch.no_grad():
Mv = M_t @ v.unsqueeze(-1)
surprise = (v.unsqueeze(-1) - Mv).norm(dim=-2)
# 自适应学习率
eta = F.sigmoid(self.eta) * surprise
# 梯度下降更新记忆
grad = torch.outer(k, v) # 一阶梯度近似
M_new = M_t - eta.unsqueeze(-1) * grad
return M_new, M_new @ q关键:
- 记忆权重 通过测试时梯度下降更新
- 学习率 是输入依赖的(自适应门控)
- 突破 200 万 token 上下文长度
2.4.4 MIRAS (Behrouz et al. 2025)
MIRAS 提出统一框架:
所有序列模型都是联想记忆模块(Associative Memory Modules)
包括:
- Attention = Hopfield 网络
- SSM = 线性联想记忆
- RNN = 有限状态联想记忆
- TTT = 持续学习的联想记忆
YAAD、MONETA、MEMORA 是 MIRAS 框架下的具体实例。
2.4.5 复杂度分析
TTT 的计算开销主要在测试时梯度计算:
- 单步:( 是内层步数)
- 与序列长度解耦
三、统一数学框架
3.1 线性可加性是统一的核心
几乎所有现代序列模型都可以表示为线性可加递推:
其中 可能是输入依赖的。
3.1.1 各范式在此框架下
| 范式 | 维度 | ||
|---|---|---|---|
| Vanilla RNN | |||
| LSTM | (对角门) | ||
| Linear SSM | (HiPPO/Mamba) | ||
| Mamba | (输入依赖) | ||
| Mamba-3 | (复值) | ||
| Linear Attention | |||
| TTT/Titans |
3.2 状态维度的本质差异
- RNN/SSM:状态 , 与序列长度无关 → 状态
- Linear Attention:状态 ,矩阵 → 状态
- Full Attention:状态 = 完整 K, V 历史 → 状态
- TTT/Titans:状态 = 可学习参数 → 可任意大
3.3 SSD 框架的严格证明
定理(Dao & Gu 2024):设 是半可分矩阵:
或等价地:
则 既可以用 SSM(前向递推)计算,也可以用结构化掩码注意力计算,且完全等价。
证明概要:
- SSM 视角:
- 注意力视角:
- 矩阵 在两种视角下形式相同
3.4 记忆容量理论
不同范式的记忆容量不同:
| 范式 | 容量 | 理论依据 |
|---|---|---|
| Vanilla RNN | 信息瓶颈 | |
| LSTM | 有效 | 门控机制 |
| SSM | 状态维度 | |
| Full Attention | 精确 | KV Cache |
| Linear Attention | 有限 | 状态维度 |
| TTT/Titans | 可学习 | 神经长期记忆 |
信息论下界:任何状态维度 的序列模型最多记忆 比特信息(Shannon 论证)。
四、复杂度-表达力-效率三角
4.1 三角权衡
任何序列模型都面临三个核心约束:
- 计算复杂度:训练与推理的计算开销
- 表达力:能学习的函数类别
- 实现效率:并行化能力、内存占用
表达力
▲
/│\
/ │ \
/ │ \
/ │ \
/ │ \
/ ◆ │ ◆ \
/ RNN │ Attn \
/ 范式 │ 范式 \
/───────┼────────\
/ │ \
/ SSM │ Linear \
/ 范式 │ Attn \
/──────────┴──────────\
复杂度 ─────────────►
没有银弹:每种范式都在三角上做了不同取舍。
4.2 复杂度对比表
| 范式 | 训练 | 推理 | 内存 | 并行化 |
|---|---|---|---|---|
| Vanilla RNN | 低 | |||
| LSTM | 低 | |||
| xLSTM | 中 | |||
| S4 | 高 | |||
| Mamba-1 | 高 | |||
| Mamba-2 | 高 | |||
| Linear Attn | 高 | |||
| Full Attn | 中(KV 缓存) | |||
| TTT | 中 | |||
| Titans | 中 |
4.3 表达力对比
| 范式 | 形式语言 | 长程依赖 | 状态查询 | 算术 |
|---|---|---|---|---|
| Vanilla RNN | TC⁰ | 弱 | 否 | 中 |
| LSTM | TC⁰ | 中 | 部分 | 中 |
| SSM/S4 | TC⁰ | 强 | 否 | 强 |
| Mamba | TC⁰ | 强 | 选择性 | 强 |
| Full Attn | TC⁰ | 强 | 精确 | 强 |
| Linear Attn | 有限 | 弱 | 近似 | 中 |
| TTT | 任意(测试时) | 强 | 自学习 | 可调 |
关键洞察:Full Attention 的表达力理论上最强(任何 TC⁰ 函数),但实践中未必——LSTM/SSM 在特定任务上能匹敌或超越。
4.4 TC⁰ 复杂性类
Merrill & Sabharwal 2023 严格证明:
标准 Transformer 精确识别的形式语言类 = TC⁰(阈值电路常数深度多项式大小)
而 RNN(包括 LSTM、GRU)识别能力严格弱于 TC⁰。这是 Transformer vs RNN 表达力的严格分层。
TC⁰ 包含:
- 加法、乘法
- 多数表决
- 阈值函数
- 排序
TC⁰ 不包含(需要 TC¹):
- 匹配括号(Dyck-)的完整识别
- 一般图灵可计算函数
但 Transformer 通过”软”近似仍能在实践中处理这些任务。
五、测试时学习(TTT)作为新维度
5.1 传统序列建模的局限
所有传统范式都有一个隐含假设:
训练时学习参数,推理时固定参数
这个假设限制了信息存储和适应能力。
5.2 TTT 的革命性
TTT 引入了第三维度:
训练时 ──────── 推理时
↓ ↓
传统:学习参数 固定参数
TTT: 学习超参数 持续学习参数
这意味着:
- 模型可以记住训练时未见过的信息
- 模型可以根据输入分布自适应
- 模型在长上下文中能持续学习
5.3 TTT 的统一视角
TTT 层可视为**“模型作为模型权重”**:
- 状态 是模型权重
- 输入 是训练数据
- 更新 是梯度下降
因此 TTT 是一种元学习形式(learning to learn)。
5.4 TTT vs 微调
| 维度 | 微调 | TTT |
|---|---|---|
| 更新频率 | 训练阶段 | 推理时每步 |
| 数据量 | 大批量 | 单样本 |
| 计算位置 | 离线 | 在线 |
| 内存开销 | 全模型 | 单层状态 |
| 性能稳定性 | 稳定 | 可能不稳定 |
5.5 Titans 的具体创新
Titans (Behrouz et al., NeurIPS 2025) 在 TTT 基础上:
-
三层记忆架构:
- 短期:滑动窗口注意力
- 中期:Persistent Memory(可学习参数)
- 长期:神经长期记忆(测试时学习)
-
自适应学习率:
学习率与”surprise”成正比——惊讶越大,学习越快 -
门控机制:
用门控控制记忆更新强度 -
实验结果:
- Needle-in-Haystack:200 万 token
- 性能超越 Mamba-2 和 Transformer
六、混合架构
6.1 为什么需要混合
没有单一范式在所有任务上最优:
- Transformer:短序列强表达力、长序列贵
- SSM:长序列高效、表达力受限
- TTT:长上下文、自适应
混合策略:用 Attention 处理局部 + SSM 处理全局 + TTT 长期记忆。
6.2 代表性混合架构
6.2.1 Jamba(AI21, 2024)
7B 参数,52 层,1:7 Attention:SSM 比例:
class JambaBlock(nn.Module):
def __init__(self, d_model, n_heads, d_state, mamba_expand, attn_period=8):
super().__init__()
# 每 8 层用 1 个 Attention
self.use_attn = False
if attn_period and random.random() < 1/attn_period:
self.use_attn = True
self.attn = Attention(d_model, n_heads)
self.mamba = MambaBlock(d_model, d_state, mamba_expand)
self.moe = MoE(d_model, n_experts=16, top_k=2)
def forward(self, x):
if self.use_attn:
x = x + self.attn(x)
x = x + self.mamba(x)
x = x + self.moe(x)
return x6.2.2 TransMamba(2025)
层次化混合:底层 Mamba,高层 Attention。
6.2.3 Hymba(NVIDIA, 2025)
头级别混合:每头决定用 Attention 还是 SSM。
6.2.4 HydraHead
类似 Hymba,但有更细粒度的头级别门控。
6.3 混合设计原则
| 原则 | 说明 |
|---|---|
| Attention 占比 | 1:8 到 1:4 是常见区间 |
| 层次 | 浅层 Attention、深层 SSM(或反之) |
| 头级别 | 不同头用不同范式 |
| 门控 | 软门控学习权重 |
七、2025-2026 前沿进展
7.1 Mamba-3 (ICLR 2026 Best Paper)
Mamba-3 引入:
-
复值状态:
- 表达力翻倍
- 自然编码振荡模式
- 谱方法更易分析
-
MIMO 架构:单步多输入多输出
- 每次更新处理多个 token
- 提高硬件利用率
-
选择性增强:
更精细的输入依赖离散化
7.2 Titans (NeurIPS 2025)
Titans 的核心创新:
- 神经长期记忆(MLP 作为状态)
- 自适应学习率
- 三层记忆架构
实验结果:
- Needle-in-Haystack:200 万 token 99% 准确率
- BABILong 基准:超越 Transformer 和 Mamba-2
- 训练效率:2-10x 加速
7.3 MIRAS(Behrouz et al. 2025)
MIRAS 提出统一框架:
所有序列模型 = 联想记忆模块
包括四大家族:
- YAAD (You Actually Add Delta):测试时学习
- MONETA:状态作为测试时学习
- MEMORA:多时间尺度记忆
- TITANS:神经长期记忆
核心洞察:
“将序列模型视为联想记忆模块,统一了 RNN/SSM/Attention/TTT 的本质”
7.4 扩散序列模型
2025-2026 年,扩散模型开始应用于序列建模:
- Diffusion Language Model (LLaDA):用掩码扩散替代自回归
- MDLM (Masked Diffusion Language Model):扩展到 7B 参数
- dLLM:离散扩散 LM
class MaskedDiffusionLM(nn.Module):
"""简化版掩码扩散 LM"""
def __init__(self, d_model, vocab_size, n_layers):
super().__init__()
self.backbone = Transformer(d_model, n_layers)
self.vocab_proj = nn.Linear(d_model, vocab_size)
def forward(self, x_t, t):
"""x_t 是 t 时刻被掩码的序列"""
h = self.backbone(x_t)
# 预测被掩码的位置
logits = self.vocab_proj(h)
return logits
def sample(self, n_steps, batch_size, seq_len):
"""反向扩散采样"""
x = torch.full((batch_size, seq_len), MASK_TOKEN_ID)
for t in reversed(range(n_steps)):
logits = self.forward(x, t)
# 按噪声调度采样被掩码的位置
x = self.mask_schedule.sample_step(logits, x, t)
return x7.5 状态空间对偶性 (SSD) 的新发展
Mamba-2 SSD 框架在 2025-2026 进一步推广:
- 多层次 SSD:跨层应用 SSD
- Gated SSD:加门控的 SSD(Mamba-2 用 Gated DeltaNet 接近)
- 稀疏 SSD:稀疏化的 SSD(用于长上下文)
- Hybrid SSD:与 Attention 混合的 SSD
八、选型决策树
8.1 决策流程
任务: 序列建模
│
├─ 序列长度 L < 1K?
│ └─ 是 → Full Attention (表达力优先)
│
├─ L > 32K 且硬件受限?
│ ├─ 是 → Mamba-2/3 或 Linear Attention
│ └─ 否 → 考虑混合架构
│
├─ 需要超长上下文(> 1M tokens)?
│ └─ 是 → Titans / TTT 类
│
├─ 任务对长程依赖敏感(如代码)?
│ └─ 是 → Titans 或 Mamba-3
│
├─ 多模态/视频/音频?
│ └─ 是 → 混合架构(Jamba 风格)
│
└─ 默认 → Transformer (Mamba-2 备份)
8.2 任务-范式映射表
| 任务 | 推荐范式 | 备选 |
|---|---|---|
| 短文本分类 | Full Attention | RNN 即可 |
| 长文档 QA | Mamba-2/3 或 Titans | Hybrid |
| 代码生成 | Full Attention + 滑动窗口 | Titans |
| 时间序列预测 | xLSTM 或 Mamba | TCN |
| 视频理解 | Hybrid (Jamba 风格) | Mamba-3 |
| 语音识别 | Conformer (CNN + Attention) | Mamba-2 |
| 强化学习策略 | LSTM (xLSTM) | RNN |
| 长程对话 | Titans | Mamba-3 |
| 实时流式 | SSM (因果) | Linear Attn |
8.3 工程考虑
- 训练资源:Full Attention 训练最贵,Linear 训练最便宜
- 推理延迟:RNN/SSM 恒定延迟,Attention 增长
- KV 缓存:Full Attention 内存爆炸
- 硬件友好:SSM 在 GPU 上有专门 kernel
- 可扩展性:Titans/Mamba 适合长上下文
九、完整 PyTorch 实现:UnifiedSequenceBlock
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from enum import Enum
class SequenceKind(Enum):
RNN = "rnn"
LSTM = "lstm"
XLSTM = "xlstm"
SSM = "ssm"
MAMBA = "mamba"
LINEAR_ATTN = "linear_attn"
TITANS = "titans"
HYBRID = "hybrid"
class UnifiedSequenceBlock(nn.Module):
"""统一序列建模块:可在四大范式间切换"""
def __init__(self, d_model, d_state, kind, n_heads=4, mlp_hidden=None):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.kind = SequenceKind(kind)
self.n_heads = n_heads
self.mlp_hidden = mlp_hidden or 4 * d_model
# 归一化
self.norm = nn.RMSNorm(d_model)
self.norm_state = nn.LayerNorm(d_state) if self.kind in [SequenceKind.RNN, SequenceKind.LSTM, SequenceKind.SSM] else nn.Identity()
# 范式特定组件
if self.kind == SequenceKind.RNN:
self.cell = nn.RNNCell(d_model, d_state)
self.out_proj = nn.Linear(d_state, d_model)
elif self.kind == SequenceKind.LSTM:
self.cell = nn.LSTMCell(d_model, d_state)
self.out_proj = nn.Linear(d_state, d_model)
elif self.kind == SequenceKind.XLSTM:
self.cell = XLSTMCell(d_model, d_state)
self.out_proj = nn.Linear(d_state, d_model)
elif self.kind == SequenceKind.SSM:
# S4 简化版
self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
self.B = nn.Linear(d_model, d_state)
self.C = nn.Linear(d_state, d_model)
self.D = nn.Parameter(torch.zeros(d_model))
self.log_dt = nn.Parameter(torch.zeros(d_state))
elif self.kind == SequenceKind.MAMBA:
# 简化 Mamba 块
self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
self.B_proj = nn.Linear(d_model, d_state)
self.C_proj = nn.Linear(d_model, d_state)
self.dt_proj = nn.Linear(d_model, d_state)
self.out_proj = nn.Linear(d_state, d_model)
elif self.kind == SequenceKind.LINEAR_ATTN:
self.W_q = nn.Linear(d_model, d_state, bias=False)
self.W_k = nn.Linear(d_model, d_state, bias=False)
self.W_v = nn.Linear(d_model, d_model)
self.feature_map = lambda x: F.elu(x) + 1
elif self.kind == SequenceKind.TITANS:
self.memory_mlp = nn.Sequential(
nn.Linear(d_model, d_state),
nn.SiLU(),
nn.Linear(d_state, d_state * d_state),
)
self.theta_q = nn.Linear(d_model, d_state)
self.theta_k = nn.Linear(d_model, d_state)
self.theta_v = nn.Linear(d_model, d_state)
self.eta = nn.Parameter(torch.zeros(d_state))
elif self.kind == SequenceKind.HYBRID:
# Jamba 风格:Attention + SSM
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
self.B = nn.Linear(d_model, d_state)
self.C = nn.Linear(d_state, d_model)
# 共享 FFN
self.ffn = nn.Sequential(
nn.Linear(d_model, self.mlp_hidden),
nn.GELU(),
nn.Linear(self.mlp_hidden, d_model),
)
def forward(self, x):
"""x: (B, T, d_model)"""
B, T, d = x.shape
if self.kind == SequenceKind.RNN:
h = torch.zeros(B, self.d_state, device=x.device)
outputs = []
for t in range(T):
h = self.cell(x[:, t], h)
outputs.append(self.out_proj(h))
y = torch.stack(outputs, dim=1)
elif self.kind == SequenceKind.LSTM:
h = torch.zeros(B, self.d_state, device=x.device)
c = torch.zeros(B, self.d_state, device=x.device)
outputs = []
for t in range(T):
h, c = self.cell(x[:, t], (h, c))
outputs.append(self.out_proj(h))
y = torch.stack(outputs, dim=1)
elif self.kind == SequenceKind.XLSTM:
h = torch.zeros(B, self.d_state, device=x.device)
C = torch.zeros(B, self.d_state, self.d_state, device=x.device)
outputs = []
for t in range(T):
h, C = self.cell(x[:, t], h, C)
outputs.append(self.out_proj(h))
y = torch.stack(outputs, dim=1)
elif self.kind == SequenceKind.SSM:
h = torch.zeros(B, self.d_state, device=x.device)
A = -torch.exp(self.log_dt).unsqueeze(0) * self.A
outputs = []
for t in range(T):
h = h @ A + self.B(x[:, t]).unsqueeze(1) * 1.0
outputs.append(self.C(h))
y = torch.stack(outputs, dim=1) + x * self.D
elif self.kind == SequenceKind.MAMBA:
h = torch.zeros(B, self.d_state, device=x.device)
outputs = []
for t in range(T):
# 输入依赖
B_t = self.B_proj(x[:, t])
C_t = self.C_proj(x[:, t])
dt = F.softplus(self.dt_proj(x[:, t]))
# 离散化
A_bar = torch.exp(dt.unsqueeze(-1) * self.A)
h = h * A_bar + B_t * x[:, t].unsqueeze(1)
y_t = (C_t.unsqueeze(1) * h).sum(dim=-1)
outputs.append(self.out_proj(y_t))
y = torch.stack(outputs, dim=1)
elif self.kind == SequenceKind.LINEAR_ATTN:
Q = self.feature_map(self.W_q(x)) # (B, T, d_state)
K = self.feature_map(self.W_k(x))
V = self.W_v(x)
# 状态递推
S = torch.zeros(B, self.d_state, d, device=x.device)
outputs = []
for t in range(T):
S = S + torch.bmm(K[:, t:t+1].transpose(-2, -1), V[:, t:t+1])
y_t = torch.bmm(Q[:, t:t+1], S)
outputs.append(y_t)
y = torch.cat(outputs, dim=1)
elif self.kind == SequenceKind.TITANS:
# 神经长期记忆
M = self.memory_mlp(x[:, 0]).view(B, self.d_state, self.d_state)
outputs = []
for t in range(T):
q = self.theta_q(x[:, t])
k = self.theta_k(x[:, t])
v = self.theta_v(x[:, t])
# 读取
y_t = torch.bmm(M, q.unsqueeze(-1)).squeeze(-1)
# Surprise-based 更新
with torch.no_grad():
Mv = torch.bmm(M, v.unsqueeze(-1))
surprise = (v.unsqueeze(-1) - Mv).norm(dim=-2)
eta = torch.sigmoid(self.eta) * surprise
grad = torch.bmm(k.unsqueeze(-1), v.unsqueeze(1))
M = M - eta.unsqueeze(-1) * grad
outputs.append(y_t)
y = torch.stack(outputs, dim=1)
elif self.kind == SequenceKind.HYBRID:
# Jamba 风格:先 Attention,再 SSM
x_norm = self.norm(x)
attn_out, _ = self.attn(x_norm, x_norm, x_norm)
x = x + attn_out
h = torch.zeros(B, self.d_state, device=x.device)
outputs = []
for t in range(T):
h = h @ self.A + self.B(x[:, t])
outputs.append(self.C(h))
y = torch.stack(outputs, dim=1)
# 残差 + FFN
y = x + y
y = y + self.ffn(self.norm(y))
return y
# 使用示例
if __name__ == "__main__":
B, T, d_model = 4, 64, 128
d_state = 64
x = torch.randn(B, T, d_model)
for kind in SequenceKind:
model = UnifiedSequenceBlock(d_model, d_state, kind.value)
y = model(x)
print(f"{kind.value}: input {x.shape} -> output {y.shape}")十、未来方向
10.1 开放问题
- 统一架构:是否存在统一范式同时实现 RNN 效率、SSM 长程、Attention 表达力、TTT 自适应?
- 硬件-算法协同设计:新硬件(如 HBM-GPU、Wafer-Scale)将如何改变序列建模范式?
- 超大规模记忆: 状态维度的 TTT 实用化
- 跨模态序列:视频、音频、3D 统一建模
- 测试时计算与记忆的协同:用 TTT 替代 RAG?
10.2 2026 年值得关注的论文
- Mamba-3 (ICLR 2026 Best Paper)
- Titans / MIRAS (NeurIPS 2025)
- xLSTM-7B 工业部署
- Diffusion Language Model 系列
- 状态空间对偶性 进一步推广
10.3 与其他领域的交叉
- 强化学习:TTT 与元学习的天然联系
- 神经科学:海马体-皮层记忆系统的工程类比
- 符号推理:RNN 的有限状态机视角
- 优化理论:TTT 是隐式正则化器
总结
本文建立了序列建模的统一分类学,将四大范式(RNN/SSM/Attention/TTT)整合到同一数学框架内。关键洞察:
- 线性可加递推 是几乎所有现代序列模型的核心
- 状态空间对偶性 (SSD) 揭示了 SSM 与 Linear Attention 的等价性
- 测试时训练 (TTT) 引入了”测试时持续学习”的新维度
- 混合架构是当前最优实践,平衡表达力与效率
2025-2026 年见证了三个重大突破:
- Mamba-3 (ICLR 2026 Best):复值状态
- Titans (NeurIPS 2025):神经长期记忆
- MIRAS:所有序列模型 = 联想记忆
未来方向是统一架构——在 RNN 效率、SSM 长程、Attention 表达力、TTT 自适应之间找到最优平衡点。1
参考资料
Footnotes
-
本文综合了以下工作的核心思想:Katharopoulos et al. 2020 (Linear Attention)、Gu et al. 2022 (S4)、Gu & Dao 2023 (Mamba)、Dao & Gu 2024 (Mamba-2/SSD)、Behrouz et al. 2025 (Titans/MIRAS)、Beck et al. 2024 (xLSTM)、Dao et al. 2022 (FlashAttention)、Sun et al. 2024 (TTT)、Merrill & Sabharwal 2023 (TC⁰)、Chizat & Bach 2024 (Mean-Field) 等。 ↩ ↩2