概述
因果表示学习(Causal Representation Learning)是连接因果推断与深度学习的桥梁,旨在从观测数据中学习能够反映底层因果结构的表示。1 与传统表示学习关注统计相关性不同,因果表示学习追求的是可解释、可干预、泛化性强的表示。
从相关性到因果性
传统表示学习的局限
传统深度学习通过优化重建损失或对比损失学习表示:
这种方法学到的表示:
- 捕获统计规律而非因果机制
- 对分布偏移脆弱
- 难以进行反事实推理
因果表示的愿景
因果表示学习追求:
- 因果解耦:表示的每个维度对应一个独立的因果机制
- 干预不变性:对某些干预具有不变性
- 反事实可推断:给定表示可以回答反事实问题
- 跨环境泛化:在不同环境中保持有效性
因果表示的数学基础
结构因果模型回顾
给定 SCM ,其中:
- :结构方程
- :独立外生变量
SCM 定义了因果生成过程:
其中 是 的直接原因(父节点)。
因果表示的形式化
设 是观测数据, 是潜在表示。因果表示学习的目标是找到编码器/解码器对:
使得 捕获底层因果结构。
因果表示的识别条件:
因果表示可识别的充分条件包括:
- 因果充分性:表示覆盖所有生成数据的因果机制
- 独立性约束:表示维度间满足因果独立假设
- 环境多样性:存在多个环境/干预提供识别信号
独立机制原则
独立因果机制(ICM)
核心假设:复杂系统的因果生成过程由若干独立的机制组成,每个机制对应一个条件分布:
ICM的直观理解:
因果结构:X → Z → Y
对应的独立机制:
┌────────────────────┐
│ 机制1: P(X) │ 生成X的边缘分布
└────────────────────┘
↓
┌────────────────────┐
│ 机制2: P(Z|X) │ X如何影响Z(独立于Y的生成机制)
└────────────────────┘
↓
┌────────────────────┐
│ 机制3: P(Y|Z) │ Z如何影响Y(独立于X的生成机制)
└────────────────────┘
独立因果噪声(ICN)
ICN假设进一步约束噪声的独立性:
为什么ICN重要:
- 可区分性:不同的噪声分布对应不同的因果机制
- 可识别性:为因果方向识别提供信号
- 干预鲁棒性:改变一个机制不影响其他机制
ICM与统计独立的区别
| 方面 | 统计独立 | 因果独立(ICM) |
|---|---|---|
| 条件 | $P(Y | |
| 干预效应 | 不变 | 干预X时 $P(Y |
| 实际含义 | 无相关性 | 独立因果机制 |
示例:
考虑数据生成 :
- 统计关系:、、 可能高度相关
- 因果机制:、 是独立机制
不变因果预测(ICP)
核心思想
不变因果预测(Invariant Causal Prediction, ICP)2利用环境/领域的异质性来识别因果变量:
核心假设:因果机制 在不同环境中保持不变,而非因果机制会随环境变化。
ICP算法
class InvariantCausalPrediction:
"""
不变因果预测算法实现
"""
def __init__(self, alpha=0.05):
self.alpha = alpha # 显著性水平
self.confounder_set = None
def fit(self, environments):
"""
输入: 多个环境的数据
environments = [env1, env2, ...]
每个 env = {X, Y, E},E是环境标识符
"""
n_vars = len(environments[0]['X'].columns)
candidates = []
# 1. 尝试所有可能的父节点集合
for parent_set in powerset(range(n_vars)):
if self._test_invariance(parent_set, environments):
candidates.append(parent_set)
# 2. 取交集得到不变的因果父节点
if candidates:
self.confounder_set = set.intersection(*[set(c) for c in candidates])
else:
self.confounder_set = set()
return self
def _test_invariance(self, parent_set, environments):
"""
测试给定父节点集合是否在所有环境中不变
"""
from scipy import stats
regressions = []
for env in environments:
# 在每个环境中拟合回归
X_parents = env['X'][list(parent_set)]
Y = env['Y']
beta = np.linalg.lstsq(X_parents, Y, rcond=None)[0]
regressions.append(beta)
# 检验回归系数是否在不同环境中一致
# (使用适当的多样本检验)
return self._homogeneity_test(regressions, self.alpha)ICP的理论保证
识别定理:如果存在一个环境集合使得因果父节点集合不变,且数据生成过程满足特定条件,则ICP能一致地识别出因果父节点。
关键条件:
- 环境异质性:至少存在一个环境,使得混杂变量不干扰因果识别
- 正确指定:因果图结构正确
- 足够的信噪比:统计检验有足够的功效
ICP的扩展
| 扩展 | 方法 | 改进 |
|---|---|---|
| 非线性ICP | 神经网络+不变损失 | 处理非线性因果关系 |
| 潜在变量ICP | 处理未观测混杂 | 处理更复杂场景 |
| 分布式ICP | 联邦学习设置 | 保护数据隐私 |
可识别的因果表示
因果表示识别问题
给定观测数据 ,能否唯一恢复底层因果表示 ?
识别的不确定性:
由于因果方向在纯观测数据下不可区分(如 和 产生相同联合分布),直接识别完整的因果结构通常是不可能的。
弱因果表示识别
可识别条件下的表示:
在特定条件下,可以识别到”等价类”级别的因果表示:
- 后门调整可用的表示: 可识别的表示
- 前门调整可用的表示:通过中介变量的表示
- 时间结构可用的表示:时间序列因果表示
表示空间的约束
class CausalRepresentationSpace:
"""
因果表示空间的结构约束
"""
def __init__(self, n_factors, causal_graph):
self.n_factors = n_factors
self.causal_graph = causal_graph
def enforce_causal_constraints(self, z):
"""
对表示应用因果结构约束
"""
# 1. 前向传播(按拓扑序)
z_constrained = z.clone()
for node in self.causal_graph.topological_order():
parents = self.causal_graph.parents(node)
if parents:
# z[node] 应该主要由 parents 决定
z_constrained[node] = self._mechanism(
z_constrained[parents],
self.causal_graph.mechanism(node)
)
return z_constrained
def independent_mechanism_loss(self, z):
"""
ICM损失:鼓励表示维度间的独立机制
"""
loss = 0
for node in self.causal_graph.nodes:
parents = self.causal_graph.parents(node)
if parents:
# 预测z[node]从z[parents]
pred = self.predict_from_parents(z, parents, node)
loss += reconstruction_loss(z[node], pred)
return loss神经因果模型
CausalVAE架构
CausalVAE3将因果结构融入变分自编码器:
架构设计:
┌─────────────────────────────────────────────────────────────┐
│ CausalVAE │
│ │
│ 输入x ──→ 编码器 E ──→ 潜在变量 z ──┐ │
│ │ │
│ ┌───────────┴───────────┐ │
│ ▼ ▼ │
│ 因果解码器 独立解码器 │
│ (符合因果结构) (标准重建) │
│ │ │ │
│ └───────────┬───────────┘ │
│ ▼ │
│ 重构 x̂ │
└─────────────────────────────────────────────────────────────┘
因果解码器:
其中 遵循因果图定义的依赖关系:
训练目标
因果结构损失:
\mathcal{L}_{\text{causal}} = \lambda_1 \mathcal{L}_{\text{ICM}} + \lambda_2 \mathcal{L}_{\text{sparsity}} + \lambda_3 \mathcal{L}_{\text{dag}}}其中:
- :独立机制损失
- :稀疏性损失(鼓励稀疏因果图)
- :DAG约束(无环)
因果归一化流
因果归一化流4利用因果结构增强标准化流模型:
核心思想:
传统归一化流假设潜在空间与数据空间同构。因果归一化流强制潜在空间具有因果结构:
数学框架:
class CausalNormalizingFlow(Flow):
"""
因果归一化流
"""
def __init__(self, base_dist, causal_graph):
self.base_dist = base_dist # 满足因果结构的先验
self.causal_graph = causal_graph
def forward(self, x):
"""
前向: x → z (推断因果表示)
"""
z, log_det = self._causal_forward(x)
log_prob = self.base_dist.log_prob(z) + log_det
return z, log_prob
def _causal_forward(self, x):
"""
按因果顺序进行变换
"""
z = x.clone()
log_det = 0
for node in self.causal_graph.topological_order():
parents = self.causal_graph.parents(node)
if parents:
# 使用父节点信息变换当前节点
z[:, node], ld = self._transform_node(
z[:, node], z[:, parents]
)
log_det += ld
return z, log_det因果表示学习基准
CausalBench 2025
CausalBench5是评估因果表示学习方法的综合基准:
评估维度:
| 维度 | 指标 | 描述 |
|---|---|---|
| 因果发现 | SID, SHD | 与真实因果图的相似度 |
| 解耦性 | DCI, MIG | 表示维度的独立程度 |
| 泛化性 | OOD accuracy | 分布外测试性能 |
| 可解释性 | 人工评估 | 因果解释的合理性 |
数据集分类
| 数据集 | 特点 | 适用任务 |
|---|---|---|
| Causal3DIdent | 3D图像+因果因素 | 视觉因果表示 |
| CelebA-Causal | 人脸属性+因果关系 | 属性因果学习 |
| Molecular | 分子图+因果机制 | 科学因果发现 |
| Intervention | 多环境+干预 | 不变表示学习 |
实践指南
实现步骤
- 定义因果图结构:明确变量间的因果假设
- 设计编码器:将观测映射到潜在空间
- 引入因果约束:ICM、稀疏性、DAG约束
- 多环境训练:利用环境异质性增强因果表示
- 评估与验证:使用因果指标评估
代码模板
import torch
import torch.nn as nn
class CausalRepresentationLearner(nn.Module):
def __init__(self, input_dim, latent_dim, causal_graph):
super().__init__()
self.causal_graph = causal_graph
# 编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, latent_dim)
)
# 因果解码器
self.decoder = nn.ModuleDict({
str(node): nn.Sequential(
nn.Linear(len(causal_graph.parents(node)) + 1, 64),
nn.ReLU(),
nn.Linear(64, 1)
)
for node in causal_graph.nodes
})
def forward(self, x):
# 编码
z = self.encoder(x)
# 因果解码(按拓扑序)
x_recon = {}
for node in self.causal_graph.topological_order():
parents = self.causal_graph.parents(node)
if parents:
parent_z = torch.stack([z[:, p] for p in parents], dim=1)
z_node = torch.cat([z[:, node:node+1], parent_z], dim=1)
else:
z_node = z[:, node:node+1]
x_recon[node] = self.decoder[str(node)](z_node)
return z, x_recon
def causal_loss(self, z, x):
"""
因果结构损失
"""
icm_loss = 0
for node in self.causal_graph.nodes:
parents = self.causal_graph.parents(node)
if parents:
# 预测z[node]从z[parents]
pred = self._predict_mechanism(z, parents, node)
icm_loss += F.mse_loss(z[:, node], pred)
return icm_loss注意事项
- 因果假设的敏感性:结果依赖正确的因果图假设
- 环境异质性的利用:确保训练环境足够多样
- 可识别性限制:在纯观测数据下,通常只能识别到等价类
相关内容
- 因果推断基础 — SCM、do-calculus基础
- 因果解耦 — 因果解耦表示的具体方法
- 因果VAE与隐变量模型 — CausalVAE与iVAE深度解析
- 反事实推理 — 因果表示的反事实应用
参考文献
Last updated: 2026-05-14
Footnotes
-
Schölkopf et al. “Toward Causal Representation Learning.” Proceedings of the IEEE, 2021. ↩
-
Peters et al. “Invariant Causal Prediction.” JMLR, 2016. ↩
-
Locatello et al. “CausalVAE: Disentangled Representation Learning via Neural Structural Causal Models.” CVPR 2020. ↩
-
arXiv:2301.00976. “Causal Normalizing Flows.” 2023. ↩
-
arXiv:2603.17403. “Causal Representation Learning Benchmark.” 2026. ↩