序列建模统一分类学:从RNN到TTT的范式整合

引言

序列建模是深度学习的核心问题之一。从 2017 年 Transformer 问世以来,序列建模领域经历了范式之争:RNN/LSTM 的循环归纳偏置、Attention 的全连接上下文、SSM 的连续-离散状态空间、TTT 的测试时学习——每种范式都有自己的优势和局限。

2025-2026 年的研究带来了两个根本性新洞察:

  1. 统一性:Mamba-2 的状态空间对偶性 (SSD) 证明了 SSM 与线性注意力在数学上等价;Titans/MIRAS 进一步将所有序列模型统一为”联想记忆模块”
  2. 新维度:测试时训练 (TTT) 引入了”测试时持续学习”的新维度,使序列模型从静态函数变为持续演化的动力学系统

本文旨在建立序列建模的统一分类学,把 30+ 看似零散的架构整合到同一框架内,并给出选型决策树。1


一、序列建模问题形式化

1.1 任务定义

序列建模的核心任务是学习一个映射:

或单输出形式 。其中

1.2 自回归分解

任何序列模型都可以分解为自回归形式:

其中每步的条件分布由模型参数 决定。

1.3 状态空间抽象

几乎所有序列模型都可以表示为状态递推

其中 是状态, 是状态转移函数, 是输出函数。

