1. 概述

1.1 什么是反事实生成

反事实生成(Counterfactual Generation) 是指在给定因果模型的情况下,回答”如果…会怎样”(What if)类型的问题:

给定观测数据 和因果模型 ,生成在干预 下的反事实结果

1.2 反事实 vs 条件生成

方面条件生成反事实生成
干预方式观察性条件 $p(Xc)$
因果效应无法区分直接/间接效应可追踪因果路径
反事实能力不支持完全支持
理论基础贝叶斯概率结构因果模型

1.3 核心挑战

  1. 可识别性:何时能从观测数据中识别反事实分布
  2. 生成质量:如何保证反事实图像的真实性和一致性
  3. 因果图的准确性:错误的因果图导致错误的反事实
  4. 高维复杂性:高维图像空间中的因果推理

2. 因果生成模型框架

2.1 整体架构

观测数据 X ──→ 编码器 E ──→ 潜在因果变量 Z_c
                                    │
                                    ↓ (因果机制 SCM)
                              反事实干预 do(Z_i = z_i')
                                    │
                                    ↓
                              反事实因果变量 Z_c^cf
                                    │
                                    ↓
解码器 D ──→ 反事实生成 X^cf

2.2 结构因果模型(SCM)

因果生成模型的核心是结构因果模型:

其中:

  • 的父节点
  • 是非线性因果机制
  • 是独立的外生噪声

2.3 潜在空间干预

加性干预

在潜在空间直接修改因果变量:

其中 是对因果因素 的干预量。

结构干预

修改 SCM 的因果机制:


3. 关键算法

3.1 CausalVAE

CausalVAE(CVPR 2021)是最早的因果生成模型之一。

架构特点

  1. 编码器:将图像编码为独立潜在变量
  2. 因果层:通过 SCM 将 变换为因果变量
  3. 解码器:从因果变量生成图像

反事实生成过程

def counterfactual_generation(model, x, target_attr, delta):
    """CausalVAE 反事实生成"""
    # 编码
    z = model.encoder(x)
    
    # 前向计算因果变量
    c = model.causal_layer(z)
    
    # 反事实干预
    c_cf = c.clone()
    c_cf[:, target_attr] += delta
    
    # 解码
    x_cf = model.decoder(c_cf)
    
    return x_cf

3.2 Causal Diffusion Autoencoders

核心思想

将 Diffusion 模型的强大生成能力与因果表示的可控性结合。

架构

图像 X
    ↓
因果编码器 E_c → 因果表示 z_c
    ↓
Diffusion Forward Process: z_t = α_t · z_c + σ_t · ε
    ↓
Diffusion Decoder: 生成反事实图像

损失函数

其中 强制因果先验结构。

3.3 CausalCF

基于梯度的方法

通过梯度分析识别因果因素对输出的影响:

def causal_attribution(model, x, latent_dim):
    """基于梯度的因果归因"""
    z = model.encoder(x)
    z.requires_grad_(True)
    
    x_recon = model.decoder(z)
    attr = torch.zeros_like(z)
    
    for i in range(latent_dim):
        grad = torch.autograd.grad(
            x_recon.sum(), z, retain_graph=True
        )[0]
        attr[:, i] = grad[:, i].abs().sum(dim=tuple(range(1, x.dim())))
    
    return attr

3.4 对比总结

方法优点缺点适用场景
CausalVAE理论基础扎实,可解释性强生成质量受限可解释AI
Causal Diffusion生成质量高计算开销大高质量图像生成
CausalCF无需显式因果图需要多次前向传播因果归因分析

4. 因果可控生成应用

4.1 属性控制生成

在图像生成中,通过干预特定因果因素来控制生成属性:

干预效果
生成红色物体
生成圆形物体
生成暗光场景

4.2 反事实问答

反事实视觉问答:回答”如果图像中的某个部分改变了,结果会怎样?“类型的问题。

任务定义

给定图像 、问题 、目标干预 ,生成答案

方法

4.3 医学影像反事实生成

应用场景

  • 疾病模拟:模拟疾病进展或消退的影像
  • 治疗效果预测:预测干预后的影像变化
  • 跨设备泛化:生成不同设备拍摄的模拟影像

示例

原始 CT 影像
    ↓ (识别肿瘤因果因素)
干预:do(肿瘤大小 = 0)
    ↓
反事实影像:无肿瘤的 CT

5. 评估方法

5.1 因果有效性评估

干预效果度量

理想情况下,干预第 个因果因素应只影响与之相关的图像属性。

反事实一致性

相同的反事实干预应产生一致的生成结果。

5.2 生成质量评估

指标描述
FID与真实图像的分布距离
LPIPS感知相似度
Counterfactual Accuracy反事实准确性

5.3 因果发现评估

  • SHD(Structural Hamming Distance):与真实因果图的编辑距离
  • MEC(Minimum Description Length Score):因果图评分

6. 代码实现

6.1 基础反事实生成器

#include <bits/stdc++.h>
using namespace std;
 
// 简化的反事实生成器
class CounterfactualGenerator {
public:
    // 编码器
    virtual torch::Tensor encode(torch::Tensor x) = 0;
    
    // 因果层(SCM)
    virtual torch::Tensor causal_transform(torch::Tensor z) = 0;
    
    // 反事实干预
    torch::Tensor intervene(torch::Tensor c, int idx, float value) {
        torch::Tensor c_cf = c.clone();
        c_cf.index_put_({torch::indexing::Slice(), idx}, value);
        return c_cf;
    }
    
    // 解码器
    virtual torch::Tensor decode(torch::Tensor c_cf) = 0;
    
    // 反事实生成
    torch::Tensor generate_counterfactual(
        torch::Tensor x, int attr_idx, float attr_value
    ) {
        torch::Tensor z = encode(x);
        torch::Tensor c = causal_transform(z);
        torch::Tensor c_cf = intervene(c, attr_idx, attr_value);
        return decode(c_cf);
    }
};

6.2 训练流程

class CausalGenerativeModel:
    def __init__(self, encoder, causal_layer, decoder):
        self.encoder = encoder
        self.causal_layer = causal_layer
        self.decoder = decoder
    
    def training_step(self, x, adj_matrix):
        # 编码
        z = self.encoder(x)
        
        # 因果变换
        c = self.causal_layer(z, adj_matrix)
        
        # 重构
        x_recon = self.decoder(c)
        
        # 重建损失
        recon_loss = F.mse_loss(x_recon, x)
        
        # 因果结构损失
        causal_prior = self.causal_layer.get_prior()
        causal_loss = self.compute_causal_consistency(c, causal_prior)
        
        return recon_loss + lambda_causal * causal_loss

7. 参考文献