概述

2024 年,Albert Gu 和 Tri Dao 发表了里程碑式的论文 Mamba-2,提出了状态空间对偶性(State Space Duality, SSD) 理论,首次在数学上统一了 LSTM、状态空间模型(SSM)和 Transformer 三种看似不同的序列建模范式。1

这一理论不仅解释了为什么 SSM(如 Mamba)在许多任务上能够与 Transformer 竞争,更重要的是揭示了深度学习中序列建模的本质。


从 LSTM 到 SSM 的演进

标准 LSTM 的矩阵形式

回忆 LSTM 的核心方程:

简化 LSTM

为揭示与 SSM 的联系,考虑简化 LSTM(没有窥视孔):

其中 是输入依赖的门控。

SSM 的标准形式

连续时间 SSM

离散化 SSM(步长 ):

其中


结构化半可分矩阵

定义

结构化半可分矩阵(Structured Semiseparable, SSS)是 SSD 理论的核心数学对象。

定义:矩阵 秩- 结构化半可分的,如果存在向量 使得:

即矩阵的下三角部分和上三角部分可以分别用低秩分解表示

SSS 矩阵的可视化

SSS 矩阵结构:

    j=1  2  3  4  5
i=1 [ a₁b₁ a₁b₂ a₁b₃ a₁b₄ a₁b₅ ]
i=2 [ a₂b₁ a₂b₂ a₂b₃ a₂b₄ a₂b₅ ]
i=3 [ a₃b₁ a₃b₂ a₃b₃ a₃b₄ a₃b₅ ]
i=4 [ a₄b₁ a₄b₂ a₄b₃ a₄b₄ a₄b₅ ]
i=5 [ a₅b₁ a₅b₂ a₅b₃ a₅b₄ a₅b₅ ]

下三角: M[i,j] = a_i b_j^T  (i ≥ j)
上三角: M[i,j] = a_j b_i^T  (i < j)

