SfMamba:基于选择性扫描的免源域适应

概述

SfMamba(arXiv:2601.08608)提出了一种基于选择性扫描机制的高效免源域适应(Source-Free Domain Adaptation, SFDA)框架。其核心创新是引入通道级视觉状态空间块(Ch-VSS),解决了传统方法在感受野和计算效率之间的权衡问题。

核心贡献

  1. 首个Mamba-based SFDA:利用状态空间模型的长期依赖建模能力,同时保持线性复杂度
  2. 通道级视觉状态空间块:沿通道维度进行双向扫描,学习领域不变频率特征
  3. 语义一致打乱策略:通过打乱背景补丁缓解错误累积

与现有方法对比

方法架构复杂度准确率
SHOTCNN76.2%
NRCCNN78.1%
DATMambaMamba80.9%
SfMambaMamba81.7%

问题背景

免源域适应定义

给定:

  • 预训练的源域模型
  • 无标签目标域数据

目标:学习一个目标域模型 ,使得 在目标域上表现良好。

关键约束:训练过程中不能访问源域数据。

现有方法的挑战

方法类型挑战
CNN-based感受野有限,难以捕获长距离依赖
ViT-based全局注意力导致 复杂度
Mamba-based通道交互不足,领域不变特征学习受限

方法详解

1. 整体框架

输入图像
    ↓
VMamba编码器 (2D扫描)
    ↓
通道级视觉状态空间块 (Ch-VSS)
    ↓
语义一致性打乱 (SCS)
    ↓
分类器
    ↓
KL散度一致性正则化

2. VMamba编码器

基于VMamba的2D选择性扫描:

输入: X ∈ R^{B×C×H×W}
    ↓
重塑: X' ∈ R^{B×C×(HW)} (展平为序列)
    ↓
四方向扫描:
  - 左→右
  - 右→左
  - 上→下
  - 下→上
    ↓
合并: 沿扫描方向拼接
    ↓
输出: Y ∈ R^{B×C×H×W}

选择性机制

其中 是时间步增量。

3. 通道级视觉状态空间块(Ch-VSS)

核心创新:沿通道维度应用双向SSM,实现通道间信息交互。

输入特征: F ∈ R^{B×D×H×W}
    ↓
沿通道展平: F' ∈ R^{B×D×(HW)}
    ↓
双向SSM:
  - Forward: 沿通道正序扫描
  - Backward: 沿通道逆序扫描
    ↓
输出: F'' ∈ R^{B×D×(HW)}
    ↓
重塑: F'' ∈ R^{B×D×H×W}

数学形式

其中 表示通道索引。

通道扫描的优势

  • 学习通道间的频率关系
  • 捕获领域不变的低频特征
  • 保持线性复杂度

4. 语义一致打乱策略(SCS)

问题:错误累积导致伪标签质量下降。

解决思路

  • 通过打乱背景补丁生成扰动样本
  • 要求扰动样本的预测与原始样本一致

4.1 低激活补丁识别

使用Grad-CAM识别低激活区域(背景):

# 伪代码
cam = compute_gradcam(model, x)  # (H, W)
threshold = cam.mean()  # 使用均值作为阈值
low_activation_mask = cam < threshold  # (H, W)

4.2 打乱操作

原始图像 (4×4补丁)
┌────┬────┬────┬────┐
│ P1 │ P2 │ P3 │ P4 │
├────┼────┼────┼────┤
│ P5 │ P6 │ P7 │ P8 │  ← 背景补丁
├────┼────┼────┼────┤
│ P9 │ P10│ P11│ P12│  ← 前景补丁
├────┼────┼────┼────┤
│ P13│ P14│ P15│ P16│  ← 前景补丁
└────┴────┴────┴────┘
        ↓ 打乱
┌────┬────┬────┬────┐
│ P1 │ P2 │ P3 │ P4 │
├────┼────┼────┼────┤
│ P6 │ P8 │ P5 │ P7 │  ← 打乱背景
├────┼────┼────┼────┤
│ P9 │ P10│ P11│ P12│  ← 保持前景
├────┼────┼────┼────┤
│ P13│ P14│ P15│ P16│  ← 保持前景
└────┴────┴────┴────┘

4.3 一致性损失

其中 是打乱后的图像。

5. 训练目标

损失项作用
分类损失(伪标签监督)
一致性正则化
熵正则化(鼓励高置信度预测)

实验结果

Office-Home数据集

方法Ar→ClAr→PrAr→Rw平均
Source Only58.9%65.3%76.0%66.7%
SHOT72.1%78.8%81.5%77.5%
NRC74.2%80.2%82.4%78.9%
DATMamba76.8%81.5%84.3%80.9%
SfMamba-S77.5%82.3%85.4%81.7%

VisDA-C数据集

方法准确率
Source Only52.4%
SHOT78.3%
NRC81.2%
C-SFTrans88.3%
SfMamba-S89.3%

DomainNet-126

方法准确率
Source Only42.1%
SHOT71.2%
SHOT++75.2%
SfMamba-S77.9%

效率对比

方法参数量FLOPs吞吐量
C-SFTrans86.6M17.6G1.0×
DATMamba61.2M10.1G1.5×
SfMamba-S58.9M9.2G1.26×

消融实验

Ch-VSS的效果

配置Office-Home说明
无Ch-VSS79.8%使用标准VMamba
Ch-VSS (单方向)80.9%仅正向扫描
Ch-VSS (双向)81.7%双向通道扫描

SCS的效果

配置Office-Home打乱比例
无SCS80.2%-
SCS (10%)80.8%打乱10%补丁
SCS (20%)81.7%打乱20%补丁
SCS (30%)81.3%打乱30%补丁

各组件贡献