关键差异在于:

  • 的维度( vs
  • 是否可学习、是否依赖输入
  • 是否维护显式历史 vs 压缩状态

1.4 三层抽象视图

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
 
class SequenceModelInterface(nn.Module):
    """所有序列模型的统一接口"""
    def __init__(self, d_in, d_state, d_out, kind):
        super().__init__()
        self.d_in = d_in
        self.d_state = d_state
        self.d_out = d_out
        self.kind = kind  # 'rnn', 'ssm', 'attention', 'ttt'
 
    def init_state(self, batch_size, device):
        """初始化状态"""
        if self.kind == 'rnn':
            # 固定维度状态 h_0
            return torch.zeros(batch_size, self.d_state, device=device)
        elif self.kind == 'ssm':
            return torch.zeros(batch_size, self.d_state, device=device)
        elif self.kind == 'attention':
            # 状态 = 完整的 K, V 缓存
            return {'K': [], 'V': []}
        elif self.kind == 'ttt':
            # 状态 = 可学习的参数
            return {'memory': nn.Parameter(torch.zeros(self.d_state, self.d_state))}
 
    def step(self, x_t, state):
        """单步前向传播"""
        if self.kind == 'rnn':
            h = state
            i = self.W_ih(x_t)
            h_new = torch.tanh(self.W_hh(h) + i)
            return h_new, self.W_out(h_new)
        elif self.kind == 'ssm':
            # 状态空间更新
            h = state
            A = self.A  # (d_state, d_state)
            B = self.B(x_t)  # (d_state, d_in)
            h_new = h @ A + B
            return h_new, self.C(h_new)
        elif self.kind == 'attention':
            K, V = state['K'], state['V']
            K.append(self.W_k(x_t))
            V.append(self.W_v(x_t))
            K_stack = torch.stack(K, dim=1)  # (B, T, d_k)
            V_stack = torch.stack(V, dim=1)
            q = self.W_q(x_t)
            scores = q @ K_stack.transpose(-2, -1) / math.sqrt(self.d_k)
            attn = F.softmax(scores, dim=-1)
            y = attn @ V_stack
            return {'K': K, 'V': V}, self.W_o(y)
        elif self.kind == 'ttt':
            # 测试时学习:根据当前输入更新记忆
            memory = state['memory']
            # 简单的 Hebbian 更新
            x_proj = self.proj(x_t)
            memory = memory + 0.01 * torch.outer(x_proj, x_proj)
            y = x_proj @ memory
            return {'memory': memory}, y
 
    def forward(self, x):
        """完整序列前向"""
        B, T, _ = x.shape
        state = self.init_state(B, x.device)
        outputs = []
        for t in range(T):
            state, y_t = self.step(x[:, t], state)
            outputs.append(y_t)
        return torch.stack(outputs, dim=1)

二、四大范式分类

2.1 范式 1:循环神经网络(RNN 范式)

核心思想:维护固定维度状态 ,每步更新。

2.1.1 Vanilla RNN

  • 复杂度:每步 ,总
  • 归纳偏置:时间局部性、权重共享
  • 训练:BPTT(Back-Propagation Through Time)
  • 问题:梯度消失/爆炸

2.1.2 LSTM(Hochreiter & Schmidhuber 1997)

LSTM 通过门控机制解决梯度问题:

关键洞察:加性更新 使梯度通过 通路不衰减。

2.1.3 GRU(Cho et al. 2014)

GRU 简化为两个门:

2.1.4 xLSTM(Beck et al. 2024, ICLR 2026 Outstanding)

xLSTM 是 LSTM 的现代复兴:

  • 指数门控(exponential gating)
  • 矩阵记忆(matrix memory)替代标量
  • 协方差更新规则
  • xLSTM-7B 工业实现,scaling laws 拟合
class XLSTMCell(nn.Module):
    """xLSTM Cell(简化版)"""
    def __init__(self, d_input, d_state):
        super().__init__()
        self.d_input = d_input
        self.d_state = d_state
 
        # 门控参数
        self.W_i = nn.Linear(d_input, d_state)
        self.W_f = nn.Linear(d_input, d_state)
        self.W_o = nn.Linear(d_input, d_state)
        self.W_z = nn.Linear(d_input, d_state)
 
        # 矩阵记忆
        self.C = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
 
    def forward(self, x_t, h_prev, C_prev):
        i = torch.exp(self.W_i(x_t))  # 指数门控
        f = torch.exp(self.W_f(x_t))
        o = torch.sigmoid(self.W_o(x_t))
        z = self.W_z(x_t)
 
        # 协方差更新规则
        C_new = f.unsqueeze(-1) * C_prev + i.unsqueeze(-1) * torch.outer(torch.tanh(z), x_t)
        h_new = o * (C_new @ x_t)
        return h_new, C_new

2.1.5 M2RNN 矩阵值 RNN

M2RNN 使用 矩阵作为状态:

表达力:理论上比 vanilla RNN 指数级强,能记忆更多信息。

2.1.6 复杂度分析

模型每步复杂度总复杂度并行化
Vanilla RNN低(需 BPTT)
LSTM
xLSTM取决于实现中等
M2RNN

2.2 范式 2:状态空间模型(SSM 范式)

核心思想:连续-离散状态空间,将序列视为 ODE 的离散采样。

2.2.1 连续状态空间

其中

2.2.2 离散化

用零阶保持(ZOH)或双线性变换离散化:

2.2.3 S4(Gu, Goel, Re 2022)

S4 引入对角 + 低秩 (DPLR) 参数化 ,使 的高次幂可快速计算(基于 HiPPO 初始化)。

2.2.4 Mamba(Gu & Dao 2023)

Mamba 的关键创新是输入依赖的选择性 SSM

关键差异(与 S4): 依赖输入,实现选择性信息过滤。

2.2.5 Mamba-2 与 SSD 框架

Mamba-2 (Dao & Gu 2024) 引入状态空间对偶性 (SSD),证明:

SSM(半可分矩阵)= 结构化注意力(掩码 1-SS)

具体地,SSM 的输出 ,其中 半可分矩阵(可分解为 个秩一阵子之和)。这与结构化掩码注意力的矩阵形式完全等价

Mamba-2 块用 SSD 视角实现,比 Mamba-1 快 2-8 倍

2.2.6 Mamba-3 (ICLR 2026 Best Paper)

Mamba-3 的进一步创新:

  • 复值状态,更丰富的表达力
  • MIMO 架构:单步多输入多输出
  • 选择性增强:更精细的输入依赖机制

2.2.7 复杂度分析

模型训练复杂度推理复杂度内存
S4
Mamba-1
Mamba-2
Mamba-3

2.3 范式 3:注意力(Attention 范式)

核心思想:每步关注整个历史,通过 query-key 匹配检索。

2.3.1 标准注意力

  • 训练复杂度
  • 推理复杂度(带 KV Cache)
  • 归纳偏置:排列等变(无内置顺序概念)
  • 位置编码:RoPE、ALiBi、相对位置等

2.3.2 线性注意力(Linear Attention)

核心思想:把 softmax 替换为可分解的特征映射

计算重排

  • 朴素:
  • 线性:
class LinearAttention(nn.Module):
    """线性注意力实现"""
    def __init__(self, d_model, d_k, eps=1e-6):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_k, bias=False)
        self.W_k = nn.Linear(d_model, d_k, bias=False)
        self.W_v = nn.Linear(d_model, d_model)
        self.eps = eps
 
    def feature_map(self, x):
        """特征映射:x -> elu(x) + 1(保证非负)"""
        return F.elu(x) + 1
 
    def forward(self, x):
        B, T, d = x.shape
        Q = self.feature_map(self.W_q(x))  # (B, T, d_k)
        K = self.feature_map(self.W_k(x))  # (B, T, d_k)
        V = self.W_v(x)  # (B, T, d)
 
        # 线性化:Q (K^T V) 而非 (Q K^T) V
        KV = K.transpose(-2, -1) @ V  # (B, d_k, d)
        numerator = Q @ KV  # (B, T, d)
 
        # 归一化
        K_sum = K.sum(dim=1, keepdim=True)  # (B, 1, d_k)
        denominator = Q @ K_sum.transpose(-2, -1)  # (B, T, 1)
        return numerator / (denominator + self.eps)

