CausalDA:基于因果发现的统一免源域适应
概述
CausalDA(arXiv:2403.07601)提出了首个**统一的免源域适应(SFDA)**框架,从因果发现的视角综合处理四种SFDA场景:Closed-set、Open-set、Partial-set和Generalized设置。该方法利用预训练的视觉-语言模型(如CLIP)来发现潜在因果因子,实现无需源域数据访问的鲁棒适应。
核心贡献
- 统一问题定义:首次将四种SFDA场景统一为一个框架
- 因果视角:从因果发现而非统计关联的角度建模
- CLIP辅助:利用视觉-语言模型的丰富世界知识
- 理论保证:具有自监督信息瓶颈的理论分析
SFDA场景统一
| 场景 | 协变量偏移 | 语义偏移 | 抗遗忘 | 难度 |
|---|---|---|---|---|
| Closed-set | ✓ | ✗ | ✗ | 最简单 |
| Generalized | ✓ | ✗ | ✓ | 较难 |
| Open-set | ✓ | ✓ | ✗ | 较难 |
| Partial-set | ✓ | ✓ | ✗ | 较难 |
| Unified (本文) | ✓ | ✓ | ✓ | 最难 |
因果理论基础
结构因果模型(SCM)
CausalDA假设数据生成过程遵循以下因果结构:
U (非因果因子)
↓
┌─────┴─────┐
↓ ↓
X ←────── S (因果因子)
↓
↓
↓
Y (标签)
- :因果因子,决定标签的语义本质
- :非因果因子,导致分布偏移(光照、背景、噪声等)
- :图像由因果和非因果因子共同决定
- :标签仅由因果因子决定
关键洞察
域偏移由引起,但仅由决定。因此,分离和可以实现跨域泛化。
潜在因果因子分解
其中:
- (外部因子):从CLIP提取的通用世界知识
- (内部因子):从目标域数据学习的域不变表示
方法详解
1. 整体框架
Phase 1: 外部因果发现 (External Causal Discovery)
┌─────────────────────────────────────────┐
│ CLIP + 提示学习 → 发现潜在因果因子 S_e │
│ ↓ │
│ 自监督信息瓶颈 → 保证 S_e 与 Y 的关系 │
└─────────────────────────────────────────┘
↓
Phase 2: 内部因果发现 (Internal Causal Discovery)
┌─────────────────────────────────────────┐
│ 目标域数据 → 学习内部因果因子 S_i │
│ ↓ │
│ 半监督因果对齐 → S_e 与 S_i 联合建模 │
└─────────────────────────────────────────┘
↓
最终输出: 适应后的分类器
2. Phase 1: 外部因果发现
2.1 CLIP提示学习
利用CLIP的文本编码器生成类原型:
# 类别提示模板
templates = [
"a photo of a {}",
"a picture of a {}",
"a {} in the scene"
]
# 文本特征提取
for cls in classes:
text_features = []
for template in templates:
text = template.format(cls)
feat = clip.encode_text(text)
text_features.append(feat)
class_prototype[cls] = mean(text_features)2.2 自监督信息瓶颈
目标函数:
其中:
- :原型匹配损失(Prototype Matching)
- :变分互信息损失(Variational Mutual Information)
原型匹配损失:
变分互信息估计:
其中 是学到的转换模型。
3. Phase 2: 内部因果发现
3.1 半监督因果对齐
其中:
- :无监督损失(熵/一致性)
- :半监督因果对齐损失
- :源模型提取的特征分布
- :目标模型提取的特征分布
3.2 因果因子对齐
通过KL散度对齐外部和内部因果因子:
4. 统一适应算法
# CausalDA 核心算法
def causalda(source_model, target_data, clip_model, epochs=100):
"""
Causal Discovery for Unified Source-Free Domain Adaptation
"""
# Phase 1: 外部因果发现
S_e = external_causal_discovery(clip_model, classes)
# Phase 2: 内部因果发现
target_model = copy.deepcopy(source_model)
for epoch in range(epochs):
# 无监督适应
for x_t in target_data:
# 提取特征
f_t = target_model.encode(x_t)
# 因果因子分解
s_e = S_e(x_t) # 外部因子(冻结)
s_i = learn_internal_factor(f_t) # 内部因子
# 联合建模
pred = combine_factors(s_e, s_i)
# 更新内部因子
loss = compute_loss(pred, target_model)
loss.backward()
optimizer.step()
return target_model
def external_causal_discovery(clip_model, classes):
"""
Phase 1: 使用CLIP发现外部因果因子
"""
# 提取类原型
class_prototypes = {}
for cls in classes:
proto = extract_class_prototype(clip_model, cls)
class_prototypes[cls] = proto
# 信息瓶颈优化
for iteration in range(max_iterations):
# 随机采样图像
x = sample_batch(target_data)
# 特征提取
f = clip_model.encode_image(x)
# 计算原型匹配损失
L_pmi = prototype_matching_loss(f, class_prototypes)
# 计算变分互信息损失
L_vmi = variational_mi_loss(f)
# 总损失
L = L_pmi - alpha * L_vmi
L.backward()
optimizer.step()
return ExternalCausalFactor(class_prototypes)实验结果
Office-Home数据集
| 方法 | Ar→Cl | Ar→Pr | Ar→Rw | Cl→Ar | Cl→Pr | Cl→Rw | 平均 |
|---|---|---|---|---|---|---|---|
| Source Only | 45.1% | 59.5% | 58.5% | 49.2% | 56.1% | 58.7% | 54.5% |
| SHOT | 63.2% | 75.8% | 76.1% | 62.1% | 73.2% | 73.5% | 70.7% |
| NRC | 65.8% | 77.2% | 78.9% | 63.5% | 75.8% | 75.2% | 72.7% |
| TPDS | 66.4% | 78.2% | 78.9% | 63.9% | 76.5% | 76.1% | 73.3% |
| GKD | 68.1% | 79.5% | 80.2% | 65.8% | 77.8% | 77.9% | 74.9% |
| CausalDA-C-B32 | 70.5% | 82.3% | 83.1% | 68.2% | 80.1% | 79.6% | 77.3% |
VisDA-C数据集
| 方法 | 准确率 |
|---|---|
| Source Only | 52.4% |
| SHOT | 78.3% |
| NRC | 81.2% |
| PLUE | 88.3% |
| CausalDA | 90.3% |
DomainNet-126
| 方法 | 准确率 |
|---|---|
| Source Only | 42.1% |
| SHOT | 71.2% |
| GKD | 77.5% |
| CausalDA | 82.2% |
跨域泛化(Out-of-Distribution)
| 方法 | PACS | OfficeHome | VLCS | 平均 |
|---|---|---|---|---|
| Source Only | 64.8% | 56.1% | 62.3% | 61.1% |
| SHOT | 78.9% | 75.2% | 76.8% | 77.0% |
| CausalDA | 82.3% | 79.8% | 80.1% | 80.7% |
消融实验
因果子分解的作用
| 配置 | Office-Home |
|---|---|
| 仅外部因子 | 75.2% |
| 仅内部因子 | 74.8% |
| 联合建模 | 77.3% |
信息瓶颈的作用
| Office-Home | 说明 | |
|---|---|---|
| 0 | 75.1% | 无VMI正则 |
| 0.1 | 76.8% | 轻正则 |
| 0.5 | 77.3% | 适中 |
| 1.0 | 76.5% | 过正则 |
CLIP模型选择
| CLIP变体 | Office-Home |
|---|---|
| CLIP-B/32 | 75.6% |
| CLIP-B/16 | 76.8% |
| CLIP-L/14 | 77.3% |
理论分析
因果鲁棒性
定理:如果 ,则跨任意 分布的域适应是可能的。
直觉:由于标签仅由因果因子 决定,而非因果因子 (导致域偏移)不影响标签,因此学习 即可实现跨域泛化。
信息瓶颈的理论保证
引理:信息瓶颈项 保证了学到的表示 与输入 之间去除冗余信息。
直觉: 应该只编码与 相关的因果信息,而去除与 无关的 信息。
与其他方法的对比
因果 vs 统计方法
| 维度 | 统计方法 (SHOT, NRC) | 因果方法 (CausalDA) |
|---|---|---|
| 学习目标 | 最小化分布差异 | 发现因果结构 |
| 域偏移处理 | 对齐特征分布 | 分离因果/非因果因子 |
| 泛化能力 | 受限于同分布假设 | 可处理分布外数据 |
| 可解释性 | 黑盒 | 因果关系明确 |
| 源数据依赖 | 间接(通过模型) | 无需源数据 |
方法路线对比
传统SFDA:
源模型 → 伪标签生成 → 特征聚类 → 自训练
↓
问题:错误累积、噪声传播
CausalDA:
源模型 → 因果因子分解(S_e, S_i) → 干预学习
↓
优势:对非因果因子的鲁棒性
应用场景
1. 医疗影像跨机构适应
场景:跨医院医学影像分类
├─ 医院A (源): 有标注,设备型号A
├─ 医院B (目标): 无标注,设备型号B
└─ 挑战: HIPAA隐私限制
→ CausalDA利用CLIP知识,无需访问源数据
2. 自动驾驶跨环境部署
场景:仿真→真实世界 / 不同国家
├─ 仿真数据 (源): 合成图像,标注完备
├─ 真实道路 (目标): 多变天气、光照、国家
└─ CausalDA的因果视角天然适合处理
不同国家/环境的语义不变性
3. 工业检测系统
场景:跨产线缺陷检测
├─ 产线A (源): 特定光照、相机角度
├─ 产线B (目标): 新环境配置
└─ CausalDA的S_e捕获缺陷的因果本质,
S_i学习新环境的表示
总结
CausalDA的核心贡献是从因果发现的视角统一了四种SFDA场景。
关键创新:
- 潜在因果因子分解: 分离外部知识和内部学习
- 自监督信息瓶颈:保证因果因子的质量
- 统一框架:无需针对特定场景的先验知识
性能提升:
- Office-Home: 77.3% (vs Prior SOTA 74.9%)
- VisDA-C: 90.3% (vs Prior SOTA 88.3%)
- DomainNet-126: 82.2% (vs Prior SOTA 77.5%)
理论贡献:
- 首个具有理论保证的统一SFDA框架
- 因果鲁棒性定理
- 信息瓶颈的理论分析