概述

概率图电路(Probabilistic Graph Circuits, PGC)是一种在图结构数据上实现可处理概率推断的生成模型框架。1

传统的图生成模型(如GNN-based生成器、GraphRNN)面临以下挑战:

  • 推断不精确: 只能近似推断,难以计算精确概率
  • 计算复杂度高: 大规模图上的推断代价昂贵
  • 缺乏理论保证: 缺乏PAC学习等理论保证

PGC的核心思想是:

将概率电路的可处理推断能力扩展到图结构数据,同时保持图生成模型的表达能力。

这一框架使得:

  • 图上的精确边际推断可以在多项式时间内完成
  • 图结构的条件概率可以精确计算
  • 图生成模型具备可验证的推断能力

1. 问题背景

1.1 图上的概率推断挑战

在图结构数据上进行概率推断面临独特挑战:

挑战描述影响
结构异质性节点和边的类型多样统一建模困难
规模复杂性节点数指数级组合空间推断困难
依赖复杂性节点间存在长程依赖条件独立假设失效
动态性图结构随时间变化时序建模复杂

1.2 现有方法的局限性

方法优点缺点
GNN+VAE端到端可微推断近似,ELBO下界
GraphRNN自回归生成无法精确计算概率
EBMs灵活建模推断需要采样
NFs for Graphs可逆变换结构约束复杂

1.3 PGC的解决方案

PGC通过以下设计解决上述问题:

┌─────────────────────────────────────────────────────────────┐
│                    概率图电路 (PGC)                          │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  输入图G    ┌─────────────────────┐    精确概率              │
│      │      │   图电路结构学习      │        │                │
│      ▼      │  (节点/边的PC分解)   │        ▼                │
│  ┌──────┐   └─────────────────────┘    ┌──────┐            │
│  │ 编码 │            │                 │      │            │
│  └──────┘            ▼                 │ 输出 │            │
│  ┌──────┐   ┌─────────────────────┐    │ 概率 │            │
│  │ 边PC │   │   可处理推断引擎     │───→│      │            │
│  └──────┘   │ (边际/条件/MAP)      │    │ P(G) │            │
│  ┌──────┐   └─────────────────────┘    │      │            │
│  │ 节点PC│            │                 └──────┘            │
│  └──────┘            │                                      │
│  ┌──────┐            ▼                                      │
│  │ 拓扑PC│   ┌─────────────────────┐                        │
│  └──────┘   │   生成/推断双模式    │                        │
│              └─────────────────────┘                        │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. 形式化定义

2.1 图的概率表示

为一个属性图,其中:

  • :节点集合,
  • :边集合
  • :节点特征
  • :边特征

定义(图的概率分布): PGC定义图上的概率分布为:

其中 表示节点 的父节点集合。

2.2 图电路结构

定义(图电路): 图电路是一个有向无环图 ,其中:

  1. 每个节点 是以下类型之一:

    • 输入节点: 对应图的元素(节点/边/邻接)
    • 乘积节点: 实现边的条件独立
    • 求和节点: 实现边际化
    • 特征节点: 处理节点/边特征
  2. 可处理条件: 对于任意节点 ,其子树的计算复杂度为 ,其中 是常数

2.3 分解性质

PGC利用图的稀疏性和局部性实现高效推断:

定理(局部分解): 设 是一个图电路, 是一个图。若 满足:

  1. 乘积节点只连接相邻节点
  2. 求和节点实现局部边际化

则边际推断 可以在 时间内完成。


3. 核心架构

3.1 节点级电路

节点级电路建模节点特征的分布:

