引言

Mamba 的硬件感知算法(Hardware-aware Algorithms)是使选择性状态空间模型能够在现代 GPU 上高效运行的关键工程基础。与传统方法不同,这些算法在设计时显式考虑了目标硬件的内存层次结构、并行度模型和吞吐量约束,而非将硬件视为透明抽象层。1

核心设计目标有三个:

  1. 将选择性 SSM 递推重新表述为可并行化的形式
  2. 显式管理 GPU 内存层级(HBM vs SRAM)以减少数据移动
  3. 在反向传播时不存储中间状态,通过重计算处理梯度1

选择性扫描的并行前缀算法

关联扫描理论基础

选择性扫描的核心数据流是参数化线性递推:

其中 是输入依赖的(选择性)矩阵。Mamba 将此递推重新表述为**关联扫描(Associative Scan)**问题,使得可以在 步并行步骤内完成整个序列长度的计算,而无需逐token顺序执行。12

CUDA Warp 级并行实现

Mamba 的并行扫描实现在 selective_scan_fwd_kernel.cuh 中采用了 warp 级扫描原语:

// 每个 warp 使用 Kogge-Stone 风格并行扫描
// 所有 lane 0 的线程读取 running_prefix 进行跨块状态传递
running_prefix = chunk > 0 && threadIdx.x % 32 == 0 
    ? smem_running_prefix[state_idx + r * MAX_DSTATE] 
    : make_float2(1.f, 0.f);

关键实现细节:

  • 使用 cub::BlockScanWARP_SCAN 模式实现低延迟 warp 级扫描
  • 通过共享内存(shared memory)传递跨块边界的状态
  • 每个线程块处理固定大小的 chunk,块间通过原子操作同步3

块分解与状态传递

Mamba-2 的 SSD 算法进一步优化,将序列划分为固定大小的块(chunk):

  1. 块内输出:计算每个块的局部输出(假设初始状态为0)
  2. 块状态:计算每个块的最终状态
  3. 状态传递:在块级别执行并行或顺序扫描
  4. 输出状态:根据真实的初始状态计算输出贡献2

这种块分解使得大多数计算可以完全并行化,只有状态传递步骤需要跨块通信。

核融合策略

融合选择性扫描操作

Mamba 的核心优化是将离散化步骤、SSM 递推和输出投影融合为单个 CUDA kernel

// 融合 kernel 中的执行流程
// 1. 离散化:Δ, A, B → Ā, B̄(零阶保持公式)
// 2. 输入加载到共享内存
// 3. 本地前缀扫描
// 4. 跨块状态传递
// 5. 输出投影 y_t = C_t · h_t
// 6. 结果写回 HBM

融合策略的核心优势:

  • 消除中间张量物化:避免在 HBM 中存储递推过程中的中间状态
  • 减少 kernel 启动开销:一次启动替代多次串行启动
  • 提高缓存局部性:同一线程块内产生的数据立即被后续步骤消费4

Mamba-2 五步融合

Mamba-2 将原本分散的五个 SSD 操作融合为单个 Triton kernel:

步骤操作融合前
1Chunk Cumsum独立 kernel
2SSD BMM独立 kernel
3Chunk State独立 kernel
4State Passing独立 kernel
5Chunk Scan独立 kernel

PyTorch 团队的实验表明,融合后的 SSD kernel 在 A100 和 H100 上实现 1.50x-2.51x 加速,端到端推理提升约 8-20%。4

反向传播重计算策略

内存-计算权衡

Mamba 的关键设计决策是:前向传播时不将隐状态 存储到 HBM,而是让它们驻留在 SRAM 中瞬态使用。反向传播时重新计算这些状态。1

# PyTorch 梯度检查点伪代码
# 与 Mamba 的 kernel 级重计算策略对比
def forward_with_recompute(x):
    h = ssm_scan(x)  # 状态留在 SRAM
    y = output_proj(h)
    return y
 
def backward_recompute(grad_y):
    # 重新执行前向 pass 计算 h_t
    h_recomputed = ssm_scan(x)
    # 计算梯度
    grad_x = compute_gradients(grad_y, h_recomputed)
    return grad_x

这种策略以 30-50% 的额外 FLOPs 换取显著降低的峰值 HBM 占用。在序列长度 较大时, 的内存节省是决定性优势。1

与 PyTorch Gradient Checkpointing 的区别

Mamba 的重计算是在 CUDA kernel 级别集成,相比 PyTorch 级别的梯度检查点开销更低:

  • PyTorch checkpoint:Python 调用开销 + 多 kernel 启动
  • Mamba kernel recompute:单一 kernel 内的轻量级重计算1

内存层次管理

HBM vs SRAM 分层

现代 GPU 的内存层次结构决定了算法设计:

内存层级A100 规格典型延迟带宽
HBM40/80 GB~500-700 ns~2 TB/s
L2 Cache40-50 MB~200 ns
SRAM (SM 私有)128 KB/SM~30 ns极高

A100 提供 312 TFLOPS 的 BF16 矩阵乘法计算能力,但 HBM 带宽仅 2 TB/s。算术强度不足会导致内存带宽成为瓶颈,而非计算吞吐量。14

