概述
Mamba2D是首个从第一性原理推导的原生二维状态空间模型,不同于之前将1D SSM应用于展平序列的方法,Mamba2D直接在2D图像上进行状态空间建模。1
核心创新包括:
- 原生2D扫描:沿两个空间方向并行扫描
- Wavefront CUDA核:对角线并行,充分利用GPU并行能力
- 子线性扩展:高分辨率下延迟接近O(N^0.56)
从1D到2D的理论推导
标准Mamba的1D公式
对于1D输入序列 ,选择性状态空间模型定义为:
其中 是状态转移矩阵。
2D扩展的挑战
将1D SSM展平应用于2D图像存在根本问题:
| 问题 | 描述 |
|---|---|
| 空间结构丢失 | 展平后相邻像素在1D序列中可能距离很远 |
| 扫描方向固定 | 单一扫描方向无法捕获所有空间关系 |
| 并行度低 | 串行扫描无法充分利用GPU |
原生2D公式
设2D输入 ,定义4个方向的扫描:
水平向右:
水平向左:
垂直向下:
垂直向上:
融合输出:
Wavefront并行扫描
对角线并行原理
Wavefront扫描利用2D依赖关系的对角线结构:
时间步 t=0: (0,0)
时间步 t=1: (0,1) (1,0)
时间步 t=2: (0,2) (1,1) (2,0)
时间步 t=3: (0,3) (1,2) (2,1) (3,0)
...
每一时间步处理的像素满足:
GPU并行化优势
| 扫描方式 | 并行度 | GPU利用率 |
|---|---|---|
| 行扫描 | O(W) | 中等 |
| 列扫描 | O(H) | 中等 |
| Wavefront | O(min(H,W)) | 高 |
Wavefront扫描确保每个时间步有大量像素可并行处理:
def wavefront_scan(grid, scan_directions):
"""Wavefront并行扫描实现"""
H, W = grid.shape[:2]
# 初始化4个方向的隐藏状态
states_hR = torch.zeros_like(grid)
states_hL = torch.zeros_like(grid)
states_vD = torch.zeros_like(grid)
states_vU = torch.zeros_like(grid)
# 从左上到右下的对角线
for t in range(H + W - 1):
# 找到当前对角线上的所有像素
for i in range(max(0, t - W + 1), min(H, t + 1)):
j = t - i
if 0 <= j < W:
# 四个方向的并行更新
# hR: 来自左边
if j > 0:
states_hR[i,j] = A_h * states_hR[i,j-1] + B_h * grid[i,j]
# hL: 来自右边
if j < W - 1:
states_hL[i,j] = A_h * states_hL[i,j+1] + B_h * grid[i,j]
# vD: 来自上边
if i > 0:
states_vD[i,j] = A_v * states_vD[i-1,j] + B_v * grid[i,j]
# vU: 来自下边
if i < H - 1:
states_vU[i,j] = A_v * states_vU[i+1,j] + B_v * grid[i,j]
return states_hR + states_hL + states_vD + states_vU架构设计
整体结构
Mamba2D采用类似ResNet的层次化结构:
输入图像
↓
Stage 1: Conv Stem + Mamba2D Block × 3
↓ (2×下采样)
Stage 2: Mamba2D Block × 4
↓ (2×下采样)
Stage 3: Mamba2D Block × 6
↓ (2×下采样)
Stage 4: Mamba2D Block × 3
↓
分类头 / 检测头 / 分割头
Mamba2D Block结构
class Mamba2DBlock(nn.Module):
def __init__(self, dim, state_dim=16, expand=2):
super().__init__()
d_inner = dim * expand
# 归一化
self.norm = nn.LayerNorm(dim)
# 投影
self.x_proj = nn.Linear(dim, dim * 2 + state_dim * 4, bias=False)
self.dt_proj = nn.Linear(state_dim * 4, dim, bias=True)
# 四个方向的SSM
self.ssm_hR = SelectiveSSM2D(dim, state_dim, direction='hR')
self.ssm_hL = SelectiveSSM2D(dim, state_dim, direction='hL')
self.ssm_vD = SelectiveSSM2D(dim, state_dim, direction='vD')
self.ssm_vU = SelectiveSSM2D(dim, state_dim, direction='vU')
# 融合门控
self.gate = nn.Sequential(
nn.Linear(dim * 4, dim),
nn.SiLU(),
nn.Linear(dim, dim * 4)
)
# 输出投影
self.proj = nn.Linear(dim, dim)
def forward(self, x):
B, C, H, W = x.shape
# 残差连接
residual = x
# 归一化
x_norm = self.norm(x.permute(0, 2, 3, 1)) # B,H,W,C
x_flat = x_norm.reshape(B * H * W, C)
# 计算SSM参数
xz = self.x_proj(x_flat)
x_inner, dt = xz[:, :C*2], xz[:, C*2:]
# 四个方向的扫描
out_hR = self.ssm_hR(x_inner[:, :C], dt)
out_hL = self.ssm_hL(x_inner[:, C:C*2], dt)
out_vD = self.ssm_vD(x_inner[:, C*2:C*3], dt)
out_vU = self.ssm_vU(x_inner[:, C*3:], dt)
# 融合
combined = torch.cat([out_hR, out_hL, out_vD, out_vU], dim=-1)
gated = self.gate(combined)
out = self.proj(gated)
# 恢复形状 + 残差
out = out.reshape(B, H, W, C).permute(0, 3, 1, 2)
return out + residual实验结果
ImageNet分类
| 模型 | 参数量 | FLOPs | Top-1 | 吞吐量 |
|---|---|---|---|---|
| M2D-Ti | 8M | 1.5G | 78.8% | 4200 img/s |
| M2D-S | 27M | 4.5G | 84.0% | 1850 img/s |
| M2D-B | 50M | 8.2G | 85.3% | 1100 img/s |
| M2D-L | 200M | 30G | 86.2% | 380 img/s |
与其他视觉SSM对比
| 模型 | Top-1 | 吞吐量(A100) | 特点 |
|---|---|---|---|
| VMamba-T | 82.6% | 2100 img/s | 十字扫描 |
| Mamba2D-S | 84.0% | 1850 img/s | 原生2D |
| LocalMamba-T | 82.1% | 1650 img/s | 局部窗口 |
| Mamba2D-B | 85.3% | 1100 img/s | 原生2D |
下游任务
COCO目标检测 (Mask R-CNN)
| Backbone | AP^b | AP^m | 参数 |
|---|---|---|---|
| M2D-Ti | 49.2 | 44.5 | 28M |
| M2D-S | 50.5 | 45.8 | 47M |
| M2D-B | 52.2 | 47.3 | 70M |
ADE20K语义分割 (UperNet)
| Backbone | mIoU | 参数量 |
|---|---|---|
| M2D-S | 48.9 | 52M |
| M2D-B | 51.7 | 80M |
| Swin-B | 49.0 | 88M |
扩展性分析
延迟随分辨率扩展
Mamba2D的wavefront扫描实现亚线性扩展:
| 图像尺寸 | 注意力延迟 | Mamba2D延迟 | 加速比 |
|---|---|---|---|
| 224² | O(N²) = 50K | O(N^0.56) ≈ 300 | 167× |
| 384² | O(N²) = 147K | O(N^0.56) ≈ 600 | 245× |
| 512² | O(N²) = 262K | O(N^0.56) ≈ 900 | 291× |
| 1024² | O(N²) = 1M | O(N^0.56) ≈ 2200 | 455× |
内存效率
| 尺寸 | 注意力内存 | Mamba2D内存 | 节省 |
|---|---|---|---|
| 224² | 256 MB | 32 MB | 8× |
| 512² | 1.5 GB | 64 MB | 24× |
| 1024² | 6 GB | 128 MB | 47× |
与MambaVision的对比
| 特性 | Mamba2D | MambaVision |
|---|---|---|
| 扫描方式 | Wavefront 4方向 | 选择性2D扫描 |
| 并行策略 | 对角线并行 | 标准并行 |
| 空间建模 | 原生2D | 混合CNN+SSM |
| ImageNet | 85.3% (50M) | 85.0% (228M) |
| 高分辨率效率 | 更优 | 中等 |
| 实现复杂度 | 较高 | 较低 |
总结
Mamba2D的核心贡献:
- 首个原生2D SSM,从第一性原理推导
- Wavefront并行扫描,充分利用GPU并行能力
- 子线性扩展,高分辨率下显著优于注意力
- 四方向融合,捕获完整空间依赖
- SOTA性能,成为视觉SSM的新基准
参考文献
相关主题
Footnotes
-
Sun, Y., et al. (2024). Mamba2D: Efficient 2D State Space Model with Wavefront Parallel Scan. arXiv:2405.14410. https://arxiv.org/abs/2405.14410 ↩