概述

Mamba2D是首个从第一性原理推导的原生二维状态空间模型,不同于之前将1D SSM应用于展平序列的方法,Mamba2D直接在2D图像上进行状态空间建模。1

核心创新包括:

  • 原生2D扫描:沿两个空间方向并行扫描
  • Wavefront CUDA核:对角线并行,充分利用GPU并行能力
  • 子线性扩展:高分辨率下延迟接近O(N^0.56)

从1D到2D的理论推导

标准Mamba的1D公式

对于1D输入序列 ,选择性状态空间模型定义为:

其中 是状态转移矩阵。

2D扩展的挑战

将1D SSM展平应用于2D图像存在根本问题:

问题描述
空间结构丢失展平后相邻像素在1D序列中可能距离很远
扫描方向固定单一扫描方向无法捕获所有空间关系
并行度低串行扫描无法充分利用GPU

原生2D公式

设2D输入 ,定义4个方向的扫描:

水平向右

水平向左

垂直向下

垂直向上

融合输出


Wavefront并行扫描

对角线并行原理

Wavefront扫描利用2D依赖关系的对角线结构:

时间步 t=0:  (0,0)
时间步 t=1:  (0,1) (1,0)
时间步 t=2:  (0,2) (1,1) (2,0)
时间步 t=3:  (0,3) (1,2) (2,1) (3,0)
    ...

每一时间步处理的像素满足:

GPU并行化优势

扫描方式并行度GPU利用率
行扫描O(W)中等
列扫描O(H)中等
WavefrontO(min(H,W))

Wavefront扫描确保每个时间步有大量像素可并行处理:

def wavefront_scan(grid, scan_directions):
    """Wavefront并行扫描实现"""
    H, W = grid.shape[:2]
    
    # 初始化4个方向的隐藏状态
    states_hR = torch.zeros_like(grid)
    states_hL = torch.zeros_like(grid)
    states_vD = torch.zeros_like(grid)
    states_vU = torch.zeros_like(grid)
    
    # 从左上到右下的对角线
    for t in range(H + W - 1):
        # 找到当前对角线上的所有像素
        for i in range(max(0, t - W + 1), min(H, t + 1)):
            j = t - i
            if 0 <= j < W:
                # 四个方向的并行更新
                # hR: 来自左边
                if j > 0:
                    states_hR[i,j] = A_h * states_hR[i,j-1] + B_h * grid[i,j]
                # hL: 来自右边
                if j < W - 1:
                    states_hL[i,j] = A_h * states_hL[i,j+1] + B_h * grid[i,j]
                # vD: 来自上边
                if i > 0:
                    states_vD[i,j] = A_v * states_vD[i-1,j] + B_v * grid[i,j]
                # vU: 来自下边
                if i < H - 1:
                    states_vU[i,j] = A_v * states_vU[i+1,j] + B_v * grid[i,j]
    
    return states_hR + states_hL + states_vD + states_vU

架构设计

整体结构

Mamba2D采用类似ResNet的层次化结构:

输入图像
    ↓
Stage 1: Conv Stem + Mamba2D Block × 3
    ↓ (2×下采样)
Stage 2: Mamba2D Block × 4
    ↓ (2×下采样)
Stage 3: Mamba2D Block × 6
    ↓ (2×下采样)
Stage 4: Mamba2D Block × 3
    ↓
分类头 / 检测头 / 分割头

Mamba2D Block结构

