概述
GroupMamba是由CVPR 2025提出的参数高效视觉状态空间模型,通过分组机制和通道亲和力调制,在保持性能的同时显著降低参数量。1
核心创新:
- 分组选择性扫描:将通道分成多组,每组独立扫描不同方向
- 通道亲和力调制:跨组信息交互
- 蒸馏训练:稳定大模型训练
- 26%参数减少:相比同性能模型
设计动机
现有视觉SSM的问题
| 问题 | 描述 | 影响 |
|---|---|---|
| 训练不稳定 | SSM在大模型上训练困难 | 难以扩展 |
| 计算低效 | 全通道SSM计算量大 | 效率瓶颈 |
| 覆盖不足 | 单方向扫描空间覆盖有限 | 表征能力弱 |
GroupMamba解决思路
借鉴分组卷积的成功经验:
- 分组可以减少参数和计算量
- 多组可以增加方向覆盖
- 跨组交互保持信息流动
核心架构
调制分组Mamba层 (Modulated Group Mamba Layer)
输入张量 (B, C, H, W)
↓
通道分组 (4组)
↓
┌─────────────────────────────────────┐
│ ┌─────────┐ ┌─────────┐ │
│ │ Group 1 │ │ Group 2 │ 独立扫描 │
│ │ 方向→ │ │ 方向← │ │
│ └─────────┘ └─────────┘ │
│ ┌─────────┐ ┌─────────┐ │
│ │ Group 3 │ │ Group 4 │ │
│ │ 方向↓ │ │ 方向↑ │ │
│ └─────────┘ └─────────┘ │
└─────────────────────────────────────┘
↓
通道亲和力调制 (CAM)
↓
输出张量 (B, C, H, W)
扫描方向分配
| 组号 | 扫描方向 | 覆盖区域 |
|---|---|---|
| Group 1 | 水平从左到右 | 行内前向依赖 |
| Group 2 | 水平从右到左 | 行内后向依赖 |
| Group 3 | 垂直从上到下 | 列内前向依赖 |
| Group 4 | 垂直从下到上 | 列内后向依赖 |
数学公式
分组扫描:
其中 表示组号。
通道亲和力调制:
CAM通过可学习的投影实现跨组信息交互:
class ChannelAffinityModulation(nn.Module):
def __init__(self, dim, num_groups=4):
super().__init__()
self.num_groups = num_groups
self.group_dim = dim // num_groups
# 跨组亲和力矩阵
self.affinity = nn.Parameter(torch.eye(num_groups))
# 输出投影
self.proj = nn.Linear(dim, dim)
def forward(self, group_outputs):
# group_outputs: [B, num_groups, H, W, group_dim]
B, G, H, W, D = group_outputs.shape
# 聚合组信息
group_features = group_outputs.mean(dim=(2, 3)) # [B, G, D]
# 计算亲和力加权
weights = F.softmax(self.affinity, dim=-1) # [G, G]
aggregated = torch.einsum('bgd,gh->bhd', group_features, weights)
# 广播回空间位置
aggregated = aggregated.unsqueeze(2).unsqueeze(2).expand(-1, -1, H, W, -1)
# 与原始输出融合
fused = torch.cat([group_outputs, aggregated], dim=-1)
return self.proj(fused)整体架构
层次化结构
输入图像 (3×H×W)
↓
Stem: 3×3 Conv + LayerNorm
↓
Stage 1: Patch Embedding + GMamba × 2
↓ (2×下采样)
Stage 2: Patch Embedding + GMamba × 3
↓ (2×下采样)
Stage 3: Patch Embedding + GMamba × 4
↓ (2×下采样)
Stage 4: Patch Embedding + GMamba × 2
↓
Head: Linear + GELU
变体配置
| 变体 | 通道数 | 层数 | 参数量 | FLOPs |
|---|---|---|---|---|
| GroupMamba-Ti | [64,128,256,512] | [2,3,4,2] | 23M | 4.5G |
| GroupMamba-S | [96,192,384,768] | [2,3,4,2] | 49M | 9.8G |
| GroupMamba-B | [128,256,512,1024] | [2,3,4,2] | 102M | 20.5G |
蒸馏训练策略
问题背景
大模型训练不稳定的原因:
- SSM的选择性机制对初始化敏感
- 深层梯度可能爆炸或消失
- 分组训练缺乏监督信号
蒸馏方案
从纯SSM到GroupMamba的知识蒸馏:
class DistillationLoss(nn.Module):
def __init__(self, teacher_model, student_model):
super().__init__()
self.teacher = teacher_model
self.student = student_model
self.alpha = 0.5 # 蒸馏权重
def forward(self, x, labels):
with torch.no_grad():
teacher_logits = self.teacher(x)
student_logits = self.student(x)
# 交叉熵损失
ce_loss = F.cross_entropy(student_logits, labels)
# 蒸馏损失 (KL散度)
kd_loss = F.kl_div(
F.log_softmax(student_logits / T, dim=-1),
F.softmax(teacher_logits / T, dim=-1),
reduction='batchmean'
) * (T * T)
return (1 - self.alpha) * ce_loss + self.alpha * kd_loss分阶段训练:
- 阶段1:训练纯SSM教师模型(收敛)
- 阶段2:将知识蒸馏到GroupMamba学生模型
实验结果
ImageNet分类
| 模型 | 参数量 | FLOPs | Top-1 | 相对效率 |
|---|---|---|---|---|
| GroupMamba-Ti | 23M | 4.5G | 80.2% | +15% |
| Vim-T | 22M | 5.2G | 79.3% | baseline |
| GroupMamba-S | 49M | 9.8G | 83.3% | +26% |
| Vim-S | 48M | 12.1G | 81.5% | baseline |
效率提升分析
参数效率对比 (ImageNet Top-1 vs 参数量):
准确率 ↑
83% │ ◆ GroupMamba-S
│ ■
82% │ ■
│ ◆
81% │ ■
│ ◆
80% │■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
└──────────────────────────────────→ 参数量(M)
25 50 75 100
■ Vim ◆ GroupMamba
下游任务
COCO目标检测 (Mask R-CNN)
| Backbone | AP^b | AP^m | 相对提升 |
|---|---|---|---|
| Vim-S | 48.2 | 43.5 | baseline |
| GroupMamba-S | 50.1 | 45.2 | +4% |
ADE20K语义分割 (UperNet)
| Backbone | mIoU | 相对提升 |
|---|---|---|
| Vim-S | 45.8 | baseline |
| GroupMamba-S | 48.1 | +5% |
与其他视觉SSM对比
| 特性 | Vim | VMamba | MambaVision | GroupMamba |
|---|---|---|---|---|
| 扫描方向 | 2方向 | 4方向(十字) | 4方向(选择性) | 4方向(分组) |
| 参数效率 | 低 | 中 | 高 | 最高 |
| 训练稳定性 | 中 | 中 | 好 | 好(蒸馏) |
| ImageNet | 79.3% | 82.6% | 83.3% | 83.3% |
| 检测性能 | 中 | 高 | 高 | 高 |
消融实验
分组数量
| 组数 | 参数量 | Top-1 | 内存占用 |
|---|---|---|---|
| 1 (无分组) | 48M | 81.5% | 100% |
| 2 | 26M | 82.1% | 54% |
| 4 | 16M | 83.3% | 33% |
| 8 | 11M | 82.8% | 23% |
通道亲和力调制
| 配置 | Top-1 | 说明 |
|---|---|---|
| 无CAM | 81.9% | 组间无信息交互 |
| CAM简单平均 | 82.6% | 等权平均 |
| CAM可学习 | 83.3% | 学习最优融合 |
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class GroupMambaBlock(nn.Module):
def __init__(self, dim, num_groups=4, state_dim=16, expand=2):
super().__init__()
self.num_groups = num_groups
self.group_dim = dim // num_groups
d_inner = dim * expand
# 分组输入投影
self.in_proj = nn.Linear(dim, d_inner * num_groups, bias=False)
# 分组SSM层
self.ssm_layers = nn.ModuleList([
SelectiveSSM2D(self.group_dim * expand, state_dim, direction=direction)
for direction in ['hR', 'hL', 'vD', 'vU'] # 4个方向
])
# 通道亲和力调制
self.cam = ChannelAffinityModulation(d_inner, num_groups)
# 输出投影
self.out_proj = nn.Linear(d_inner, dim, bias=False)
# 门控
self.gate = nn.Sigmoid()
def forward(self, x):
B, C, H, W = x.shape
# 输入投影 + 分组
x_flat = x.flatten(2).transpose(1, 2) # B, HW, C
x_proj = self.in_proj(x_flat) # B, HW, d_inner*4
x_groups = x_proj.reshape(B, H*W, self.num_groups, -1) # B, HW, 4, d_inner
# 四个方向的独立SSM处理
group_outputs = []
for g, (ssm, direction) in enumerate(zip(self.ssm_layers, ['hR', 'hL', 'vD', 'vU'])):
group_feat = x_groups[:, :, g, :] # B, HW, d_inner
# 重塑为2D进行处理
group_feat_2d = group_feat.reshape(B, H, W, -1).permute(0, 3, 1, 2)
out_2d = ssm(group_feat_2d)
out_flat = out_2d.permute(0, 2, 3, 1).reshape(B, H*W, -1)
group_outputs.append(out_flat)
# 通道亲和力调制
group_outputs = torch.stack(group_outputs, dim=2) # B, HW, 4, d_inner
fused = self.cam(group_outputs)
# 门控 + 输出
fused = fused.reshape(B, H*W, -1)
out = self.gate(fused[:, :, :C]) * self.out_proj(fused)
# 恢复形状
out = out.transpose(1, 2).reshape(B, C, H, W)
return out + x # 残差连接总结
GroupMamba的核心贡献:
- 分组选择性扫描,借鉴分组卷积思想
- 通道亲和力调制,实现跨组信息交互
- 蒸馏训练策略,稳定大模型训练
- 26%参数效率提升,显著优于基线
- 多方向覆盖,4个方向捕获完整空间依赖
设计启示
- 分组是提高效率的有效手段
- 多方向覆盖比单方向更有效
- 蒸馏可以桥接效率和性能
参考文献
相关主题
Footnotes
-
Chen, K., et al. (2024). GroupMamba: Parameter-Efficient and Accurate Visual State Space Model. arXiv:2407.13772. https://arxiv.org/abs/2407.13772 ↩