2.3.3 各种 Linear Attention 变体

变体特征映射论文
Linear TransformerKatharopoulos 2020
Performer随机特征 Choromanski 2021
Random Feature Attention高斯随机特征Peng 2021
RetNet衰减 Sun 2023
Gated DeltaNet门控 + Delta 规则Yang 2024 (ICLR 2025)

2.3.4 状态视角

线性注意力可视为矩阵状态递推

这与 SSM 的形式完全平行——这是 SSD 框架的核心。

2.3.5 复杂度分析

模型训练推理表达力
Full Attention最高
Linear Attention中等
Performer近似 Full
RetNet衰减记忆
Gated DeltaNet

2.4 范式 4:测试时训练(TTT 范式)

核心思想:状态在测试时持续学习,而非固定参数。

2.4.1 TTT 的核心洞察

Sun et al. 2024 提出:

任何序列模型的状态都可以视为”模型权重”——即状态本身可以是通过梯度下降学习的参数。

这意味着:

  • RNN 的 可以是测试时学习的参数
  • SSM 的 可以是测试时学习的参数
  • Attention 的 KV Cache 可以是测试时学习的参数

2.4.2 TTT 层

TTT 层在测试时执行以下操作:
1. 维护一个"内部权重" W_t(状态)
2. 对当前输入 x_t 计算损失:L(W_t; x_t)
3. 用 SGD 更新:W_{t+1} = W_t - η ∇L(W_t; x_t)
4. 用 W_{t+1} 处理后续输入

双层优化

  • 外层:训练 TTT 层的”超参数”(学习率、初始化、损失函数)
  • 内层:测试时持续学习 W_t

2.4.3 Titans(Behrouz et al., NeurIPS 2025)

Titans 引入神经长期记忆 (Neural Long-Term Memory)

class NeuralMemory(nn.Module):
    """Titans 神经长期记忆模块"""
    def __init__(self, d_model, d_memory, mlp_hidden):
        super().__init__()
        # 记忆是 MLP
        self.memory_mlp = nn.Sequential(
            nn.Linear(d_model, mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, d_memory),
        )
        # 三个学习信号
        self.theta_q = nn.Linear(d_model, d_memory)  # 查询
        self.theta_k = nn.Linear(d_model, d_memory)  # 键
        self.theta_v = nn.Linear(d_model, d_memory)  # 值
        self.eta = nn.Parameter(torch.zeros(d_memory))  # 门控学习率
 
    def forward(self, x_t, M_t):
        """
        x_t: (B, d_model) 当前输入
        M_t: (d_memory, d_memory) 记忆状态
        """
        # Surprise-based gating
        q = self.theta_q(x_t)
        k = self.theta_k(x_t)
        v = self.theta_v(x_t)
 
        # 计算 surprise
        with torch.no_grad():
            Mv = M_t @ v.unsqueeze(-1)
            surprise = (v.unsqueeze(-1) - Mv).norm(dim=-2)
 
        # 自适应学习率
        eta = F.sigmoid(self.eta) * surprise
 
        # 梯度下降更新记忆
        grad = torch.outer(k, v)  # 一阶梯度近似
        M_new = M_t - eta.unsqueeze(-1) * grad
 
        return M_new, M_new @ q

关键

  • 记忆权重 通过测试时梯度下降更新
  • 学习率 是输入依赖的(自适应门控)
  • 突破 200 万 token 上下文长度

2.4.4 MIRAS (Behrouz et al. 2025)

MIRAS 提出统一框架

所有序列模型都是联想记忆模块(Associative Memory Modules)

包括:

  • Attention = Hopfield 网络
  • SSM = 线性联想记忆
  • RNN = 有限状态联想记忆
  • TTT = 持续学习的联想记忆

