概述
2024 年,Albert Gu 和 Tri Dao 发表了里程碑式的论文 Mamba-2,提出了状态空间对偶性(State Space Duality, SSD) 理论,首次在数学上统一了 LSTM、状态空间模型(SSM)和 Transformer 三种看似不同的序列建模范式。1
这一理论不仅解释了为什么 SSM(如 Mamba)在许多任务上能够与 Transformer 竞争,更重要的是揭示了深度学习中序列建模的本质。
从 LSTM 到 SSM 的演进
标准 LSTM 的矩阵形式
回忆 LSTM 的核心方程:
简化 LSTM
为揭示与 SSM 的联系,考虑简化 LSTM(没有窥视孔):
其中 是输入依赖的门控。
SSM 的标准形式
连续时间 SSM:
离散化 SSM(步长 ):
其中 ,。
结构化半可分矩阵
定义
结构化半可分矩阵(Structured Semiseparable, SSS)是 SSD 理论的核心数学对象。
定义:矩阵 是秩- 结构化半可分的,如果存在向量 和 使得:
即矩阵的下三角部分和上三角部分可以分别用低秩分解表示。
SSS 矩阵的可视化
SSS 矩阵结构:
j=1 2 3 4 5
i=1 [ a₁b₁ a₁b₂ a₁b₃ a₁b₄ a₁b₅ ]
i=2 [ a₂b₁ a₂b₂ a₂b₃ a₂b₄ a₂b₅ ]
i=3 [ a₃b₁ a₃b₂ a₃b₃ a₃b₄ a₃b₅ ]
i=4 [ a₄b₁ a₄b₂ a₄b₃ a₄b₄ a₄b₅ ]
i=5 [ a₅b₁ a₅b₂ a₅b₃ a₅b₄ a₅b₅ ]
下三角: M[i,j] = a_i b_j^T (i ≥ j)
上三角: M[i,j] = a_j b_i^T (i < j)
SSS 矩阵的性质
- 低秩表示:可用 参数表示整个 矩阵
- 矩阵-向量乘法:可在 时间内完成(而非 )
- 递归结构:等价于线性 RNN
SSM 视角的 SSD
SSM 的矩阵形式
设 SSM 的离散化参数为 ,定义状态转移矩阵:
核心定理:
对于对角 的 SSM,其选择性扫描(Selective Scan) 操作等价于与某个 SSS 矩阵的乘法:
其中 是秩为 的 SSS 矩阵。
数学推导
递归形式:
展开为:
这可以写成矩阵形式:
其中 当 ,否则为 0。
注意力视角的 SSD
标准注意力矩阵
标准 Transformer 的注意力矩阵为:
结构化掩码注意力
SSD 理论引入结构化掩码注意力(Structured Masked Attention, SMA):
定义:SMA 操作定义为矩阵-向量乘积:
其中 是掩码矩阵。
关键连接
定理(SSM-注意力对偶性):
对于对角转移矩阵的 SSM,存在一个结构化掩码 使得:
其中 是 SSM 对应的 SSS 矩阵。
直观理解
Transformer 注意力 vs SSM
Transformer: SSM (SSD):
┌─────────┐ ┌─────────┐
│ ■ ■ □ □ │ │ ■ │
│ ■ ■ ■ □ │ ≈ │ ■ ■ │
│ ■ ■ ■ ■ │ │ ■ ■ ■ │
│ ■ ■ ■ ■ │ │ ■ ■ ■ ■ │
└─────────┘ └─────────┘
■ = 强连接 □ = 弱连接
SSM 的下三角结构是 SMA 掩码的特殊形式
统一的理论框架
三元对偶性
SSD 理论建立了三种表示之间的等价性:
SSM 视角
(递归形式)
│
│ ← 等价变换
▼
┌─────────────────────────┐
│ │
│ 结构化半可分矩阵 (SSS) │ ← SSM-注意力对偶
│ │
└─────────────────────────┘
▲
│ ← 矩阵分解
│
注意力视角
(矩阵形式)
表示转换
| 视角 | 表示形式 | 优点 |
|---|---|---|
| SSM | 递归: | 因果建模、可解释性 |
| SSS | 矩阵: | 并行计算、算法优化 |
| SMA | 掩码: | 与 Transformer 兼容 |
Mamba-2 架构
核心设计
Mamba-2 基于 SSD 框架,具有以下关键特性:2
class Mamba2Block(nn.Module):
"""
Mamba-2 Block
基于状态空间对偶性设计
"""
def __init__(self, d_model, d_state=128, d_conv=4, expand=2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
d_inner = d_model * expand
# 输入投影
self.x_proj = nn.Linear(d_model, d_state + 2 * d_conv, bias=False)
# 输出投影
self.dt_proj = nn.Linear(d_state, d_model)
self.B_proj = nn.Linear(d_conv, d_state, bias=False)
self.C_proj = nn.Linear(d_state, d_conv, bias=False)
# SSM 参数
self.A_log = nn.Parameter(torch.randn(d_state, d_conv))
self.D = nn.Parameter(torch.ones(d_conv))
# 归一化
self.norm = nn.LayerNorm(d_model)
# 输出投影
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x):
"""
Args:
x: (batch, seq_len, d_model)
Returns:
out: (batch, seq_len, d_model)
"""
residual = x
x = self.norm(x)
# 投影得到 SSM 参数
x_dbl = self.x_proj(x) # (B, L, d_state + 2*d_conv)
dt, B, C = x_dbl.split([self.d_state, self.d_conv, self.d_conv], dim=-1)
# 离散化
A = -torch.exp(self.A_log.float()) # 确保稳定性
dA = torch.exp(dt @ A.T) # (B, d_state, d_conv)
dB = B.unsqueeze(-1) * dt.unsqueeze(-2) # (B, d_conv, 1)
# 选择性扫描 (SSD)
y = self.selective_scan(
dA, dB, C, self.D, x
)
# 输出投影
y = self.out_proj(y)
return y + residual
def selective_scan(self, dA, dB, C, D, u):
"""
Mamba-2 的选择性扫描
利用 SSM-SSS 对偶性高效实现
"""
B, T, d_state = dA.shape[:3]
d_conv = dA.shape[-1]
# 重组为 SSS 矩阵形式
# 使用并行前缀扫描算法
dA_cumsum = torch.cumsum(dA, dim=1)
# 简化的扫描实现
ys = []
h = torch.zeros(B, d_state, d_conv, device=dA.device)
for t in range(T):
# SSM 递归更新
h = dA[:, t] * h + dB[:, t] * u[:, t:t+1]
y = (C[:, t] @ h).squeeze(-1)
ys.append(y)
y = torch.stack(ys, dim=1)
# 跳跃连接
y = y + u * D
return y与 Mamba-1 的对比
| 特性 | Mamba-1 | Mamba-2 |
|---|---|---|
| 状态维度 | 可变 | 固定 (SSM-SSD) |
| 并行化 | 关联扫描 | 原生并行 |
| 注意力兼容性 | 低 | 高 |
| 训练速度 | 基线 | 2-8× 提升 |
| 理论框架 | 启发式 | SSD 数学严格 |
硬件感知优化
Mamba-2 使用 FlashAttention 风格的内存层级优化:
def mamba2_flash_scan(dA, dB, C, D, u, chunk_size=64):
"""
分块计算以优化 GPU 内存使用
"""
B, T, d_state, d_conv = dA.shape
# 分块处理
num_chunks = (T + chunk_size - 1) // chunk_size
# 状态扩展因子
scale = torch.exp(torch.cumsum(torch.log(dA + 1e-6), dim=1))
ys = []
h_prev = torch.zeros(B, d_state, d_conv, device=dA.device)
for i in range(num_chunks):
start = i * chunk_size
end = min((i + 1) * chunk_size, T)
# 块内计算
chunk_u = u[:, start:end]
chunk_scale = scale[:, start:end]
# 重置状态检测
reset = detect_reset(chunk_scale)
# 更新状态
h = reset * 0 + (1 - reset) * (chunk_scale[:, :-1] * h_prev)
# 块内并行扫描
chunk_y = parallel_scan(h, chunk_u, C[:, start:end], D)
ys.append(chunk_y)
h_prev = h[:, -1:]
return torch.cat(ys, dim=1)LSTM 作为 SSM 的特例
连接推导
考虑一个简化的线性 LSTM(无非线性):
设 为常数(无输入依赖),,则:
这正是齐次 SSM 的形式!
LSTM 门控 = SSM 选择性
关键洞察:LSTM 的输入依赖门控本质上是在选择性地修改 SSM 的转移矩阵 。
| LSTM 门控 | SSM 对应 |
|---|---|
| 遗忘门 | 状态衰减率 |
| 输入门 | 输入权重 |
| 输出门 | 输出投影 |
表达力比较
定理(表达力关系):
- SSM 的表达能力 ≥ 固定门控 LSTM
- 带选择性扫描的 SSM 可以模拟 任意输入依赖的 LSTM
- 但 SSM 不一定能模拟 Transformer 的全局注意力
实践意义
为什么 SSD 重要?
1. 算法设计
- SSD 提供了一致的框架来设计新的序列模型
- 可以在 SSM 和注意力之间自由切换
2. 硬件优化
- SSD 使得 SSM 可以利用 Transformer 的优化(如 FlashAttention)
- Mamba-2 的训练速度提升 2-8 倍
3. 理论理解
- 揭示了不同架构之间的数学联系
- 帮助理解为什么某些模型有效
未来方向
- 混合架构:结合 SSM 的效率和注意力的大上下文
- 新变体:基于 SSD 设计更多高效架构
- 理论深化:更深入理解表达力边界
代码:完整的 SSD 框架
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SSM_SSD_Block(nn.Module):
"""
基于状态空间对偶性的完整 SSM Block
支持两种模式:
- 'ssm': 纯 SSM 递归
- 'attention': 结构化掩码注意力
- 'hybrid': 混合模式
"""
def __init__(self, d_model, d_state=128, mode='ssm'):
super().__init__()
self.mode = mode
self.d_model = d_model
self.d_state = d_state
# 输入投影
self.x_proj = nn.Linear(d_model, d_state * 2, bias=False)
# SSM 参数
self.A = nn.Parameter(torch.randn(d_state, d_state))
self.B = nn.Parameter(torch.randn(d_state, 1))
self.C = nn.Parameter(torch.randn(1, d_state))
self.D = nn.Parameter(torch.ones(1))
# 注意力头(用于混合模式)
if mode == 'hybrid':
self.num_heads = 8
self.head_dim = d_model // self.num_heads
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
self.norm = nn.LayerNorm(d_model)
# 初始化 A 为稳定矩阵
nn.init.normal_(self.A, 0, 0.02)
def ssm_forward(self, x):
"""SSM 前向传播"""
B, T, D = x.shape
# 投影得到 B, C
x_proj = self.x_proj(x)
B_t = x_proj[:, :, :self.d_state].sigmoid() # (B, T, d_state)
C_t = x_proj[:, :, self.d_state:].sigmoid() # (B, T, d_state)
# 初始化状态
h = torch.zeros(B, self.d_state, device=x.device)
ys = []
for t in range(T):
# SSM 更新
h = F.linear(h, self.A) + B_t[:, t] * self.B * x[:, t]
y = (C_t[:, t] @ self.C) * x[:, t] + self.D * x[:, t]
ys.append(y)
return torch.stack(ys, dim=1)
def attention_forward(self, x):
"""结构化掩码注意力前向传播"""
B, T, D = x.shape
# Q, K, V
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
# 重塑为多头
Q = Q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
# 计算注意力分数
scale = math.sqrt(self.head_dim)
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
# 结构化掩码:下三角(因果)
mask = torch.tril(torch.ones(T, T, device=x.device))
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax
attn = F.softmax(scores, dim=-1)
# 输出
out = torch.matmul(attn, V)
out = out.transpose(1, 2).contiguous().view(B, T, D)
return self.o_proj(out)
def forward(self, x):
residual = x
x = self.norm(x)
if self.mode == 'ssm':
out = self.ssm_forward(x)
elif self.mode == 'attention':
out = self.attention_forward(x)
else: # hybrid
# 结合 SSM 和注意力
ssm_out = self.ssm_forward(x)
attn_out = self.attention_forward(x)
out = 0.5 * ssm_out + 0.5 * attn_out
return out + residual参考
相关阅读
- LSTM 详解 — LSTM 基础架构
- 状态空间模型 — Mamba 等 SSM 架构
- Transformer 与注意力机制 — 完全基于注意力的序列建模
- 混合 SSM-Transformer — 结合两种范式的架构