class NodeCircuit(nn.Module):
    """节点级概率电路"""
    def __init__(self, feature_dim, hidden_dim, num_mixtures):
        super().__init__()
        self.feature_dim = feature_dim
        self.num_mixtures = num_mixtures
        
        # 节点特征编码
        self.encoder = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 混合模型参数
        self.mixture_weights = nn.Linear(hidden_dim, num_mixtures)
        self.means = nn.Linear(hidden_dim, num_mixtures * feature_dim)
        self.log_stds = nn.Linear(hidden_dim, num_mixtures * feature_dim)
    
    def forward(self, x_v):
        """
        计算节点特征的密度
        x_v: (batch, feature_dim)
        """
        h = self.encoder(x_v)
        
        # 混合高斯参数
        pi = F.softmax(self.mixture_weights(h), dim=-1)
        mu = self.means(h).view(-1, self.num_mixtures, self.feature_dim)
        log_std = self.log_stds(h).view(-1, self.num_mixtures, self.feature_dim)
        
        # 密度计算
        log_probs = []
        for k in range(self.num_mixtures):
            diff = x_v.unsqueeze(1) - mu[:, k:k+1, :]
            log_prob = -0.5 * ((diff ** 2) / (torch.exp(2 * log_std[:, k:k+1, :]) + 1e-8))
            log_prob = log_prob.sum(dim=-1)
            log_probs.append(log_prob + torch.log(pi[:, k:k+1] + 1e-8))
        
        log_probs = torch.cat(log_probs, dim=1)
        return torch.logsumexp(log_probs, dim=1)

3.2 边级电路

边级电路建模边存在性和特征的分布:

class EdgeCircuit(nn.Module):
    """边级概率电路"""
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        
        # 边存在性网络
        self.edge_exists = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # 边特征网络
        self.edge_features = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, edge_dim)
        )
    
    def forward(self, x_u, x_v, edge_exists_prior=0.1):
        """
        计算边的概率
        x_u, x_v: (batch, node_dim)
        """
        # 边存在概率
        combined = torch.cat([x_u, x_v], dim=-1)
        p_exists = self.edge_exists(combined).squeeze(-1)
        
        # 边特征分布
        edge_feat = self.edge_features(combined)
        
        # 边的总体概率(存在性 × 特征)
        return p_exists, edge_feat

3.3 拓扑电路

拓扑电路建模图结构的分布:

class TopologyCircuit(nn.Module):
    """拓扑级概率电路"""
    def __init__(self, node_dim, hidden_dim, max_degree):
        super().__init__()
        self.max_degree = max_degree
        
        # 度分布建模
        self.degree_net = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, max_degree + 1),
            nn.Softmax(dim=-1)
        )
        
        # 邻接模式建模
        self.adj_pattern = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def compute_topology_prob(self, x, adj):
        """
        计算拓扑结构的概率
        x: (n, node_dim)
        adj: (n, n) 邻接矩阵
        """
        n = x.size(0)
        log_prob = 0
        
        # 度分布概率
        for i in range(n):
            degree_i = adj[i].sum()
            if degree_i <= self.max_degree:
                p_degree = self.degree_net(x[i])[int(degree_i)]
                log_prob += torch.log(p_degree + 1e-8)
        
        # 邻接模式概率
        for i in range(n):
            for j in range(i+1, n):
                combined = torch.cat([x[i], x[j]], dim=-1)
                p_edge = self.adj_pattern(combined).squeeze(-1)
                
                if adj[i, j] > 0:
                    log_prob += torch.log(p_edge + 1e-8)
                else:
                    log_prob += torch.log(1 - p_edge + 1e-8)
        
        return log_prob

3.4 完整PGC模型

