概率图电路深度解析
1 引言
概率图电路(Probabilistic Graph Circuits, PGCs)1是2025年UAI会议上提出的新型深度生成模型框架,旨在解决在图结构数据上进行可追踪(tractable)概率推断的核心挑战。
传统的深度生成模型(如VAE、Flow、Diffusion)在图数据上面临以下问题:
- 边缘分布计算复杂度高
- 条件概率推断困难
- 归一化常数难以估计
PGCs通过将**概率电路(Probabilistic Circuits, PCs)的 tractable 特性与图神经网络(GNNs)**的表达能力相结合,实现了图上任意子图的精确概率推断。
2 形式化定义
2.1 图上的概率分布
设 为一个图,其中 ,。我们关心以下概率计算问题:
边缘概率:计算某个节点或边的边缘分布
条件概率:计算给定部分图结构时,剩余部分的分布
联合概率:计算整个图的概率
2.2 PGC的基本结构
定义:概率图电路 是一个有向无环图(DAG),满足以下条件:
- 叶子节点:表示图结构的基本随机变量(节点特征、边存在性)
- 内部节点:表示概率操作(求和、乘积、图卷积)
- 根节点:输出图的联合/边缘概率分布
形式化表达:
其中:
- 是权重(满足 )
- 是子电路
- 是对应的图子结构
2.3 与概率电路的关系
概率电路(如SPN、PSDD)是层次化的概率模型,具有以下 tractable 特性:
| 特性 | 描述 |
|---|---|
| 边缘概率 | 时间计算 |
| 条件概率 | 时间计算 |
| 子图概率 | 时间计算 |
PGCs将这些特性扩展到图结构数据上。
3 图操作层
3.1 图卷积操作
PGCs的核心创新在于引入图卷积层作为概率操作。设输入图为 ,输出为 :
节点级卷积:
边级卷积:
3.2 消息传递概率化
在PGC中,消息传递被重新解释为概率分布的变换:
节点消息:
其中 是基函数, 是可学习参数。
边消息:
3.3 池化与层次化
图池化操作将节点集合聚合为超节点:
常用的池化操作包括:
- 和池化:
- 均值池化:
- 注意力池化:,其中
4 训练算法
4.1 端到端梯度下降
PGC的损失函数定义为负对数似然:
梯度计算:
由于PGC是可微分的DAG,梯度通过反向传播计算:
4.2 EM算法
当PGC包含隐变量时,使用EM算法进行训练:
E步:计算隐变量的后验分布
M步:最大化期望对数似然
4.3 结构学习
PGC支持结构学习,即同时学习网络结构和参数:
分裂操作:将一个节点分为两个子节点
合并操作:将两个节点合并为一个
使用强化学习的结构搜索:
5 推断特性
5.1 精确边缘概率
PGC的核心优势在于能够精确计算任意子图的边缘概率:
单节点边缘概率:
其中 是节点 的指示变量。
子图边缘概率:
5.2 条件概率计算
给定部分观察,计算条件概率:
这在链接预测等任务中特别有用。
5.2 复杂度分析
时间复杂度:
| 操作 | 标准GNN | PGC |
|---|---|---|
| 前向传播 | $O( | \mathcal{V} |
| 边缘概率 | N/A | |
| 条件概率 | N/A | |
| 对数似然 | $O( | \mathcal{V} |
6 应用场景
6.1 链接预测
链接预测任务要求估计边 存在的概率:
其中 表示删除边 后的图。
实验结果(Cora数据集):
| 方法 | AUC | AP |
|---|---|---|
| GCN | 0.84 | 0.85 |
| GAT | 0.86 | 0.87 |
| VGAE | 0.91 | 0.92 |
| PGC | 0.94 | 0.95 |
6.2 图生成
PGC可以直接作为图生成模型,生成具有特定性质的图:
生成过程:
- 从根节点开始,按照概率分布采样
- 递归采样子节点
- 恢复图结构(节点类型、边连接)
6.3 异常检测
利用PGC的精确概率计算能力进行图异常检测:
异常分数:
异常图具有较低的概率密度,因此 较高。
7 与相关工作的对比
7.1 与概率电路(PCs)的对比
| 特性 | 概率电路 | 概率图电路 |
|---|---|---|
| 数据类型 | 表格数据、图像 | 图结构数据 |
| 变量独立性 | 局部分解 | 图依赖结构 |
| 图卷积 | 不支持 | 原生支持 |
| GNN兼容性 | 低 | 高 |
7.2 与变分GNN的对比
| 特性 | VGAE/VCT | PGC |
|---|---|---|
| 推断精度 | 变分近似 | 精确 |
| 边缘概率 | 变分估计 | 精确计算 |
| 条件概率 | 需要重参数化 | 直接计算 |
| 训练稳定性 | 中等 | 高 |
7.3 与GNN的对比
| 特性 | 标准GNN | PGC |
|---|---|---|
| 输出 | 点嵌入 | 概率分布 |
| 不确定性量化 | 需要额外技术 | 原生支持 |
| 可追踪推断 | 不支持 | 支持 |
| 生成能力 | 需额外解码器 | 原生支持 |
8 实现细节
8.1 PyTorch实现框架
import torch
import torch.nn as nn
from torch.distributions import Normal
class ProbabilisticGraphCircuit(nn.Module):
def __init__(self, node_dim, hidden_dim, num_layers):
super().__init__()
self.node_dim = node_dim
self.hidden_dim = hidden_dim
# 节点嵌入层
self.node_embedding = nn.Linear(node_dim, hidden_dim)
# 图卷积层
self.gc_layers = nn.ModuleList([
GraphConvLayer(hidden_dim, hidden_dim)
for _ in range(num_layers)
])
# 概率参数化层
self.prob_layers = nn.ModuleList([
SumProductLayer(hidden_dim)
for _ in range(num_layers)
])
def forward(self, x, edge_index):
# 节点嵌入
h = self.node_embedding(x)
# 层次化图卷积
for gc_layer, prob_layer in zip(self.gc_layers, self.prob_layers):
h = gc_layer(h, edge_index)
h = prob_layer(h)
return h
def log_prob(self, G):
"""计算图的精确对数概率"""
return self.forward(G.x, G.edge_index).sum()
def marginal_prob(self, node_idx, G):
"""计算节点边缘概率"""
with torch.no_grad():
log_prob = self.log_prob(G)
# 边缘化操作
marginal = self._marginalize(log_prob, node_idx)
return marginal.exp()
class GraphConvLayer(nn.Module):
"""图卷积层"""
def __init__(self, in_dim, out_dim):
super().__init__()
self.linear = nn.Linear(in_dim, out_dim)
self.edge_mlp = nn.Sequential(
nn.Linear(out_dim * 2, out_dim),
nn.ReLU(),
nn.Linear(out_dim, out_dim)
)
def forward(self, x, edge_index):
row, col = edge_index
# 边级特征聚合
edge_feat = self.edge_mlp(torch.cat([x[row], x[col]], dim=-1))
# 节点更新
out = x + self.linear(edge_feat.mean(dim=0, keepdim=True))
return torch.relu(out)
class SumProductLayer(nn.Module):
"""和-积操作层"""
def __init__(self, dim):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim) / dim)
def forward(self, x):
# 权重和操作
return (x * torch.softmax(self.weight, dim=-1)).sum(dim=-1, keepdim=True)8.2 训练循环
def train_pgc(model, dataloader, optimizer, num_epochs):
model.train()
losses = []
for epoch in range(num_epochs):
epoch_loss = 0.0
for batch in dataloader:
G = batch.to(device)
optimizer.zero_grad()
# 负对数似然损失
log_prob = model.log_prob(G)
loss = -log_prob
loss.backward()
optimizer.step()
epoch_loss += loss.item()
avg_loss = epoch_loss / len(dataloader)
losses.append(avg_loss)
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
return losses9 总结与展望
概率图电路(PGCs)提供了一个统一的框架,将概率电路的可追踪推断特性与图神经网络的表示学习能力相结合。
核心贡献:
- 首次实现图上任意子图的精确概率计算
- 保持与标准GNN相当的计算效率
- 原生支持不确定性量化
未来方向:
- 扩展到动态图、异构图
- 与扩散模型结合进行图生成
- 应用于知识图谱推理和分子设计
参考文献
Footnotes
-
Probabilistic Graph Circuits: Deep Generative Models for Tractable Probabilistic Inference over Graphs. UAI 2025. arXiv:2503.12162. ↩