带宽瓶颈分析

Mamba 的选择性扫描需要 HBM 读写,对于 的典型配置:

序列长度 L=4096, N=64:
- 总数据量: 4096 * 64 * 2(输入+状态) * 2B(BF16) ≈ 1 MB
- 带宽需求 vs 计算: ~10x 差距

解决方案:
1. 核融合减少数据在 HBM 和 SRAM 间的移动
2. 将热点数据保留在共享内存和 L1/L2 缓存
3. 使用 fp16 状态减少数据传输量

与 FlashAttention 的系统设计对比

Mamba 的硬件感知设计与 FlashAttention 同属 IO-aware 算法这一更广泛的类别:

特性FlashAttentionMamba
IO 建模HBM vs SRAMHBM vs SRAM
核融合融合注意力分块计算融合选择性扫描
重计算标准 + 可选重计算前向不存储,backward 重算
并行化分块矩阵乘法关联前缀扫描
数学等价与注意力精确等价与递推精确等价
内存复杂度 序列无关 线性

FlashAttention 通过分块矩阵乘法实现注意力计算,而 Mamba 通过关联扫描实现选择性 SSM。两者都避免了中间结果的完整物化,但处理的数学运算不同。14

数值稳定性考虑

长期依赖的数值问题

关联扫描累积 矩阵的乘积,对于特征值接近 1 的状态矩阵,在序列超过 ~16,000 步时会遇到数值精度退化问题。1

Mamba-2 Segsum 稳定化

Mamba-2 论文提出了分段求和(Segment Sum, segsum) 原语来处理稳定性:

def segsum(x):
    """
    稳定的分段求和计算。
    在 log-space 进行累加避免数值下溢。
    """
    T = x.size(-1)
    x_cumsum = torch.cumsum(x, dim=-1)
    # 使用掩码避免下三角以外的区域
    x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum

关键设计原则:

  1. 在 log-space 执行所有运算(乘变加)
  2. 避免所有减法操作:即使在 log-space,累积求和的相减也会导致灾难性抵消
  3. 使用批量独立 cumsum 直接产生正确答案2

精度-性能权衡

Mamba-2 融合 kernel 支持多种精度配置:

配置精度速度准确率匹配
exact dtypes所有运算相同 dtype基准~99.7%
relaxed dtypes部分运算使用 fp16+16%>99.7% @ 1e-3

对于 relaxed dtypes,在 atol=1e-2 阈值下实际达到 100% 匹配,对实际应用影响可忽略。4

Mamba-2 SSD 核改进

结构化状态空间对偶性

Mamba-2 引入了 结构化状态空间对偶性(Structured State Space Duality, SSD) 框架,建立了 SSM 与线性注意力之间的联系:

  • 利用张量核心:SSD 算法允许使用矩阵乘法作为原语,相比标量运算快达 16x
  • 块分解:将 SSM 矩阵分解为对角块(使用注意力风格计算)和低秩非对角块(使用批量矩阵乘法)2

张量并行与序列并行

SSD 框架使得原本为 Transformer 开发的系统优化可以应用于 SSM:

# Mamba-2 支持的并行策略
# 1. 张量并行:跨设备分割线性层
# 2. 序列并行:支持可变长度序列
# 3. 变量长度序列的伪块(pseudo chunks)处理

这解决了 Mamba-1 在大规模训练中 tensor parallelism 的困难。5

硬件特性利用

Mamba-2 融合 kernel 的利用率分析(A100):

指标利用率
计算利用率40-50%
内存利用率65-75%

瓶颈主要来自:

  • Warp 阻塞:等待 L2/VRAM 内存访问
  • 同步等待:块间状态传递的屏障
  • Occupancy 限制:受寄存器数量约束4

未来优化方向:

  • 使用 Hopper 的 TMA(Tensor Memory Accelerator)
  • Thread Block Clusters 优化广播加载
  • Blackwell 架构的 Tensor Memory 支持

总结

Mamba 的硬件感知算法代表了 SSM 从理论到高效工程实现的关键跨越:

  1. 并行扫描 的串行计算转化为 的并行步骤
  2. 核融合通过消除 HBM 中间物化解决了内存墙问题
  3. 重计算策略以 30-50% 的 FLOPs 开销换取线性内存复杂度
  4. Mamba-2 SSD 通过张量核心利用和块分解进一步提升性能

这些技术的组合使 Mamba 能够在现代 GPU 上实现 5 倍于 Transformer 的推理吞吐量,同时保持线性序列长度扩展能力。

参考

Footnotes

  1. Mamba: Linear-Time Sequence Modeling with Selective State Spaces - Gu & Dao, 2023 2 3 4 5 6 7 8 9

  2. State Space Duality (Mamba-2) Part III - The Algorithm - Tri Dao 2 3 4

  3. Mamba Selective Scan CUDA Kernel - state-spaces/mamba

  4. Accelerating Mamba2 with Kernel Fusion - PyTorch Blog 2 3 4 5 6

  5. State Space Duality (Mamba-2) Part IV - The Systems - Goomba Lab