1. 概述

1.1 可识别性问题的起源

标准 VAE 存在不可识别(Unidentifiable) 问题:即使有无限数据,也存在无限多组参数产生相同的观测分布

问题

这意味着无法从观测数据中恢复真实的潜在结构。

1.2 可识别 VAE 的核心思想

可识别变分自编码器(Identifiable VAE, iVAE) 通过引入辅助变量因果结构假设,打破不可识别性,实现从观测数据恢复真实因果表示。


2. iVAE 理论基础

2.1 非线性 ICA 的不可识别性

考虑非线性独立成分分析模型:

其中 是非线性函数。在无额外假设的情况下:

定理(Hyvärinen & Pajunen, 1999):非线性 ICA 在本质上不可识别,即使有无限数据也无法恢复真实的独立成分。

2.2 可识别性条件

iVAE 通过以下条件实现可识别性:

辅助变量条件

假设存在辅助变量 与潜在变量 有以下关系:

定理(Khemakhem et al., 2020):在辅助变量条件下,非线性 ICA 可识别。

时序/多环境条件

在时序数据或多环境设置下:

2.3 iVAE 目标函数

其中:

  • 是辅助变量(环境标识、时间戳等)
  • 是依赖 的先验分布

3. CausalVAE 架构

3.1 核心思想

CausalVAE(Yang et al., CVPR 2021)将结构因果模型(SCM)集成到 VAE 框架中,同时学习:

  1. 因果图结构
  2. 因果表示

3.2 架构图

输入图像 X
    ↓
编码器 E_φ → 独立潜在变量 z = (z₁, z₂, ..., zₙ)
    ↓ (因果层)
因果层 L_ψ → 因果变量 c = (c₁, c₂, ..., cₙ)
    ↓ (SCM: c = f(z, θ))
    ↓
解码器 G_θ → 重构图像 X̂

3.3 因果层实现

因果层是 CausalVAE 的核心创新,它通过 SCM 将独立潜在变量转化为因果相关变量:

SCM 定义

其中:

  • 的父节点集合
  • 是非线性函数
  • 是独立的外生噪声

拓扑序前向传播

