1. 概述

1.1 什么是因果解耦

因果解耦(Causal Disentanglement) 旨在学习满足以下条件的表示:

  1. 因果因素可解释:每个潜在维度对应一个有意义的因果因素
  2. 因果关系明确:因素间的关系可以用因果图描述
  3. 干预可控:修改一个因素只影响相关的因果机制

1.2 与统计解耦的对比

维度统计解耦(VAE/β-VAE)因果解耦
独立性定义统计独立(互信息=0)因果独立(独立机制)
可解释性低(因素可能混合)高(因素有因果语义)
干预效果不可控可预测和可控
分布外泛化有限

2. 独立因果机制原理

2.1 ICM-VAE 框架

核心原理:独立因果机制 → 独立的条件分布

其中 是因果变量 的父节点。

2.2 损失函数设计

ICM-VAE 的损失函数包含两个部分:

第一项确保表示能重建观测数据,第二项强制因果先验结构。

2.3 SCM(结构因果模型)集成

每个因果变量由其父节点通过非线性函数决定:

其中:

  • 的父节点集合
  • :非线性因果机制
  • :独立的外生噪声

3. 方法分类

3.1 方法总览

因果解耦方法
├── 基于变分推断
│   ├── CausalVAE
│   ├── iVAE / AdaVAE
│   └── ICM-VAE
├── 基于对比学习
│   ├── C-Disentanglement
│   └── Contrastive CRL
├── 基于不变性原理
│   ├── IRM-based CRL
│   └── Environment-based
├── 基于 Diffusion
│   ├── CausalDiffVAE
│   └── Diffusion + SCM
└── 基于 Transformer
    └── CausalTransformer

3.2 CausalVAE 详解

架构

输入图像 X
    ↓
编码器 E_φ → 独立潜在变量 z = (z₁, z₂, ..., zₙ)
    ↓ (因果层:通过 SCM 变换)
因果变量 c = (c₁, c₂, ..., cₙ)
    ↓
解码器 G_θ → 重构图像 X̂

因果层实现

class CausalLayer(nn.Module):
    """将独立噪声通过 SCM 转化为因果相关变量"""
    def __init__(self, graph_adj):
        super().__init__()
        self.graph = graph_adj  # 邻接矩阵表示因果图
        self.mechanisms = nn.ModuleDict()
        
        for i in range(n_vars):
            parents = np.where(graph_adj[:, i])[0]
            self.mechanisms[f'f_{i}'] = nn.Sequential(
                nn.Linear(len(parents) + 1, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )
    
    def forward(self, z):
        """拓扑序前向传播"""
        c = torch.zeros_like(z)
        for i in topological_order:
            parents = self.parents[i]
            inputs = torch.cat([z[:, parents], c[:, parents]], dim=-1)
            c[:, i] = self.mechanisms[f'f_{i}'](inputs).squeeze(-1)
        return c

3.3 C-Disentanglement

核心思想:通过混淆因子(confounder)框架处理因果解耦

关键假设:

  • 存在未观察到的混淆因子影响多个因果因素
  • 通过对比学习识别和分离混淆效应

3.4 多环境因果解耦

当数据来自多个环境(具有不同的干预)时:

其中 鼓励找到在环境间不变的因果表示。


4. 因果图的先验与学习

4.1 图结构先验

先验类型适用场景表示方法
完全图无先验知识全部连接
稀疏先验假设因果关系稀疏L1正则/结构化先验
层次先验已知层级结构分层邻接矩阵
专家知识部分已知因果关系固定部分边

4.2 图结构学习

基于梯度的方法

# 可学习邻接矩阵(带稀疏正则)
adjacency = torch.zeros(n_vars, n_vars)
adjacency = nn.Parameter(adjacency)  # 可学习
 
# 稀疏正则项
sparsity_loss = lambda_reg * torch.sum(torch.abs(adjacency))

基于结构学习的方法

  • PC算法改编:条件独立性检验
  • 分数匹配方法:基于因果发现评分
  • 端到端学习:与表示学习联合优化

5. 评估指标

5.1 解耦质量评估

指标描述理想值
MIG(Mutual Information Gap)因果因素的信息差距
DCI(Disentanglement, Completeness, Informativeness)解耦度、完整性、信息量
SAP(Separated Attribute Predictability)因素可分离性

5.2 因果有效性评估

指标描述
干预效果准确性干预因果变量后的变化符合预期
反事实一致性反事实预测与真实一致
分布外泛化在新环境下性能保持

5.3 下游任务评估

  • 分类准确率:因果表示用于下游任务的性能
  • 少样本泛化:在少样本场景下的表现
  • 反事实问答:因果推理问题的准确率

6. 代码实现框架

6.1 基础 CausalVAE

#include <bits/stdc++.h>
using namespace std;
 
// 简化的因果层实现
struct CausalLayer {
    int n_vars, hidden_dim;
    vector<vector<int>> parents;  // 拓扑序
    vector<nn.Linear> mechanisms;
    
    CausalLayer(int n_vars, int hidden_dim, const vector<vector<int>>& graph)
        : n_vars(n_vars), hidden_dim(hidden_dim), parents(graph) {
        // 为每个变量创建因果机制网络
        for (int i = 0; i < n_vars; ++i) {
            int n_parents = parents[i].size();
            mechanisms.push_back(nn.Linear(n_parents + 1, 1));
        }
    }
    
    torch::Tensor forward(torch::Tensor z) {
        torch::Tensor c = torch::zeros_like(z);
        // 拓扑序前向传播
        for (int i : topological_order) {
            vector<torch::Tensor> inputs;
            inputs.push_back(z.index({"...", i}));
            for (int p : parents[i]) {
                inputs.push_back(c.index({"...", p}));
            }
            c.index({"...", i}) = mechanisms[i](torch::cat(inputs, -1));
        }
        return c;
    }
};

6.2 训练循环

def training_loop(encoder, causal_layer, decoder, dataloader, optimizer):
    for x, in dataloader:
        # 前向传播
        z = encoder(x)  # 独立潜在变量
        c = causal_layer(z)  # 因果变换
        x_recon = decoder(c)
        
        # 损失计算
        recon_loss = F.mse_loss(x_recon, x)
        causal_prior_loss = compute_causal_prior_loss(c, causal_layer.graph)
        
        loss = recon_loss + beta * causal_prior_loss
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

7. 参考文献