class Mamba2DBlock(nn.Module):
    def __init__(self, dim, state_dim=16, expand=2):
        super().__init__()
        d_inner = dim * expand
        
        # 归一化
        self.norm = nn.LayerNorm(dim)
        
        # 投影
        self.x_proj = nn.Linear(dim, dim * 2 + state_dim * 4, bias=False)
        self.dt_proj = nn.Linear(state_dim * 4, dim, bias=True)
        
        # 四个方向的SSM
        self.ssm_hR = SelectiveSSM2D(dim, state_dim, direction='hR')
        self.ssm_hL = SelectiveSSM2D(dim, state_dim, direction='hL')
        self.ssm_vD = SelectiveSSM2D(dim, state_dim, direction='vD')
        self.ssm_vU = SelectiveSSM2D(dim, state_dim, direction='vU')
        
        # 融合门控
        self.gate = nn.Sequential(
            nn.Linear(dim * 4, dim),
            nn.SiLU(),
            nn.Linear(dim, dim * 4)
        )
        
        # 输出投影
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, x):
        B, C, H, W = x.shape
        
        # 残差连接
        residual = x
        
        # 归一化
        x_norm = self.norm(x.permute(0, 2, 3, 1))  # B,H,W,C
        x_flat = x_norm.reshape(B * H * W, C)
        
        # 计算SSM参数
        xz = self.x_proj(x_flat)
        x_inner, dt = xz[:, :C*2], xz[:, C*2:]
        
        # 四个方向的扫描
        out_hR = self.ssm_hR(x_inner[:, :C], dt)
        out_hL = self.ssm_hL(x_inner[:, C:C*2], dt)
        out_vD = self.ssm_vD(x_inner[:, C*2:C*3], dt)
        out_vU = self.ssm_vU(x_inner[:, C*3:], dt)
        
        # 融合
        combined = torch.cat([out_hR, out_hL, out_vD, out_vU], dim=-1)
        gated = self.gate(combined)
        out = self.proj(gated)
        
        # 恢复形状 + 残差
        out = out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return out + residual

实验结果

ImageNet分类

模型参数量FLOPsTop-1吞吐量
M2D-Ti8M1.5G78.8%4200 img/s
M2D-S27M4.5G84.0%1850 img/s
M2D-B50M8.2G85.3%1100 img/s
M2D-L200M30G86.2%380 img/s

与其他视觉SSM对比

模型Top-1吞吐量(A100)特点
VMamba-T82.6%2100 img/s十字扫描
Mamba2D-S84.0%1850 img/s原生2D
LocalMamba-T82.1%1650 img/s局部窗口
Mamba2D-B85.3%1100 img/s原生2D

下游任务

COCO目标检测 (Mask R-CNN)

BackboneAP^bAP^m参数
M2D-Ti49.244.528M
M2D-S50.545.847M
M2D-B52.247.370M

ADE20K语义分割 (UperNet)

BackbonemIoU参数量
M2D-S48.952M
M2D-B51.780M
Swin-B49.088M

扩展性分析

延迟随分辨率扩展

Mamba2D的wavefront扫描实现亚线性扩展

图像尺寸注意力延迟Mamba2D延迟加速比
224²O(N²) = 50KO(N^0.56) ≈ 300167×
384²O(N²) = 147KO(N^0.56) ≈ 600245×
512²O(N²) = 262KO(N^0.56) ≈ 900291×
1024²O(N²) = 1MO(N^0.56) ≈ 2200455×

内存效率

尺寸注意力内存Mamba2D内存节省
224²256 MB32 MB
512²1.5 GB64 MB24×
1024²6 GB128 MB47×

与MambaVision的对比

特性Mamba2DMambaVision
扫描方式Wavefront 4方向选择性2D扫描
并行策略对角线并行标准并行
空间建模原生2D混合CNN+SSM
ImageNet85.3% (50M)85.0% (228M)
高分辨率效率更优中等
实现复杂度较高较低

总结

Mamba2D的核心贡献:

  1. 首个原生2D SSM,从第一性原理推导
  2. Wavefront并行扫描,充分利用GPU并行能力
  3. 子线性扩展,高分辨率下显著优于注意力
  4. 四方向融合,捕获完整空间依赖
  5. SOTA性能,成为视觉SSM的新基准

参考文献


相关主题

Footnotes

  1. Sun, Y., et al. (2024). Mamba2D: Efficient 2D State Space Model with Wavefront Parallel Scan. arXiv:2405.14410. https://arxiv.org/abs/2405.14410