class ProbabilisticGraphCircuit(nn.Module):
    """完整概率图电路"""
    def __init__(self, node_dim, edge_dim, hidden_dim, 
                 num_mixtures, max_degree):
        super().__init__()
        
        self.node_circuit = NodeCircuit(node_dim, hidden_dim, num_mixtures)
        self.edge_circuit = EdgeCircuit(hidden_dim, edge_dim, hidden_dim)
        self.topology_circuit = TopologyCircuit(hidden_dim, hidden_dim, max_degree)
        
        # 共享编码器
        self.encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
    
    def forward(self, x, adj):
        """
        计算图的联合对数似然
        
        Args:
            x: (batch, n, node_dim) 节点特征
            adj: (batch, n, n) 邻接矩阵
        
        Returns:
            log_prob: (batch,) 图的对数概率
        """
        batch_size, n, _ = x.size()
        
        # 编码节点特征
        h = self.encoder(x)  # (batch, n, hidden_dim)
        
        # 节点级概率
        node_log_prob = 0
        for i in range(n):
            node_log_prob += self.node_circuit(x[:, i, :])
        
        # 边级概率
        edge_log_prob = 0
        for i in range(n):
            for j in range(i+1, n):
                p_exists, _ = self.edge_circuit(h[:, i, :], h[:, j, :])
                
                # 对batch中每个样本
                for b in range(batch_size):
                    if adj[b, i, j] > 0:
                        edge_log_prob += torch.log(p_exists[b] + 1e-8)
                    else:
                        edge_log_prob += torch.log(1 - p_exists[b] + 1e-8)
        
        # 拓扑概率
        topo_log_prob = 0
        for b in range(batch_size):
            topo_log_prob += self.topology_circuit(
                h[b], adj[b]
            )
        
        return node_log_prob + edge_log_prob + topo_log_prob
    
    def marginal_inference(self, x, observed_edges=None):
        """
        边际推断
        
        P(G | x) = Σ_{E'} P(G' | x) 其中 E' 遍历未观测边
        """
        # 简化实现:利用局部性近似
        pass
    
    def conditional_inference(self, x, adj_observed):
        """
        条件推断
        
        P(G | x, adj_observed) ∝ P(G, adj_observed | x)
        """
        # 计算观测部分的概率
        log_prob = self.forward(x, adj_observed)
        return torch.exp(log_prob)
    
    def map_inference(self, x, num_steps=100):
        """
        MAP推断:找到最可能的图结构
        """
        # 贪心搜索
        adj_pred = torch.zeros_like(x[:, :, 0])
        
        for step in range(num_steps):
            best_score = -float('inf')
            best_edge = None
            
            for i in range(x.size(1)):
                for j in range(i+1, x.size(1)):
                    if adj_pred[i, j] == 0:
                        # 尝试添加边
                        adj_pred[i, j] = 1
                        adj_pred[j, i] = 1
                        
                        score = self.forward(x, adj_pred.unsqueeze(0))
                        
                        if score > best_score:
                            best_score = score
                            best_edge = (i, j)
                        
                        # 撤销
                        adj_pred[i, j] = 0
                        adj_pred[j, i] = 0
            
            if best_edge is not None:
                adj_pred[best_edge[0], best_edge[1]] = 1
                adj_pred[best_edge[1], best_edge[0]] = 1
        
        return adj_pred

4. 可处理推断算法

4.1 精确边际推断

PGC支持图上精确边际推断:

def exact_marginal_inference(pgc, x, query_nodes):
    """
    精确边际推断
    
    目标:P(nodes_in_query | rest_of_graph)
    
    利用图电路的局部性实现多项式时间计算
    """
    # 1. 识别查询相关的子电路
    sub_circuit = pgc.extract_subcircuit(query_nodes)
    
    # 2. 局部边际化
    log_prob = 0
    for node in query_nodes:
        log_prob += pgc.node_circuit(x[:, node, :])
    
    for i, u in enumerate(query_nodes):
        for j, v in enumerate(query_nodes):
            if i < j:
                p_exists, _ = pgc.edge_circuit(
                    pgc.encoder(x[:, u, :]),
                    pgc.encoder(x[:, v, :])
                )
                log_prob += torch.log(p_exists + 1e-8)
    
    return torch.exp(log_prob)
 
 
