1. 概述
1.1 什么是反事实生成
反事实生成(Counterfactual Generation) 是指在给定因果模型的情况下,回答”如果…会怎样”(What if)类型的问题:
给定观测数据 和因果模型 ,生成在干预 下的反事实结果 。
1.2 反事实 vs 条件生成
| 方面 | 条件生成 | 反事实生成 |
|---|---|---|
| 干预方式 | 观察性条件 $p(X | c)$ |
| 因果效应 | 无法区分直接/间接效应 | 可追踪因果路径 |
| 反事实能力 | 不支持 | 完全支持 |
| 理论基础 | 贝叶斯概率 | 结构因果模型 |
1.3 核心挑战
- 可识别性:何时能从观测数据中识别反事实分布
- 生成质量:如何保证反事实图像的真实性和一致性
- 因果图的准确性:错误的因果图导致错误的反事实
- 高维复杂性:高维图像空间中的因果推理
2. 因果生成模型框架
2.1 整体架构
观测数据 X ──→ 编码器 E ──→ 潜在因果变量 Z_c
│
↓ (因果机制 SCM)
反事实干预 do(Z_i = z_i')
│
↓
反事实因果变量 Z_c^cf
│
↓
解码器 D ──→ 反事实生成 X^cf
2.2 结构因果模型(SCM)
因果生成模型的核心是结构因果模型:
其中:
- 是 的父节点
- 是非线性因果机制
- 是独立的外生噪声
2.3 潜在空间干预
加性干预
在潜在空间直接修改因果变量:
其中 是对因果因素 的干预量。
结构干预
修改 SCM 的因果机制:
3. 关键算法
3.1 CausalVAE
CausalVAE(CVPR 2021)是最早的因果生成模型之一。
架构特点
- 编码器:将图像编码为独立潜在变量
- 因果层:通过 SCM 将 变换为因果变量
- 解码器:从因果变量生成图像
反事实生成过程
def counterfactual_generation(model, x, target_attr, delta):
"""CausalVAE 反事实生成"""
# 编码
z = model.encoder(x)
# 前向计算因果变量
c = model.causal_layer(z)
# 反事实干预
c_cf = c.clone()
c_cf[:, target_attr] += delta
# 解码
x_cf = model.decoder(c_cf)
return x_cf3.2 Causal Diffusion Autoencoders
核心思想
将 Diffusion 模型的强大生成能力与因果表示的可控性结合。
架构
图像 X
↓
因果编码器 E_c → 因果表示 z_c
↓
Diffusion Forward Process: z_t = α_t · z_c + σ_t · ε
↓
Diffusion Decoder: 生成反事实图像
损失函数
其中 强制因果先验结构。
3.3 CausalCF
基于梯度的方法
通过梯度分析识别因果因素对输出的影响:
def causal_attribution(model, x, latent_dim):
"""基于梯度的因果归因"""
z = model.encoder(x)
z.requires_grad_(True)
x_recon = model.decoder(z)
attr = torch.zeros_like(z)
for i in range(latent_dim):
grad = torch.autograd.grad(
x_recon.sum(), z, retain_graph=True
)[0]
attr[:, i] = grad[:, i].abs().sum(dim=tuple(range(1, x.dim())))
return attr3.4 对比总结
| 方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| CausalVAE | 理论基础扎实,可解释性强 | 生成质量受限 | 可解释AI |
| Causal Diffusion | 生成质量高 | 计算开销大 | 高质量图像生成 |
| CausalCF | 无需显式因果图 | 需要多次前向传播 | 因果归因分析 |
4. 因果可控生成应用
4.1 属性控制生成
在图像生成中,通过干预特定因果因素来控制生成属性:
| 干预 | 效果 |
|---|---|
| 生成红色物体 | |
| 生成圆形物体 | |
| 生成暗光场景 |
4.2 反事实问答
反事实视觉问答:回答”如果图像中的某个部分改变了,结果会怎样?“类型的问题。
任务定义
给定图像 、问题 、目标干预 ,生成答案 。
方法
4.3 医学影像反事实生成
应用场景
- 疾病模拟:模拟疾病进展或消退的影像
- 治疗效果预测:预测干预后的影像变化
- 跨设备泛化:生成不同设备拍摄的模拟影像
示例
原始 CT 影像
↓ (识别肿瘤因果因素)
干预:do(肿瘤大小 = 0)
↓
反事实影像:无肿瘤的 CT
5. 评估方法
5.1 因果有效性评估
干预效果度量
理想情况下,干预第 个因果因素应只影响与之相关的图像属性。
反事实一致性
相同的反事实干预应产生一致的生成结果。
5.2 生成质量评估
| 指标 | 描述 |
|---|---|
| FID | 与真实图像的分布距离 |
| LPIPS | 感知相似度 |
| Counterfactual Accuracy | 反事实准确性 |
5.3 因果发现评估
- SHD(Structural Hamming Distance):与真实因果图的编辑距离
- MEC(Minimum Description Length Score):因果图评分
6. 代码实现
6.1 基础反事实生成器
#include <bits/stdc++.h>
using namespace std;
// 简化的反事实生成器
class CounterfactualGenerator {
public:
// 编码器
virtual torch::Tensor encode(torch::Tensor x) = 0;
// 因果层(SCM)
virtual torch::Tensor causal_transform(torch::Tensor z) = 0;
// 反事实干预
torch::Tensor intervene(torch::Tensor c, int idx, float value) {
torch::Tensor c_cf = c.clone();
c_cf.index_put_({torch::indexing::Slice(), idx}, value);
return c_cf;
}
// 解码器
virtual torch::Tensor decode(torch::Tensor c_cf) = 0;
// 反事实生成
torch::Tensor generate_counterfactual(
torch::Tensor x, int attr_idx, float attr_value
) {
torch::Tensor z = encode(x);
torch::Tensor c = causal_transform(z);
torch::Tensor c_cf = intervene(c, attr_idx, attr_value);
return decode(c_cf);
}
};6.2 训练流程
class CausalGenerativeModel:
def __init__(self, encoder, causal_layer, decoder):
self.encoder = encoder
self.causal_layer = causal_layer
self.decoder = decoder
def training_step(self, x, adj_matrix):
# 编码
z = self.encoder(x)
# 因果变换
c = self.causal_layer(z, adj_matrix)
# 重构
x_recon = self.decoder(c)
# 重建损失
recon_loss = F.mse_loss(x_recon, x)
# 因果结构损失
causal_prior = self.causal_layer.get_prior()
causal_loss = self.compute_causal_consistency(c, causal_prior)
return recon_loss + lambda_causal * causal_loss