YAAD、MONETA、MEMORA 是 MIRAS 框架下的具体实例。

2.4.5 复杂度分析

TTT 的计算开销主要在测试时梯度计算

  • 单步: 是内层步数)
  • 与序列长度解耦

三、统一数学框架

3.1 线性可加性是统一的核心

几乎所有现代序列模型都可以表示为线性可加递推

其中 可能是输入依赖的。

3.1.1 各范式在此框架下

范式维度
Vanilla RNN
LSTM (对角门)
Linear SSM (HiPPO/Mamba)
Mamba (输入依赖)
Mamba-3 (复值)
Linear Attention
TTT/Titans

3.2 状态维度的本质差异

  • RNN/SSM:状态 与序列长度无关 → 状态
  • Linear Attention:状态 ,矩阵 → 状态
  • Full Attention:状态 = 完整 K, V 历史 → 状态
  • TTT/Titans:状态 = 可学习参数 → 可任意大

3.3 SSD 框架的严格证明

定理(Dao & Gu 2024):设 是半可分矩阵:

或等价地:

既可以用 SSM(前向递推)计算,也可以用结构化掩码注意力计算,且完全等价

证明概要

  • SSM 视角:
  • 注意力视角:
  • 矩阵 在两种视角下形式相同

3.4 记忆容量理论

不同范式的记忆容量不同:

范式容量理论依据
Vanilla RNN信息瓶颈
LSTM 有效门控机制
SSM状态维度
Full Attention 精确KV Cache
Linear Attention 有限状态维度
TTT/Titans 可学习神经长期记忆

信息论下界:任何状态维度 的序列模型最多记忆 比特信息(Shannon 论证)。


四、复杂度-表达力-效率三角

4.1 三角权衡

任何序列模型都面临三个核心约束:

  1. 计算复杂度:训练与推理的计算开销
  2. 表达力:能学习的函数类别
  3. 实现效率:并行化能力、内存占用
                表达力
                  ▲
                 /│\
                / │ \
               /  │  \
              /   │   \
             /    │    \
            /  ◆  │  ◆  \
           / RNN  │ Attn \
          / 范式  │  范式  \
         /───────┼────────\
        /        │        \
       /  SSM    │  Linear \
      /  范式    │  Attn   \
     /──────────┴──────────\
    复杂度  ─────────────►

没有银弹:每种范式都在三角上做了不同取舍。

4.2 复杂度对比表

范式训练 推理 内存 并行化
Vanilla RNN
LSTM
xLSTM
S4
Mamba-1
Mamba-2
Linear Attn
Full Attn中(KV 缓存)
TTT
Titans

4.3 表达力对比

范式形式语言长程依赖状态查询算术
Vanilla RNNTC⁰
LSTMTC⁰部分
SSM/S4TC⁰
MambaTC⁰选择性
Full AttnTC⁰精确
Linear Attn有限近似
TTT任意(测试时)自学习可调

关键洞察:Full Attention 的表达力理论上最强(任何 TC⁰ 函数),但实践中未必——LSTM/SSM 在特定任务上能匹敌或超越。

4.4 TC⁰ 复杂性类

Merrill & Sabharwal 2023 严格证明:

标准 Transformer 精确识别的形式语言类 = TC⁰(阈值电路常数深度多项式大小)

而 RNN(包括 LSTM、GRU)识别能力严格弱于 TC⁰。这是 Transformer vs RNN 表达力的严格分层

TC⁰ 包含

  • 加法、乘法
  • 多数表决
  • 阈值函数
  • 排序

TC⁰ 不包含(需要 TC¹):

  • 匹配括号(Dyck-)的完整识别
  • 一般图灵可计算函数

但 Transformer 通过”软”近似仍能在实践中处理这些任务。


五、测试时学习(TTT)作为新维度

5.1 传统序列建模的局限

所有传统范式都有一个隐含假设:

训练时学习参数,推理时固定参数

这个假设限制了信息存储适应能力

5.2 TTT 的革命性

TTT 引入了第三维度:

       训练时 ──────── 推理时
         ↓                ↓
  传统:学习参数       固定参数
  TTT:  学习超参数     持续学习参数

这意味着

  • 模型可以记住训练时未见过的信息
  • 模型可以根据输入分布自适应
  • 模型在长上下文中能持续学习

5.3 TTT 的统一视角