def marginal_likelihood(pgc, x):
    """
    计算图的边际似然 P(x)
    
    积分掉所有可能的图结构
    """
    n = x.size(1)
    log_marginal = 0
    
    # 利用分解性质
    for i in range(n):
        # 节点边际
        log_marginal += pgc.node_circuit(x[:, i, :])
    
    # 边际化(近似)
    for i in range(n):
        for j in range(i+1, n):
            # 计算边存在的期望
            h_i = pgc.encoder(x[:, i, :])
            h_j = pgc.encoder(x[:, j, :])
            p_exists = pgc.edge_circuit(h_i, h_j)[0]
            
            # log(1 - P(edge)) 近似边际化
            log_marginal += torch.log(1 - p_exists + 1e-8)
    
    return torch.exp(log_marginal)

4.2 条件推断

def conditional_inference(pgc, x, evidence_adj):
    """
    条件推断
    
    P(G | x, evidence) ∝ P(G, evidence | x)
    """
    # 观测边作为证据
    log_prob = pgc.forward(x, evidence_adj)
    
    # 归一化
    Z = compute_partition_function(pgc, x)
    
    return torch.exp(log_prob - Z)
 
 
def compute_partition_function(pgc, x, num_samples=1000):
    """
    计算配分函数 Z = Σ_G P(G | x)
    
    使用重要性采样近似
    """
    n = x.size(1)
    samples = []
    weights = []
    
    for _ in range(num_samples):
        # 从提议分布采样
        adj_sample = torch.rand(n, n) > 0.5
        adj_sample = (adj_sample + adj_sample.t()) / 2  # 对称化
        adj_sample.fill_diagonal_(0)
        
        # 计算权重
        log_w = pgc.forward(x, adj_sample.unsqueeze(0))
        samples.append(adj_sample)
        weights.append(torch.exp(log_w))
    
    weights = torch.stack(weights)
    weights = weights / weights.sum()
    
    return weights.sum().item()

4.3 MAP推断

def map_inference_greedy(pgc, x, num_iterations=100):
    """
    贪心MAP推断
    """
    n = x.size(1)
    adj_hat = torch.zeros(n, n)
    
    for _ in range(num_iterations):
        best_delta = 0
        best_edge = None
        
        for i in range(n):
            for j in range(i+1, n):
                # 当前边的贡献
                if adj_hat[i, j] == 0:
                    adj_hat[i, j] = 1
                    adj_hat[j, i] = 1
                    
                    delta = pgc.forward(x, adj_hat.unsqueeze(0))
                    
                    if delta > best_delta:
                        best_delta = delta
                        best_edge = (i, j)
                    
                    adj_hat[i, j] = 0
                    adj_hat[j, i] = 0
        
        if best_edge is not None:
            adj_hat[best_edge[0], best_edge[1]] = 1
            adj_hat[best_edge[1], best_edge[0]] = 1
    
    return adj_hat
 
 
def map_inference_lp(pgc, x):
    """
    线性规划松弛MAP推断
    """
    # 将离散优化松弛为连续优化
    pass  # 详细实现略

5. 与GNN的关系

5.1 表达能力对比

维度PGCGNN
图生成能力✓ 概率模型需要额外生成器
精确推断✓ 多项式时间✗ 需要近似
概率校准✓ 原生支持✗ 需要校准
可解释性✓ 因果路径中等
表达能力中等✓ 强

5.2 融合方法

PGC可以与GNN融合以结合两者优势:

class GNNPGCFusion(nn.Module):
    """GNN与PGC的融合模型"""
    def __init__(self, gnn_module, pgc_module):
        super().__init__()
        self.gnn = gnn_module
        self.pgc = pgc_module
    
    def forward(self, x, adj):
        """
        融合前向传播
        """
        # 1. GNN提取节点表示
        h = self.gnn(x, adj)
        
        # 2. PGC建模结构分布
        log_prob = self.pgc(h, adj)
        
        return log_prob, h
    
    def gnn_guided_generation(self, x, num_steps=50):
        """
        GNN引导的图生成
        """
        adj = torch.zeros_like(x[:, :, 0])
        
        for _ in range(num_steps):
            h = self.gnn(x, adj)
            
            # PGC评分
            scores = self.pgc.score_edges(h, adj)
            
            # 选择最高分边
            top_edge = scores.argmax()
            i, j = top_edge // adj.size(0), top_edge % adj.size(0)
            
            adj[i, j] = 1
            adj[j, i] = 1
        
        return adj

6. 应用场景

6.1 分子图生成

class MolecularGraphCircuit(ProbabilisticGraphCircuit):
    """
    分子图的概率生成模型
    """
    def __init__(self, atom_types, bond_types):
        super().__init__(
            node_dim=len(atom_types),
            edge_dim=len(bond_types),
            hidden_dim=256,
            num_mixtures=8,
            max_degree=4  # 碳的最大度数为4
        )
        self.atom_types = atom_types
        self.bond_types = bond_types
    
    def generate_molecule(self, num_atoms=20, temperature=1.0):
        """
        生成新分子
        """
        # 初始化
        x = torch.zeros(1, num_atoms, len(self.atom_types))
        adj = torch.zeros(num_atoms, num_atoms)
        
        # 生成原子类型
        for i in range(num_atoms):
            probs = torch.softmax(
                torch.randn(len(self.atom_types)) / temperature, 
                dim=0
            )
            atom_idx = torch.multinomial(probs, 1)
            x[0, i, atom_idx] = 1
        
        # 生成边
        for i in range(num_atoms):
            for j in range(i+1, num_atoms):
                bond_probs = torch.softmax(
                    torch.randn(len(self.bond_types)) / temperature,
                    dim=0
                )
                bond_idx = torch.multinomial(bond_probs, 1)
                
                # 根据原子类型限制键
                if self.is_valid_bond(x[0, i], x[0, j], bond_idx):
                    adj[i, j] = bond_idx
                    adj[j, i] = bond_idx
        
        return x, adj
    
    def is_valid_bond(self, atom1, atom2, bond_idx):
        """化学有效性检查"""
        # 实现化学规则
        return True

6.2 知识图谱补全

class KnowledgeGraphCircuit(ProbabilisticGraphCircuit):
    """
    知识图谱的概率补全模型
    """
    def __init__(self, num_entities, num_relations, embed_dim):
        super().__init__(
            node_dim=embed_dim,
            edge_dim=num_relations,
            hidden_dim=128,
            num_mixtures=4,
            max_degree=100
        )
        self.num_entities = num_entities
        self.num_relations = num_relations
        self.entity_embeddings = nn.Embedding(num_entities, embed_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embed_dim)
    
    def complete_triples(self, head, relation, candidates):
        """
        补全缺失的尾实体
        
        P(tail | head, relation)
        """
        h = self.entity_embeddings(head)
        r = self.relation_embeddings(relation)
        
        scores = []
        for tail in candidates:
            t = self.entity_embeddings(tail)
            
            # 计算三元组分数
            score = self.compute_triple_score(h, r, t)
            scores.append(score)
        
        scores = torch.stack(scores)
        probs = F.softmax(scores, dim=0)
        
        return probs
    
    def compute_triple_score(self, h, r, t):
        """TransE风格的评分函数"""
        return -torch.norm(h + r - t, dim=-1)
    
    def predict_relation(self, head, tail):
        """预测头尾实体间的关系"""
        h = self.entity_embeddings(head)
        t = self.entity_embeddings(tail)
        
        scores = []
        for r in range(self.num_relations):
            r_emb = self.relation_embeddings(r)
            score = self.compute_triple_score(h, r_emb, t)
            scores.append(score)
        
        scores = torch.stack(scores)
        return F.softmax(scores, dim=0)

7. 理论分析

7.1 表达能力