SSS 矩阵的性质

  1. 低秩表示:可用 参数表示整个 矩阵
  2. 矩阵-向量乘法:可在 时间内完成(而非
  3. 递归结构:等价于线性 RNN

SSM 视角的 SSD

SSM 的矩阵形式

设 SSM 的离散化参数为 ,定义状态转移矩阵

核心定理

对于对角 的 SSM,其选择性扫描(Selective Scan) 操作等价于与某个 SSS 矩阵的乘法:

其中 是秩为 的 SSS 矩阵。

数学推导

递归形式

展开为:

这可以写成矩阵形式:

其中 ,否则为 0。


注意力视角的 SSD

标准注意力矩阵

标准 Transformer 的注意力矩阵为:

结构化掩码注意力

SSD 理论引入结构化掩码注意力(Structured Masked Attention, SMA):

定义:SMA 操作定义为矩阵-向量乘积:

其中 是掩码矩阵。

关键连接

定理(SSM-注意力对偶性):

对于对角转移矩阵的 SSM,存在一个结构化掩码 使得:

其中 是 SSM 对应的 SSS 矩阵。

直观理解

Transformer 注意力 vs SSM

Transformer:     SSM (SSD):
┌─────────┐      ┌─────────┐
│ ■ ■ □ □ │      │ ■       │
│ ■ ■ ■ □ │  ≈   │ ■ ■     │
│ ■ ■ ■ ■ │      │ ■ ■ ■   │
│ ■ ■ ■ ■ │      │ ■ ■ ■ ■ │
└─────────┘      └─────────┘

■ = 强连接  □ = 弱连接

SSM 的下三角结构是 SMA 掩码的特殊形式

统一的理论框架

三元对偶性

SSD 理论建立了三种表示之间的等价性

           SSM 视角
          (递归形式)
              │
              │  ← 等价变换
              ▼
┌─────────────────────────┐
│                         │
│  结构化半可分矩阵 (SSS)  │ ← SSM-注意力对偶
│                         │
└─────────────────────────┘
              ▲
              │  ← 矩阵分解
              │
      注意力视角
      (矩阵形式)

表示转换

视角表示形式优点
SSM递归: 因果建模、可解释性
SSS矩阵: 并行计算、算法优化
SMA掩码: 与 Transformer 兼容

Mamba-2 架构

核心设计

Mamba-2 基于 SSD 框架,具有以下关键特性:2

class Mamba2Block(nn.Module):
    """
    Mamba-2 Block
    
    基于状态空间对偶性设计
    """
    def __init__(self, d_model, d_state=128, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        d_inner = d_model * expand
        
        # 输入投影
        self.x_proj = nn.Linear(d_model, d_state + 2 * d_conv, bias=False)
        
        # 输出投影
        self.dt_proj = nn.Linear(d_state, d_model)
        self.B_proj = nn.Linear(d_conv, d_state, bias=False)
        self.C_proj = nn.Linear(d_state, d_conv, bias=False)
        
        # SSM 参数
        self.A_log = nn.Parameter(torch.randn(d_state, d_conv))
        self.D = nn.Parameter(torch.ones(d_conv))
        
        # 归一化
        self.norm = nn.LayerNorm(d_model)
        
        # 输出投影
        self.out_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, d_model)
        Returns:
            out: (batch, seq_len, d_model)
        """
        residual = x
        x = self.norm(x)
        
        # 投影得到 SSM 参数
        x_dbl = self.x_proj(x)  # (B, L, d_state + 2*d_conv)
        dt, B, C = x_dbl.split([self.d_state, self.d_conv, self.d_conv], dim=-1)
        
        # 离散化
        A = -torch.exp(self.A_log.float())  # 确保稳定性
        dA = torch.exp(dt @ A.T)  # (B, d_state, d_conv)
        dB = B.unsqueeze(-1) * dt.unsqueeze(-2)  # (B, d_conv, 1)
        
        # 选择性扫描 (SSD)
        y = self.selective_scan(
            dA, dB, C, self.D, x
        )
        
        # 输出投影
        y = self.out_proj(y)
        
        return y + residual
    
    def selective_scan(self, dA, dB, C, D, u):
        """
        Mamba-2 的选择性扫描
        
        利用 SSM-SSS 对偶性高效实现
        """
        B, T, d_state = dA.shape[:3]
        d_conv = dA.shape[-1]
        
        # 重组为 SSS 矩阵形式
        # 使用并行前缀扫描算法
        dA_cumsum = torch.cumsum(dA, dim=1)
        
        # 简化的扫描实现
        ys = []
        h = torch.zeros(B, d_state, d_conv, device=dA.device)
        
        for t in range(T):
            # SSM 递归更新
            h = dA[:, t] * h + dB[:, t] * u[:, t:t+1]
            y = (C[:, t] @ h).squeeze(-1)
            ys.append(y)
        
        y = torch.stack(ys, dim=1)
        
        # 跳跃连接
        y = y + u * D
        
        return y

与 Mamba-1 的对比

特性Mamba-1Mamba-2
状态维度可变固定 (SSM-SSD)
并行化关联扫描原生并行
注意力兼容性
训练速度基线2-8× 提升
理论框架启发式SSD 数学严格

硬件感知优化

Mamba-2 使用 FlashAttention 风格的内存层级优化:

def mamba2_flash_scan(dA, dB, C, D, u, chunk_size=64):
    """
    分块计算以优化 GPU 内存使用
    """
    B, T, d_state, d_conv = dA.shape
    
    # 分块处理
    num_chunks = (T + chunk_size - 1) // chunk_size
    
    # 状态扩展因子
    scale = torch.exp(torch.cumsum(torch.log(dA + 1e-6), dim=1))
    
    ys = []
    h_prev = torch.zeros(B, d_state, d_conv, device=dA.device)
    
    for i in range(num_chunks):
        start = i * chunk_size
        end = min((i + 1) * chunk_size, T)
        
        # 块内计算
        chunk_u = u[:, start:end]
        chunk_scale = scale[:, start:end]
        
        # 重置状态检测
        reset = detect_reset(chunk_scale)
        
        # 更新状态
        h = reset * 0 + (1 - reset) * (chunk_scale[:, :-1] * h_prev)
        
        # 块内并行扫描
        chunk_y = parallel_scan(h, chunk_u, C[:, start:end], D)
        
        ys.append(chunk_y)
        h_prev = h[:, -1:]
    
    return torch.cat(ys, dim=1)

LSTM 作为 SSM 的特例

连接推导

考虑一个简化的线性 LSTM(无非线性):

为常数(无输入依赖),,则:

这正是齐次 SSM 的形式!

LSTM 门控 = SSM 选择性

关键洞察:LSTM 的输入依赖门控本质上是在选择性地修改 SSM 的转移矩阵

LSTM 门控SSM 对应
遗忘门 状态衰减率
输入门 输入权重
输出门 输出投影

表达力比较

定理(表达力关系):

  1. SSM 的表达能力 固定门控 LSTM
  2. 带选择性扫描的 SSM 可以模拟 任意输入依赖的 LSTM
  3. 但 SSM 不一定能模拟 Transformer 的全局注意力

实践意义

为什么 SSD 重要?

1. 算法设计

  • SSD 提供了一致的框架来设计新的序列模型
  • 可以在 SSM 和注意力之间自由切换

2. 硬件优化

  • SSD 使得 SSM 可以利用 Transformer 的优化(如 FlashAttention)
  • Mamba-2 的训练速度提升 2-8 倍

3. 理论理解

  • 揭示了不同架构之间的数学联系
  • 帮助理解为什么某些模型有效

未来方向

  1. 混合架构:结合 SSM 的效率和注意力的大上下文
  2. 新变体:基于 SSD 设计更多高效架构
  3. 理论深化:更深入理解表达力边界

代码:完整的 SSD 框架

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class SSM_SSD_Block(nn.Module):
    """
    基于状态空间对偶性的完整 SSM Block
    
    支持两种模式:
    - 'ssm': 纯 SSM 递归
    - 'attention': 结构化掩码注意力
    - 'hybrid': 混合模式
    """
    def __init__(self, d_model, d_state=128, mode='ssm'):
        super().__init__()
        self.mode = mode
        self.d_model = d_model
        self.d_state = d_state
        
        # 输入投影
        self.x_proj = nn.Linear(d_model, d_state * 2, bias=False)
        
        # SSM 参数
        self.A = nn.Parameter(torch.randn(d_state, d_state))
        self.B = nn.Parameter(torch.randn(d_state, 1))
        self.C = nn.Parameter(torch.randn(1, d_state))
        self.D = nn.Parameter(torch.ones(1))
        
        # 注意力头(用于混合模式)
        if mode == 'hybrid':
            self.num_heads = 8
            self.head_dim = d_model // self.num_heads
            self.q_proj = nn.Linear(d_model, d_model)
            self.k_proj = nn.Linear(d_model, d_model)
            self.v_proj = nn.Linear(d_model, d_model)
            self.o_proj = nn.Linear(d_model, d_model)
        
        self.norm = nn.LayerNorm(d_model)
        
        # 初始化 A 为稳定矩阵
        nn.init.normal_(self.A, 0, 0.02)
    
    def ssm_forward(self, x):
        """SSM 前向传播"""
        B, T, D = x.shape
        
        # 投影得到 B, C
        x_proj = self.x_proj(x)
        B_t = x_proj[:, :, :self.d_state].sigmoid()  # (B, T, d_state)
        C_t = x_proj[:, :, self.d_state:].sigmoid()  # (B, T, d_state)
        
        # 初始化状态
        h = torch.zeros(B, self.d_state, device=x.device)
        
        ys = []
        for t in range(T):
            # SSM 更新
            h = F.linear(h, self.A) + B_t[:, t] * self.B * x[:, t]
            y = (C_t[:, t] @ self.C) * x[:, t] + self.D * x[:, t]
            ys.append(y)
        
        return torch.stack(ys, dim=1)
    
    def attention_forward(self, x):
        """结构化掩码注意力前向传播"""
        B, T, D = x.shape
        
        # Q, K, V
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)
        
        # 重塑为多头
        Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 计算注意力分数
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
        
        # 结构化掩码:下三角(因果)
        mask = torch.tril(torch.ones(T, T, device=x.device))
        scores = scores.masked_fill(mask == 0, float('-inf'))
        
        # Softmax
        attn = F.softmax(scores, dim=-1)
        
        # 输出
        out = torch.matmul(attn, V)
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        
        return self.o_proj(out)
    
    def forward(self, x):
        residual = x
        x = self.norm(x)
        
        if self.mode == 'ssm':
            out = self.ssm_forward(x)
        elif self.mode == 'attention':
            out = self.attention_forward(x)
        else:  # hybrid
            # 结合 SSM 和注意力
            ssm_out = self.ssm_forward(x)
            attn_out = self.attention_forward(x)
            out = 0.5 * ssm_out + 0.5 * attn_out
        
        return out + residual

参考


相关阅读

Footnotes

  1. Gu, A., & Dao, T. (2024). “Mamba-2: State Space Duality at Scale”. arXiv:2405.21060.

  2. Dao, T., & Gu, A. (2024). “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”. arXiv:2312.00752.