组件Office-Home提升
VMamba基线78.5%-
+ Ch-VSS80.9%+2.4%
+ SCS81.7%+0.8%
+ 一致性损失81.7%-

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class ChannelWiseSSM(nn.Module):
    """
    Channel-wise Visual State Space Block
    """
    def __init__(self, dim, d_state=16, dt_rank='auto'):
        super().__init__()
        self.dim = dim
        self.d_state = d_state
        
        # Input projection
        self.x_proj = nn.Linear(dim, d_state * 2, bias=False)
        
        # dt projection
        self.dt_proj = nn.Linear(d_state, dim, bias=True)
        
        # A and B parameters
        self.A_log = nn.Parameter(torch.randn(dim, d_state))
        self.B_log = nn.Parameter(torch.randn(dim, d_state))
        
        # Output projection
        self.out_proj = nn.Linear(dim, dim, bias=False)
    
    def forward(self, x):
        """
        x: (B, D, N) where D=channels, N=spatial tokens
        """
        B, D, N = x.shape
        
        # Compute input-dependent parameters
        x_gate = self.x_proj(x.transpose(1, 2))  # (B, N, d_state*2)
        x_inner, x_dt = x_gate.chunk(2, dim=-1)  # (B, N, d_state) each
        
        # Sigmoid gate
        x_inner = torch.sigmoid(x_inner)
        x_dt = torch.softplus(self.dt_proj(x_dt))  # (B, N, D)
        
        # State transition
        A = -torch.exp(self.A_log.float())  # (D, d_state)
        B = torch.exp(self.B_log.float())  # (D, d_state)
        
        # Bidirectional scanning
        # Forward direction
        h_forward = torch.zeros(B, D, self.d_state, device=x.device)
        outputs_forward = []
        
        for n in range(N):
            h_forward = x_inner[:, n, :] * (B * x[:, :, n]) + (A.exp() * h_forward) * (1 - x_inner[:, n, :])
            outputs_forward.append(h_forward)
        
        # Backward direction
        h_backward = torch.zeros(B, D, self.d_state, device=x.device)
        outputs_backward = []
        
        for n in reversed(range(N)):
            h_backward = x_inner[:, n, :] * (B * x[:, :, n]) + (A.exp() * h_backward) * (1 - x_inner[:, n, :])
            outputs_backward.insert(0, h_backward)
        
        # Concatenate forward and backward
        h_combined = torch.stack([h_f + h_b for h_f, h_b in zip(outputs_forward, outputs_backward)], dim=2)
        
        # Output projection
        y = self.out_proj(h_combined.transpose(1, 2))  # (B, N, D)
        
        return y.transpose(1, 2)  # (B, D, N)
 
 
class SemanticConsistentShuffle(nn.Module):
    """
    Semantic-Consistent Shuffle for background patches
    """
    def __init__(self, shuffle_ratio=0.2):
        super().__init__()
        self.shuffle_ratio = shuffle_ratio
    
    def forward(self, x, cam=None):
        """
        x: (B, C, H, W)
        cam: (B, H, W) class activation maps, if provided
        """
        B, C, H, W = x.shape
        patch_size = 4
        n_patches = (H // patch_size) * (W // patch_size)
        
        # Identify low activation patches using CAM or default
        if cam is None:
            # Use center region as proxy for foreground
            low_mask = torch.ones(B, H, W, device=x.device, dtype=torch.bool)
            h_start, h_end = H // 4, 3 * H // 4
            w_start, w_end = W // 4, 3 * W // 4
            low_mask[:, h_start:h_end, w_start:w_end] = False
        else:
            threshold = cam.mean()
            low_mask = cam < threshold
        
        # Reshape to patches
        x_patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        x_patches = x_patches.reshape(B, C, -1, patch_size, patch_size)  # (B, C, n_patches, p, p)
        
        mask_patches = low_mask.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size)
        mask_patches = mask_patches.reshape(B, 1, -1)  # (B, 1, n_patches)
        
        # Get low activation indices
        n_shuffle = int(n_patches * self.shuffle_ratio)
        low_indices = mask_patches.squeeze(1).topk(n_shuffle, dim=1, largest=True)[1]
        
        # Shuffle
        x_shuffled = x_patches.clone()
        for b in range(B):
            shuffle_idx = low_indices[b][torch.randperm(n_shuffle)]
            x_shuffled[b, :, shuffle_idx] = x_patches[b, :, low_indices[b]]
        
        # Reshape back
        x_shuffled = x_shuffled.reshape(B, C, H // patch_size, W // patch_size, patch_size, patch_size)
        x_shuffled = x_shuffled.permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
        
        return x_shuffled

与其他SFDA方法的对比

架构演进

时代代表方法架构特点
CNN时代SHOT, NRCResNet, VGG局部感受野
ViT时代DSiT, C-SFTransViT, DeiT全局注意力
Mamba时代SfMamba, DATMambaVMamba线性复杂度+全局建模

SfMamba vs DATMamba

维度DATMambaSfMamba
扫描方向空间4方向空间4方向 + 通道2方向
通道交互Ch-VSS
打乱策略SCS
准确率80.9%81.7%

总结

SfMamba的核心贡献是将Mamba架构引入SFDA领域,并通过通道级扫描和语义一致打乱策略实现了SOTA性能。

关键创新

  1. Ch-VSS块:双向通道扫描,学习领域不变频率特征
  2. SCS策略:打乱背景补丁,保持语义一致性
  3. 线性效率 复杂度,适合高分辨率图像

性能总结

  • Office-Home: 81.7% (SOTA)
  • VisDA-C: 89.3% (SOTA)
  • DomainNet-126: 77.9% (SOTA)
  • 参数量: 58.9M (比C-SFTrans少32%)

参考