class CausalLayer(nn.Module):
    def __init__(self, n_vars, hidden_dim, adjacency_matrix):
        super().__init__()
        self.n_vars = n_vars
        self.adj = adjacency_matrix  # 因果邻接矩阵
        self.topo_order = self.compute_topological_order()
        
        # 为每个变量创建因果机制
        for i in range(n_vars):
            n_parents = sum(self.adj[:, i])
            self.mechanisms[f'f_{i}'] = nn.Sequential(
                nn.Linear(n_parents + 1, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )
    
    def forward(self, z):
        """拓扑序前向传播"""
        c = torch.zeros_like(z)
        
        for idx in self.topo_order:
            parents = torch.where(self.adj[:, idx])[0]
            
            # 构建输入
            inputs = [z[:, idx]]  # 外生噪声
            for p in parents:
                inputs.append(c[:, p])
            
            x = torch.cat(inputs, dim=-1)
            c[:, idx] = self.mechanisms[f'f_{idx}'](x).squeeze(-1)
        
        return c

4. 损失函数设计

4.1 重建损失

其中 是通过因果层变换得到的因果变量。

4.2 因果先验损失

强制因果变量满足 SCM 定义的独立性:

其中 由 SCM 隐式定义。

4.3 图结构正则

鼓励稀疏的因果图:

其中 是因果邻接矩阵。

4.4 总损失


5. 变体与扩展

5.1 AdaVAE

Adaptive VAE 通过自适应方式学习因果图的邻接矩阵:

class AdaVAE(CausalVAE):
    def __init__(self, n_vars, hidden_dim, lambda_sparse=0.01):
        super().__init__(n_vars, hidden_dim)
        # 可学习的邻接矩阵(初始为全连接)
        self.adj = nn.Parameter(
            torch.ones(n_vars, n_vars) - torch.eye(n_vars)
        )
        self.lambda_sparse = lambda_sparse
    
    def get_sparse_adj(self):
        """通过 sigmoid 和阈值获得稀疏邻接矩阵"""
        adj_sigmoid = torch.sigmoid(self.adj)
        # 阈值化
        return (adj_sigmoid > 0.5).float()

5.2 Causal-Disentangled VAE (CDVAE)

结合对比学习和因果解耦:

其中 拉远不同样本中相同因果因素的表示。

5.3 多视图 CausalVAE

处理多模态数据的扩展:


6. 实验与分析

6.1 基准数据集

数据集描述因果因素
Shapes3D3D 形状图像颜色、形状、大小、旋转
MPI3D真实/仿真 3D 物体6 个物理属性
Sprites精灵图像角色类型、颜色、背景
CelebA人脸图像发型、眼镜、表情等

6.2 可识别性验证

实验设置

  1. 生成数据:
  2. 训练 CausalVAE
  3. 验证学到的表示与真实因果变量的一致性

评估指标

指标计算方法理想值
MIG1.0
DCIDisentanglement × Completeness × Informativeness
干预效果ATE 准确性

6.3 反事实生成质量

原始图像 → CausalVAE → 干预因果变量 → 反事实图像
     ↓
对比真实干预后的图像
     ↓
评估 FID、LPIPS、干预准确性

7. 代码实现

7.1 完整 CausalVAE

#include <bits/stdc++.h>
#include <torch/torch.h>
 
namespace {
 
// 因果机制
struct CausalMechanismImpl : torch::nn::Module {
    torch::nn::Sequential net;
    
    CausalMechanismImpl(int input_dim, int hidden_dim) {
        net = torch::nn::Sequential(
            torch::nn::Linear(input_dim, hidden_dim),
            torch::nn::ReLU(),
            torch::nn::Linear(hidden_dim, 1)
        );
        register_module("net", net);
    }
    
    torch::Tensor forward(torch::Tensor x) {
        return net->forward(x);
    }
};
 
TORCH_MODULE(CausalMechanism);
 
// 因果层
struct CausalLayerImpl : torch::nn::Module {
    int n_vars, hidden_dim;
    torch::Tensor adj;  // 邻接矩阵
    std::vector<int> topo_order;
    std::vector<CausalMechanism> mechanisms;
    
    CausalLayerImpl(int n_vars, int hidden_dim, torch::Tensor adj)
        : n_vars(n_vars), hidden_dim(hidden_dim), adj(adj) {
        
        topo_order = compute_topological_order(adj);
        
        for (int i = 0; i < n_vars; ++i) {
            int n_parents = adj.index({torch::indexing::Slice(), i}).sum().item<int>();
            mechanisms.push_back(
                CausalMechanism(n_parents + 1, hidden_dim)
            );
            register_module("mechanism_" + std::to_string(i), mechanisms.back());
        }
    }
    
    torch::Tensor forward(torch::Tensor z) {
        // z: (batch_size, n_vars)
        torch::Tensor c = torch::zeros_like(z);
        
        for (int idx : topo_order) {
            std::vector<torch::Tensor> inputs;
            inputs.push_back(z.index({torch::indexing::Slice(), idx}));
            
            // 添加父节点的因果变量
            for (int p = 0; p < n_vars; ++p) {
                if (adj.index({p, idx}).item<float>() > 0.5f) {
                    inputs.push_back(c.index({torch::indexing::Slice(), p}));
                }
            }
            
            torch::Tensor x = torch::cat(inputs, -1);
            c.index_put_({torch::indexing::Slice(), idx}, 
                         mechanisms[idx]->forward(x).squeeze(-1));
        }
        
        return c;
    }
};
 
TORCH_MODULE(CausalLayer);
 
// CausalVAE 主类
struct CausalVAEImpl : torch::nn::Module {
    torch::nn::ModuleEncoder encoder;
    CausalLayer causal_layer;
    torch::nn::ModuleDecoder decoder;
    
    CausalVAEImpl(/* 参数 */) {
        encoder = register_module("encoder", ModuleEncoder(...));
        causal_layer = register_module("causal_layer", CausalLayer(...));
        decoder = register_module("decoder", ModuleDecoder(...));
    }
    
    std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor x) {
        torch::Tensor z = encoder->forward(x);
        torch::Tensor c = causal_layer->forward(z);
        torch::Tensor x_recon = decoder->forward(c);
        return {x_recon, c};
    }
};
 
TORCH_MODULE(CausalVAE);
 
} // namespace

7.2 训练循环

def train_causalvae(model, dataloader, optimizer, epochs):
    for epoch in range(epochs):
        for x, in dataloader:
            # 前向
            x_recon, c = model(x)
            
            # 损失
            recon_loss = F.mse_loss(x_recon, x)
            kl_loss = compute_causal_kl(model, c)
            
            loss = recon_loss + beta * kl_loss
            
            # 反向
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

8. 参考文献