引言
Mamba 的硬件感知算法(Hardware-aware Algorithms)是使选择性状态空间模型能够在现代 GPU 上高效运行的关键工程基础。与传统方法不同,这些算法在设计时显式考虑了目标硬件的内存层次结构、并行度模型和吞吐量约束,而非将硬件视为透明抽象层。1
核心设计目标有三个:
- 将选择性 SSM 递推重新表述为可并行化的形式
- 显式管理 GPU 内存层级(HBM vs SRAM)以减少数据移动
- 在反向传播时不存储中间状态,通过重计算处理梯度1
选择性扫描的并行前缀算法
关联扫描理论基础
选择性扫描的核心数据流是参数化线性递推:
其中 、、 是输入依赖的(选择性)矩阵。Mamba 将此递推重新表述为**关联扫描(Associative Scan)**问题,使得可以在 步并行步骤内完成整个序列长度的计算,而无需逐token顺序执行。12
CUDA Warp 级并行实现
Mamba 的并行扫描实现在 selective_scan_fwd_kernel.cuh 中采用了 warp 级扫描原语:
// 每个 warp 使用 Kogge-Stone 风格并行扫描
// 所有 lane 0 的线程读取 running_prefix 进行跨块状态传递
running_prefix = chunk > 0 && threadIdx.x % 32 == 0
? smem_running_prefix[state_idx + r * MAX_DSTATE]
: make_float2(1.f, 0.f);关键实现细节:
- 使用
cub::BlockScan的WARP_SCAN模式实现低延迟 warp 级扫描 - 通过共享内存(shared memory)传递跨块边界的状态
- 每个线程块处理固定大小的 chunk,块间通过原子操作同步3
块分解与状态传递
Mamba-2 的 SSD 算法进一步优化,将序列划分为固定大小的块(chunk):
- 块内输出:计算每个块的局部输出(假设初始状态为0)
- 块状态:计算每个块的最终状态
- 状态传递:在块级别执行并行或顺序扫描
- 输出状态:根据真实的初始状态计算输出贡献2
这种块分解使得大多数计算可以完全并行化,只有状态传递步骤需要跨块通信。
核融合策略
融合选择性扫描操作
Mamba 的核心优化是将离散化步骤、SSM 递推和输出投影融合为单个 CUDA kernel:
// 融合 kernel 中的执行流程
// 1. 离散化:Δ, A, B → Ā, B̄(零阶保持公式)
// 2. 输入加载到共享内存
// 3. 本地前缀扫描
// 4. 跨块状态传递
// 5. 输出投影 y_t = C_t · h_t
// 6. 结果写回 HBM融合策略的核心优势:
- 消除中间张量物化:避免在 HBM 中存储递推过程中的中间状态
- 减少 kernel 启动开销:一次启动替代多次串行启动
- 提高缓存局部性:同一线程块内产生的数据立即被后续步骤消费4
Mamba-2 五步融合
Mamba-2 将原本分散的五个 SSD 操作融合为单个 Triton kernel:
| 步骤 | 操作 | 融合前 |
|---|---|---|
| 1 | Chunk Cumsum | 独立 kernel |
| 2 | SSD BMM | 独立 kernel |
| 3 | Chunk State | 独立 kernel |
| 4 | State Passing | 独立 kernel |
| 5 | Chunk Scan | 独立 kernel |
PyTorch 团队的实验表明,融合后的 SSD kernel 在 A100 和 H100 上实现 1.50x-2.51x 加速,端到端推理提升约 8-20%。4
反向传播重计算策略
内存-计算权衡
Mamba 的关键设计决策是:前向传播时不将隐状态 存储到 HBM,而是让它们驻留在 SRAM 中瞬态使用。反向传播时重新计算这些状态。1
# PyTorch 梯度检查点伪代码
# 与 Mamba 的 kernel 级重计算策略对比
def forward_with_recompute(x):
h = ssm_scan(x) # 状态留在 SRAM
y = output_proj(h)
return y
def backward_recompute(grad_y):
# 重新执行前向 pass 计算 h_t
h_recomputed = ssm_scan(x)
# 计算梯度
grad_x = compute_gradients(grad_y, h_recomputed)
return grad_x这种策略以 30-50% 的额外 FLOPs 换取显著降低的峰值 HBM 占用。在序列长度 较大时, 的内存节省是决定性优势。1
与 PyTorch Gradient Checkpointing 的区别
Mamba 的重计算是在 CUDA kernel 级别集成,相比 PyTorch 级别的梯度检查点开销更低:
- PyTorch checkpoint:Python 调用开销 + 多 kernel 启动
- Mamba kernel recompute:单一 kernel 内的轻量级重计算1
内存层次管理
HBM vs SRAM 分层
现代 GPU 的内存层次结构决定了算法设计:
| 内存层级 | A100 规格 | 典型延迟 | 带宽 |
|---|---|---|---|
| HBM | 40/80 GB | ~500-700 ns | ~2 TB/s |
| L2 Cache | 40-50 MB | ~200 ns | 高 |
| SRAM (SM 私有) | 128 KB/SM | ~30 ns | 极高 |
A100 提供 312 TFLOPS 的 BF16 矩阵乘法计算能力,但 HBM 带宽仅 2 TB/s。算术强度不足会导致内存带宽成为瓶颈,而非计算吞吐量。14
带宽瓶颈分析
Mamba 的选择性扫描需要 HBM 读写,对于 的典型配置:
序列长度 L=4096, N=64:
- 总数据量: 4096 * 64 * 2(输入+状态) * 2B(BF16) ≈ 1 MB
- 带宽需求 vs 计算: ~10x 差距
解决方案:
1. 核融合减少数据在 HBM 和 SRAM 间的移动
2. 将热点数据保留在共享内存和 L1/L2 缓存
3. 使用 fp16 状态减少数据传输量
与 FlashAttention 的系统设计对比
Mamba 的硬件感知设计与 FlashAttention 同属 IO-aware 算法这一更广泛的类别:
| 特性 | FlashAttention | Mamba |
|---|---|---|
| IO 建模 | HBM vs SRAM | HBM vs SRAM |
| 核融合 | 融合注意力分块计算 | 融合选择性扫描 |
| 重计算 | 标准 + 可选重计算 | 前向不存储,backward 重算 |
| 并行化 | 分块矩阵乘法 | 关联前缀扫描 |
| 数学等价 | 与注意力精确等价 | 与递推精确等价 |
| 内存复杂度 | 序列无关 | 线性 |
FlashAttention 通过分块矩阵乘法实现注意力计算,而 Mamba 通过关联扫描实现选择性 SSM。两者都避免了中间结果的完整物化,但处理的数学运算不同。14
数值稳定性考虑
长期依赖的数值问题
关联扫描累积 矩阵的乘积,对于特征值接近 1 的状态矩阵,在序列超过 ~16,000 步时会遇到数值精度退化问题。1
Mamba-2 Segsum 稳定化
Mamba-2 论文提出了分段求和(Segment Sum, segsum) 原语来处理稳定性:
def segsum(x):
"""
稳定的分段求和计算。
在 log-space 进行累加避免数值下溢。
"""
T = x.size(-1)
x_cumsum = torch.cumsum(x, dim=-1)
# 使用掩码避免下三角以外的区域
x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :]
mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0)
x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
return x_segsum关键设计原则:
精度-性能权衡
Mamba-2 融合 kernel 支持多种精度配置:
| 配置 | 精度 | 速度 | 准确率匹配 |
|---|---|---|---|
| exact dtypes | 所有运算相同 dtype | 基准 | ~99.7% |
| relaxed dtypes | 部分运算使用 fp16 | +16% | >99.7% @ 1e-3 |
对于 relaxed dtypes,在 atol=1e-2 阈值下实际达到 100% 匹配,对实际应用影响可忽略。4
Mamba-2 SSD 核改进
结构化状态空间对偶性
Mamba-2 引入了 结构化状态空间对偶性(Structured State Space Duality, SSD) 框架,建立了 SSM 与线性注意力之间的联系:
- 利用张量核心:SSD 算法允许使用矩阵乘法作为原语,相比标量运算快达 16x
- 块分解:将 SSM 矩阵分解为对角块(使用注意力风格计算)和低秩非对角块(使用批量矩阵乘法)2
张量并行与序列并行
SSD 框架使得原本为 Transformer 开发的系统优化可以应用于 SSM:
# Mamba-2 支持的并行策略
# 1. 张量并行:跨设备分割线性层
# 2. 序列并行:支持可变长度序列
# 3. 变量长度序列的伪块(pseudo chunks)处理这解决了 Mamba-1 在大规模训练中 tensor parallelism 的困难。5
硬件特性利用
Mamba-2 融合 kernel 的利用率分析(A100):
| 指标 | 利用率 |
|---|---|
| 计算利用率 | 40-50% |
| 内存利用率 | 65-75% |
瓶颈主要来自:
- Warp 阻塞:等待 L2/VRAM 内存访问
- 同步等待:块间状态传递的屏障
- Occupancy 限制:受寄存器数量约束4
未来优化方向:
- 使用 Hopper 的 TMA(Tensor Memory Accelerator)
- Thread Block Clusters 优化广播加载
- Blackwell 架构的 Tensor Memory 支持
总结
Mamba 的硬件感知算法代表了 SSM 从理论到高效工程实现的关键跨越:
- 并行扫描将 的串行计算转化为 的并行步骤
- 核融合通过消除 HBM 中间物化解决了内存墙问题
- 重计算策略以 30-50% 的 FLOPs 开销换取线性内存复杂度
- Mamba-2 SSD 通过张量核心利用和块分解进一步提升性能
这些技术的组合使 Mamba 能够在现代 GPU 上实现 5 倍于 Transformer 的推理吞吐量,同时保持线性序列长度扩展能力。
参考
Footnotes
-
Mamba: Linear-Time Sequence Modeling with Selective State Spaces - Gu & Dao, 2023 ↩ ↩2 ↩3 ↩4 ↩5 ↩6 ↩7 ↩8 ↩9
-
State Space Duality (Mamba-2) Part III - The Algorithm - Tri Dao ↩ ↩2 ↩3 ↩4
-
Mamba Selective Scan CUDA Kernel - state-spaces/mamba ↩
-
Accelerating Mamba2 with Kernel Fusion - PyTorch Blog ↩ ↩2 ↩3 ↩4 ↩5 ↩6
-
State Space Duality (Mamba-2) Part IV - The Systems - Goomba Lab ↩