概述
MambaVision是由NVIDIA提出的首个混合Mamba-Transformer视觉骨干架构,旨在结合状态空间模型(SSM)的高效长程建模能力与Transformer的强大表达能力。1
与之前的纯SSM视觉架构不同,MambaVision采用分阶段混合设计:
- 前两个阶段使用纯CNN块,保留局部特征提取能力
- 后两个阶段使用混合Mamba+Transformer块,融合两种架构的优势
核心设计
整体架构
输入图像 (H×W×3)
↓
Stage 1: Patch Embedding + CNN Blocks (局部特征)
↓
Stage 2: Patch Embedding + CNN Blocks (局部特征)
↓
Stage 3: Hybrid Mamba-Transformer Blocks (混合建模)
↓
Stage 4: Hybrid Mamba-Transformer Blocks (混合建模)
↓
输出特征 (H/32×W/32×C)
阶段配置
| 阶段 | 输出分辨率 | 通道数 | 层数 | 模块类型 |
|---|---|---|---|---|
| Stage 1 | H/4 × W/4 | 64 | 3 | CNN |
| Stage 2 | H/8 × W/8 | 128 | 3 | CNN |
| Stage 3 | H/16 × W/16 | 256 | 27 | Hybrid Mamba-Transformer |
| Stage 4 | H/32 × W/32 | 512 | 3 | Hybrid Mamba-Transformer |
Mamba块设计
MambaVision对原始Mamba进行了关键修改,使其更适合视觉任务:
问题分析
原始Mamba的因果卷积设计存在两个问题:
- 缺失的空间建模:因果扫描沿序列方向,忽略了空间局部性
- 不对称信息流:只能看到当前位置之前的token
解决方案:对称分支设计
class MambaVisionBlock(nn.Module):
def __init__(self, dim, state_dim=16):
super().__init__()
# 分支1:选择性状态空间(压缩历史信息)
self.sss = MambaBlock(dim, state_dim)
# 分支2:对称路径(保持空间信息)
self.sym_branch = nn.Sequential(
nn.Conv2d(dim, dim, 3, padding=1, groups=dim), # 深度可分离卷积
nn.Conv2d(dim, dim, 1) # 逐点卷积
)
# 门控机制
self.gate = nn.Sigmoid()
def forward(self, x):
# 并行处理
sss_out = self.sss(x)
sym_out = self.sym_branch(x)
# 门控融合
return self.gate(sss_out) * sym_out关键改进
- 移除因果卷积:用常规卷积替代,保持空间结构
- 添加对称分支:提供一个不经过SSM的信息通路
- 门控融合:动态控制两个分支的信息流
与其他视觉SSM的对比
扫描策略对比
| 模型 | 扫描方式 | 空间建模 | 因果性 |
|---|---|---|---|
| Vision Mamba (Vim) | 十字扫描 | 部分 | 因果 |
| VMamba | 十字扫描 + Cross-scan | 部分 | 非因果 |
| MambaVision | 2D选择性扫描 | 完整 | 非因果 |
| Pure ViT | 全注意力 | 完整 | 非因果 |
准确率-吞吐量权衡
吞吐量 (images/sec on A100)
↑
│ ◆ MambaVision-L
1000 │ ●
│ ■
500 │ ■
│ ●
200 │ ● ◆ MambaVision-B
│ ■ ◆
100 │ ●■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
│
└─────────────────────────────────────────→
准确率 (ImageNet Top-1 %)
● Pure SSM ■ Pure ViT ◆ Hybrid
实验结果
ImageNet分类
| 模型 | 参数量 | FLOPs | Top-1 | Throughput |
|---|---|---|---|---|
| MambaVision-T | 22M | 4.5G | 82.3% | 1520 img/s |
| MambaVision-S | 50M | 8.9G | 83.3% | 890 img/s |
| MambaVision-B | 98M | 15.G | 84.2% | 580 img/s |
| MambaVision-L | 228M | 37.9G | 85.0% | 290 img/s |
下游任务
COCO目标检测 (Mask R-CNN)
| Backbone | AP^b | AP^m | AP^25 |
|---|---|---|---|
| MambaVision-T | 51.0 | 46.2 | 71.3 |
| MambaVision-S | 52.1 | 47.1 | 72.5 |
| MambaVision-B | 53.4 | 48.0 | 73.8 |
| Swin-T | 50.6 | 46.1 | 70.9 |
ADE20K语义分割 (UperNet)
| Backbone | mIoU (SS) | mIoU (MS) |
|---|---|---|
| MambaVision-S | 48.2 | 49.5 |
| MambaVision-B | 50.1 | 51.4 |
| Swin-S | 47.6 | 49.0 |
设计启示
为什么需要混合设计?
- 局部特征提取:CNN的归纳偏置(局部性、平移不变性)对于低级视觉特征提取非常有效
- 全局依赖建模:SSM和Transformer在高级特征阶段捕获长程依赖
- 效率权衡:纯Transformer在早期阶段计算代价过高
SSM vs Transformer的选择
| 特性 | SSM (Mamba) | Transformer |
|---|---|---|
| 计算复杂度 | O(N) | O(N²) |
| 长程依赖 | 压缩表示 | 完整表示 |
| 局部建模 | 较弱 | 强 |
| 内存效率 | 高 | 中 |
| 表达多样性 | 中 | 高 |
PyTorch实现概要
import torch
import torch.nn as nn
from einops import rearrange
class MambaVisionLayer(nn.Module):
def __init__(self, dim, state_dim=16, expand=2):
super().__init__()
self.dim = dim
self.state_dim = state_dim
self.expand = expand
d_inner = dim * expand
# 输入投影
self.in_proj = nn.Linear(dim, d_inner * 2, bias=False)
# SSM分支
self.ssm = MambaBlock(d_inner, state_dim)
# 对称分支(卷积)
self.conv = nn.Conv2d(d_inner, d_inner, 3, padding=1, groups=d_inner)
# 输出投影
self.out_proj = nn.Linear(d_inner, dim, bias=False)
# 门控
self.gate = nn.SiLU()
def forward(self, x, H, W):
B, N, C = x.shape
# 输入投影 + 分割
x_inner = self.in_proj(x)
x_ssm, x_conv = x_inner.chunk(2, dim=-1)
# SSM分支
ssm_out = self.ssm(x_ssm)
# 对称分支(转换为2D,处理,再转回序列)
x_conv = rearrange(x_conv, 'b (h w) c -> b c h w', h=H, w=W)
conv_out = self.conv(x_conv)
conv_out = rearrange(conv_out, 'b c h w -> b (h w) c')
# 门控融合
out = self.gate(ssm_out) * conv_out
# 输出投影
return self.out_proj(out)总结
MambaVision的核心贡献:
- 首个混合Mamba-Transformer视觉骨干,结合两种架构的优势
- 对称分支设计,解决原始Mamba的空间建模问题
- 分阶段混合策略,在准确率和效率间取得最佳权衡
- 在多个视觉任务上达到SOTA,验证了混合架构的有效性
参考文献
相关主题
Footnotes
-
Liu, Y., et al. (2024). MambaVision: A Hybrid Mamba-Transformer Vision Backbone. arXiv:2407.07505. https://arxiv.org/abs/2407.07505 ↩