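Overview
Neural Probabilistic Circuits (NPC) are a framework that combines the compositionality of neural networks with the tractable inference of probabilistic circuits. [1]
Classical probabilistic circuits (such as SPNs) are often treated as having a fixed structure and parameters once learned, which makes it hard for them to exploit large-scale data; neural networks are more expressive, but probabilistic inference with them usually requires approximate methods (such as variational inference or sampling). The core idea of NPC is:
Use neural networks as the basic building blocks of a probabilistic circuit while preserving exact inference.
This framework means that:
- Modular neural network components can be embedded directly in a probabilistic circuit
- Key probabilistic queries (marginals, conditionals) can be computed exactly in time polynomial in the circuit size
- The reasoning path is transparent and traceable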
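1. Background
1.1 Limitations of probabilistic circuits
Classical probabilistic circuits (such as Sum-Product Networks, SPNs) face the following challenges:
| Problem | Description | Impact |
|---|---|---|
| Limited expressiveness | A fixed-structure PC struggles to capture complex patterns | Requires manual design or costly structure learning |
| Low parameter efficiency | Many parameters, poorly utilized | Hard to scale |
| Lack of compositionality | Modules are hard to reuse | Building complex models is difficult |
1.2 Limitations of neural networks
Neural networks are powerful, but they have their own problems:
| Problem | Description | Impact |
|---|---|---|
| Opaque inference | Relies on approximate inference (VI, MC Dropout) | Uncertainty estimates can be inaccurate |
| Black-box behavior | The decision process is not interpretable | Limited use in high-stakes applications |
| High computational cost | Sampling and variational inference are expensive | Real-time use is difficult |
1.3 How NPC addresses these problems
NPC addresses the problems above with the following design: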
┌─────────────────────────────────────────────────────────────┐
│             Neural Probabilistic Circuit (NPC)              │
├─────────────────────────────────────────────────────────────┤
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐    │
│  │   Neural    │     │   Neural    │     │   Neural    │    │
│  │   module    │  +  │   module    │  +  │   module    │    │
│  │ (learnable) │     │ (learnable) │     │ (learnable) │    │
│  └──────┬──────┘     └──────┬──────┘     └──────┬──────┘    │
│         │                   │                   │           │
│         └───────────────────┼───────────────────┘           │
│                             ▼                               │
│                 ┌───────────────────────┐                   │
│                 │ Probabilistic circuit │                   │
│                 │    inference layer    │                   │
│                 │  (exact marginals /   │                   │
│                 │     conditionals)     │                   │
│                 └───────────┬───────────┘                   │
│                             │                               │
│                             ▼                               │
│                 ┌───────────────────────┐                   │
│                 │ Interpretable output  │                   │
│                 │   (reasoning path)    │                   │
│                 └───────────────────────┘                   │
└─────────────────────────────────────────────────────────────┘
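2. The neural probabilistic circuit framework
2.1 Formal definition
Definition (neural probabilistic circuit):
A neural probabilistic circuit is a directed acyclic graph $\mathcal{G} = (V, E)$ in which every node $v \in V$ is one of two types:
- Input node $v_{\text{in}}$: corresponds to a random variable or to a neural network encoding
- Computation node $v_{\text{op}}$: implements a specific probabilistic operation
Formally, an NPC defines a composite function:
$$f_\theta(x), \quad x \in \mathcal{X}$$
where $\theta$ denotes the learnable parameters and $\mathcal{X}$ is the input space.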
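2.2 Node types
NPC contains the following core node types:
| Node type | Role | Mathematical form |
|---|---|---|
| Neural network node | Feature extraction / transformation | $h = g_\phi(x)$ |
| Product node | Conditional independence (factorization) | $p(x) = \prod_i p_i(x_i)$ |
| Sum node | Marginalization (mixture over children) | $p(x) = \sum_i w_i \, p_i(x)$ |
| Evidence node | Conditional probability | $p(q \mid e) = p(q, e) / p(e)$ |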
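2.3 Compositionality
The core design principle of NPC is compositionality:
Composition preserves tractability: if two sub-circuits are tractable, so is their composition.
In symbols:
If $C_1$ and $C_2$ are tractable, then:
- $C_1 \times C_2$ (product composition) is tractable, provided the two sub-circuits range over disjoint sets of variables
- $w_1 C_1 + w_2 C_2$ (sum composition) is tractable, provided the two sub-circuits range over the same variables
- $C_1 \mid C_2$ (conditional composition) is tractable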
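The closure properties above can be checked numerically on a toy circuit. The sketch below is a minimal illustration, not the paper's implementation: the tuple-based node encoding, the helper names `bernoulli_leaf` and `evaluate`, and all numbers are assumptions made for this example. It mixes two fully factorized distributions over binary variables A and B, then marginalizes B in a single bottom-up pass by letting its leaves evaluate to 1.

```python
from itertools import product

def bernoulli_leaf(var, p_one):
    """Leaf holding P(var = 1) = p_one (hypothetical encoding)."""
    return ("leaf", var, p_one)

def evaluate(node, assignment):
    """Bottom-up evaluation. Variables missing from `assignment` are
    marginalized out: their leaves evaluate to 1."""
    kind = node[0]
    if kind == "leaf":
        _, var, p_one = node
        if var not in assignment:
            return 1.0
        return p_one if assignment[var] == 1 else 1.0 - p_one
    if kind == "product":
        result = 1.0
        for child in node[1]:
            result *= evaluate(child, assignment)
        return result
    if kind == "sum":
        return sum(w * evaluate(child, assignment) for w, child in node[1])
    raise ValueError(f"unknown node kind: {kind}")

# Product nodes whose children cover disjoint variables {A} and {B},
# mixed by a sum node whose children cover the same variables {A, B}.
component_1 = ("product", [bernoulli_leaf("A", 0.9), bernoulli_leaf("B", 0.2)])
component_2 = ("product", [bernoulli_leaf("A", 0.1), bernoulli_leaf("B", 0.7)])
circuit = ("sum", [(0.6, component_1), (0.4, component_2)])

# Exact marginal P(A = 1) from one pass, with B marginalized out.
p_a1 = evaluate(circuit, {"A": 1})

# Brute-force checks: sum the joint over B, and over all four states.
p_a1_brute = sum(evaluate(circuit, {"A": 1, "B": b}) for b in (0, 1))
total = sum(evaluate(circuit, {"A": a, "B": b})
            for a, b in product((0, 1), repeat=2))
```

Here the single-pass marginal agrees with the brute-force sum because the product nodes split the variables and the sum node's children share them; without those two conditions, setting leaves to 1 would not in general yield a marginal.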
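3. Core mechanisms
3.1 Hybrid symbolic-continuous representation
A key idea in NPC is a unified representation for symbolic (discrete) random variables and continuous variables:

    from torch.distributions import Normal

    class SymbolicVariable:
        """Symbolic variable node."""
        def __init__(self, name, domain):
            self.name = name
            self.domain = domain  # e.g. {0, 1} or {"A", "B", "C"}

        def __repr__(self):
            return f"Symbolic({self.name} ∈ {self.domain})"

    class ContinuousVariable:
        """Continuous variable node."""
        def __init__(self, name, encoder_net):
            self.name = name
            self.encoder = encoder_net  # neural network encoder

        def encode(self, x):
            # Encode a continuous input as distribution parameters.
            # Assumption: the encoder outputs [mean, log_std] of a Gaussian;
            # the distribution family is not fixed by this article.
            mean, log_std = self.encoder(x).chunk(2, dim=-1)
            return Normal(mean, log_std.exp())

3.2 Integrating neural network modules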
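NPC allows a range of neural network architectures to be embedded as modules:

    import torch.nn as nn
    from torch.distributions import Bernoulli

    class NeuralModule(nn.Module):
        """Base class for neural network modules."""
        def __init__(self, input_dim, output_dim, hidden_dim=64):
            super().__init__()
            # hidden_dim must be supplied; 64 is an arbitrary default.
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )

        def forward(self, x):
            return self.net(x)

        def to_probabilistic(self, x):
            """Interpret the outputs as distribution parameters."""
            params = self.forward(x)
            return Bernoulli(logits=params)  # or another distribution

    class NPCModule(nn.Module):
        """Container that wraps a module for use inside an NPC."""
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.is_conditioning = False
            self.conditioned_value = None

        def condition(self, value):
            """Fix the variable to an observed value."""
            self.is_conditioning = True
            self.conditioned_value = value
            return self

        def marginalize(self):
            """Clear any conditioning so the variable is marginalized."""
            self.is_conditioning = False
            self.conditioned_value = None
            return self

3.3 Exact inference algorithms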
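NPC supports several exact inference operations:
3.3.1 Marginal inference

    def marginal_inference(circuit, query_vars, evidence=None):
        """
        Exact marginal inference.
        Cost: one bottom-up pass, linear in circuit size, per joint
        assignment of query_vars; enumerating the full table takes
        O(2^{|query_vars|}) passes for binary variables.
        simplify, bottom_up and extract_marginal are helpers that are
        not defined in this article.
        """
        evidence = evidence or {}  # avoid a mutable default argument
        # 1. Simplify the circuit by fixing the evidence
        simplified = simplify(circuit, evidence)
        # 2. Propagate values bottom-up
        bottom_up(simplified)
        # 3. Read off the marginal
        return extract_marginal(simplified, query_vars)

3.3.2 Conditional inference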
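As a worked instance of the identity P(query | evidence) = P(query, evidence) / P(evidence) used in this subsection, the sketch below writes out a two-component mixture over binary variables A and B by hand. The weights, the probabilities, and the function names `bernoulli` and `mixture_prob` are illustrative assumptions. Each call to `mixture_prob` plays the role of one circuit pass; passing `None` for a variable marginalizes it out.

```python
WEIGHTS = (0.6, 0.4)   # sum-node weights
P_A = (0.9, 0.1)       # P(A = 1) in each mixture component
P_B = (0.2, 0.7)       # P(B = 1) in each mixture component

def bernoulli(p_one, value):
    """Likelihood of `value`; None means the variable is marginalized out."""
    if value is None:
        return 1.0
    return p_one if value == 1 else 1.0 - p_one

def mixture_prob(a=None, b=None):
    """Sum node over two product nodes, evaluated for a partial assignment."""
    return sum(w * bernoulli(pa, a) * bernoulli(pb, b)
               for w, pa, pb in zip(WEIGHTS, P_A, P_B))

p_joint = mixture_prob(a=1, b=1)      # P(A=1, B=1)
p_evidence = mixture_prob(b=1)        # P(B=1), with A marginalized out
p_a_given_b = p_joint / p_evidence    # P(A=1 | B=1)
```

Two passes suffice regardless of how many variables are marginalized, which is the practical payoff of tractable marginals.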

    def conditional_inference(circuit, query, evidence):
        """
        P(query | evidence) = P(query, evidence) / P(evidence)
        """
        # Joint term P(query, evidence): query marginal with the evidence fixed
        joint = marginal_inference(circuit, query, evidence)
        # Evidence term P(evidence): every non-evidence variable marginalized out
        p_evidence = marginal_inference(circuit, [], evidence)
        # Conditional probability
        return joint / p_evidence

3.3.3 MAP inference
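MAP inference by a max-product pass replaces each sum with a max and keeps the maximizing branch. This is exact only when the circuit is deterministic (at most one child of each sum node is non-zero for any complete assignment); otherwise it is a heuristic. The sketch below is a toy illustration under that assumption: the model P(A) P(B | A), its numbers, and the function names are made up for the example.

```python
from itertools import product

P_A = {0: 0.3, 1: 0.7}                    # P(A = a)
P_B_GIVEN_A = {0: {0: 0.1, 1: 0.9},       # P(B = b | A = 0)
               1: {0: 0.6, 1: 0.4}}       # P(B = b | A = 1)

def joint(a, b):
    return P_A[a] * P_B_GIVEN_A[a][b]

def map_max_product(evidence=None):
    """Max-product pass: the sums over A and B are replaced by max,
    keeping the argmax at each step. Evidence fixes a variable."""
    evidence = evidence or {}
    best_value, best_assignment = -1.0, None
    for a in P_A:
        if evidence.get("A", a) != a:
            continue  # branch contradicts the evidence
        # Inner maximization over B for this branch of the sum over A.
        candidates = [b for b in (0, 1) if evidence.get("B", b) == b]
        b_star = max(candidates, key=lambda b: P_B_GIVEN_A[a][b])
        value = P_A[a] * P_B_GIVEN_A[a][b_star]
        if value > best_value:
            best_value, best_assignment = value, {"A": a, "B": b_star}
    return best_assignment, best_value

assignment, value = map_max_product()
# Brute-force check over all four joint states.
brute = max(product((0, 1), repeat=2), key=lambda ab: joint(*ab))
```

The branches of the outer sum are indexed by the value of A, so exactly one branch is active for any complete assignment; that is what makes the max-product result match the brute-force maximum here.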
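
    def map_inference(circuit, evidence=None):
        """
        Find the most probable assignment given the evidence.
        Replacing sums with max is exact only for deterministic circuits;
        for general circuits MAP is NP-hard and this is a heuristic.
        """
        evidence = evidence or {}
        # Exploit the circuit structure for an efficient search
        # (viterbi_search is a helper that is not defined in this article)
        return viterbi_search(circuit, evidence)

4. Logical reasoning integration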
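4.1 Propositional logic
NPC can combine propositional logic rules with probabilistic inference:

    class LogicalRule(nn.Module):
        """A weighted propositional rule: antecedent → consequent."""
        def __init__(self, antecedent, consequent, weight=1.0):
            super().__init__()
            self.antecedent = antecedent  # list of (variable, value) premises
            self.consequent = consequent  # (variable, value) conclusion
            self.weight = weight

        def forward(self, assignments):
            """
            Evaluate the truth value of the rule.
            assignments: {var: value} dict
            """
            antecedent_holds = all(
                assignments.get(var) == val
                for var, val in self.antecedent
            )
            var, val = self.consequent
            consequent_holds = assignments.get(var) == val
            # Material implication: false only when the premises hold
            # and the conclusion does not
            return float((not antecedent_holds) or consequent_holds)

    class LogicAwareCircuit(nn.Module):
        """NPC with integrated logic rules."""
        def __init__(self, nn_modules=(), logic_rules=()):
            super().__init__()
            self.nn_modules = nn.ModuleList(nn_modules)    # neural network modules
            self.logic_rules = nn.ModuleList(logic_rules)  # logic rules

        def soft_logic_loss(self, data):
            """
            Soft logic loss: penalize violated rules, scaled by rule weight.
            """
            # extract_assignments is assumed to be provided by a subclass
            assignments = self.extract_assignments(data)
            total_loss = 0.0
            for rule in self.logic_rules:
                rule_violation = 1 - rule(assignments)
                total_loss += rule.weight * rule_violation
            return total_loss

4.2 First-order logic approximation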
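For first-order logic, NPC provides an approximation over a finite domain:

    class FirstOrderLogicApproximation:
        """Approximate first-order quantifiers with an NPC."""
        def __init__(self, pc_circuit):
            self.pc = pc_circuit
            self.grounding_cache = {}

        def forall(self, variable, domain, formula):
            """
            Soft score for ∀x: Formula(x).
            Uses the mean over groundings, which is an upper bound on the
            probability that every grounding holds, not that probability.
            """
            total_prob = 0.0
            for x in domain:
                prob = self.pc.evaluate(formula.replace(variable, x))
                total_prob += prob
            return total_prob / len(domain)

        def exists(self, variable, domain, formula):
            """
            Soft score for ∃x: Formula(x).
            Uses the max over groundings, which is a lower bound on the
            probability that at least one grounding holds.
            """
            max_prob = 0.0
            for x in domain:
                prob = self.pc.evaluate(formula.replace(variable, x))
                max_prob = max(max_prob, prob)
            return max_prob

4.3 Rule mining and learning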
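NPC can mine logic rules from data:

    from itertools import combinations

    def proper_subsets(itemset):
        """Yield every non-empty proper subset of a frozenset."""
        items = list(itemset)
        for size in range(1, len(items)):
            for combo in combinations(items, size):
                yield frozenset(combo)

    class RuleLearner(nn.Module):
        """Mine logic rules from an NPC."""
        def __init__(self, circuit, min_support=0.1, min_confidence=0.8):
            super().__init__()
            self.circuit = circuit
            self.min_support = min_support
            self.min_confidence = min_confidence

        def mine_rules(self, data):
            """
            Mine frequent patterns and association rules.
            """
            # 1. Find frequent itemsets (find_frequent is assumed to return
            #    frozensets; it is not defined in this article)
            frequent_itemsets = self.find_frequent(data)
            # 2. Generate association rules
            rules = []
            for itemset in frequent_itemsets:
                p_joint = self.circuit.marginal(list(itemset))
                if p_joint < self.min_support:
                    continue
                for antecedent in proper_subsets(itemset):
                    consequent = itemset - antecedent
                    # confidence = P(itemset) / P(antecedent)
                    p_antecedent = self.circuit.marginal(list(antecedent))
                    confidence = p_joint / (p_antecedent + 1e-10)
                    if confidence >= self.min_confidence:
                        rules.append({
                            'antecedent': antecedent,
                            'consequent': consequent,
                            'confidence': confidence,
                            'support': p_joint
                        })
            return rules

5. Interpretability mechanisms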
5.1 Transparent reasoning paths
One of NPC's main advantages is that the reasoning process can be traced step by step:

    class ExplainableNPC(nn.Module):
        """NPC that records its reasoning path."""
        def forward_with_explanation(self, x):
            """
            Return the prediction together with a trace of the forward pass.
            """
            path_trace = []
            value_trace = []

            def hook_fn(module, inputs, output):
                path_trace.append({
                    'module': module,  # keep the module itself for isinstance checks
                    'module_id': id(module),
                    'inputs': [i.clone() for i in inputs],
                    'outputs': output.clone()
                })
                value_trace.append(output)

            # Register hooks on the node types to be traced
            handles = []
            for module in self.modules():
                if isinstance(module, (NeuralModule, SumNode, ProductNode)):
                    handles.append(module.register_forward_hook(hook_fn))
            try:
                output = self.forward(x)
            finally:
                # Always remove the hooks, even if the forward pass fails
                for h in handles:
                    h.remove()
            return output, path_trace, value_trace

        def generate_explanation(self, x, prediction):
            """
            Build a short textual explanation from the trace.
            """
            _, path_trace, _ = self.forward_with_explanation(x)
            explanation_parts = []
            for step in path_trace:
                if isinstance(step['module'], NeuralModule):
                    explanation_parts.append(
                        f"neural network extracted features: {tuple(step['outputs'].shape)}"
                    )
                elif isinstance(step['module'], SumNode):
                    explanation_parts.append(
                        f"weighted combination of {len(step['outputs'])} candidates"
                    )
            return " → ".join(explanation_parts)

5.2 Causal chain tracing
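
    class CausalTrace:
        """Trace paths through the circuit graph."""
        def trace_causal_path(self, circuit, query_var, target_var):
            """
            Enumerate all directed paths from query_var to target_var.
            These are paths in the circuit DAG; reading them as causal
            requires the circuit structure to encode causal direction.
            """
            paths = []

            def dfs(node, current_path):
                if node == target_var:
                    paths.append(current_path.copy())
                    return
                # child_nodes: the node's list of children (named this way
                # because nn.Module reserves children() as a method)
                for child in node.child_nodes:
                    current_path.append(child)
                    dfs(child, current_path)
                    current_path.pop()

            dfs(query_var, [query_var])
            return paths

        def compute_path_effects(self, circuit, paths):
            """
            Compute the effect attributed to each path.
            """
            effects = {}
            for path in paths:
                # compute_path_effect is not defined in this article
                effect = self.compute_path_effect(circuit, path)
                effects[tuple(path)] = effect
            return effects

5.3 Confidence calibration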
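Temperature scaling, the technique used by the PyTorch class in this subsection, divides the logits by a scalar T fitted on held-out data to minimize negative log-likelihood. The dependency-free sketch below illustrates the idea only: the validation logits are made up, and a grid search stands in for the LBFGS optimizer.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)                           # subtract the max for stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def mean_nll(logit_rows, labels, temperature):
    """Average negative log-likelihood at a given temperature."""
    losses = [-math.log(softmax(row, temperature)[y])
              for row, y in zip(logit_rows, labels)]
    return sum(losses) / len(losses)

# Made-up validation logits: every prediction is confident, one is wrong.
val_logits = [[4.0, 0.0], [3.5, 0.5], [0.0, 4.0], [4.0, 0.0]]
val_labels = [0, 0, 1, 1]

# Grid search over T in [0.5, 10.0] (assumption: replaces gradient-based fitting).
grid = [0.5 + 0.1 * k for k in range(96)]
best_t = min(grid, key=lambda t: mean_nll(val_logits, val_labels, t))
```

Because the toy model is overconfident, the fitted temperature comes out above 1, which softens the predicted probabilities without changing the argmax.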

    class CalibratedNPC(nn.Module):
        """NPC with temperature-scaled confidence."""
        def __init__(self, circuit, temperature=1.0):
            super().__init__()
            self.circuit = circuit
            self.temperature = nn.Parameter(torch.tensor(temperature))

        def forward(self, x):
            # Scale the circuit's logits by the learned temperature
            return self.circuit(x) / self.temperature

        def calibrate(self, val_loader):
            """
            Fit the temperature on a validation set by minimizing NLL.
            """
            optimizer = torch.optim.LBFGS([self.temperature], lr=0.01)

            def closure():
                optimizer.zero_grad()
                total_nll = 0.0
                for x, y in val_loader:
                    # Only the temperature is optimized; the circuit stays fixed
                    with torch.no_grad():
                        logits = self.circuit(x)
                    total_nll = total_nll + F.cross_entropy(logits / self.temperature, y)
                total_nll.backward()
                return total_nll

            optimizer.step(closure)
            return self.temperature.item()

        def predict(self, x):
            """
            Return calibrated predictions and confidences.
            """
            probs = F.softmax(self.forward(x), dim=-1)
            confidences, predictions = probs.max(dim=-1)
            return predictions, confidences

6. Application scenarios
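6.1 Combinatorial optimization
NPC can be used to encode combinatorial optimization problems and solve them by MAP inference; the result is guaranteed optimal only when the MAP query is exact for the circuit:

    class CombinatorialNPC:
        """
        NPC for combinatorial optimization.
        Example: the travelling salesman problem (TSP).
        """
        def __init__(self, num_cities):
            self.num_cities = num_cities
            self.circuit = self.build_tsp_circuit()

        def build_tsp_circuit(self):
            """
            Build an NPC encoding of the TSP degree constraints.
            Note: these constraints alone do not rule out subtours.
            ProbabilisticCircuit and its add_node method are assumed;
            they are not defined in this article.
            """
            # Decision variable: X[i,j] = 1 means the tour goes from city i to city j
            circuit = ProbabilisticCircuit()
            n = self.num_cities
            # One shared node per directed edge
            edges = {(i, j): ProductNode(f"edge_{i}_{j}")
                     for i in range(n) for j in range(n) if i != j}
            for i in range(n):
                # Each city is left exactly once
                out_sum = SumNode(f"out_{i}")
                for j in range(n):
                    if i != j:
                        out_sum.add_child(edges[(i, j)], weight=1.0)
                circuit.add_node(out_sum)
            for j in range(n):
                # Each city is entered exactly once
                in_sum = SumNode(f"in_{j}")
                for i in range(n):
                    if i != j:
                        in_sum.add_child(edges[(i, j)], weight=1.0)
                circuit.add_node(in_sum)
            # Attach distance costs (helper not defined in this article)
            self.add_distance_costs()
            return circuit

        def solve(self, distances):
            """
            Solve the TSP instance.
            Returns: the best path found and its cost.
            """
            # Fix the distances as evidence
            evidence = self.set_distance_evidence(distances)
            # MAP inference
            best_assignment = self.circuit.map_inference(evidence)
            # Extract the path
            path = self.extract_path(best_assignment)
            cost = self.compute_cost(path, distances)
            return path, cost

6.2 Knowledge graph reasoning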

    class KnowledgeGraphNPC(nn.Module):
        """
        NPC for knowledge graph reasoning.
        """
        def __init__(self, num_entities, num_relations, embed_dim=64):
            super().__init__()
            # Stored before build_npc(), which reads num_relations.
            # embed_dim must be supplied; 64 is an arbitrary default.
            self.num_entities = num_entities
            self.num_relations = num_relations
            self.entity_embeddings = nn.Embedding(num_entities, embed_dim)
            self.relation_embeddings = nn.Embedding(num_relations, embed_dim)
            self.npc = self.build_npc()

        def build_npc(self):
            """
            Build the NPC for the knowledge graph.
            """
            circuit = ProbabilisticCircuit()
            # Relation prior
            relation_prior = SumNode("relation_prior")
            # Entity conditionals, one product node per relation
            for r in range(self.num_relations):
                condition_node = ProductNode(f"cond_{r}")
                # nn.Embedding expects an index tensor, not a Python int
                condition_node.add_factor(
                    self.relation_embeddings(torch.tensor(r))
                )
                relation_prior.add_child(condition_node, weight=1.0)
            circuit.add_root(relation_prior)
            return circuit

        def predict_link(self, head, relation, tail):
            """
            Predict the probability that a link exists:
            P(tail | head, relation)
            """
            evidence = {
                'head': head,
                'relation': relation
            }
            # Conditional inference
            probs = self.npc.conditional(
                query='tail',
                evidence=evidence
            )
            return probs[tail]

        def learn_rules(self, kg_data):
            """
            Learn logic rules from the knowledge graph.
            """
            rule_learner = RuleLearner(self.npc)
            rules = rule_learner.mine_rules(kg_data)
            # Keep only high-quality rules
            high_quality = [
                r for r in rules
                if r['confidence'] > 0.9 and r['support'] > 0.1
            ]
            return high_quality

6.3 Medical diagnosis

    import math

    class MedicalDiagnosisNPC(nn.Module):
        """
        NPC for medical diagnosis.
        Supports interventions and explanations.
        """
        def __init__(self, symptoms, diseases):
            super().__init__()
            self.symptoms = symptoms
            self.diseases = diseases
            self.circuit = self.build_diagnostic_circuit()

        def build_diagnostic_circuit(self):
            """
            Build the diagnostic circuit.
            Structure: symptoms → diseases → treatment
            get_prior and get_likelihood are not defined in this article.
            """
            circuit = ProbabilisticCircuit()
            # Disease nodes with prior probabilities
            disease_nodes = {}
            for disease in self.diseases:
                disease_nodes[disease] = SumNode(f"disease_{disease}")
                disease_nodes[disease].set_weights(self.get_prior(disease))
            # Symptom nodes, conditioned on disease. Each (symptom, disease)
            # product node is created once and linked to its disease node.
            symptom_nodes = {}
            for symptom in self.symptoms:
                symptom_node = SumNode(f"symptom_{symptom}")
                for disease in self.diseases:
                    product = ProductNode(f"{symptom}|{disease}")
                    product.add_factor(disease_nodes[disease])
                    symptom_node.add_child(
                        product,
                        weight=self.get_likelihood(symptom, disease)
                    )
                symptom_nodes[symptom] = symptom_node
            return circuit

        def diagnose(self, observed_symptoms):
            """
            Diagnosis: compute the posterior probability of each disease.
            """
            evidence = {
                symptom: 1 for symptom in observed_symptoms
            }
            # Posterior inference
            posteriors = {}
            for disease in self.diseases:
                posteriors[disease] = self.circuit.conditional(
                    query=disease,
                    evidence=evidence
                )
            return posteriors

        def explain_diagnosis(self, diagnosis, evidence):
            """
            Generate an explanation for a diagnosis.
            """
            # Collect the paths linking the evidence to the diagnosis
            paths = self.circuit.get_paths(
                query=diagnosis,
                evidence=list(evidence.keys())
            )
            # Contribution of each path (treated as a log-probability)
            contributions = []
            for path in paths:
                contribution = self.compute_path_contribution(path, evidence)
                contributions.append({
                    'path': path,
                    'contribution': contribution,
                    'probability': math.exp(contribution)
                })
            # Sort and report the top three paths
            contributions.sort(key=lambda x: x['probability'], reverse=True)
            explanation = f"Main evidence for diagnosis {diagnosis}:\n"
            for i, c in enumerate(contributions[:3]):
                explanation += f"{i+1}. {c['path']}: probability {c['probability']:.3f}\n"
            return explanation

        def what_if_intervention(self, diagnosis, do_intervention):
            """
            Interventional ("what-if") query: what happens under an intervention?
            """
            # do(intervention): force a variable to a given value
            do_evidence = {f"do({k})": v for k, v in do_intervention.items()}
            # Disease probabilities after the intervention
            new_posteriors = {}
            for disease in self.diseases:
                new_posteriors[disease] = self.circuit.intervene(
                    query=disease,
                    intervention=do_evidence
                )
            return new_posteriors

7. Comparison with existing methods
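7.1 Comparison table
| Dimension | NPC | Standard NN | Variational autoencoder | Probabilistic circuit (SPN) |
|---|---|---|---|---|
| Inference | ✓ Exact | ✗ Approximate | ✗ Approximate | ✓ Exact |
| Expressiveness | ✓ High | ✓ Very high | ✓ High | Medium |
| Interpretability | ✓ High | ✗ Low | Medium | ✓ High |
| Composability | ✓ High | ✓ High | Medium | Medium |
| Scalability | Medium | ✓ High | ✓ High | Medium |
| Causal inference | ✓ Native support | ✗ | ✗ | ✓ |
7.2 When to use NPC
NPC is recommended when:
- The application needs exact uncertainty quantification
- Decisions must be interpretable in a high-stakes setting
- Causal interventions and counterfactual reasoning are required
- The optimization problem has combinatorial constraints that must be satisfied
NPC is not recommended when:
- Efficient learning on very large datasets is the priority
- Very deep networks are required
- Real-time inference at high throughput is required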
8. Implementation details
8.1 PyTorch implementation

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Bernoulli, Categorical

    class Node(nn.Module):
        """Base class for NPC nodes. forward(x) returns a log-probability."""
        def __init__(self, name):
            super().__init__()
            self.name = name
            # Not named `children`: that would shadow nn.Module.children()
            # and break .to(), .apply() and similar methods.
            self.child_nodes = nn.ModuleList()
            self.parents = []

        def forward(self, x):
            raise NotImplementedError

    class InputNode(Node):
        """Input node wrapping a neural network encoder."""
        def __init__(self, name, encoder_net):
            super().__init__(name)
            self.encoder = encoder_net

        def forward(self, x):
            # log P(variable = 1 | x), reading encoder outputs as logits
            return F.logsigmoid(self.encoder(x))

        def marginal(self, x):
            return Bernoulli(logits=self.encoder(x))

    class SumNode(Node):
        """Sum node: a weighted mixture of its children."""
        def __init__(self, name):
            super().__init__(name)
            self.weights = []  # unnormalized, positive mixture weights

        def set_weights(self, weights):
            self.weights = list(weights)

        def add_child(self, child, weight):
            self.child_nodes.append(child)
            child.parents.append(self)
            self.weights.append(weight)

        def log_weights(self):
            w = torch.tensor(self.weights, dtype=torch.float32)
            return torch.log(w / w.sum())

        def forward(self, x):
            # log sum_i w_i * p_i(x), computed stably in the log domain
            log_w = self.log_weights()
            terms = [lw + child(x) for lw, child in zip(log_w, self.child_nodes)]
            return torch.logsumexp(torch.stack(terms, dim=0), dim=0)

        def marginal(self):
            # Distribution over which child is selected
            return Categorical(logits=self.log_weights())

    class ProductNode(Node):
        """Product node: factorization over independent children."""
        def __init__(self, name):
            super().__init__(name)
            self.constants = []  # fixed probabilities, e.g. floats

        def add_factor(self, factor):
            if isinstance(factor, Node):
                self.child_nodes.append(factor)
            else:
                self.constants.append(factor)

        def forward(self, x):
            # log prod_i p_i(x) = sum_i log p_i(x)
            total = sum(child(x) for child in self.child_nodes)
            for c in self.constants:
                total = total + torch.log(torch.as_tensor(c, dtype=torch.float32))
            return total

    class NeuralProbabilisticCircuit(nn.Module):
        """Main class for a neural probabilistic circuit."""
        def __init__(self):
            super().__init__()
            self.nodes = nn.ModuleDict()
            self.root = None

        def add_node(self, node):
            self.nodes[node.name] = node

        def set_root(self, node):
            self.root = node

        def forward(self, x):
            """Log-probability at the root."""
            if self.root is None:
                raise ValueError("Root node not set")
            return self.root(x)

        def marginal(self, x):
            """Exact marginal: exponentiate the root log-probability."""
            return torch.exp(self.forward(x))

        def conditional(self, x_joint, x_evidence):
            """
            Exact conditional P(query | evidence) = P(query, evidence) / P(evidence).
            Assumption: x_joint encodes query and evidence together, and
            x_evidence encodes the evidence with the query marginalized out.
            """
            log_joint = self.forward(x_joint)
            log_evidence = self.forward(x_evidence)
            return torch.exp(log_joint - log_evidence)

        def map_inference(self, evidence):
            """MAP inference: most probable assignment."""
            # A Viterbi-style search is left unimplemented here
            raise NotImplementedError

8.2 Training loop

    def train_npc(model, train_loader, num_epochs, lr=1e-3):
        """Training loop for an NPC."""
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        for epoch in range(num_epochs):
            total_loss = 0.0
            for x, y in train_loader:  # y is unused: the objective is the likelihood of x
                optimizer.zero_grad()
                # Forward pass
                log_prob = model(x)
                # Negative log-likelihood loss
                loss = -log_prob.mean()
                # Logic regularization (optional): only if the model defines it
                if hasattr(model, "soft_logic_loss"):
                    loss = loss + 0.1 * model.soft_logic_loss(x)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            print(f"Epoch {epoch}: Loss = {total_loss / len(train_loader):.4f}")
        return model

9. Limitations and challenges
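9.1 Current limitations
| Problem | Description | Possible remedy |
|---|---|---|
| Complex structure design | Designing the PC structure requires domain knowledge | Automatic structure learning |
| Computational complexity | Some operations can still be exponential | Approximate methods, hierarchical decomposition |
| Scaling | Training on large-scale data | Distributed training, sparsification |
9.2 Future directions
- Automated structure learning: learn the PC structure end to end
- Integration with LLMs: incorporate the reasoning capabilities of language models
- Multimodal extension: handle mixed image, text, and audio inputs
- Causal discovery: discover causal structure from data automatically
10. References
Related documents: Probabilistic circuits and deep learning (topic page) | Fundamentals of probabilistic circuits | Causal neural probabilistic circuits
Footnotes
[1] Chen et al. (2025): Neural Probabilistic Circuits: Enabling Compositional and Interpretable Predictions through Logical Reasoning. arXiv:2501.07021. UIUC.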