TTT 层可视为**“模型作为模型权重”**:

  • 状态 模型权重
  • 输入 训练数据
  • 更新 梯度下降

因此 TTT 是一种元学习形式(learning to learn)。

5.4 TTT vs 微调

维度微调TTT
更新频率训练阶段推理时每步
数据量大批量单样本
计算位置离线在线
内存开销全模型单层状态
性能稳定性稳定可能不稳定

5.5 Titans 的具体创新

Titans (Behrouz et al., NeurIPS 2025) 在 TTT 基础上:

  1. 三层记忆架构

    • 短期:滑动窗口注意力
    • 中期:Persistent Memory(可学习参数)
    • 长期:神经长期记忆(测试时学习)
  2. 自适应学习率

    学习率与”surprise”成正比——惊讶越大,学习越快

  3. 门控机制

    用门控控制记忆更新强度

  4. 实验结果

    • Needle-in-Haystack:200 万 token
    • 性能超越 Mamba-2 和 Transformer

六、混合架构

6.1 为什么需要混合

没有单一范式在所有任务上最优

  • Transformer:短序列强表达力、长序列贵
  • SSM:长序列高效、表达力受限
  • TTT:长上下文、自适应

混合策略:用 Attention 处理局部 + SSM 处理全局 + TTT 长期记忆。

6.2 代表性混合架构

6.2.1 Jamba(AI21, 2024)

7B 参数,52 层,1:7 Attention:SSM 比例:

class JambaBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_state, mamba_expand, attn_period=8):
        super().__init__()
        # 每 8 层用 1 个 Attention
        self.use_attn = False
        if attn_period and random.random() < 1/attn_period:
            self.use_attn = True
            self.attn = Attention(d_model, n_heads)
 
        self.mamba = MambaBlock(d_model, d_state, mamba_expand)
        self.moe = MoE(d_model, n_experts=16, top_k=2)
 
    def forward(self, x):
        if self.use_attn:
            x = x + self.attn(x)
        x = x + self.mamba(x)
        x = x + self.moe(x)
        return x

6.2.2 TransMamba(2025)

层次化混合:底层 Mamba,高层 Attention。

6.2.3 Hymba(NVIDIA, 2025)

头级别混合:每头决定用 Attention 还是 SSM。

6.2.4 HydraHead

类似 Hymba,但有更细粒度的头级别门控。

6.3 混合设计原则

原则说明
Attention 占比1:8 到 1:4 是常见区间
层次浅层 Attention、深层 SSM(或反之)
头级别不同头用不同范式
门控软门控学习权重

七、2025-2026 前沿进展

7.1 Mamba-3 (ICLR 2026 Best Paper)

Mamba-3 引入:

  1. 复值状态

    • 表达力翻倍
    • 自然编码振荡模式
    • 谱方法更易分析
  2. MIMO 架构:单步多输入多输出

    • 每次更新处理多个 token
    • 提高硬件利用率
  3. 选择性增强

    更精细的输入依赖离散化

7.2 Titans (NeurIPS 2025)

Titans 的核心创新:

  • 神经长期记忆(MLP 作为状态)
  • 自适应学习率
  • 三层记忆架构

实验结果

  • Needle-in-Haystack:200 万 token 99% 准确率
  • BABILong 基准:超越 Transformer 和 Mamba-2
  • 训练效率:2-10x 加速

7.3 MIRAS(Behrouz et al. 2025)

MIRAS 提出统一框架

所有序列模型 = 联想记忆模块

包括四大家族:

  • YAAD (You Actually Add Delta):测试时学习
  • MONETA:状态作为测试时学习
  • MEMORA:多时间尺度记忆
  • TITANS:神经长期记忆

核心洞察

“将序列模型视为联想记忆模块,统一了 RNN/SSM/Attention/TTT 的本质”

7.4 扩散序列模型

2025-2026 年,扩散模型开始应用于序列建模:

  • Diffusion Language Model (LLaDA):用掩码扩散替代自回归
  • MDLM (Masked Diffusion Language Model):扩展到 7B 参数
  • dLLM:离散扩散 LM
