概述

Mamba 2基于**状态空间对偶性(State Space Duality, SSD)**框架,建立了结构化状态空间模型(SSMs)与注意力变体之间的深刻数学联系。1

SSD框架的核心贡献:

  1. 揭示了SSMs和注意力变体实际上是同一数学对象的两种分解形式
  2. 基于SSS(结构化半可分)矩阵的统一理论
  3. 设计了Mamba 2架构,训练速度比Mamba 1快2-8倍
  4. 在语言建模任务上保持与Transformer的竞争力

1. 结构化半可分矩阵

1.1 定义

结构化半可分(Structured Semi-Separable, SSS)矩阵是一类具有特殊结构的矩阵,可以高效表示和计算。

定义:一个矩阵 阶为 的SSS矩阵,如果对于所有 ,子矩阵 可以表示为:

其中 ,且 称为

1.2 SSS矩阵的性质

SSS矩阵具有以下关键性质:

  1. 低秩结构:对角线以下的每个元素都可以由低秩分解表示
  2. 高效计算:矩阵-向量乘积可以在 时间复杂度内完成
  3. 闭合性:SSS矩阵的乘积、求逆仍是SSS矩阵
  4. 对角化:可以通过三角分解 表示

1.3 SSS与状态空间模型的联系

考虑线性时不变(LTI)SSM:

其输入-输出映射可以写为:

其中 是一个SSS矩阵。具体来说:

这表明SSM的隐藏状态压缩了历史信息到低维表示。


2. 状态空间对偶性框架

2.1 核心定理

SSD框架定理:设 为一个SSM对应的SSS矩阵,则存在以下等价关系:

其中注意力机制的查询、键、值由SSM参数构造。

2.2 矩阵分解视角

从矩阵分解角度,SSD框架建立了以下对应:

SSM组件矩阵表示注意力视角
状态矩阵 SSS结构注意力分数计算
输入矩阵 低秩分解Query构造
输出矩阵 低秩分解Key构造
选择性扫描参数化SSS动态注意力

2.3 数学形式化

选择机制的SSD表示

Mamba 2的选择机制可以写为:


其中:

  • 是输入依赖的时间步长
  • 是线性投影函数
  • 表示逐元素乘法

对应的SSS矩阵元素

这等价于输入依赖的注意力分数。


3. 分组值注意力头(GVA)

3.1 动机

标准注意力机制中,每个头需要计算完整的键-值投影:

这导致 的计算复杂度。

3.2 GVA定义

**分组值注意力(Grouped Value Attention)**将值向量分组,每个组共享键向量:

给定: num_heads = h, groups = g
对于每个头 i:
    属于组 g(i) = i mod groups
    Key_i = Key_{g(i)}  # 共享键
    Value_i = Value_i   # 独立值

3.3 计算效率

GVA的计算复杂度分析:

操作标准注意力GVA
Key投影
Value投影
注意力分数
注意力加权

时,GVA显著减少参数量。

3.4 与SSM的联系

GVA的分组机制与SSM的状态压缩有异曲同工之妙:

  • SSM:将 步历史压缩到 维状态
  • GVA:将 个头的信息压缩到 个键

4. Mamba 2核心算法

4.1 选择性SSM层

Mamba 2的核心层是**选择性SSM(Selective SSM)**的改进版本:

class Mamba2SSMLayer(nn.Module):
    def __init__(self, d_model, d_state=128, d_conv=4, expand=2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand
        
        # 输入投影
        self.in_proj = nn.Linear(d_model, self.d_inner * 2 + d_state)
        
        # 卷积
        self.conv1d = nn.Conv1d(d_inner, d_inner, d_conv, padding=d_conv-1)
        
        # SSM参数
        self.x_proj = nn.Linear(d_inner, d_state + 1, bias=False)
        self.dt_proj = nn.Linear(d_state, d_inner, bias=True)
        
        # 输出投影
        self.out_proj = nn.Linear(d_inner, d_model)
        
        # A矩阵(对角参数化)
        self.A = nn.Parameter(torch.randn(d_state, d_inner))
        
    def forward(self, x):
        B, L, D = x.shape
        
        # 输入投影 + 分割
        xz = self.in_proj(x)
        x_inner, z = xz[..., :self.d_inner], xz[..., self.d_inner:]
        
        # 卷积
        x_conv = self.conv1d(x_inner.transpose(1,2))[..., :L].transpose(1,2)
        x_silu = F.silu(x_conv)
        
        # 选择性参数(与输入相关)
        x_gate = self.x_proj(x_silu)
        B_t, C_t = x_gate[..., :self.d_state], x_gate[..., self.d_state:]
        dt = self.dt_proj(B_t)
        
        # SSM扫描(使用SSD算法)
        y = self.ssd_scan(x_silu, dt, C_t, self.A)
        
        # 门控 + 输出
        y = y * F.silu(z)
        return self.out_proj(y)
    
    def ssd_scan(self, u, dt, C, A):
        """状态空间对偶扫描算法"""
        # 离散化
        dA = torch.exp(dt.unsqueeze(-1) * A)
        dB_u = dt.unsqueeze(-1) * u.unsqueeze(-1)
        
        # 扫描(并行化)
        return self.ssd_parallel_scan(dA, dB_u, C)

4.2 SSD并行扫描

Mamba 2的关键创新是SSD(Selective State Space Duality)并行扫描

def ssd_parallel_scan(dA, dB_u, C):
    """
    状态空间对偶并行扫描
    
    参数:
        dA: (batch, seqlen, d_state, d_inner) - 离散化的A矩阵
        dB_u: (batch, seqlen, d_state, d_inner) - 离散化的B*u
        C: (batch, seqlen, d_state) - 输出投影
    """
    batch, seqlen, N, D = dA.shape
    
    # 展开为半可分结构
    # M[i,j] = C[i] @ A[i]@...@A[j+1] @ B[j] if i > j
    
    # 并行扫描(类似前缀和)
    # 使用横享扫描(Tiled Scan)算法
    T = seqlen
    
    # 分块处理
    num_chunks = (T + chunk_size - 1) // chunk_size
    
    # 对角块积分
    for c in range(num_chunks):
        start, end = c*chunk_size, min((c+1)*chunk_size, T)
        
        # 计算当前块的贡献
        chunk_contrib = compute_chunk(dA, dB_u, C, start, end)
        
    # 跨块传播
    carry = None
    for c in range(num_chunks):
        carry = propagate_chunk(carry, dA, dB_u, c)
    
    return final_output

4.3 与FlashAttention的联系

SSD扫描与FlashAttention有深刻联系:

FlashAttentionSSD扫描
分块计算避免HBM访问分块扫描避免HBM访问
矩阵乘法融合SSM递归融合
内存 内存

5. 张量并行支持

5.1 Mamba 1的限制

Mamba 1使用线性递归

这种递归形式难以并行化,无法直接支持张量并行。

5.2 Mamba 2的解决方案

Mamba 2将递归转化为矩阵乘法形式

  1. SSS矩阵表示,其中 是SSS矩阵
  2. 并行扫描:使用SSD并行扫描算法
  3. 列切割:对输入/输出维度进行分布式计算
# 张量并行示例
class Mamba2Parallel(nn.Module):
    def __init__(self, d_model, n_heads, tp_size):
        super().__init__()
        self.tp_size = tp_size
        self.d_head = d_model // n_heads
        
        # 列切割:每个GPU持有部分列
        self.in_proj = ColumnParallelLinear(d_model, d_model*2, tp_size)
        
    def forward(self, x):
        # 每个GPU独立计算
        x_local = self.in_proj(x)  # [B, L, D/tp]
        
        # SSD扫描(本地)
        y_local = self.ssd_scan(x_local)
        
        return y_local

5.3 与Megatron-LM集成

NVIDIA Megatron-LM提供了Mamba 2的分布式实现:

# Megatron-LM中的Mamba-2配置
mamba2_config = {
    "hidden_size": 4096,
    "num_attention_heads": 32,
    "num_layers": 32,
    "ssm_state_size": 128,
    "ssm_expand_factor": 2,
    "use_bias": False,
    "fused_add_norm": True,
    "num_tp_procs": 8,  # 张量并行度
}

6. 性能基准

6.1 训练速度对比

模型序列长度训练速度( tokens/sec/GPU)相对加速
Mamba 1 (2.8B)2K100%1x
Mamba 2 (2.8B)2K280%2.8x
Mamba 2 (2.8B)8K340%3.4x
Transformer (2.8B)2K100%1x

6.2 推理内存对比

模型KV Cache大小(2K序列)状态大小
Transformer-
Mamba 1
Mamba 2

6.3 基准测试结果

标准语言建模任务

任务Mamba 2 (2.8B)Transformer (2.8B)Delta
WikiText-103 PPL15.214.8+0.4
Pile PPL8.98.7+0.2
MMLU (5-shot)52.1%53.2%-1.1%

7. 与其他注意力变体的联系

7.1 Linear Attention

SSD框架与Linear Attention的关系:

Linear AttentionSSD表示
SSS矩阵的核近似
循环形式SSS递归形式
固定核输入依赖核

7.2 FlashAttention

SSD可以看作FlashAttention的状态空间解释

  • FlashAttention:分块计算注意力
  • SSD:分块计算SSS矩阵乘法

两者都避免了HBM访问,实现了内存高效计算。

7.3 RetNet/GLA

RetNet和GLA都可以纳入SSD框架:

模型SSS表示特点
RetNet保留率矩阵多尺度衰减
GLA门控SSS输入依赖门控
Mamba 2选择性SSS完全输入依赖

8. 总结

Mamba 2的SSD框架建立了状态空间模型与注意力机制的深刻统一:

  1. 数学基础:SSS矩阵提供了统一表示
  2. 算法效率:SSD并行扫描实现2-8x加速
  3. 工程优化:支持张量并行,与Megatron-LM集成
  4. 理论贡献:揭示了SSM-Attention的内在联系

SSD框架不仅是一个工程优化,更是对神经网络架构本质的深刻洞察。


参考资料


相关文档mamba-2-ssd-theorystate-space-modellinear-attention-mechanism-theory

Footnotes

  1. Dao, T., & Gu, A. (2024). Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality. International Conference on Machine Learning (ICML).