概述
Mamba 2基于**状态空间对偶性(State Space Duality, SSD)**框架,建立了结构化状态空间模型(SSMs)与注意力变体之间的深刻数学联系。1
SSD框架的核心贡献:
- 揭示了SSMs和注意力变体实际上是同一数学对象的两种分解形式
- 基于SSS(结构化半可分)矩阵的统一理论
- 设计了Mamba 2架构,训练速度比Mamba 1快2-8倍
- 在语言建模任务上保持与Transformer的竞争力
1. 结构化半可分矩阵
1.1 定义
结构化半可分(Structured Semi-Separable, SSS)矩阵是一类具有特殊结构的矩阵,可以高效表示和计算。
定义:一个矩阵 是阶为 的SSS矩阵,如果对于所有 ,子矩阵 可以表示为:
其中 ,,且 称为秩。
1.2 SSS矩阵的性质
SSS矩阵具有以下关键性质:
- 低秩结构:对角线以下的每个元素都可以由低秩分解表示
- 高效计算:矩阵-向量乘积可以在 时间复杂度内完成
- 闭合性:SSS矩阵的乘积、求逆仍是SSS矩阵
- 对角化:可以通过三角分解 表示
1.3 SSS与状态空间模型的联系
考虑线性时不变(LTI)SSM:
其输入-输出映射可以写为:
其中 是一个SSS矩阵。具体来说:
这表明SSM的隐藏状态压缩了历史信息到低维表示。
2. 状态空间对偶性框架
2.1 核心定理
SSD框架定理:设 为一个SSM对应的SSS矩阵,则存在以下等价关系:
其中注意力机制的查询、键、值由SSM参数构造。
2.2 矩阵分解视角
从矩阵分解角度,SSD框架建立了以下对应:
| SSM组件 | 矩阵表示 | 注意力视角 |
|---|---|---|
| 状态矩阵 | SSS结构 | 注意力分数计算 |
| 输入矩阵 | 低秩分解 | Query构造 |
| 输出矩阵 | 低秩分解 | Key构造 |
| 选择性扫描 | 参数化SSS | 动态注意力 |
2.3 数学形式化
选择机制的SSD表示:
Mamba 2的选择机制可以写为:
其中:
- 是输入依赖的时间步长
- 是线性投影函数
- 表示逐元素乘法
对应的SSS矩阵元素:
这等价于输入依赖的注意力分数。
3. 分组值注意力头(GVA)
3.1 动机
标准注意力机制中,每个头需要计算完整的键-值投影:
这导致 的计算复杂度。
3.2 GVA定义
**分组值注意力(Grouped Value Attention)**将值向量分组,每个组共享键向量:
给定: num_heads = h, groups = g
对于每个头 i:
属于组 g(i) = i mod groups
Key_i = Key_{g(i)} # 共享键
Value_i = Value_i # 独立值
3.3 计算效率
GVA的计算复杂度分析:
| 操作 | 标准注意力 | GVA |
|---|---|---|
| Key投影 | ||
| Value投影 | ||
| 注意力分数 | ||
| 注意力加权 |
当 时,GVA显著减少参数量。
3.4 与SSM的联系
GVA的分组机制与SSM的状态压缩有异曲同工之妙:
- SSM:将 步历史压缩到 维状态
- GVA:将 个头的信息压缩到 个键
4. Mamba 2核心算法
4.1 选择性SSM层
Mamba 2的核心层是**选择性SSM(Selective SSM)**的改进版本:
class Mamba2SSMLayer(nn.Module):
def __init__(self, d_model, d_state=128, d_conv=4, expand=2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = d_model * expand
# 输入投影
self.in_proj = nn.Linear(d_model, self.d_inner * 2 + d_state)
# 卷积
self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv, padding=d_conv-1)
# SSM参数
self.x_proj = nn.Linear(d_inner, d_state + 1, bias=False)
self.dt_proj = nn.Linear(d_state, d_inner, bias=True)
# 输出投影
self.out_proj = nn.Linear(d_inner, d_model)
# A矩阵(对角参数化)
self.A = nn.Parameter(torch.randn(d_state, d_inner))
def forward(self, x):
B, L, D = x.shape
# 输入投影 + 分割
xz = self.in_proj(x)
x_inner, z = xz[..., :self.d_inner], xz[..., self.d_inner:]
# 卷积
x_conv = self.conv1d(x_inner.transpose(1,2))[..., :L].transpose(1,2)
x_silu = F.silu(x_conv)
# 选择性参数(与输入相关)
x_gate = self.x_proj(x_silu)
B_t, C_t = x_gate[..., :self.d_state], x_gate[..., self.d_state:]
dt = self.dt_proj(B_t)
# SSM扫描(使用SSD算法)
y = self.ssd_scan(x_silu, dt, C_t, self.A)
# 门控 + 输出
y = y * F.silu(z)
return self.out_proj(y)
def ssd_scan(self, u, dt, C, A):
"""状态空间对偶扫描算法"""
# 离散化
dA = torch.exp(dt.unsqueeze(-1) * A)
dB_u = dt.unsqueeze(-1) * u.unsqueeze(-1)
# 扫描(并行化)
return self.ssd_parallel_scan(dA, dB_u, C)4.2 SSD并行扫描
Mamba 2的关键创新是SSD(Selective State Space Duality)并行扫描:
def ssd_parallel_scan(dA, dB_u, C):
"""
状态空间对偶并行扫描
参数:
dA: (batch, seqlen, d_state, d_inner) - 离散化的A矩阵
dB_u: (batch, seqlen, d_state, d_inner) - 离散化的B*u
C: (batch, seqlen, d_state) - 输出投影
"""
batch, seqlen, N, D = dA.shape
# 展开为半可分结构
# M[i,j] = C[i] @ A[i]@...@A[j+1] @ B[j] if i > j
# 并行扫描(类似前缀和)
# 使用横享扫描(Tiled Scan)算法
T = seqlen
# 分块处理
num_chunks = (T + chunk_size - 1) // chunk_size
# 对角块积分
for c in range(num_chunks):
start, end = c*chunk_size, min((c+1)*chunk_size, T)
# 计算当前块的贡献
chunk_contrib = compute_chunk(dA, dB_u, C, start, end)
# 跨块传播
carry = None
for c in range(num_chunks):
carry = propagate_chunk(carry, dA, dB_u, c)
return final_output4.3 与FlashAttention的联系
SSD扫描与FlashAttention有深刻联系:
| FlashAttention | SSD扫描 |
|---|---|
| 分块计算避免HBM访问 | 分块扫描避免HBM访问 |
| 矩阵乘法融合 | SSM递归融合 |
| 内存 | 内存 |
5. 张量并行支持
5.1 Mamba 1的限制
Mamba 1使用线性递归:
这种递归形式难以并行化,无法直接支持张量并行。
5.2 Mamba 2的解决方案
Mamba 2将递归转化为矩阵乘法形式:
- SSS矩阵表示:,其中 是SSS矩阵
- 并行扫描:使用SSD并行扫描算法
- 列切割:对输入/输出维度进行分布式计算
# 张量并行示例
class Mamba2Parallel(nn.Module):
def __init__(self, d_model, n_heads, tp_size):
super().__init__()
self.tp_size = tp_size
self.d_head = d_model // n_heads
# 列切割:每个GPU持有部分列
self.in_proj = ColumnParallelLinear(d_model, d_model*2, tp_size)
def forward(self, x):
# 每个GPU独立计算
x_local = self.in_proj(x) # [B, L, D/tp]
# SSD扫描(本地)
y_local = self.ssd_scan(x_local)
return y_local5.3 与Megatron-LM集成
NVIDIA Megatron-LM提供了Mamba 2的分布式实现:
# Megatron-LM中的Mamba-2配置
mamba2_config = {
"hidden_size": 4096,
"num_attention_heads": 32,
"num_layers": 32,
"ssm_state_size": 128,
"ssm_expand_factor": 2,
"use_bias": False,
"fused_add_norm": True,
"num_tp_procs": 8, # 张量并行度
}6. 性能基准
6.1 训练速度对比
| 模型 | 序列长度 | 训练速度( tokens/sec/GPU) | 相对加速 |
|---|---|---|---|
| Mamba 1 (2.8B) | 2K | 100% | 1x |
| Mamba 2 (2.8B) | 2K | 280% | 2.8x |
| Mamba 2 (2.8B) | 8K | 340% | 3.4x |
| Transformer (2.8B) | 2K | 100% | 1x |
6.2 推理内存对比
| 模型 | KV Cache大小(2K序列) | 状态大小 |
|---|---|---|
| Transformer | - | |
| Mamba 1 | ||
| Mamba 2 |
6.3 基准测试结果
标准语言建模任务:
| 任务 | Mamba 2 (2.8B) | Transformer (2.8B) | Delta |
|---|---|---|---|
| WikiText-103 PPL | 15.2 | 14.8 | +0.4 |
| Pile PPL | 8.9 | 8.7 | +0.2 |
| MMLU (5-shot) | 52.1% | 53.2% | -1.1% |
7. 与其他注意力变体的联系
7.1 Linear Attention
SSD框架与Linear Attention的关系:
| Linear Attention | SSD表示 |
|---|---|
| SSS矩阵的核近似 | |
| 循环形式 | SSS递归形式 |
| 固定核 | 输入依赖核 |
7.2 FlashAttention
SSD可以看作FlashAttention的状态空间解释:
- FlashAttention:分块计算注意力
- SSD:分块计算SSS矩阵乘法
两者都避免了HBM访问,实现了内存高效计算。
7.3 RetNet/GLA
RetNet和GLA都可以纳入SSD框架:
| 模型 | SSS表示 | 特点 |
|---|---|---|
| RetNet | 保留率矩阵 | 多尺度衰减 |
| GLA | 门控SSS | 输入依赖门控 |
| Mamba 2 | 选择性SSS | 完全输入依赖 |
8. 总结
Mamba 2的SSD框架建立了状态空间模型与注意力机制的深刻统一:
- 数学基础:SSS矩阵提供了统一表示
- 算法效率:SSD并行扫描实现2-8x加速
- 工程优化:支持张量并行,与Megatron-LM集成
- 理论贡献:揭示了SSM-Attention的内在联系
SSD框架不仅是一个工程优化,更是对神经网络架构本质的深刻洞察。
参考资料
相关文档:mamba-2-ssd-theory、state-space-model、linear-attention-mechanism-theory
Footnotes
-
Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. International Conference on Machine Learning (ICML). ↩