1. 概述
1.1 什么是因果解耦
因果解耦(Causal Disentanglement) 旨在学习满足以下条件的表示:
- 因果因素可解释:每个潜在维度对应一个有意义的因果因素
- 因果关系明确:因素间的关系可以用因果图描述
- 干预可控:修改一个因素只影响相关的因果机制
1.2 与统计解耦的对比
| 维度 | 统计解耦(VAE/β-VAE) | 因果解耦 |
|---|---|---|
| 独立性定义 | 统计独立(互信息=0) | 因果独立(独立机制) |
| 可解释性 | 低(因素可能混合) | 高(因素有因果语义) |
| 干预效果 | 不可控 | 可预测和可控 |
| 分布外泛化 | 有限 | 强 |
2. 独立因果机制原理
2.1 ICM-VAE 框架
核心原理:独立因果机制 → 独立的条件分布
其中 是因果变量 的父节点。
2.2 损失函数设计
ICM-VAE 的损失函数包含两个部分:
第一项确保表示能重建观测数据,第二项强制因果先验结构。
2.3 SCM(结构因果模型)集成
每个因果变量由其父节点通过非线性函数决定:
其中:
- : 的父节点集合
- :非线性因果机制
- :独立的外生噪声
3. 方法分类
3.1 方法总览
因果解耦方法
├── 基于变分推断
│ ├── CausalVAE
│ ├── iVAE / AdaVAE
│ └── ICM-VAE
├── 基于对比学习
│ ├── C-Disentanglement
│ └── Contrastive CRL
├── 基于不变性原理
│ ├── IRM-based CRL
│ └── Environment-based
├── 基于 Diffusion
│ ├── CausalDiffVAE
│ └── Diffusion + SCM
└── 基于 Transformer
└── CausalTransformer
3.2 CausalVAE 详解
架构
输入图像 X
↓
编码器 E_φ → 独立潜在变量 z = (z₁, z₂, ..., zₙ)
↓ (因果层:通过 SCM 变换)
因果变量 c = (c₁, c₂, ..., cₙ)
↓
解码器 G_θ → 重构图像 X̂
因果层实现
class CausalLayer(nn.Module):
"""将独立噪声通过 SCM 转化为因果相关变量"""
def __init__(self, graph_adj):
super().__init__()
self.graph = graph_adj # 邻接矩阵表示因果图
self.mechanisms = nn.ModuleDict()
for i in range(n_vars):
parents = np.where(graph_adj[:, i])[0]
self.mechanisms[f'f_{i}'] = nn.Sequential(
nn.Linear(len(parents) + 1, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, z):
"""拓扑序前向传播"""
c = torch.zeros_like(z)
for i in topological_order:
parents = self.parents[i]
inputs = torch.cat([z[:, parents], c[:, parents]], dim=-1)
c[:, i] = self.mechanisms[f'f_{i}'](inputs).squeeze(-1)
return c3.3 C-Disentanglement
核心思想:通过混淆因子(confounder)框架处理因果解耦
关键假设:
- 存在未观察到的混淆因子影响多个因果因素
- 通过对比学习识别和分离混淆效应
3.4 多环境因果解耦
当数据来自多个环境(具有不同的干预)时:
其中 鼓励找到在环境间不变的因果表示。
4. 因果图的先验与学习
4.1 图结构先验
| 先验类型 | 适用场景 | 表示方法 |
|---|---|---|
| 完全图 | 无先验知识 | 全部连接 |
| 稀疏先验 | 假设因果关系稀疏 | L1正则/结构化先验 |
| 层次先验 | 已知层级结构 | 分层邻接矩阵 |
| 专家知识 | 部分已知因果关系 | 固定部分边 |
4.2 图结构学习
基于梯度的方法
# 可学习邻接矩阵(带稀疏正则)
adjacency = torch.zeros(n_vars, n_vars)
adjacency = nn.Parameter(adjacency) # 可学习
# 稀疏正则项
sparsity_loss = lambda_reg * torch.sum(torch.abs(adjacency))基于结构学习的方法
- PC算法改编:条件独立性检验
- 分数匹配方法:基于因果发现评分
- 端到端学习:与表示学习联合优化
5. 评估指标
5.1 解耦质量评估
| 指标 | 描述 | 理想值 |
|---|---|---|
| MIG(Mutual Information Gap) | 因果因素的信息差距 | 高 |
| DCI(Disentanglement, Completeness, Informativeness) | 解耦度、完整性、信息量 | 高 |
| SAP(Separated Attribute Predictability) | 因素可分离性 | 高 |
5.2 因果有效性评估
| 指标 | 描述 |
|---|---|
| 干预效果准确性 | 干预因果变量后的变化符合预期 |
| 反事实一致性 | 反事实预测与真实一致 |
| 分布外泛化 | 在新环境下性能保持 |
5.3 下游任务评估
- 分类准确率:因果表示用于下游任务的性能
- 少样本泛化:在少样本场景下的表现
- 反事实问答:因果推理问题的准确率
6. 代码实现框架
6.1 基础 CausalVAE
#include <bits/stdc++.h>
using namespace std;
// 简化的因果层实现
struct CausalLayer {
int n_vars, hidden_dim;
vector<vector<int>> parents; // 拓扑序
vector<nn.Linear> mechanisms;
CausalLayer(int n_vars, int hidden_dim, const vector<vector<int>>& graph)
: n_vars(n_vars), hidden_dim(hidden_dim), parents(graph) {
// 为每个变量创建因果机制网络
for (int i = 0; i < n_vars; ++i) {
int n_parents = parents[i].size();
mechanisms.push_back(nn.Linear(n_parents + 1, 1));
}
}
torch::Tensor forward(torch::Tensor z) {
torch::Tensor c = torch::zeros_like(z);
// 拓扑序前向传播
for (int i : topological_order) {
vector<torch::Tensor> inputs;
inputs.push_back(z.index({"...", i}));
for (int p : parents[i]) {
inputs.push_back(c.index({"...", p}));
}
c.index({"...", i}) = mechanisms[i](torch::cat(inputs, -1));
}
return c;
}
};6.2 训练循环
def training_loop(encoder, causal_layer, decoder, dataloader, optimizer):
for x, in dataloader:
# 前向传播
z = encoder(x) # 独立潜在变量
c = causal_layer(z) # 因果变换
x_recon = decoder(c)
# 损失计算
recon_loss = F.mse_loss(x_recon, x)
causal_prior_loss = compute_causal_prior_loss(c, causal_layer.graph)
loss = recon_loss + beta * causal_prior_loss
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()