1. 概述
1.1 可识别性问题的起源
标准 VAE 存在不可识别(Unidentifiable) 问题:即使有无限数据,也存在无限多组参数产生相同的观测分布 。
问题:
这意味着无法从观测数据中恢复真实的潜在结构。
1.2 可识别 VAE 的核心思想
可识别变分自编码器(Identifiable VAE, iVAE) 通过引入辅助变量或因果结构假设,打破不可识别性,实现从观测数据恢复真实因果表示。
2. iVAE 理论基础
2.1 非线性 ICA 的不可识别性
考虑非线性独立成分分析模型:
其中 是非线性函数。在无额外假设的情况下:
定理(Hyvärinen & Pajunen, 1999):非线性 ICA 在本质上不可识别,即使有无限数据也无法恢复真实的独立成分。
2.2 可识别性条件
iVAE 通过以下条件实现可识别性:
辅助变量条件
假设存在辅助变量 与潜在变量 有以下关系:
定理(Khemakhem et al., 2020):在辅助变量条件下,非线性 ICA 可识别。
时序/多环境条件
在时序数据或多环境设置下:
2.3 iVAE 目标函数
其中:
- 是辅助变量(环境标识、时间戳等)
- 是依赖 的先验分布
3. CausalVAE 架构
3.1 核心思想
CausalVAE(Yang et al., CVPR 2021)将结构因果模型(SCM)集成到 VAE 框架中,同时学习:
- 因果图结构
- 因果表示
3.2 架构图
输入图像 X
↓
编码器 E_φ → 独立潜在变量 z = (z₁, z₂, ..., zₙ)
↓ (因果层)
因果层 L_ψ → 因果变量 c = (c₁, c₂, ..., cₙ)
↓ (SCM: c = f(z, θ))
↓
解码器 G_θ → 重构图像 X̂
3.3 因果层实现
因果层是 CausalVAE 的核心创新,它通过 SCM 将独立潜在变量转化为因果相关变量:
SCM 定义
其中:
- 是 的父节点集合
- 是非线性函数
- 是独立的外生噪声
拓扑序前向传播
class CausalLayer(nn.Module):
def __init__(self, n_vars, hidden_dim, adjacency_matrix):
super().__init__()
self.n_vars = n_vars
self.adj = adjacency_matrix # 因果邻接矩阵
self.topo_order = self.compute_topological_order()
# 为每个变量创建因果机制
for i in range(n_vars):
n_parents = sum(self.adj[:, i])
self.mechanisms[f'f_{i}'] = nn.Sequential(
nn.Linear(n_parents + 1, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, z):
"""拓扑序前向传播"""
c = torch.zeros_like(z)
for idx in self.topo_order:
parents = torch.where(self.adj[:, idx])[0]
# 构建输入
inputs = [z[:, idx]] # 外生噪声
for p in parents:
inputs.append(c[:, p])
x = torch.cat(inputs, dim=-1)
c[:, idx] = self.mechanisms[f'f_{idx}'](x).squeeze(-1)
return c4. 损失函数设计
4.1 重建损失
其中 是通过因果层变换得到的因果变量。
4.2 因果先验损失
强制因果变量满足 SCM 定义的独立性:
其中 由 SCM 隐式定义。
4.3 图结构正则
鼓励稀疏的因果图:
其中 是因果邻接矩阵。
4.4 总损失
5. 变体与扩展
5.1 AdaVAE
Adaptive VAE 通过自适应方式学习因果图的邻接矩阵:
class AdaVAE(CausalVAE):
def __init__(self, n_vars, hidden_dim, lambda_sparse=0.01):
super().__init__(n_vars, hidden_dim)
# 可学习的邻接矩阵(初始为全连接)
self.adj = nn.Parameter(
torch.ones(n_vars, n_vars) - torch.eye(n_vars)
)
self.lambda_sparse = lambda_sparse
def get_sparse_adj(self):
"""通过 sigmoid 和阈值获得稀疏邻接矩阵"""
adj_sigmoid = torch.sigmoid(self.adj)
# 阈值化
return (adj_sigmoid > 0.5).float()5.2 Causal-Disentangled VAE (CDVAE)
结合对比学习和因果解耦:
其中 拉远不同样本中相同因果因素的表示。
5.3 多视图 CausalVAE
处理多模态数据的扩展:
6. 实验与分析
6.1 基准数据集
| 数据集 | 描述 | 因果因素 |
|---|---|---|
| Shapes3D | 3D 形状图像 | 颜色、形状、大小、旋转 |
| MPI3D | 真实/仿真 3D 物体 | 6 个物理属性 |
| Sprites | 精灵图像 | 角色类型、颜色、背景 |
| CelebA | 人脸图像 | 发型、眼镜、表情等 |
6.2 可识别性验证
实验设置
- 生成数据:
- 训练 CausalVAE
- 验证学到的表示与真实因果变量的一致性
评估指标
| 指标 | 计算方法 | 理想值 |
|---|---|---|
| MIG | 1.0 | |
| DCI | Disentanglement × Completeness × Informativeness | 高 |
| 干预效果 | ATE 准确性 | 高 |
6.3 反事实生成质量
原始图像 → CausalVAE → 干预因果变量 → 反事实图像
↓
对比真实干预后的图像
↓
评估 FID、LPIPS、干预准确性
7. 代码实现
7.1 完整 CausalVAE
#include <bits/stdc++.h>
#include <torch/torch.h>
namespace {
// 因果机制
struct CausalMechanismImpl : torch::nn::Module {
torch::nn::Sequential net;
CausalMechanismImpl(int input_dim, int hidden_dim) {
net = torch::nn::Sequential(
torch::nn::Linear(input_dim, hidden_dim),
torch::nn::ReLU(),
torch::nn::Linear(hidden_dim, 1)
);
register_module("net", net);
}
torch::Tensor forward(torch::Tensor x) {
return net->forward(x);
}
};
TORCH_MODULE(CausalMechanism);
// 因果层
struct CausalLayerImpl : torch::nn::Module {
int n_vars, hidden_dim;
torch::Tensor adj; // 邻接矩阵
std::vector<int> topo_order;
std::vector<CausalMechanism> mechanisms;
CausalLayerImpl(int n_vars, int hidden_dim, torch::Tensor adj)
: n_vars(n_vars), hidden_dim(hidden_dim), adj(adj) {
topo_order = compute_topological_order(adj);
for (int i = 0; i < n_vars; ++i) {
int n_parents = adj.index({torch::indexing::Slice(), i}).sum().item<int>();
mechanisms.push_back(
CausalMechanism(n_parents + 1, hidden_dim)
);
register_module("mechanism_" + std::to_string(i), mechanisms.back());
}
}
torch::Tensor forward(torch::Tensor z) {
// z: (batch_size, n_vars)
torch::Tensor c = torch::zeros_like(z);
for (int idx : topo_order) {
std::vector<torch::Tensor> inputs;
inputs.push_back(z.index({torch::indexing::Slice(), idx}));
// 添加父节点的因果变量
for (int p = 0; p < n_vars; ++p) {
if (adj.index({p, idx}).item<float>() > 0.5f) {
inputs.push_back(c.index({torch::indexing::Slice(), p}));
}
}
torch::Tensor x = torch::cat(inputs, -1);
c.index_put_({torch::indexing::Slice(), idx},
mechanisms[idx]->forward(x).squeeze(-1));
}
return c;
}
};
TORCH_MODULE(CausalLayer);
// CausalVAE 主类
struct CausalVAEImpl : torch::nn::Module {
torch::nn::ModuleEncoder encoder;
CausalLayer causal_layer;
torch::nn::ModuleDecoder decoder;
CausalVAEImpl(/* 参数 */) {
encoder = register_module("encoder", ModuleEncoder(...));
causal_layer = register_module("causal_layer", CausalLayer(...));
decoder = register_module("decoder", ModuleDecoder(...));
}
std::pair<torch::Tensor, torch::Tensor> forward(torch::Tensor x) {
torch::Tensor z = encoder->forward(x);
torch::Tensor c = causal_layer->forward(z);
torch::Tensor x_recon = decoder->forward(c);
return {x_recon, c};
}
};
TORCH_MODULE(CausalVAE);
} // namespace7.2 训练循环
def train_causalvae(model, dataloader, optimizer, epochs):
for epoch in range(epochs):
for x, in dataloader:
# 前向
x_recon, c = model(x)
# 损失
recon_loss = F.mse_loss(x_recon, x)
kl_loss = compute_causal_kl(model, c)
loss = recon_loss + beta * kl_loss
# 反向
optimizer.zero_grad()
loss.backward()
optimizer.step()