SfMamba:基于选择性扫描的免源域适应
概述
SfMamba(arXiv:2601.08608)提出了一种基于选择性扫描机制的高效免源域适应(Source-Free Domain Adaptation, SFDA)框架。其核心创新是引入通道级视觉状态空间块(Ch-VSS),解决了传统方法在感受野和计算效率之间的权衡问题。
核心贡献
- 首个Mamba-based SFDA:利用状态空间模型的长期依赖建模能力,同时保持线性复杂度
- 通道级视觉状态空间块:沿通道维度进行双向扫描,学习领域不变频率特征
- 语义一致打乱策略:通过打乱背景补丁缓解错误累积
与现有方法对比
| 方法 | 架构 | 复杂度 | 准确率 |
|---|---|---|---|
| SHOT | CNN | 76.2% | |
| NRC | CNN | 78.1% | |
| DATMamba | Mamba | 80.9% | |
| SfMamba | Mamba | 81.7% |
问题背景
免源域适应定义
给定:
- 预训练的源域模型
- 无标签目标域数据
目标:学习一个目标域模型 ,使得 在目标域上表现良好。
关键约束:训练过程中不能访问源域数据。
现有方法的挑战
| 方法类型 | 挑战 |
|---|---|
| CNN-based | 感受野有限,难以捕获长距离依赖 |
| ViT-based | 全局注意力导致 复杂度 |
| Mamba-based | 通道交互不足,领域不变特征学习受限 |
方法详解
1. 整体框架
输入图像
↓
VMamba编码器 (2D扫描)
↓
通道级视觉状态空间块 (Ch-VSS)
↓
语义一致性打乱 (SCS)
↓
分类器
↓
KL散度一致性正则化
2. VMamba编码器
基于VMamba的2D选择性扫描:
输入: X ∈ R^{B×C×H×W}
↓
重塑: X' ∈ R^{B×C×(HW)} (展平为序列)
↓
四方向扫描:
- 左→右
- 右→左
- 上→下
- 下→上
↓
合并: 沿扫描方向拼接
↓
输出: Y ∈ R^{B×C×H×W}
选择性机制:
其中 是时间步增量。
3. 通道级视觉状态空间块(Ch-VSS)
核心创新:沿通道维度应用双向SSM,实现通道间信息交互。
输入特征: F ∈ R^{B×D×H×W}
↓
沿通道展平: F' ∈ R^{B×D×(HW)}
↓
双向SSM:
- Forward: 沿通道正序扫描
- Backward: 沿通道逆序扫描
↓
输出: F'' ∈ R^{B×D×(HW)}
↓
重塑: F'' ∈ R^{B×D×H×W}
数学形式:
其中 表示通道索引。
通道扫描的优势:
- 学习通道间的频率关系
- 捕获领域不变的低频特征
- 保持线性复杂度
4. 语义一致打乱策略(SCS)
问题:错误累积导致伪标签质量下降。
解决思路:
- 通过打乱背景补丁生成扰动样本
- 要求扰动样本的预测与原始样本一致
4.1 低激活补丁识别
使用Grad-CAM识别低激活区域(背景):
# 伪代码
cam = compute_gradcam(model, x) # (H, W)
threshold = cam.mean() # 使用均值作为阈值
low_activation_mask = cam < threshold # (H, W)4.2 打乱操作
原始图像 (4×4补丁)
┌────┬────┬────┬────┐
│ P1 │ P2 │ P3 │ P4 │
├────┼────┼────┼────┤
│ P5 │ P6 │ P7 │ P8 │ ← 背景补丁
├────┼────┼────┼────┤
│ P9 │ P10│ P11│ P12│ ← 前景补丁
├────┼────┼────┼────┤
│ P13│ P14│ P15│ P16│ ← 前景补丁
└────┴────┴────┴────┘
↓ 打乱
┌────┬────┬────┬────┐
│ P1 │ P2 │ P3 │ P4 │
├────┼────┼────┼────┤
│ P6 │ P8 │ P5 │ P7 │ ← 打乱背景
├────┼────┼────┼────┤
│ P9 │ P10│ P11│ P12│ ← 保持前景
├────┼────┼────┼────┤
│ P13│ P14│ P15│ P16│ ← 保持前景
└────┴────┴────┴────┘
4.3 一致性损失
其中 是打乱后的图像。
5. 训练目标
| 损失项 | 作用 |
|---|---|
| 分类损失(伪标签监督) | |
| 一致性正则化 | |
| 熵正则化(鼓励高置信度预测) |
实验结果
Office-Home数据集
| 方法 | Ar→Cl | Ar→Pr | Ar→Rw | 平均 |
|---|---|---|---|---|
| Source Only | 58.9% | 65.3% | 76.0% | 66.7% |
| SHOT | 72.1% | 78.8% | 81.5% | 77.5% |
| NRC | 74.2% | 80.2% | 82.4% | 78.9% |
| DATMamba | 76.8% | 81.5% | 84.3% | 80.9% |
| SfMamba-S | 77.5% | 82.3% | 85.4% | 81.7% |
VisDA-C数据集
| 方法 | 准确率 |
|---|---|
| Source Only | 52.4% |
| SHOT | 78.3% |
| NRC | 81.2% |
| C-SFTrans | 88.3% |
| SfMamba-S | 89.3% |
DomainNet-126
| 方法 | 准确率 |
|---|---|
| Source Only | 42.1% |
| SHOT | 71.2% |
| SHOT++ | 75.2% |
| SfMamba-S | 77.9% |
效率对比
| 方法 | 参数量 | FLOPs | 吞吐量 |
|---|---|---|---|
| C-SFTrans | 86.6M | 17.6G | 1.0× |
| DATMamba | 61.2M | 10.1G | 1.5× |
| SfMamba-S | 58.9M | 9.2G | 1.26× |
消融实验
Ch-VSS的效果
| 配置 | Office-Home | 说明 |
|---|---|---|
| 无Ch-VSS | 79.8% | 使用标准VMamba |
| Ch-VSS (单方向) | 80.9% | 仅正向扫描 |
| Ch-VSS (双向) | 81.7% | 双向通道扫描 |
SCS的效果
| 配置 | Office-Home | 打乱比例 |
|---|---|---|
| 无SCS | 80.2% | - |
| SCS (10%) | 80.8% | 打乱10%补丁 |
| SCS (20%) | 81.7% | 打乱20%补丁 |
| SCS (30%) | 81.3% | 打乱30%补丁 |
各组件贡献
| 组件 | Office-Home | 提升 |
|---|---|---|
| VMamba基线 | 78.5% | - |
| + Ch-VSS | 80.9% | +2.4% |
| + SCS | 81.7% | +0.8% |
| + 一致性损失 | 81.7% | - |
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChannelWiseSSM(nn.Module):
"""
Channel-wise Visual State Space Block
"""
def __init__(self, dim, d_state=16, dt_rank='auto'):
super().__init__()
self.dim = dim
self.d_state = d_state
# Input projection
self.x_proj = nn.Linear(dim, d_state * 2, bias=False)
# dt projection
self.dt_proj = nn.Linear(d_state, dim, bias=True)
# A and B parameters
self.A_log = nn.Parameter(torch.randn(dim, d_state))
self.B_log = nn.Parameter(torch.randn(dim, d_state))
# Output projection
self.out_proj = nn.Linear(dim, dim, bias=False)
def forward(self, x):
"""
x: (B, D, N) where D=channels, N=spatial tokens
"""
B, D, N = x.shape
# Compute input-dependent parameters
x_gate = self.x_proj(x.transpose(1, 2)) # (B, N, d_state*2)
x_inner, x_dt = x_gate.chunk(2, dim=-1) # (B, N, d_state) each
# Sigmoid gate
x_inner = torch.sigmoid(x_inner)
x_dt = torch.softplus(self.dt_proj(x_dt)) # (B, N, D)
# State transition
A = -torch.exp(self.A_log.float()) # (D, d_state)
B = torch.exp(self.B_log.float()) # (D, d_state)
# Bidirectional scanning
# Forward direction
h_forward = torch.zeros(B, D, self.d_state, device=x.device)
outputs_forward = []
for n in range(N):
h_forward = x_inner[:, n, :] * (B * x[:, :, n]) + (A.exp() * h_forward) * (1 - x_inner[:, n, :])
outputs_forward.append(h_forward)
# Backward direction
h_backward = torch.zeros(B, D, self.d_state, device=x.device)
outputs_backward = []
for n in reversed(range(N)):
h_backward = x_inner[:, n, :] * (B * x[:, :, n]) + (A.exp() * h_backward) * (1 - x_inner[:, n, :])
outputs_backward.insert(0, h_backward)
# Concatenate forward and backward
h_combined = torch.stack([h_f + h_b for h_f, h_b in zip(outputs_forward, outputs_backward)], dim=2)
# Output projection
y = self.out_proj(h_combined.transpose(1, 2)) # (B, N, D)
return y.transpose(1, 2) # (B, D, N)
class SemanticConsistentShuffle(nn.Module):
"""
Semantic-Consistent Shuffle for background patches
"""
def __init__(self, shuffle_ratio=0.2):
super().__init__()
self.shuffle_ratio = shuffle_ratio
def forward(self, x, cam=None):
"""
x: (B, C, H, W)
cam: (B, H, W) class activation maps, if provided
"""
B, C, H, W = x.shape
patch_size = 4
n_patches = (H // patch_size) * (W // patch_size)
# Identify low activation patches using CAM or default
if cam is None:
# Use center region as proxy for foreground
low_mask = torch.ones(B, H, W, device=x.device, dtype=torch.bool)
h_start, h_end = H // 4, 3 * H // 4
w_start, w_end = W // 4, 3 * W // 4
low_mask[:, h_start:h_end, w_start:w_end] = False
else:
threshold = cam.mean()
low_mask = cam < threshold
# Reshape to patches
x_patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
x_patches = x_patches.reshape(B, C, -1, patch_size, patch_size) # (B, C, n_patches, p, p)
mask_patches = low_mask.unfold(0, patch_size, patch_size).unfold(1, patch_size, patch_size)
mask_patches = mask_patches.reshape(B, 1, -1) # (B, 1, n_patches)
# Get low activation indices
n_shuffle = int(n_patches * self.shuffle_ratio)
low_indices = mask_patches.squeeze(1).topk(n_shuffle, dim=1, largest=True)[1]
# Shuffle
x_shuffled = x_patches.clone()
for b in range(B):
shuffle_idx = low_indices[b][torch.randperm(n_shuffle)]
x_shuffled[b, :, shuffle_idx] = x_patches[b, :, low_indices[b]]
# Reshape back
x_shuffled = x_shuffled.reshape(B, C, H // patch_size, W // patch_size, patch_size, patch_size)
x_shuffled = x_shuffled.permute(0, 1, 2, 4, 3, 5).reshape(B, C, H, W)
return x_shuffled与其他SFDA方法的对比
架构演进
| 时代 | 代表方法 | 架构 | 特点 |
|---|---|---|---|
| CNN时代 | SHOT, NRC | ResNet, VGG | 局部感受野 |
| ViT时代 | DSiT, C-SFTrans | ViT, DeiT | 全局注意力 |
| Mamba时代 | SfMamba, DATMamba | VMamba | 线性复杂度+全局建模 |
SfMamba vs DATMamba
| 维度 | DATMamba | SfMamba |
|---|---|---|
| 扫描方向 | 空间4方向 | 空间4方向 + 通道2方向 |
| 通道交互 | 无 | Ch-VSS |
| 打乱策略 | 无 | SCS |
| 准确率 | 80.9% | 81.7% |
总结
SfMamba的核心贡献是将Mamba架构引入SFDA领域,并通过通道级扫描和语义一致打乱策略实现了SOTA性能。
关键创新:
- Ch-VSS块:双向通道扫描,学习领域不变频率特征
- SCS策略:打乱背景补丁,保持语义一致性
- 线性效率: 复杂度,适合高分辨率图像
性能总结:
- Office-Home: 81.7% (SOTA)
- VisDA-C: 89.3% (SOTA)
- DomainNet-126: 77.9% (SOTA)
- 参数量: 58.9M (比C-SFTrans少32%)