概述

因果表示学习(Causal Representation Learning)是连接因果推断与深度学习的桥梁,旨在从观测数据中学习能够反映底层因果结构的表示。1 与传统表示学习关注统计相关性不同,因果表示学习追求的是可解释、可干预、泛化性强的表示。


从相关性到因果性

传统表示学习的局限

传统深度学习通过优化重建损失或对比损失学习表示:

这种方法学到的表示:

  • 捕获统计规律而非因果机制
  • 对分布偏移脆弱
  • 难以进行反事实推理

因果表示的愿景

因果表示学习追求:

  1. 因果解耦:表示的每个维度对应一个独立的因果机制
  2. 干预不变性:对某些干预具有不变性
  3. 反事实可推断:给定表示可以回答反事实问题
  4. 跨环境泛化:在不同环境中保持有效性

因果表示的数学基础

结构因果模型回顾

给定 SCM ,其中:

  • :结构方程
  • :独立外生变量

SCM 定义了因果生成过程:

其中 的直接原因(父节点)。

因果表示的形式化

是观测数据, 是潜在表示。因果表示学习的目标是找到编码器/解码器对:

使得 捕获底层因果结构。

因果表示的识别条件

因果表示可识别的充分条件包括:

  1. 因果充分性:表示覆盖所有生成数据的因果机制
  2. 独立性约束:表示维度间满足因果独立假设
  3. 环境多样性:存在多个环境/干预提供识别信号

独立机制原则

独立因果机制(ICM)

核心假设:复杂系统的因果生成过程由若干独立的机制组成,每个机制对应一个条件分布:

ICM的直观理解

因果结构:X → Z → Y

对应的独立机制:
┌────────────────────┐
│ 机制1: P(X)       │  生成X的边缘分布
└────────────────────┘
         ↓
┌────────────────────┐
│ 机制2: P(Z|X)      │  X如何影响Z(独立于Y的生成机制)
└────────────────────┘
         ↓
┌────────────────────┐
│ 机制3: P(Y|Z)      │  Z如何影响Y(独立于X的生成机制)
└────────────────────┘

独立因果噪声(ICN)

ICN假设进一步约束噪声的独立性:

为什么ICN重要

  1. 可区分性:不同的噪声分布对应不同的因果机制
  2. 可识别性:为因果方向识别提供信号
  3. 干预鲁棒性:改变一个机制不影响其他机制

ICM与统计独立的区别

