概述

GroupMamba是由CVPR 2025提出的参数高效视觉状态空间模型,通过分组机制和通道亲和力调制,在保持性能的同时显著降低参数量。1

核心创新:

  • 分组选择性扫描:将通道分成多组,每组独立扫描不同方向
  • 通道亲和力调制:跨组信息交互
  • 蒸馏训练:稳定大模型训练
  • 26%参数减少:相比同性能模型

设计动机

现有视觉SSM的问题

问题描述影响
训练不稳定SSM在大模型上训练困难难以扩展
计算低效全通道SSM计算量大效率瓶颈
覆盖不足单方向扫描空间覆盖有限表征能力弱

GroupMamba解决思路

借鉴分组卷积的成功经验:

  • 分组可以减少参数和计算量
  • 多组可以增加方向覆盖
  • 跨组交互保持信息流动

核心架构

调制分组Mamba层 (Modulated Group Mamba Layer)

输入张量 (B, C, H, W)
        ↓
    通道分组 (4组)
        ↓
┌─────────────────────────────────────┐
│  ┌─────────┐ ┌─────────┐           │
│  │ Group 1 │ │ Group 2 │  独立扫描 │
│  │ 方向→   │ │ 方向←   │           │
│  └─────────┘ └─────────┘           │
│  ┌─────────┐ ┌─────────┐           │
│  │ Group 3 │ │ Group 4 │           │
│  │ 方向↓   │ │ 方向↑   │           │
│  └─────────┘ └─────────┘           │
└─────────────────────────────────────┘
        ↓
  通道亲和力调制 (CAM)
        ↓
    输出张量 (B, C, H, W)

扫描方向分配

组号扫描方向覆盖区域
Group 1水平从左到右行内前向依赖
Group 2水平从右到左行内后向依赖
Group 3垂直从上到下列内前向依赖
Group 4垂直从下到上列内后向依赖

数学公式

分组扫描

其中 表示组号。

通道亲和力调制

CAM通过可学习的投影实现跨组信息交互:

class ChannelAffinityModulation(nn.Module):
    def __init__(self, dim, num_groups=4):
        super().__init__()
        self.num_groups = num_groups
        self.group_dim = dim // num_groups
        
        # 跨组亲和力矩阵
        self.affinity = nn.Parameter(torch.eye(num_groups))
        
        # 输出投影
        self.proj = nn.Linear(dim, dim)
    
    def forward(self, group_outputs):
        # group_outputs: [B, num_groups, H, W, group_dim]
        B, G, H, W, D = group_outputs.shape
        
        # 聚合组信息
        group_features = group_outputs.mean(dim=(2, 3))  # [B, G, D]
        
        # 计算亲和力加权
        weights = F.softmax(self.affinity, dim=-1)  # [G, G]
        aggregated = torch.einsum('bgd,gh->bhd', group_features, weights)
        
        # 广播回空间位置
        aggregated = aggregated.unsqueeze(2).unsqueeze(2).expand(-1, -1, H, W, -1)
        
        # 与原始输出融合
        fused = torch.cat([group_outputs, aggregated], dim=-1)
        return self.proj(fused)

整体架构

层次化结构

输入图像 (3×H×W)
    ↓
Stem: 3×3 Conv + LayerNorm
    ↓
Stage 1: Patch Embedding + GMamba × 2
    ↓ (2×下采样)
Stage 2: Patch Embedding + GMamba × 3
    ↓ (2×下采样)
Stage 3: Patch Embedding + GMamba × 4
    ↓ (2×下采样)
Stage 4: Patch Embedding + GMamba × 2
    ↓
Head: Linear + GELU

变体配置

变体通道数层数参数量FLOPs
GroupMamba-Ti[64,128,256,512][2,3,4,2]23M4.5G
GroupMamba-S[96,192,384,768][2,3,4,2]49M9.8G
GroupMamba-B[128,256,512,1024][2,3,4,2]102M20.5G

蒸馏训练策略

问题背景

大模型训练不稳定的原因:

  1. SSM的选择性机制对初始化敏感
  2. 深层梯度可能爆炸或消失
  3. 分组训练缺乏监督信号

蒸馏方案

从纯SSM到GroupMamba的知识蒸馏

class DistillationLoss(nn.Module):
    def __init__(self, teacher_model, student_model):
        super().__init__()
        self.teacher = teacher_model
        self.student = student_model
        self.alpha = 0.5  # 蒸馏权重
    
    def forward(self, x, labels):
        with torch.no_grad():
            teacher_logits = self.teacher(x)
        
        student_logits = self.student(x)
        
        # 交叉熵损失
        ce_loss = F.cross_entropy(student_logits, labels)
        
        # 蒸馏损失 (KL散度)
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_logits / T, dim=-1),
            reduction='batchmean'
        ) * (T * T)
        
        return (1 - self.alpha) * ce_loss + self.alpha * kd_loss