定理(PGC表达能力): 设 是一个有 个节点的图电路,则 可以表示任何定义在 节点图的空间)上的分布,满足:

  1. 分解性质:
  2. 局部性约束:每个因子只依赖于 个节点

7.2 计算复杂度

操作精确复杂度近似复杂度
联合概率
边际概率
条件概率
MAP推断NP难

7.3 学习保证

PAC学习框架: 设训练集 从真实分布 中采样,则PGC的经验风险:

满足:

其中 是参数数量, 是样本数量。


8. 实现与优化

8.1 高效实现

class OptimizedPGC(ProbabilisticGraphCircuit):
    """
    优化版PGC
    """
    def __init__(self, *args, use_sparse=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_sparse = use_sparse
    
    def sparse_forward(self, x, adj):
        """
        稀疏矩阵优化的前向传播
        """
        # 转换为稀疏表示
        if self.use_sparse:
            adj_sparse = adj.to_sparse()
        
        h = self.encoder(x)
        
        # 节点概率
        node_log_prob = self.node_circuit(x).sum(dim=1)
        
        # 边概率(利用稀疏性)
        if self.use_sparse:
            # 只计算观测边
            edge_log_prob = self.compute_sparse_edge_prob(h, adj_sparse)
        else:
            edge_log_prob = self.compute_dense_edge_prob(h, adj)
        
        return node_log_prob + edge_log_prob
    
    def compute_sparse_edge_prob(self, h, adj_sparse):
        """稀疏边概率计算"""
        # 获取边索引
        indices = adj_sparse.indices()  # (2, num_edges)
        
        # 边两端节点的特征
        h_src = h[:, indices[0], :]
        h_dst = h[:, indices[1], :]
        
        # 边存在概率
        p_exists, _ = self.edge_circuit(h_src, h_dst)
        
        return torch.log(p_exists + 1e-8).sum()

8.2 批处理优化

def batch_marginal_inference(pgc, x_batch, adj_batch):
    """
    批量边际推断
    """
    batch_size = x_batch.size(0)
    n = x_batch.size(1)
    
    # 编码
    h_batch = pgc.encoder(x_batch)
    
    # 节点概率(批量)
    node_log_probs = pgc.node_circuit(x_batch)  # (batch, n)
    node_log_probs = node_log_probs.sum(dim=1)  # (batch,)
    
    # 边概率(批量)
    edge_log_probs = []
    for b in range(batch_size):
        adj = adj_batch[b]
        h = h_batch[b]
        
        # 提取上三角(避免重复)
        i, j = torch.triu_indices(n, n, offset=1)
        
        h_i = h[i]
        h_j = h[j]
        
        p_exists, _ = pgc.edge_circuit(h_i, h_j)
        
        # 乘以邻接矩阵
        edge_exists = adj[i, j]
        log_prob = torch.where(
            edge_exists > 0,
            torch.log(p_exists + 1e-8),
            torch.log(1 - p_exists + 1e-8)
        )
        
        edge_log_probs.append(log_prob.sum())
    
    edge_log_probs = torch.stack(edge_log_probs)
    
    return node_log_probs + edge_log_probs

9. 局限性与未来方向

9.1 当前局限

问题描述影响
表达能力限制局部分解限制全局依赖建模无法捕获某些复杂模式
结构学习困难图结构学习复杂需要领域知识
规模化挑战大图计算开销难以处理大规模图

9.2 未来方向

  1. 层次化PGC: 多尺度图建模
  2. 动态PGC: 时序图建模
  3. PGC-GNN融合: 结合两者的优势
  4. 端到端学习: 从数据自动学习图电路结构

10. 参考


相关文档: 神经概率电路 | 几何感知概率电路 | GNN概率推断

Footnotes

  1. Papez et al. (2025): Probabilistic Graph Circuits: Deep Generative Models for Tractable Probabilistic Inference over Graphs. UAI 2025.