概述
概率图电路(Probabilistic Graph Circuits, PGC)是一种在图结构数据上实现可处理概率推断的生成模型框架。1
传统的图生成模型(如GNN-based生成器、GraphRNN)面临以下挑战:
- 推断不精确: 只能近似推断,难以计算精确概率
- 计算复杂度高: 大规模图上的推断代价昂贵
- 缺乏理论保证: 缺乏PAC学习等理论保证
PGC的核心思想是:
将概率电路的可处理推断能力扩展到图结构数据,同时保持图生成模型的表达能力。
这一框架使得:
- 图上的精确边际推断可以在多项式时间内完成
- 图结构的条件概率可以精确计算
- 图生成模型具备可验证的推断能力
1. 问题背景
1.1 图上的概率推断挑战
在图结构数据上进行概率推断面临独特挑战:
| 挑战 | 描述 | 影响 |
|---|---|---|
| 结构异质性 | 节点和边的类型多样 | 统一建模困难 |
| 规模复杂性 | 节点数指数级组合空间 | 推断困难 |
| 依赖复杂性 | 节点间存在长程依赖 | 条件独立假设失效 |
| 动态性 | 图结构随时间变化 | 时序建模复杂 |
1.2 现有方法的局限性
| 方法 | 优点 | 缺点 |
|---|---|---|
| GNN+VAE | 端到端可微 | 推断近似,ELBO下界 |
| GraphRNN | 自回归生成 | 无法精确计算概率 |
| EBMs | 灵活建模 | 推断需要采样 |
| NFs for Graphs | 可逆变换 | 结构约束复杂 |
1.3 PGC的解决方案
PGC通过以下设计解决上述问题:
┌─────────────────────────────────────────────────────────────┐
│ 概率图电路 (PGC) │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入图G ┌─────────────────────┐ 精确概率 │
│ │ │ 图电路结构学习 │ │ │
│ ▼ │ (节点/边的PC分解) │ ▼ │
│ ┌──────┐ └─────────────────────┘ ┌──────┐ │
│ │ 编码 │ │ │ │ │
│ └──────┘ ▼ │ 输出 │ │
│ ┌──────┐ ┌─────────────────────┐ │ 概率 │ │
│ │ 边PC │ │ 可处理推断引擎 │───→│ │ │
│ └──────┘ │ (边际/条件/MAP) │ │ P(G) │ │
│ ┌──────┐ └─────────────────────┘ │ │ │
│ │ 节点PC│ │ └──────┘ │
│ └──────┘ │ │
│ ┌──────┐ ▼ │
│ │ 拓扑PC│ ┌─────────────────────┐ │
│ └──────┘ │ 生成/推断双模式 │ │
│ └─────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────┘
2. 形式化定义
2.1 图的概率表示
设 为一个属性图,其中:
- :节点集合,
- :边集合
- :节点特征
- :边特征
定义(图的概率分布): PGC定义图上的概率分布为:
其中 表示节点 的父节点集合。
2.2 图电路结构
定义(图电路): 图电路是一个有向无环图 ,其中:
-
每个节点 是以下类型之一:
- 输入节点: 对应图的元素(节点/边/邻接)
- 乘积节点: 实现边的条件独立
- 求和节点: 实现边际化
- 特征节点: 处理节点/边特征
-
可处理条件: 对于任意节点 ,其子树的计算复杂度为 ,其中 是常数
2.3 分解性质
PGC利用图的稀疏性和局部性实现高效推断:
定理(局部分解): 设 是一个图电路, 是一个图。若 满足:
- 乘积节点只连接相邻节点
- 求和节点实现局部边际化
则边际推断 可以在 时间内完成。
3. 核心架构
3.1 节点级电路
节点级电路建模节点特征的分布:
class NodeCircuit(nn.Module):
"""节点级概率电路"""
def __init__(self, feature_dim, hidden_dim, num_mixtures):
super().__init__()
self.feature_dim = feature_dim
self.num_mixtures = num_mixtures
# 节点特征编码
self.encoder = nn.Sequential(
nn.Linear(feature_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 混合模型参数
self.mixture_weights = nn.Linear(hidden_dim, num_mixtures)
self.means = nn.Linear(hidden_dim, num_mixtures * feature_dim)
self.log_stds = nn.Linear(hidden_dim, num_mixtures * feature_dim)
def forward(self, x_v):
"""
计算节点特征的密度
x_v: (batch, feature_dim)
"""
h = self.encoder(x_v)
# 混合高斯参数
pi = F.softmax(self.mixture_weights(h), dim=-1)
mu = self.means(h).view(-1, self.num_mixtures, self.feature_dim)
log_std = self.log_stds(h).view(-1, self.num_mixtures, self.feature_dim)
# 密度计算
log_probs = []
for k in range(self.num_mixtures):
diff = x_v.unsqueeze(1) - mu[:, k:k+1, :]
log_prob = -0.5 * ((diff ** 2) / (torch.exp(2 * log_std[:, k:k+1, :]) + 1e-8))
log_prob = log_prob.sum(dim=-1)
log_probs.append(log_prob + torch.log(pi[:, k:k+1] + 1e-8))
log_probs = torch.cat(log_probs, dim=1)
return torch.logsumexp(log_probs, dim=1)3.2 边级电路
边级电路建模边存在性和特征的分布:
class EdgeCircuit(nn.Module):
"""边级概率电路"""
def __init__(self, node_dim, edge_dim, hidden_dim):
super().__init__()
# 边存在性网络
self.edge_exists = nn.Sequential(
nn.Linear(node_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
# 边特征网络
self.edge_features = nn.Sequential(
nn.Linear(node_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, edge_dim)
)
def forward(self, x_u, x_v, edge_exists_prior=0.1):
"""
计算边的概率
x_u, x_v: (batch, node_dim)
"""
# 边存在概率
combined = torch.cat([x_u, x_v], dim=-1)
p_exists = self.edge_exists(combined).squeeze(-1)
# 边特征分布
edge_feat = self.edge_features(combined)
# 边的总体概率(存在性 × 特征)
return p_exists, edge_feat3.3 拓扑电路
拓扑电路建模图结构的分布:
class TopologyCircuit(nn.Module):
"""拓扑级概率电路"""
def __init__(self, node_dim, hidden_dim, max_degree):
super().__init__()
self.max_degree = max_degree
# 度分布建模
self.degree_net = nn.Sequential(
nn.Linear(node_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, max_degree + 1),
nn.Softmax(dim=-1)
)
# 邻接模式建模
self.adj_pattern = nn.Sequential(
nn.Linear(node_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def compute_topology_prob(self, x, adj):
"""
计算拓扑结构的概率
x: (n, node_dim)
adj: (n, n) 邻接矩阵
"""
n = x.size(0)
log_prob = 0
# 度分布概率
for i in range(n):
degree_i = adj[i].sum()
if degree_i <= self.max_degree:
p_degree = self.degree_net(x[i])[int(degree_i)]
log_prob += torch.log(p_degree + 1e-8)
# 邻接模式概率
for i in range(n):
for j in range(i+1, n):
combined = torch.cat([x[i], x[j]], dim=-1)
p_edge = self.adj_pattern(combined).squeeze(-1)
if adj[i, j] > 0:
log_prob += torch.log(p_edge + 1e-8)
else:
log_prob += torch.log(1 - p_edge + 1e-8)
return log_prob3.4 完整PGC模型
class ProbabilisticGraphCircuit(nn.Module):
"""完整概率图电路"""
def __init__(self, node_dim, edge_dim, hidden_dim,
num_mixtures, max_degree):
super().__init__()
self.node_circuit = NodeCircuit(node_dim, hidden_dim, num_mixtures)
self.edge_circuit = EdgeCircuit(hidden_dim, edge_dim, hidden_dim)
self.topology_circuit = TopologyCircuit(hidden_dim, hidden_dim, max_degree)
# 共享编码器
self.encoder = nn.Sequential(
nn.Linear(node_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
def forward(self, x, adj):
"""
计算图的联合对数似然
Args:
x: (batch, n, node_dim) 节点特征
adj: (batch, n, n) 邻接矩阵
Returns:
log_prob: (batch,) 图的对数概率
"""
batch_size, n, _ = x.size()
# 编码节点特征
h = self.encoder(x) # (batch, n, hidden_dim)
# 节点级概率
node_log_prob = 0
for i in range(n):
node_log_prob += self.node_circuit(x[:, i, :])
# 边级概率
edge_log_prob = 0
for i in range(n):
for j in range(i+1, n):
p_exists, _ = self.edge_circuit(h[:, i, :], h[:, j, :])
# 对batch中每个样本
for b in range(batch_size):
if adj[b, i, j] > 0:
edge_log_prob += torch.log(p_exists[b] + 1e-8)
else:
edge_log_prob += torch.log(1 - p_exists[b] + 1e-8)
# 拓扑概率
topo_log_prob = 0
for b in range(batch_size):
topo_log_prob += self.topology_circuit(
h[b], adj[b]
)
return node_log_prob + edge_log_prob + topo_log_prob
def marginal_inference(self, x, observed_edges=None):
"""
边际推断
P(G | x) = Σ_{E'} P(G' | x) 其中 E' 遍历未观测边
"""
# 简化实现:利用局部性近似
pass
def conditional_inference(self, x, adj_observed):
"""
条件推断
P(G | x, adj_observed) ∝ P(G, adj_observed | x)
"""
# 计算观测部分的概率
log_prob = self.forward(x, adj_observed)
return torch.exp(log_prob)
def map_inference(self, x, num_steps=100):
"""
MAP推断:找到最可能的图结构
"""
# 贪心搜索
adj_pred = torch.zeros_like(x[:, :, 0])
for step in range(num_steps):
best_score = -float('inf')
best_edge = None
for i in range(x.size(1)):
for j in range(i+1, x.size(1)):
if adj_pred[i, j] == 0:
# 尝试添加边
adj_pred[i, j] = 1
adj_pred[j, i] = 1
score = self.forward(x, adj_pred.unsqueeze(0))
if score > best_score:
best_score = score
best_edge = (i, j)
# 撤销
adj_pred[i, j] = 0
adj_pred[j, i] = 0
if best_edge is not None:
adj_pred[best_edge[0], best_edge[1]] = 1
adj_pred[best_edge[1], best_edge[0]] = 1
return adj_pred4. 可处理推断算法
4.1 精确边际推断
PGC支持图上精确边际推断:
def exact_marginal_inference(pgc, x, query_nodes):
"""
精确边际推断
目标:P(nodes_in_query | rest_of_graph)
利用图电路的局部性实现多项式时间计算
"""
# 1. 识别查询相关的子电路
sub_circuit = pgc.extract_subcircuit(query_nodes)
# 2. 局部边际化
log_prob = 0
for node in query_nodes:
log_prob += pgc.node_circuit(x[:, node, :])
for i, u in enumerate(query_nodes):
for j, v in enumerate(query_nodes):
if i < j:
p_exists, _ = pgc.edge_circuit(
pgc.encoder(x[:, u, :]),
pgc.encoder(x[:, v, :])
)
log_prob += torch.log(p_exists + 1e-8)
return torch.exp(log_prob)
def marginal_likelihood(pgc, x):
"""
计算图的边际似然 P(x)
积分掉所有可能的图结构
"""
n = x.size(1)
log_marginal = 0
# 利用分解性质
for i in range(n):
# 节点边际
log_marginal += pgc.node_circuit(x[:, i, :])
# 边际化(近似)
for i in range(n):
for j in range(i+1, n):
# 计算边存在的期望
h_i = pgc.encoder(x[:, i, :])
h_j = pgc.encoder(x[:, j, :])
p_exists = pgc.edge_circuit(h_i, h_j)[0]
# log(1 - P(edge)) 近似边际化
log_marginal += torch.log(1 - p_exists + 1e-8)
return torch.exp(log_marginal)4.2 条件推断
def conditional_inference(pgc, x, evidence_adj):
"""
条件推断
P(G | x, evidence) ∝ P(G, evidence | x)
"""
# 观测边作为证据
log_prob = pgc.forward(x, evidence_adj)
# 归一化
Z = compute_partition_function(pgc, x)
return torch.exp(log_prob - Z)
def compute_partition_function(pgc, x, num_samples=1000):
"""
计算配分函数 Z = Σ_G P(G | x)
使用重要性采样近似
"""
n = x.size(1)
samples = []
weights = []
for _ in range(num_samples):
# 从提议分布采样
adj_sample = torch.rand(n, n) > 0.5
adj_sample = (adj_sample + adj_sample.t()) / 2 # 对称化
adj_sample.fill_diagonal_(0)
# 计算权重
log_w = pgc.forward(x, adj_sample.unsqueeze(0))
samples.append(adj_sample)
weights.append(torch.exp(log_w))
weights = torch.stack(weights)
weights = weights / weights.sum()
return weights.sum().item()4.3 MAP推断
def map_inference_greedy(pgc, x, num_iterations=100):
"""
贪心MAP推断
"""
n = x.size(1)
adj_hat = torch.zeros(n, n)
for _ in range(num_iterations):
best_delta = 0
best_edge = None
for i in range(n):
for j in range(i+1, n):
# 当前边的贡献
if adj_hat[i, j] == 0:
adj_hat[i, j] = 1
adj_hat[j, i] = 1
delta = pgc.forward(x, adj_hat.unsqueeze(0))
if delta > best_delta:
best_delta = delta
best_edge = (i, j)
adj_hat[i, j] = 0
adj_hat[j, i] = 0
if best_edge is not None:
adj_hat[best_edge[0], best_edge[1]] = 1
adj_hat[best_edge[1], best_edge[0]] = 1
return adj_hat
def map_inference_lp(pgc, x):
"""
线性规划松弛MAP推断
"""
# 将离散优化松弛为连续优化
pass # 详细实现略5. 与GNN的关系
5.1 表达能力对比
| 维度 | PGC | GNN |
|---|---|---|
| 图生成能力 | ✓ 概率模型 | 需要额外生成器 |
| 精确推断 | ✓ 多项式时间 | ✗ 需要近似 |
| 概率校准 | ✓ 原生支持 | ✗ 需要校准 |
| 可解释性 | ✓ 因果路径 | 中等 |
| 表达能力 | 中等 | ✓ 强 |
5.2 融合方法
PGC可以与GNN融合以结合两者优势:
class GNNPGCFusion(nn.Module):
"""GNN与PGC的融合模型"""
def __init__(self, gnn_module, pgc_module):
super().__init__()
self.gnn = gnn_module
self.pgc = pgc_module
def forward(self, x, adj):
"""
融合前向传播
"""
# 1. GNN提取节点表示
h = self.gnn(x, adj)
# 2. PGC建模结构分布
log_prob = self.pgc(h, adj)
return log_prob, h
def gnn_guided_generation(self, x, num_steps=50):
"""
GNN引导的图生成
"""
adj = torch.zeros_like(x[:, :, 0])
for _ in range(num_steps):
h = self.gnn(x, adj)
# PGC评分
scores = self.pgc.score_edges(h, adj)
# 选择最高分边
top_edge = scores.argmax()
i, j = top_edge // adj.size(0), top_edge % adj.size(0)
adj[i, j] = 1
adj[j, i] = 1
return adj6. 应用场景
6.1 分子图生成
class MolecularGraphCircuit(ProbabilisticGraphCircuit):
"""
分子图的概率生成模型
"""
def __init__(self, atom_types, bond_types):
super().__init__(
node_dim=len(atom_types),
edge_dim=len(bond_types),
hidden_dim=256,
num_mixtures=8,
max_degree=4 # 碳的最大度数为4
)
self.atom_types = atom_types
self.bond_types = bond_types
def generate_molecule(self, num_atoms=20, temperature=1.0):
"""
生成新分子
"""
# 初始化
x = torch.zeros(1, num_atoms, len(self.atom_types))
adj = torch.zeros(num_atoms, num_atoms)
# 生成原子类型
for i in range(num_atoms):
probs = torch.softmax(
torch.randn(len(self.atom_types)) / temperature,
dim=0
)
atom_idx = torch.multinomial(probs, 1)
x[0, i, atom_idx] = 1
# 生成边
for i in range(num_atoms):
for j in range(i+1, num_atoms):
bond_probs = torch.softmax(
torch.randn(len(self.bond_types)) / temperature,
dim=0
)
bond_idx = torch.multinomial(bond_probs, 1)
# 根据原子类型限制键
if self.is_valid_bond(x[0, i], x[0, j], bond_idx):
adj[i, j] = bond_idx
adj[j, i] = bond_idx
return x, adj
def is_valid_bond(self, atom1, atom2, bond_idx):
"""化学有效性检查"""
# 实现化学规则
return True6.2 知识图谱补全
class KnowledgeGraphCircuit(ProbabilisticGraphCircuit):
"""
知识图谱的概率补全模型
"""
def __init__(self, num_entities, num_relations, embed_dim):
super().__init__(
node_dim=embed_dim,
edge_dim=num_relations,
hidden_dim=128,
num_mixtures=4,
max_degree=100
)
self.num_entities = num_entities
self.num_relations = num_relations
self.entity_embeddings = nn.Embedding(num_entities, embed_dim)
self.relation_embeddings = nn.Embedding(num_relations, embed_dim)
def complete_triples(self, head, relation, candidates):
"""
补全缺失的尾实体
P(tail | head, relation)
"""
h = self.entity_embeddings(head)
r = self.relation_embeddings(relation)
scores = []
for tail in candidates:
t = self.entity_embeddings(tail)
# 计算三元组分数
score = self.compute_triple_score(h, r, t)
scores.append(score)
scores = torch.stack(scores)
probs = F.softmax(scores, dim=0)
return probs
def compute_triple_score(self, h, r, t):
"""TransE风格的评分函数"""
return -torch.norm(h + r - t, dim=-1)
def predict_relation(self, head, tail):
"""预测头尾实体间的关系"""
h = self.entity_embeddings(head)
t = self.entity_embeddings(tail)
scores = []
for r in range(self.num_relations):
r_emb = self.relation_embeddings(r)
score = self.compute_triple_score(h, r_emb, t)
scores.append(score)
scores = torch.stack(scores)
return F.softmax(scores, dim=0)7. 理论分析
7.1 表达能力
定理(PGC表达能力): 设 是一个有 个节点的图电路,则 可以表示任何定义在 (节点图的空间)上的分布,满足:
- 分解性质:
- 局部性约束:每个因子只依赖于 个节点
7.2 计算复杂度
| 操作 | 精确复杂度 | 近似复杂度 |
|---|---|---|
| 联合概率 | ||
| 边际概率 | ||
| 条件概率 | ||
| MAP推断 | NP难 |
7.3 学习保证
PAC学习框架: 设训练集 从真实分布 中采样,则PGC的经验风险:
满足:
其中 是参数数量, 是样本数量。
8. 实现与优化
8.1 高效实现
class OptimizedPGC(ProbabilisticGraphCircuit):
"""
优化版PGC
"""
def __init__(self, *args, use_sparse=True, **kwargs):
super().__init__(*args, **kwargs)
self.use_sparse = use_sparse
def sparse_forward(self, x, adj):
"""
稀疏矩阵优化的前向传播
"""
# 转换为稀疏表示
if self.use_sparse:
adj_sparse = adj.to_sparse()
h = self.encoder(x)
# 节点概率
node_log_prob = self.node_circuit(x).sum(dim=1)
# 边概率(利用稀疏性)
if self.use_sparse:
# 只计算观测边
edge_log_prob = self.compute_sparse_edge_prob(h, adj_sparse)
else:
edge_log_prob = self.compute_dense_edge_prob(h, adj)
return node_log_prob + edge_log_prob
def compute_sparse_edge_prob(self, h, adj_sparse):
"""稀疏边概率计算"""
# 获取边索引
indices = adj_sparse.indices() # (2, num_edges)
# 边两端节点的特征
h_src = h[:, indices[0], :]
h_dst = h[:, indices[1], :]
# 边存在概率
p_exists, _ = self.edge_circuit(h_src, h_dst)
return torch.log(p_exists + 1e-8).sum()8.2 批处理优化
def batch_marginal_inference(pgc, x_batch, adj_batch):
"""
批量边际推断
"""
batch_size = x_batch.size(0)
n = x_batch.size(1)
# 编码
h_batch = pgc.encoder(x_batch)
# 节点概率(批量)
node_log_probs = pgc.node_circuit(x_batch) # (batch, n)
node_log_probs = node_log_probs.sum(dim=1) # (batch,)
# 边概率(批量)
edge_log_probs = []
for b in range(batch_size):
adj = adj_batch[b]
h = h_batch[b]
# 提取上三角(避免重复)
i, j = torch.triu_indices(n, n, offset=1)
h_i = h[i]
h_j = h[j]
p_exists, _ = pgc.edge_circuit(h_i, h_j)
# 乘以邻接矩阵
edge_exists = adj[i, j]
log_prob = torch.where(
edge_exists > 0,
torch.log(p_exists + 1e-8),
torch.log(1 - p_exists + 1e-8)
)
edge_log_probs.append(log_prob.sum())
edge_log_probs = torch.stack(edge_log_probs)
return node_log_probs + edge_log_probs9. 局限性与未来方向
9.1 当前局限
| 问题 | 描述 | 影响 |
|---|---|---|
| 表达能力限制 | 局部分解限制全局依赖建模 | 无法捕获某些复杂模式 |
| 结构学习困难 | 图结构学习复杂 | 需要领域知识 |
| 规模化挑战 | 大图计算开销 | 难以处理大规模图 |
9.2 未来方向
- 层次化PGC: 多尺度图建模
- 动态PGC: 时序图建模
- PGC-GNN融合: 结合两者的优势
- 端到端学习: 从数据自动学习图电路结构
10. 参考
相关文档: 神经概率电路 | 几何感知概率电路 | GNN概率推断
Footnotes
-
Papez et al. (2025): Probabilistic Graph Circuits: Deep Generative Models for Tractable Probabilistic Inference over Graphs. UAI 2025. ↩