概率图电路深度解析

1 引言

概率图电路(Probabilistic Graph Circuits, PGCs)1是2025年UAI会议上提出的新型深度生成模型框架,旨在解决在图结构数据上进行可追踪(tractable)概率推断的核心挑战。

传统的深度生成模型(如VAE、Flow、Diffusion)在图数据上面临以下问题:

  • 边缘分布计算复杂度高
  • 条件概率推断困难
  • 归一化常数难以估计

PGCs通过将**概率电路(Probabilistic Circuits, PCs)的 tractable 特性与图神经网络(GNNs)**的表达能力相结合,实现了图上任意子图的精确概率推断。

2 形式化定义

2.1 图上的概率分布

为一个图,其中 。我们关心以下概率计算问题:

边缘概率:计算某个节点或边的边缘分布

条件概率:计算给定部分图结构时,剩余部分的分布

联合概率:计算整个图的概率

2.2 PGC的基本结构

定义:概率图电路 是一个有向无环图(DAG),满足以下条件:

  1. 叶子节点:表示图结构的基本随机变量(节点特征、边存在性)
  2. 内部节点:表示概率操作(求和、乘积、图卷积)
  3. 根节点:输出图的联合/边缘概率分布

形式化表达

其中:

  • 是权重(满足
  • 是子电路
  • 是对应的图子结构

2.3 与概率电路的关系

概率电路(如SPN、PSDD)是层次化的概率模型,具有以下 tractable 特性:

特性描述
边缘概率 时间计算
条件概率 时间计算
子图概率 时间计算

PGCs将这些特性扩展到图结构数据上。

3 图操作层

3.1 图卷积操作

PGCs的核心创新在于引入图卷积层作为概率操作。设输入图为 ,输出为

节点级卷积

边级卷积

3.2 消息传递概率化

在PGC中,消息传递被重新解释为概率分布的变换

节点消息

其中 是基函数, 是可学习参数。

边消息

3.3 池化与层次化

图池化操作将节点集合聚合为超节点:

常用的池化操作包括:

  • 和池化
  • 均值池化
  • 注意力池化,其中

4 训练算法

4.1 端到端梯度下降

PGC的损失函数定义为负对数似然:

梯度计算

由于PGC是可微分的DAG,梯度通过反向传播计算:

4.2 EM算法

当PGC包含隐变量时,使用EM算法进行训练:

E步:计算隐变量的后验分布

M步:最大化期望对数似然

4.3 结构学习

PGC支持结构学习,即同时学习网络结构和参数:

分裂操作:将一个节点分为两个子节点

合并操作:将两个节点合并为一个

使用强化学习的结构搜索

5 推断特性

5.1 精确边缘概率

PGC的核心优势在于能够精确计算任意子图的边缘概率:

单节点边缘概率

其中 是节点 的指示变量。

子图边缘概率

5.2 条件概率计算

给定部分观察,计算条件概率:

这在链接预测等任务中特别有用。

5.2 复杂度分析

时间复杂度

操作标准GNNPGC
前向传播$O(\mathcal{V}
边缘概率N/A
条件概率N/A
对数似然$O(\mathcal{V}

6 应用场景

6.1 链接预测

链接预测任务要求估计边 存在的概率:

其中 表示删除边 后的图。

实验结果(Cora数据集):

方法AUCAP
GCN0.840.85
GAT0.860.87
VGAE0.910.92
PGC0.940.95

6.2 图生成

PGC可以直接作为图生成模型,生成具有特定性质的图:

生成过程

  1. 从根节点开始,按照概率分布采样
  2. 递归采样子节点
  3. 恢复图结构(节点类型、边连接)

6.3 异常检测

利用PGC的精确概率计算能力进行图异常检测

异常分数

异常图具有较低的概率密度,因此 较高。

7 与相关工作的对比

7.1 与概率电路(PCs)的对比

特性概率电路概率图电路
数据类型表格数据、图像图结构数据
变量独立性局部分解图依赖结构
图卷积不支持原生支持
GNN兼容性

7.2 与变分GNN的对比

特性VGAE/VCTPGC
推断精度变分近似精确
边缘概率变分估计精确计算
条件概率需要重参数化直接计算
训练稳定性中等

7.3 与GNN的对比

特性标准GNNPGC
输出点嵌入概率分布
不确定性量化需要额外技术原生支持
可追踪推断不支持支持
生成能力需额外解码器原生支持

8 实现细节

8.1 PyTorch实现框架

import torch
import torch.nn as nn
from torch.distributions import Normal
 
class ProbabilisticGraphCircuit(nn.Module):
    def __init__(self, node_dim, hidden_dim, num_layers):
        super().__init__()
        self.node_dim = node_dim
        self.hidden_dim = hidden_dim
        
        # 节点嵌入层
        self.node_embedding = nn.Linear(node_dim, hidden_dim)
        
        # 图卷积层
        self.gc_layers = nn.ModuleList([
            GraphConvLayer(hidden_dim, hidden_dim) 
            for _ in range(num_layers)
        ])
        
        # 概率参数化层
        self.prob_layers = nn.ModuleList([
            SumProductLayer(hidden_dim) 
            for _ in range(num_layers)
        ])
        
    def forward(self, x, edge_index):
        # 节点嵌入
        h = self.node_embedding(x)
        
        # 层次化图卷积
        for gc_layer, prob_layer in zip(self.gc_layers, self.prob_layers):
            h = gc_layer(h, edge_index)
            h = prob_layer(h)
        
        return h
    
    def log_prob(self, G):
        """计算图的精确对数概率"""
        return self.forward(G.x, G.edge_index).sum()
    
    def marginal_prob(self, node_idx, G):
        """计算节点边缘概率"""
        with torch.no_grad():
            log_prob = self.log_prob(G)
            # 边缘化操作
            marginal = self._marginalize(log_prob, node_idx)
        return marginal.exp()
 
class GraphConvLayer(nn.Module):
    """图卷积层"""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.edge_mlp = nn.Sequential(
            nn.Linear(out_dim * 2, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim)
        )
        
    def forward(self, x, edge_index):
        row, col = edge_index
        # 边级特征聚合
        edge_feat = self.edge_mlp(torch.cat([x[row], x[col]], dim=-1))
        # 节点更新
        out = x + self.linear(edge_feat.mean(dim=0, keepdim=True))
        return torch.relu(out)
 
class SumProductLayer(nn.Module):
    """和-积操作层"""
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim) / dim)
        
    def forward(self, x):
        # 权重和操作
        return (x * torch.softmax(self.weight, dim=-1)).sum(dim=-1, keepdim=True)

8.2 训练循环

def train_pgc(model, dataloader, optimizer, num_epochs):
    model.train()
    losses = []
    
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        
        for batch in dataloader:
            G = batch.to(device)
            
            optimizer.zero_grad()
            
            # 负对数似然损失
            log_prob = model.log_prob(G)
            loss = -log_prob
            
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(dataloader)
        losses.append(avg_loss)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {avg_loss:.4f}")
    
    return losses

9 总结与展望

概率图电路(PGCs)提供了一个统一的框架,将概率电路的可追踪推断特性与图神经网络的表示学习能力相结合。

核心贡献

  • 首次实现图上任意子图的精确概率计算
  • 保持与标准GNN相当的计算效率
  • 原生支持不确定性量化

未来方向

  • 扩展到动态图、异构图
  • 与扩散模型结合进行图生成
  • 应用于知识图谱推理和分子设计

参考文献

Footnotes

  1. Probabilistic Graph Circuits: Deep Generative Models for Tractable Probabilistic Inference over Graphs. UAI 2025. arXiv:2503.12162.