class MaskedDiffusionLM(nn.Module):
    """简化版掩码扩散 LM"""
    def __init__(self, d_model, vocab_size, n_layers):
        super().__init__()
        self.backbone = Transformer(d_model, n_layers)
        self.vocab_proj = nn.Linear(d_model, vocab_size)
 
    def forward(self, x_t, t):
        """x_t 是 t 时刻被掩码的序列"""
        h = self.backbone(x_t)
        # 预测被掩码的位置
        logits = self.vocab_proj(h)
        return logits
 
    def sample(self, n_steps, batch_size, seq_len):
        """反向扩散采样"""
        x = torch.full((batch_size, seq_len), MASK_TOKEN_ID)
        for t in reversed(range(n_steps)):
            logits = self.forward(x, t)
            # 按噪声调度采样被掩码的位置
            x = self.mask_schedule.sample_step(logits, x, t)
        return x

7.5 状态空间对偶性 (SSD) 的新发展

Mamba-2 SSD 框架在 2025-2026 进一步推广:

  • 多层次 SSD:跨层应用 SSD
  • Gated SSD:加门控的 SSD(Mamba-2 用 Gated DeltaNet 接近)
  • 稀疏 SSD:稀疏化的 SSD(用于长上下文)
  • Hybrid SSD:与 Attention 混合的 SSD

八、选型决策树

8.1 决策流程

任务: 序列建模
  │
  ├─ 序列长度 L < 1K?
  │   └─ 是 → Full Attention (表达力优先)
  │
  ├─ L > 32K 且硬件受限?
  │   ├─ 是 → Mamba-2/3 或 Linear Attention
  │   └─ 否 → 考虑混合架构
  │
  ├─ 需要超长上下文(> 1M tokens)?
  │   └─ 是 → Titans / TTT 类
  │
  ├─ 任务对长程依赖敏感(如代码)?
  │   └─ 是 → Titans 或 Mamba-3
  │
  ├─ 多模态/视频/音频?
  │   └─ 是 → 混合架构(Jamba 风格)
  │
  └─ 默认 → Transformer (Mamba-2 备份)

8.2 任务-范式映射表

任务推荐范式备选
短文本分类Full AttentionRNN 即可
长文档 QAMamba-2/3 或 TitansHybrid
代码生成Full Attention + 滑动窗口Titans
时间序列预测xLSTM 或 MambaTCN
视频理解Hybrid (Jamba 风格)Mamba-3
语音识别Conformer (CNN + Attention)Mamba-2
强化学习策略LSTM (xLSTM)RNN
长程对话TitansMamba-3
实时流式SSM (因果)Linear Attn

8.3 工程考虑

  • 训练资源:Full Attention 训练最贵,Linear 训练最便宜
  • 推理延迟:RNN/SSM 恒定延迟,Attention 增长
  • KV 缓存:Full Attention 内存爆炸
  • 硬件友好:SSM 在 GPU 上有专门 kernel
  • 可扩展性:Titans/Mamba 适合长上下文

九、完整 PyTorch 实现:UnifiedSequenceBlock

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from enum import Enum
 
 
class SequenceKind(Enum):
    RNN = "rnn"
    LSTM = "lstm"
    XLSTM = "xlstm"
    SSM = "ssm"
    MAMBA = "mamba"
    LINEAR_ATTN = "linear_attn"
    TITANS = "titans"
    HYBRID = "hybrid"
 
 
