CausalDA:基于因果发现的统一免源域适应

概述

CausalDA(arXiv:2403.07601)提出了首个**统一的免源域适应(SFDA)**框架,从因果发现的视角综合处理四种SFDA场景:Closed-set、Open-set、Partial-set和Generalized设置。该方法利用预训练的视觉-语言模型(如CLIP)来发现潜在因果因子,实现无需源域数据访问的鲁棒适应。

核心贡献

  1. 统一问题定义:首次将四种SFDA场景统一为一个框架
  2. 因果视角:从因果发现而非统计关联的角度建模
  3. CLIP辅助:利用视觉-语言模型的丰富世界知识
  4. 理论保证:具有自监督信息瓶颈的理论分析

SFDA场景统一

场景协变量偏移语义偏移抗遗忘难度
Closed-set最简单
Generalized较难
Open-set较难
Partial-set较难
Unified (本文)最难

因果理论基础

结构因果模型(SCM)

CausalDA假设数据生成过程遵循以下因果结构:

       U (非因果因子)
          ↓
    ┌─────┴─────┐
    ↓           ↓
   X ←────── S (因果因子)
    ↓
    ↓
    ↓
   Y (标签)
  • :因果因子,决定标签的语义本质
  • :非因果因子,导致分布偏移(光照、背景、噪声等)
  • :图像由因果和非因果因子共同决定
  • :标签仅由因果因子决定

关键洞察

域偏移由引起,但仅由决定。因此,分离可以实现跨域泛化。

潜在因果因子分解

其中:

  • (外部因子):从CLIP提取的通用世界知识
  • (内部因子):从目标域数据学习的域不变表示

方法详解

1. 整体框架

Phase 1: 外部因果发现 (External Causal Discovery)
┌─────────────────────────────────────────┐
│  CLIP + 提示学习 → 发现潜在因果因子 S_e  │
│         ↓                                 │
│  自监督信息瓶颈 → 保证 S_e 与 Y 的关系   │
└─────────────────────────────────────────┘
                    ↓
Phase 2: 内部因果发现 (Internal Causal Discovery)
┌─────────────────────────────────────────┐
│  目标域数据 → 学习内部因果因子 S_i       │
│         ↓                                 │
│  半监督因果对齐 → S_e 与 S_i 联合建模    │
└─────────────────────────────────────────┘
                    ↓
最终输出: 适应后的分类器

2. Phase 1: 外部因果发现

2.1 CLIP提示学习

利用CLIP的文本编码器生成类原型:

# 类别提示模板
templates = [
    "a photo of a {}",
    "a picture of a {}",
    "a {} in the scene"
]
 
# 文本特征提取
for cls in classes:
    text_features = []
    for template in templates:
        text = template.format(cls)
        feat = clip.encode_text(text)
        text_features.append(feat)
    class_prototype[cls] = mean(text_features)

2.2 自监督信息瓶颈

目标函数:

其中:

  • :原型匹配损失(Prototype Matching)
  • :变分互信息损失(Variational Mutual Information)

原型匹配损失

变分互信息估计

其中 是学到的转换模型。

3. Phase 2: 内部因果发现

3.1 半监督因果对齐

其中:

  • :无监督损失(熵/一致性)
  • :半监督因果对齐损失
  • :源模型提取的特征分布
  • :目标模型提取的特征分布

3.2 因果因子对齐

通过KL散度对齐外部和内部因果因子:

4. 统一适应算法

# CausalDA 核心算法
def causalda(source_model, target_data, clip_model, epochs=100):
    """
    Causal Discovery for Unified Source-Free Domain Adaptation
    """
    # Phase 1: 外部因果发现
    S_e = external_causal_discovery(clip_model, classes)
    
    # Phase 2: 内部因果发现
    target_model = copy.deepcopy(source_model)
    
    for epoch in range(epochs):
        # 无监督适应
        for x_t in target_data:
            # 提取特征
            f_t = target_model.encode(x_t)
            
            # 因果因子分解
            s_e = S_e(x_t)  # 外部因子(冻结)
            s_i = learn_internal_factor(f_t)  # 内部因子
            
            # 联合建模
            pred = combine_factors(s_e, s_i)
            
            # 更新内部因子
            loss = compute_loss(pred, target_model)
            loss.backward()
            optimizer.step()
    
    return target_model
 
 
