概述

概率电路的学习涉及两个核心问题:参数学习结构学习1

  • 参数学习:在固定电路结构下,学习和节点的权重以及叶节点的分布参数
  • 结构学习:从数据中自动发现最优的电路拓扑结构

概率电路的一个关键优势是支持端到端可微分训练,这使得我们可以用标准的深度学习优化器来学习电路参数。


参数学习

目标函数

给定数据集 ,参数学习的目标是最大化对数似然:

EM算法

期望最大化(EM)算法是学习概率电路的经典方法。2

E步:计算后验

对于每个和节点 ,计算每个子节点的后验权重

其中:

  • 是和节点 个子节点的权重
  • 是子节点在输入 下的输出

M步:更新权重

更新和节点的权重为后验权重的平均值:

EM算法收敛性

EM算法保证对数似然单调递增,且在有限步内收敛。收敛条件:

梯度下降方法

概率电路支持直接梯度计算,这使得我们可以使用标准的随机梯度下降(SGD)方法。3

对数似然的梯度

对于参数 ,梯度为:

和节点的梯度

对于和节点的权重

这个梯度形式与策略梯度类似,存在高方差问题。

降低方差:基线技术

def gradient_with_baseline(node, x, baseline):
    """使用基线降低梯度方差"""
    log_prob = node.log_prob(x)
    advantage = log_prob - baseline
    return advantage * node.gradient(x)

半朴素的EM初始化

动机:随机初始化可能导致局部最优

方法

  1. 首先运行半朴素SPN构建算法获得初始结构
  2. 使用小数据集进行预训练
  3. 之后进行微调或继续EM迭代

结构学习

结构学习是概率电路中最具挑战性的部分,因为搜索空间巨大。

电路结构学习问题

给定数据 ,寻找最优电路结构

其中 是模型得分(如似然、BIC、AIC), 是结构复杂度惩罚。

自顶向下方法

ID-SPN算法

Gens & Domingos (2012) 提出的经典方法。4

步骤

  1. 聚类:将变量分成互不相交的子集
  2. 递归分解:对每个子集递归构建子电路
  3. 选择操作:在和与积之间选择
def learn_spn(data, scope):
    # 1. 如果作用域足够小,直接构建叶节点
    if len(scope) <= threshold:
        return build_leaf_node(data, scope)
    
    # 2. 聚类变量
    partitions = cluster_variables(data, scope)
    
    # 3. 递归构建子电路
    children = [learn_spn(data, part) for part in partitions]
    
    # 4. 选择操作:和 vs 积
    if score_sum(children) > score_product(children):
        return SumNode(children)
    else:
        return ProductNode(children)

评分函数

自底向上方法

积节点学习

识别哪些变量应该组合在积节点下:

互信息方法

如果 较低,则 应该分开。

和节点学习

识别应该如何混合不同组件:

梯度差异分析:在训练过程中观察不同子电路的重要性变化。

从贝叶斯网络编译

一种高效的方法是将已有的概率图模型编译为概率电路。

变量消除编译

对于树结构的贝叶斯网络,可以直接编译为电路:

def compile_bayesian_network(bn):
    """将贝叶斯网络编译为概率电路"""
    # 1. 确定变量的消元顺序
    ordering = get_elimination_order(bn)
    
    # 2. 自底向上构建电路
    circuits = {}
    for var in reversed(ordering):
        # 收集相关的因子
        factors = get_factors(bn, var)
        # 构建和-积结构
        circuit = build_circuit(factors)
        circuits[var] = circuit
    
    return circuits[ordering[0]]

端到端结构学习

现代方法支持与参数一起的端到端结构学习。

可微分的结构参数

class DifferentiableCircuit(nn.Module):
    def __init__(self, n_vars, hidden_dim):
        super().__init__()
        # 可学习的结构参数
        self.structure_logits = nn.Parameter(torch.randn(n_vars, hidden_dim))
    
    def forward(self, x):
        # 软化结构选择
        attn = F.softmax(self.structure_logits, dim=-1)
        # ... 构建电路 ...
        return output

Gumbel-Softmax技巧

用于在离散结构选择上计算梯度:

其中 是温度参数。


电路简化与优化

电路简化规则

冗余和节点消除

如果和节点只有一个子节点,则消除该节点:

    +        →     child
    |
  child

冗余积节点消除

如果积节点只有一个子节点,则消除该节点。

子电路合并

如果两个和节点的子电路相同,则合并:

    +         +
   /|\       /|\
  a b c     a b c
  ↓          ↓
[相同]    [相同]

电路压缩

参数剪枝

移除不重要的和分支:

def prune_sum_node(node, threshold=0.01):
    """剪除权重小于阈值的分支"""
    mask = node.weights > threshold
    new_children = [c for c, m in zip(node.children, mask) if m]
    new_weights = node.weights[mask]
    new_weights /= new_weights.sum()  # 重新归一化
    return SumNode(new_children, new_weights)

实现库与工具

SPFlow

一个基于PyTorch的概率电路库,支持自动微分和GPU加速。5

from spflow import SPNFlow
 
# 定义SPN结构
spn = SPNFlow(
    num_sums=10,
    num_gauss=20,
    num_classes=2
)
 
# 学习
spn.learn(data)
 
# 推理
log_prob = spn.log_prob(data)

cirkit

专注于精确推断的概率电路库,支持GPU加速。

import cirkit
 
# 从网络编译
circuit = cirkit.compile("data.bif")
 
# 精确查询
result = circuit.query(evidence={"X": True}, target="Y")

学习算法的选择指南

场景推荐方法
小规模数据EM算法
大规模数据随机梯度下降
已知结构先验从先验编译
完全无监督ID-SPN
需要端到端训练可微分结构学习

挑战与开放问题

1. 表达能力 vs 效率的权衡

更深的电路表达能力更强,但评估更慢。如何在效率和表达能力之间找到平衡?

2. 局部最优问题

EM和梯度下降都容易陷入局部最优。需要更好的初始化策略或全局优化方法。

3. 规模化问题

现有算法在变量数超过100时面临挑战。需要更高效的近似方法或层次化方法。

4. 与深度学习的集成

如何更好地将概率电路与神经网络结合,利用两者的优势?


参考


相关主题

Footnotes

  1. Vergari et al. (2021). “Visualizing and Understanding Sum-Product Networks”.

  2. Peharz et al. (2014). “Learning Sum-Product Networks with Latent Variables”.

  3. Cheng & Bolen (2018). “Differentiable Learning of Probabilistic Circuits”.

  4. Gens & Domingos (2012). “Learning the Structure of Sum-Product Networks”. ICML 2012.

  5. Molina et al. (2019). “SPFlow: An Easy and Extensible Library for Deep Probabilistic Learning”.