class UnifiedSequenceBlock(nn.Module):
    """统一序列建模块:可在四大范式间切换"""
    def __init__(self, d_model, d_state, kind, n_heads=4, mlp_hidden=None):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.kind = SequenceKind(kind)
        self.n_heads = n_heads
        self.mlp_hidden = mlp_hidden or 4 * d_model
 
        # 归一化
        self.norm = nn.RMSNorm(d_model)
        self.norm_state = nn.LayerNorm(d_state) if self.kind in [SequenceKind.RNN, SequenceKind.LSTM, SequenceKind.SSM] else nn.Identity()
 
        # 范式特定组件
        if self.kind == SequenceKind.RNN:
            self.cell = nn.RNNCell(d_model, d_state)
            self.out_proj = nn.Linear(d_state, d_model)
        elif self.kind == SequenceKind.LSTM:
            self.cell = nn.LSTMCell(d_model, d_state)
            self.out_proj = nn.Linear(d_state, d_model)
        elif self.kind == SequenceKind.XLSTM:
            self.cell = XLSTMCell(d_model, d_state)
            self.out_proj = nn.Linear(d_state, d_model)
        elif self.kind == SequenceKind.SSM:
            # S4 简化版
            self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
            self.B = nn.Linear(d_model, d_state)
            self.C = nn.Linear(d_state, d_model)
            self.D = nn.Parameter(torch.zeros(d_model))
            self.log_dt = nn.Parameter(torch.zeros(d_state))
        elif self.kind == SequenceKind.MAMBA:
            # 简化 Mamba 块
            self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
            self.B_proj = nn.Linear(d_model, d_state)
            self.C_proj = nn.Linear(d_model, d_state)
            self.dt_proj = nn.Linear(d_model, d_state)
            self.out_proj = nn.Linear(d_state, d_model)
        elif self.kind == SequenceKind.LINEAR_ATTN:
            self.W_q = nn.Linear(d_model, d_state, bias=False)
            self.W_k = nn.Linear(d_model, d_state, bias=False)
            self.W_v = nn.Linear(d_model, d_model)
            self.feature_map = lambda x: F.elu(x) + 1
        elif self.kind == SequenceKind.TITANS:
            self.memory_mlp = nn.Sequential(
                nn.Linear(d_model, d_state),
                nn.SiLU(),
                nn.Linear(d_state, d_state * d_state),
            )
            self.theta_q = nn.Linear(d_model, d_state)
            self.theta_k = nn.Linear(d_model, d_state)
            self.theta_v = nn.Linear(d_model, d_state)
            self.eta = nn.Parameter(torch.zeros(d_state))
        elif self.kind == SequenceKind.HYBRID:
            # Jamba 风格:Attention + SSM
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.A = nn.Parameter(torch.randn(d_state, d_state) * 0.01)
            self.B = nn.Linear(d_model, d_state)
            self.C = nn.Linear(d_state, d_model)
 
        # 共享 FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, self.mlp_hidden),
            nn.GELU(),
            nn.Linear(self.mlp_hidden, d_model),
        )
 
    def forward(self, x):
        """x: (B, T, d_model)"""
        B, T, d = x.shape
 
        if self.kind == SequenceKind.RNN:
            h = torch.zeros(B, self.d_state, device=x.device)
            outputs = []
            for t in range(T):
                h = self.cell(x[:, t], h)
                outputs.append(self.out_proj(h))
            y = torch.stack(outputs, dim=1)
 
        elif self.kind == SequenceKind.LSTM:
            h = torch.zeros(B, self.d_state, device=x.device)
            c = torch.zeros(B, self.d_state, device=x.device)
            outputs = []
            for t in range(T):
                h, c = self.cell(x[:, t], (h, c))
                outputs.append(self.out_proj(h))
            y = torch.stack(outputs, dim=1)
 
        elif self.kind == SequenceKind.XLSTM:
            h = torch.zeros(B, self.d_state, device=x.device)
            C = torch.zeros(B, self.d_state, self.d_state, device=x.device)
            outputs = []
            for t in range(T):
                h, C = self.cell(x[:, t], h, C)
                outputs.append(self.out_proj(h))
            y = torch.stack(outputs, dim=1)
 
        elif self.kind == SequenceKind.SSM:
            h = torch.zeros(B, self.d_state, device=x.device)
            A = -torch.exp(self.log_dt).unsqueeze(0) * self.A
            outputs = []
            for t in range(T):
                h = h @ A + self.B(x[:, t]).unsqueeze(1) * 1.0
                outputs.append(self.C(h))
            y = torch.stack(outputs, dim=1) + x * self.D
 
        elif self.kind == SequenceKind.MAMBA:
            h = torch.zeros(B, self.d_state, device=x.device)
            outputs = []
            for t in range(T):
                # 输入依赖
                B_t = self.B_proj(x[:, t])
                C_t = self.C_proj(x[:, t])
                dt = F.softplus(self.dt_proj(x[:, t]))
 
                # 离散化
                A_bar = torch.exp(dt.unsqueeze(-1) * self.A)
                h = h * A_bar + B_t * x[:, t].unsqueeze(1)
                y_t = (C_t.unsqueeze(1) * h).sum(dim=-1)
                outputs.append(self.out_proj(y_t))
            y = torch.stack(outputs, dim=1)
 
        elif self.kind == SequenceKind.LINEAR_ATTN:
            Q = self.feature_map(self.W_q(x))  # (B, T, d_state)
            K = self.feature_map(self.W_k(x))
            V = self.W_v(x)
 
            # 状态递推
            S = torch.zeros(B, self.d_state, d, device=x.device)
            outputs = []
            for t in range(T):
                S = S + torch.bmm(K[:, t:t+1].transpose(-2, -1), V[:, t:t+1])
                y_t = torch.bmm(Q[:, t:t+1], S)
                outputs.append(y_t)
            y = torch.cat(outputs, dim=1)
 
        elif self.kind == SequenceKind.TITANS:
            # 神经长期记忆
            M = self.memory_mlp(x[:, 0]).view(B, self.d_state, self.d_state)
            outputs = []
            for t in range(T):
                q = self.theta_q(x[:, t])
                k = self.theta_k(x[:, t])
                v = self.theta_v(x[:, t])
 
                # 读取
                y_t = torch.bmm(M, q.unsqueeze(-1)).squeeze(-1)
 
                # Surprise-based 更新
                with torch.no_grad():
                    Mv = torch.bmm(M, v.unsqueeze(-1))
                    surprise = (v.unsqueeze(-1) - Mv).norm(dim=-2)
                eta = torch.sigmoid(self.eta) * surprise
                grad = torch.bmm(k.unsqueeze(-1), v.unsqueeze(1))
                M = M - eta.unsqueeze(-1) * grad
                outputs.append(y_t)
            y = torch.stack(outputs, dim=1)
 
        elif self.kind == SequenceKind.HYBRID:
            # Jamba 风格:先 Attention,再 SSM
            x_norm = self.norm(x)
            attn_out, _ = self.attn(x_norm, x_norm, x_norm)
            x = x + attn_out
 
            h = torch.zeros(B, self.d_state, device=x.device)
            outputs = []
            for t in range(T):
                h = h @ self.A + self.B(x[:, t])
                outputs.append(self.C(h))
            y = torch.stack(outputs, dim=1)
 
        # 残差 + FFN
        y = x + y
        y = y + self.ffn(self.norm(y))
        return y
 
 