def external_causal_discovery(clip_model, classes):
    """
    Phase 1: 使用CLIP发现外部因果因子
    """
    # 提取类原型
    class_prototypes = {}
    for cls in classes:
        proto = extract_class_prototype(clip_model, cls)
        class_prototypes[cls] = proto
    
    # 信息瓶颈优化
    for iteration in range(max_iterations):
        # 随机采样图像
        x = sample_batch(target_data)
        
        # 特征提取
        f = clip_model.encode_image(x)
        
        # 计算原型匹配损失
        L_pmi = prototype_matching_loss(f, class_prototypes)
        
        # 计算变分互信息损失
        L_vmi = variational_mi_loss(f)
        
        # 总损失
        L = L_pmi - alpha * L_vmi
        
        L.backward()
        optimizer.step()
    
    return ExternalCausalFactor(class_prototypes)

实验结果

Office-Home数据集

方法Ar→ClAr→PrAr→RwCl→ArCl→PrCl→Rw平均
Source Only45.1%59.5%58.5%49.2%56.1%58.7%54.5%
SHOT63.2%75.8%76.1%62.1%73.2%73.5%70.7%
NRC65.8%77.2%78.9%63.5%75.8%75.2%72.7%
TPDS66.4%78.2%78.9%63.9%76.5%76.1%73.3%
GKD68.1%79.5%80.2%65.8%77.8%77.9%74.9%
CausalDA-C-B3270.5%82.3%83.1%68.2%80.1%79.6%77.3%

VisDA-C数据集

方法准确率
Source Only52.4%
SHOT78.3%
NRC81.2%
PLUE88.3%
CausalDA90.3%

DomainNet-126

方法准确率
Source Only42.1%
SHOT71.2%
GKD77.5%
CausalDA82.2%

跨域泛化(Out-of-Distribution)

方法PACSOfficeHomeVLCS平均
Source Only64.8%56.1%62.3%61.1%
SHOT78.9%75.2%76.8%77.0%
CausalDA82.3%79.8%80.1%80.7%

消融实验

因果子分解的作用

配置Office-Home
仅外部因子 75.2%
仅内部因子 74.8%
联合建模 77.3%

信息瓶颈的作用

Office-Home说明
075.1%无VMI正则
0.176.8%轻正则
0.577.3%适中
1.076.5%过正则

CLIP模型选择

CLIP变体Office-Home
CLIP-B/3275.6%
CLIP-B/1676.8%
CLIP-L/1477.3%

理论分析

因果鲁棒性

定理:如果 ,则跨任意 分布的域适应是可能的。

直觉:由于标签仅由因果因子 决定,而非因果因子 (导致域偏移)不影响标签,因此学习 即可实现跨域泛化。

信息瓶颈的理论保证

引理:信息瓶颈项 保证了学到的表示 与输入 之间去除冗余信息。

直觉 应该只编码与 相关的因果信息,而去除与 无关的 信息。


与其他方法的对比

因果 vs 统计方法

维度统计方法 (SHOT, NRC)因果方法 (CausalDA)
学习目标最小化分布差异发现因果结构
域偏移处理对齐特征分布分离因果/非因果因子
泛化能力受限于同分布假设可处理分布外数据
可解释性黑盒因果关系明确
源数据依赖间接(通过模型)无需源数据

方法路线对比

传统SFDA:
源模型 → 伪标签生成 → 特征聚类 → 自训练
       ↓
问题:错误累积、噪声传播

CausalDA:
源模型 → 因果因子分解(S_e, S_i) → 干预学习
       ↓
优势:对非因果因子的鲁棒性

应用场景

1. 医疗影像跨机构适应

场景:跨医院医学影像分类
├─ 医院A (源): 有标注,设备型号A
├─ 医院B (目标): 无标注,设备型号B
└─ 挑战: HIPAA隐私限制
    → CausalDA利用CLIP知识,无需访问源数据

2. 自动驾驶跨环境部署

场景:仿真→真实世界 / 不同国家
├─ 仿真数据 (源): 合成图像,标注完备
├─ 真实道路 (目标): 多变天气、光照、国家
└─ CausalDA的因果视角天然适合处理
   不同国家/环境的语义不变性

3. 工业检测系统

场景:跨产线缺陷检测
├─ 产线A (源): 特定光照、相机角度
├─ 产线B (目标): 新环境配置
└─ CausalDA的S_e捕获缺陷的因果本质,
   S_i学习新环境的表示

总结

CausalDA的核心贡献是从因果发现的视角统一了四种SFDA场景。

关键创新

  1. 潜在因果因子分解 分离外部知识和内部学习
  2. 自监督信息瓶颈:保证因果因子的质量
  3. 统一框架:无需针对特定场景的先验知识

性能提升

  • Office-Home: 77.3% (vs Prior SOTA 74.9%)
  • VisDA-C: 90.3% (vs Prior SOTA 88.3%)
  • DomainNet-126: 82.2% (vs Prior SOTA 77.5%)

理论贡献

  • 首个具有理论保证的统一SFDA框架
  • 因果鲁棒性定理
  • 信息瓶颈的理论分析

参考