方面统计独立因果独立(ICM)
条件$P(Y
干预效应不变干预X时 $P(Y
实际含义无相关性独立因果机制

示例

考虑数据生成

  • 统计关系: 可能高度相关
  • 因果机制: 是独立机制

不变因果预测(ICP)

核心思想

不变因果预测(Invariant Causal Prediction, ICP)2利用环境/领域的异质性来识别因果变量:

核心假设:因果机制 在不同环境中保持不变,而非因果机制会随环境变化。

ICP算法

class InvariantCausalPrediction:
    """
    不变因果预测算法实现
    """
    def __init__(self, alpha=0.05):
        self.alpha = alpha  # 显著性水平
        self.confounder_set = None
        
    def fit(self, environments):
        """
        输入: 多个环境的数据
              environments = [env1, env2, ...]
              每个 env = {X, Y, E},E是环境标识符
        """
        n_vars = len(environments[0]['X'].columns)
        candidates = []
        
        # 1. 尝试所有可能的父节点集合
        for parent_set in powerset(range(n_vars)):
            if self._test_invariance(parent_set, environments):
                candidates.append(parent_set)
        
        # 2. 取交集得到不变的因果父节点
        if candidates:
            self.confounder_set = set.intersection(*[set(c) for c in candidates])
        else:
            self.confounder_set = set()
            
        return self
    
    def _test_invariance(self, parent_set, environments):
        """
        测试给定父节点集合是否在所有环境中不变
        """
        from scipy import stats
        
        regressions = []
        for env in environments:
            # 在每个环境中拟合回归
            X_parents = env['X'][list(parent_set)]
            Y = env['Y']
            beta = np.linalg.lstsq(X_parents, Y, rcond=None)[0]
            regressions.append(beta)
        
        # 检验回归系数是否在不同环境中一致
        # (使用适当的多样本检验)
        return self._homogeneity_test(regressions, self.alpha)

ICP的理论保证

识别定理:如果存在一个环境集合使得因果父节点集合不变,且数据生成过程满足特定条件,则ICP能一致地识别出因果父节点。

关键条件

  1. 环境异质性:至少存在一个环境,使得混杂变量不干扰因果识别
  2. 正确指定:因果图结构正确
  3. 足够的信噪比:统计检验有足够的功效

ICP的扩展

扩展方法改进
非线性ICP神经网络+不变损失处理非线性因果关系
潜在变量ICP处理未观测混杂处理更复杂场景
分布式ICP联邦学习设置保护数据隐私

可识别的因果表示

因果表示识别问题

给定观测数据 ,能否唯一恢复底层因果表示

识别的不确定性

由于因果方向在纯观测数据下不可区分(如 产生相同联合分布),直接识别完整的因果结构通常是不可能的。

弱因果表示识别

可识别条件下的表示

在特定条件下,可以识别到”等价类”级别的因果表示:

  1. 后门调整可用的表示 可识别的表示
  2. 前门调整可用的表示:通过中介变量的表示
  3. 时间结构可用的表示:时间序列因果表示

表示空间的约束

class CausalRepresentationSpace:
    """
    因果表示空间的结构约束
    """
    def __init__(self, n_factors, causal_graph):
        self.n_factors = n_factors
        self.causal_graph = causal_graph
        
    def enforce_causal_constraints(self, z):
        """
        对表示应用因果结构约束
        """
        # 1. 前向传播(按拓扑序)
        z_constrained = z.clone()
        for node in self.causal_graph.topological_order():
            parents = self.causal_graph.parents(node)
            if parents:
                # z[node] 应该主要由 parents 决定
                z_constrained[node] = self._mechanism(
                    z_constrained[parents], 
                    self.causal_graph.mechanism(node)
                )
                
        return z_constrained
    
    def independent_mechanism_loss(self, z):
        """
        ICM损失:鼓励表示维度间的独立机制
        """
        loss = 0
        for node in self.causal_graph.nodes:
            parents = self.causal_graph.parents(node)
            if parents:
                # 预测z[node]从z[parents]
                pred = self.predict_from_parents(z, parents, node)
                loss += reconstruction_loss(z[node], pred)
                
        return loss

神经因果模型

CausalVAE架构

CausalVAE3将因果结构融入变分自编码器:

架构设计

┌─────────────────────────────────────────────────────────────┐
│                        CausalVAE                             │
│                                                               │
│  输入x ──→ 编码器 E ──→ 潜在变量 z ──┐                        │
│                                        │                       │
│                            ┌───────────┴───────────┐          │
│                            ▼                       ▼          │
│                     因果解码器              独立解码器        │
│                    (符合因果结构)            (标准重建)       │
│                            │                       │          │
│                            └───────────┬───────────┘          │
│                                        ▼                      │
│                                   重构 x̂                      │
└─────────────────────────────────────────────────────────────┘

因果解码器

其中 遵循因果图定义的依赖关系:

训练目标

因果结构损失

\mathcal{L}_{\text{causal}} = \lambda_1 \mathcal{L}_{\text{ICM}} + \lambda_2 \mathcal{L}_{\text{sparsity}} + \lambda_3 \mathcal{L}_{\text{dag}}}

其中:

  • :独立机制损失
  • :稀疏性损失(鼓励稀疏因果图)
  • :DAG约束(无环)

因果归一化流

因果归一化流4利用因果结构增强标准化流模型:

核心思想

传统归一化流假设潜在空间与数据空间同构。因果归一化流强制潜在空间具有因果结构:

数学框架

class CausalNormalizingFlow(Flow):
    """
    因果归一化流
    """
    def __init__(self, base_dist, causal_graph):
        self.base_dist = base_dist  # 满足因果结构的先验
        self.causal_graph = causal_graph
        
    def forward(self, x):
        """
        前向: x → z (推断因果表示)
        """
        z, log_det = self._causal_forward(x)
        log_prob = self.base_dist.log_prob(z) + log_det
        return z, log_prob
    
    def _causal_forward(self, x):
        """
        按因果顺序进行变换
        """
        z = x.clone()
        log_det = 0
        
        for node in self.causal_graph.topological_order():
            parents = self.causal_graph.parents(node)
            if parents:
                # 使用父节点信息变换当前节点
                z[:, node], ld = self._transform_node(
                    z[:, node], z[:, parents]
                )
                log_det += ld
                
        return z, log_det

因果表示学习基准

CausalBench 2025

CausalBench5是评估因果表示学习方法的综合基准:

评估维度

维度指标描述
因果发现SID, SHD与真实因果图的相似度
解耦性DCI, MIG表示维度的独立程度
泛化性OOD accuracy分布外测试性能
可解释性人工评估因果解释的合理性

数据集分类

数据集特点适用任务
Causal3DIdent3D图像+因果因素视觉因果表示
CelebA-Causal人脸属性+因果关系属性因果学习
Molecular分子图+因果机制科学因果发现
Intervention多环境+干预不变表示学习

实践指南

实现步骤

  1. 定义因果图结构:明确变量间的因果假设
  2. 设计编码器:将观测映射到潜在空间
  3. 引入因果约束:ICM、稀疏性、DAG约束
  4. 多环境训练:利用环境异质性增强因果表示
  5. 评估与验证:使用因果指标评估

代码模板

import torch
import torch.nn as nn
 
class CausalRepresentationLearner(nn.Module):
    def __init__(self, input_dim, latent_dim, causal_graph):
        super().__init__()
        self.causal_graph = causal_graph
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim)
        )
        
        # 因果解码器
        self.decoder = nn.ModuleDict({
            str(node): nn.Sequential(
                nn.Linear(len(causal_graph.parents(node)) + 1, 64),
                nn.ReLU(),
                nn.Linear(64, 1)
            )
            for node in causal_graph.nodes
        })
        
    def forward(self, x):
        # 编码
        z = self.encoder(x)
        
        # 因果解码(按拓扑序)
        x_recon = {}
        for node in self.causal_graph.topological_order():
            parents = self.causal_graph.parents(node)
            if parents:
                parent_z = torch.stack([z[:, p] for p in parents], dim=1)
                z_node = torch.cat([z[:, node:node+1], parent_z], dim=1)
            else:
                z_node = z[:, node:node+1]
            x_recon[node] = self.decoder[str(node)](z_node)
            
        return z, x_recon
    
    def causal_loss(self, z, x):
        """
        因果结构损失
        """
        icm_loss = 0
        for node in self.causal_graph.nodes:
            parents = self.causal_graph.parents(node)
            if parents:
                # 预测z[node]从z[parents]
                pred = self._predict_mechanism(z, parents, node)
                icm_loss += F.mse_loss(z[:, node], pred)
        return icm_loss

注意事项

  1. 因果假设的敏感性:结果依赖正确的因果图假设
  2. 环境异质性的利用:确保训练环境足够多样
  3. 可识别性限制:在纯观测数据下,通常只能识别到等价类

相关内容


参考文献


Last updated: 2026-05-14

Footnotes

  1. Schölkopf et al. “Toward Causal Representation Learning.” Proceedings of the IEEE, 2021.

  2. Peters et al. “Invariant Causal Prediction.” JMLR, 2016.

  3. Locatello et al. “CausalVAE: Disentangled Representation Learning via Neural Structural Causal Models.” CVPR 2020.

  4. arXiv:2301.00976. “Causal Normalizing Flows.” 2023.

  5. arXiv:2603.17403. “Causal Representation Learning Benchmark.” 2026.