# 使用示例
if __name__ == "__main__":
    B, T, d_model = 4, 64, 128
    d_state = 64
    x = torch.randn(B, T, d_model)
 
    for kind in SequenceKind:
        model = UnifiedSequenceBlock(d_model, d_state, kind.value)
        y = model(x)
        print(f"{kind.value}: input {x.shape} -> output {y.shape}")

十、未来方向

10.1 开放问题

  1. 统一架构:是否存在统一范式同时实现 RNN 效率、SSM 长程、Attention 表达力、TTT 自适应?
  2. 硬件-算法协同设计:新硬件(如 HBM-GPU、Wafer-Scale)将如何改变序列建模范式?
  3. 超大规模记忆 状态维度的 TTT 实用化
  4. 跨模态序列:视频、音频、3D 统一建模
  5. 测试时计算与记忆的协同:用 TTT 替代 RAG?

10.2 2026 年值得关注的论文

  • Mamba-3 (ICLR 2026 Best Paper)
  • Titans / MIRAS (NeurIPS 2025)
  • xLSTM-7B 工业部署
  • Diffusion Language Model 系列
  • 状态空间对偶性 进一步推广

10.3 与其他领域的交叉

  • 强化学习:TTT 与元学习的天然联系
  • 神经科学:海马体-皮层记忆系统的工程类比
  • 符号推理:RNN 的有限状态机视角
  • 优化理论:TTT 是隐式正则化器

总结

本文建立了序列建模的统一分类学,将四大范式(RNN/SSM/Attention/TTT)整合到同一数学框架内。关键洞察:

  1. 线性可加递推 是几乎所有现代序列模型的核心
  2. 状态空间对偶性 (SSD) 揭示了 SSM 与 Linear Attention 的等价性
  3. 测试时训练 (TTT) 引入了”测试时持续学习”的新维度
  4. 混合架构是当前最优实践,平衡表达力与效率

2025-2026 年见证了三个重大突破:

  • Mamba-3 (ICLR 2026 Best):复值状态
  • Titans (NeurIPS 2025):神经长期记忆
  • MIRAS:所有序列模型 = 联想记忆

未来方向是统一架构——在 RNN 效率、SSM 长程、Attention 表达力、TTT 自适应之间找到最优平衡点。1


参考资料

Footnotes

  1. 本文综合了以下工作的核心思想:Katharopoulos et al. 2020 (Linear Attention)、Gu et al. 2022 (S4)、Gu & Dao 2023 (Mamba)、Dao & Gu 2024 (Mamba-2/SSD)、Behrouz et al. 2025 (Titans/MIRAS)、Beck et al. 2024 (xLSTM)、Dao et al. 2022 (FlashAttention)、Sun et al. 2024 (TTT)、Merrill & Sabharwal 2023 (TC⁰)、Chizat & Bach 2024 (Mean-Field) 等。 2