概述

MambaVision是由NVIDIA提出的首个混合Mamba-Transformer视觉骨干架构,旨在结合状态空间模型(SSM)的高效长程建模能力与Transformer的强大表达能力。1

与之前的纯SSM视觉架构不同,MambaVision采用分阶段混合设计

  • 前两个阶段使用纯CNN块,保留局部特征提取能力
  • 后两个阶段使用混合Mamba+Transformer块,融合两种架构的优势

核心设计

整体架构

输入图像 (H×W×3)
        ↓
Stage 1: Patch Embedding + CNN Blocks (局部特征)
        ↓
Stage 2: Patch Embedding + CNN Blocks (局部特征)
        ↓
Stage 3: Hybrid Mamba-Transformer Blocks (混合建模)
        ↓
Stage 4: Hybrid Mamba-Transformer Blocks (混合建模)
        ↓
输出特征 (H/32×W/32×C)

阶段配置

阶段输出分辨率通道数层数模块类型
Stage 1H/4 × W/4643CNN
Stage 2H/8 × W/81283CNN
Stage 3H/16 × W/1625627Hybrid Mamba-Transformer
Stage 4H/32 × W/325123Hybrid Mamba-Transformer

Mamba块设计

MambaVision对原始Mamba进行了关键修改,使其更适合视觉任务:

问题分析

原始Mamba的因果卷积设计存在两个问题:

  1. 缺失的空间建模:因果扫描沿序列方向,忽略了空间局部性
  2. 不对称信息流:只能看到当前位置之前的token

解决方案:对称分支设计

class MambaVisionBlock(nn.Module):
    def __init__(self, dim, state_dim=16):
        super().__init__()
        # 分支1:选择性状态空间(压缩历史信息)
        self.sss = MambaBlock(dim, state_dim)
        
        # 分支2:对称路径(保持空间信息)
        self.sym_branch = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1, groups=dim),  # 深度可分离卷积
            nn.Conv2d(dim, dim, 1)  # 逐点卷积
        )
        
        # 门控机制
        self.gate = nn.Sigmoid()
    
    def forward(self, x):
        # 并行处理
        sss_out = self.sss(x)
        sym_out = self.sym_branch(x)
        # 门控融合
        return self.gate(sss_out) * sym_out

关键改进

  1. 移除因果卷积:用常规卷积替代,保持空间结构
  2. 添加对称分支:提供一个不经过SSM的信息通路
  3. 门控融合:动态控制两个分支的信息流

与其他视觉SSM的对比

扫描策略对比

模型扫描方式空间建模因果性
Vision Mamba (Vim)十字扫描部分因果
VMamba十字扫描 + Cross-scan部分非因果
MambaVision2D选择性扫描完整非因果
Pure ViT全注意力完整非因果

准确率-吞吐量权衡

吞吐量 (images/sec on A100)
     ↑
     │                                    ◆ MambaVision-L
1000 │                          ●
     │                    ■
 500 │              ■
     │        ●
 200 │    ●           ◆ MambaVision-B
     │  ■ ◆
 100 │       ●■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
     │
     └─────────────────────────────────────────→
                      准确率 (ImageNet Top-1 %)
                      
     ● Pure SSM      ■ Pure ViT      ◆ Hybrid

实验结果

ImageNet分类

模型参数量FLOPsTop-1Throughput
MambaVision-T22M4.5G82.3%1520 img/s
MambaVision-S50M8.9G83.3%890 img/s
MambaVision-B98M15.G84.2%580 img/s
MambaVision-L228M37.9G85.0%290 img/s

下游任务

COCO目标检测 (Mask R-CNN)

BackboneAP^bAP^mAP^25
MambaVision-T51.046.271.3
MambaVision-S52.147.172.5
MambaVision-B53.448.073.8
Swin-T50.646.170.9

ADE20K语义分割 (UperNet)

BackbonemIoU (SS)mIoU (MS)
MambaVision-S48.249.5
MambaVision-B50.151.4
Swin-S47.649.0

设计启示

为什么需要混合设计?

  1. 局部特征提取:CNN的归纳偏置(局部性、平移不变性)对于低级视觉特征提取非常有效
  2. 全局依赖建模:SSM和Transformer在高级特征阶段捕获长程依赖
  3. 效率权衡:纯Transformer在早期阶段计算代价过高

SSM vs Transformer的选择

特性SSM (Mamba)Transformer
计算复杂度O(N)O(N²)
长程依赖压缩表示完整表示
局部建模较弱
内存效率
表达多样性

PyTorch实现概要

import torch
import torch.nn as nn
from einops import rearrange
 
class MambaVisionLayer(nn.Module):
    def __init__(self, dim, state_dim=16, expand=2):
        super().__init__()
        self.dim = dim
        self.state_dim = state_dim
        self.expand = expand
        
        d_inner = dim * expand
        
        # 输入投影
        self.in_proj = nn.Linear(dim, d_inner * 2, bias=False)
        
        # SSM分支
        self.ssm = MambaBlock(d_inner, state_dim)
        
        # 对称分支(卷积)
        self.conv = nn.Conv2d(d_inner, d_inner, 3, padding=1, groups=d_inner)
        
        # 输出投影
        self.out_proj = nn.Linear(d_inner, dim, bias=False)
        
        # 门控
        self.gate = nn.SiLU()
    
    def forward(self, x, H, W):
        B, N, C = x.shape
        
        # 输入投影 + 分割
        x_inner = self.in_proj(x)
        x_ssm, x_conv = x_inner.chunk(2, dim=-1)
        
        # SSM分支
        ssm_out = self.ssm(x_ssm)
        
        # 对称分支(转换为2D,处理,再转回序列)
        x_conv = rearrange(x_conv, 'b (h w) c -> b c h w', h=H, w=W)
        conv_out = self.conv(x_conv)
        conv_out = rearrange(conv_out, 'b c h w -> b (h w) c')
        
        # 门控融合
        out = self.gate(ssm_out) * conv_out
        
        # 输出投影
        return self.out_proj(out)

总结

MambaVision的核心贡献:

  1. 首个混合Mamba-Transformer视觉骨干,结合两种架构的优势
  2. 对称分支设计,解决原始Mamba的空间建模问题
  3. 分阶段混合策略,在准确率和效率间取得最佳权衡
  4. 在多个视觉任务上达到SOTA,验证了混合架构的有效性

参考文献


相关主题

Footnotes

  1. Liu, Y., et al. (2024). MambaVision: A Hybrid Mamba-Transformer Vision Backbone. arXiv:2407.07505. https://arxiv.org/abs/2407.07505