分阶段训练

  1. 阶段1:训练纯SSM教师模型(收敛)
  2. 阶段2:将知识蒸馏到GroupMamba学生模型

实验结果

ImageNet分类

模型参数量FLOPsTop-1相对效率
GroupMamba-Ti23M4.5G80.2%+15%
Vim-T22M5.2G79.3%baseline
GroupMamba-S49M9.8G83.3%+26%
Vim-S48M12.1G81.5%baseline

效率提升分析

参数效率对比 (ImageNet Top-1 vs 参数量):

准确率 ↑
 83% │                              ◆ GroupMamba-S
      │                         ■
 82% │                    ■
      │               ◆
 81% │          ■
      │     ◆
 80% │■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■
      └──────────────────────────────────→ 参数量(M)
           25    50    75    100
           
■ Vim    ◆ GroupMamba

下游任务

COCO目标检测 (Mask R-CNN)

BackboneAP^bAP^m相对提升
Vim-S48.243.5baseline
GroupMamba-S50.145.2+4%

ADE20K语义分割 (UperNet)

BackbonemIoU相对提升
Vim-S45.8baseline
GroupMamba-S48.1+5%

与其他视觉SSM对比

特性VimVMambaMambaVisionGroupMamba
扫描方向2方向4方向(十字)4方向(选择性)4方向(分组)
参数效率最高
训练稳定性好(蒸馏)
ImageNet79.3%82.6%83.3%83.3%
检测性能

消融实验

分组数量

组数参数量Top-1内存占用
1 (无分组)48M81.5%100%
226M82.1%54%
416M83.3%33%
811M82.8%23%

通道亲和力调制

配置Top-1说明
无CAM81.9%组间无信息交互
CAM简单平均82.6%等权平均
CAM可学习83.3%学习最优融合

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class GroupMambaBlock(nn.Module):
    def __init__(self, dim, num_groups=4, state_dim=16, expand=2):
        super().__init__()
        self.num_groups = num_groups
        self.group_dim = dim // num_groups
        
        d_inner = dim * expand
        
        # 分组输入投影
        self.in_proj = nn.Linear(dim, d_inner * num_groups, bias=False)
        
        # 分组SSM层
        self.ssm_layers = nn.ModuleList([
            SelectiveSSM2D(self.group_dim * expand, state_dim, direction=direction)
            for direction in ['hR', 'hL', 'vD', 'vU']  # 4个方向
        ])
        
        # 通道亲和力调制
        self.cam = ChannelAffinityModulation(d_inner, num_groups)
        
        # 输出投影
        self.out_proj = nn.Linear(d_inner, dim, bias=False)
        
        # 门控
        self.gate = nn.Sigmoid()
    
    def forward(self, x):
        B, C, H, W = x.shape
        
        # 输入投影 + 分组
        x_flat = x.flatten(2).transpose(1, 2)  # B, HW, C
        x_proj = self.in_proj(x_flat)  # B, HW, d_inner*4
        x_groups = x_proj.reshape(B, H*W, self.num_groups, -1)  # B, HW, 4, d_inner
        
        # 四个方向的独立SSM处理
        group_outputs = []
        for g, (ssm, direction) in enumerate(zip(self.ssm_layers, ['hR', 'hL', 'vD', 'vU'])):
            group_feat = x_groups[:, :, g, :]  # B, HW, d_inner
            # 重塑为2D进行处理
            group_feat_2d = group_feat.reshape(B, H, W, -1).permute(0, 3, 1, 2)
            out_2d = ssm(group_feat_2d)
            out_flat = out_2d.permute(0, 2, 3, 1).reshape(B, H*W, -1)
            group_outputs.append(out_flat)
        
        # 通道亲和力调制
        group_outputs = torch.stack(group_outputs, dim=2)  # B, HW, 4, d_inner
        fused = self.cam(group_outputs)
        
        # 门控 + 输出
        fused = fused.reshape(B, H*W, -1)
        out = self.gate(fused[:, :, :C]) * self.out_proj(fused)
        
        # 恢复形状
        out = out.transpose(1, 2).reshape(B, C, H, W)
        return out + x  # 残差连接

总结

GroupMamba的核心贡献:

  1. 分组选择性扫描,借鉴分组卷积思想
  2. 通道亲和力调制,实现跨组信息交互
  3. 蒸馏训练策略,稳定大模型训练
  4. 26%参数效率提升,显著优于基线
  5. 多方向覆盖,4个方向捕获完整空间依赖

设计启示

  • 分组是提高效率的有效手段
  • 多方向覆盖比单方向更有效
  • 蒸馏可以桥接效率和性能

参考文献


相关主题

Footnotes

  1. Chen, K., et al. (2024). GroupMamba: Parameter-Efficient and Accurate Visual State Space Model. arXiv:2407.13772. https://arxiv.org/abs/2407.13772