概述

神经概率电路(Neural Probabilistic Circuits, NPC)是一种将神经网络的可组合性概率电路的可推断性相结合的框架。1

传统的概率电路(如SPN)在结构学习后参数固定,难以利用大规模数据;而神经网络虽然表达能力强,但推断往往需要近似方法(如变分推断、采样)。NPC的核心思想是:

让神经网络成为概率电路的基本构建块,同时保持精确推断的能力。

这一框架使得:

  • 模块化的神经网络组件可以直接嵌入概率电路
  • 关键的概率计算(边际、条件)可以在多项式时间内精确完成
  • 推理路径完全透明可追踪

1. 问题背景

1.1 概率电路的局限性

传统概率电路(如Sum-Product Networks, SPN)面临以下挑战:

问题描述影响
表达能力有限固定结构的PC难以捕捉复杂模式需要手动设计或复杂结构学习
参数效率低大量参数但利用不充分难以规模化
缺乏组合性模块难以复用构建复杂模型困难

1.2 神经网络的局限性

神经网络虽然强大,但存在:

问题描述影响
推断不透明近似推断(VI/MC Dropout)不确定性量化不准确
黑盒性质决策过程不可解释在高风险应用中受限
计算成本高采样/变分推断开销大实时应用困难

1.3 NPC的解决思路

NPC通过以下设计解决上述问题:

┌─────────────────────────────────────────────────────────────┐
│                    神经概率电路 (NPC)                         │
├─────────────────────────────────────────────────────────────┤
│  ┌─────────────┐    ┌─────────────┐    ┌─────────────┐      │
│  │  神经网络    │    │  神经网络    │    │  神经网络    │      │
│  │   模块       │ +  │   模块       │ +  │   模块       │      │
│  │  (可学习)    │    │  (可学习)    │    │  (可学习)    │      │
│  └──────┬──────┘    └──────┬──────┘    └──────┬──────┘      │
│         │                  │                  │             │
│         └──────────────────┼──────────────────┘             │
│                            ▼                                │
│                   ┌────────────────┐                        │
│                   │   概率电路       │                        │
│                   │   推理层        │                        │
│                   │ (精确边际/条件) │                        │
│                   └────────────────┘                        │
│                            │                                │
│                            ▼                                │
│                   ┌────────────────┐                        │
│                   │   可解释输出    │                        │
│                   │  (推理路径)     │                        │
│                   └────────────────┘                        │
└─────────────────────────────────────────────────────────────┘

2. 神经概率电路框架

2.1 形式化定义

定义(神经概率电路):

神经概率电路是一个有向无环图 ,其中每个节点 是以下两种类型之一:

  1. 输入节点 : 对应随机变量或神经网络编码
  2. 计算节点 : 实现特定的概率操作

形式上,NPC定义了一个复合函数:

其中 是可学习参数, 是输入空间。

2.2 节点类型

NPC包含以下核心节点类型:

节点类型功能数学表示
神经网络节点特征提取/变换
乘积节点条件独立
求和节点边际化
证据节点条件概率

2.3 组合性原则

NPC的核心设计原则是组合性(Compositionality)

复合操作保持可处理性: 如果两个子电路都是可处理的,则它们的组合也是可处理的。

数学表达:

是可处理的,则:

  • (乘积组合)可处理
  • (求和组合)可处理
  • (条件组合)可处理

3. 核心机制

3.1 符号-连续混合表示

NPC的一个关键创新是统一表示符号随机变量连续变量

class SymbolicVariable:
    """符号变量节点"""
    def __init__(self, name, domain):
        self.name = name
        self.domain = domain  # 如 {0, 1} 或 {A, B, C}
    
    def __repr__(self):
        return f"Symbolic({self.name}{self.domain})"
 
 
class ContinuousVariable:
    """连续变量节点"""
    def __init__(self, name, encoder_net):
        self.name = name
        self.encoder = encoder_net  # 神经网络编码器
    
    def encode(self, x):
        # 将连续输入编码为概率分布参数
        params = self.encoder(x)
        return Distribution(params)

3.2 神经网络模块集成

NPC允许各种神经网络架构作为模块嵌入:

class NeuralModule(nn.Module):
    """神经网络模块基类"""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)
    
    def to_probabilistic(self, x):
        """转换为概率参数"""
        params = self.forward(x)
        return Bernoulli(logits=params)  # 或其他分布
 
 
class NPCModule(nn.Module):
    """NPC模块容器"""
    def __init__(self, module):
        super().__init__()
        self.module = module
        self.is_conditioning = False
        self.conditioned_value = None
    
    def condition(self, value):
        """设置条件变量"""
        self.is_conditioning = True
        self.conditioned_value = value
        return self
    
    def marginalize(self):
        """边际化处理"""
        self.is_conditioning = False
        self.conditioned_value = None

3.3 精确推断算法

NPC支持多种精确推断操作:

3.3.1 边际推断

def marginal_inference(circuit, query_vars, evidence={}):
    """
    精确边际推断
    
    复杂度: O(2^{|query_vars|}) 但利用PC结构优化
    """
    # 1. 简化电路(固定证据)
    simplified = simplify(circuit, evidence)
    
    # 2. 自底向上传播
    bottom_up(simplified)
    
    # 3. 提取边际
    return extract_marginal(simplified, query_vars)

3.3.2 条件推断

def conditional_inference(circuit, query, evidence):
    """
    P(query | evidence) = P(query, evidence) / P(evidence)
    """
    # 联合分布
    joint = joint_marginal(circuit, query + list(evidence.keys()))
    
    # 证据边际
    marginal_evidence = marginal_inference(circuit, list(evidence.keys()))
    
    # 条件概率
    return joint / marginal_evidence

3.3.3 MAP推断

def map_inference(circuit, evidence={}):
    """
    找到最可能的配置
    """
    # 利用PC的结构进行高效搜索
    return viterbi_search(circuit, evidence)

4. 逻辑推理集成

4.1 命题逻辑融合

NPC可以将命题逻辑规则概率推断结合:

class LogicalRule(nn.Module):
    """逻辑规则模块"""
    def __init__(self, antecedent, consequent, weight=1.0):
        super().__init__()
        self.antecedent = antecedent  # 前提变量列表
        self.consequent = consequent  # 结论变量
        self.weight = weight
    
    def forward(self, assignments):
        """
        评估规则的真值
        assignments: {var: value} 字典
        """
        antecedent_val = all(
            assignments.get(var) == val 
            for var, val in self.antecedent
        )
        consequent_val = assignments.get(self.consequent[0])
        
        # 返回逻辑蕴含的真值
        return float(antecedent_val == consequent_val)
 
 
class LogicAwareCircuit(nn.Module):
    """集成逻辑规则的NPC"""
    def __init__(self):
        super().__init__()
        self.nn_modules = nn.ModuleList([...])  # 神经网络模块
        self.logic_rules = nn.ModuleList([...])  # 逻辑规则
    
    def soft_logic_loss(self, data):
        """
        软逻辑损失:鼓励满足逻辑约束
        """
        total_loss = 0
        for rule in self.logic_rules:
            assignments = self.extract_assignments(data)
            rule_violation = 1 - rule(assignments)
            total_loss += self.rule.weight * rule_violation
        return total_loss

4.2 一阶逻辑近似

对于一阶逻辑,NPC提供近似方法:

class FirstOrderLogicApproximation:
    """一阶逻辑的NPC近似"""
    
    def __init__(self, pc_circuit):
        self.pc = pc_circuit
        self.grounding_cache = {}
    
    def forall(self, variable, domain, formula):
        """
        ∀x: Formula(x) 的概率
        """
        total_prob = 0
        for x in domain:
            prob = self.pc.evaluate(formula.replace(variable, x))
            total_prob += prob
        return total_prob / len(domain)
    
    def exists(self, variable, domain, formula):
        """
        ∃x: Formula(x) 的概率
        """
        max_prob = 0
        for x in domain:
            prob = self.pc.evaluate(formula.replace(variable, x))
            max_prob = max(max_prob, prob)
        return max_prob

4.3 规则挖掘与学习

NPC可以从数据中自动学习逻辑规则

class RuleLearner(nn.Module):
    """从NPC中挖掘逻辑规则"""
    def __init__(self, circuit, min_support=0.1, min_confidence=0.8):
        super().__init__()
        self.circuit = circuit
        self.min_support = min_support
        self.min_confidence = min_confidence
    
    def mine_rules(self, data):
        """
        挖掘频繁模式和关联规则
        """
        # 1. 发现频繁项集
        frequent_itemsets = self.find_frequent(data)
        
        # 2. 生成关联规则
        rules = []
        for itemset in frequent_itemsets:
            for subset in proper_subsets(itemset):
                antecedent = subset
                consequent = itemset - subset
                
                # 计算置信度
                p_antecedent = self.circuit.marginal(list(antecedent))
                p_joint = self.circuit.marginal(list(itemset))
                confidence = p_joint / (p_antecedent + 1e-10)
                
                if confidence >= self.min_confidence:
                    rules.append({
                        'antecedent': antecedent,
                        'consequent': consequent,
                        'confidence': confidence,
                        'support': p_joint
                    })
        
        return rules

5. 可解释性机制

5.1 透明推理路径

NPC的核心优势之一是完全透明的推理过程

class ExplainableNPC(nn.Module):
    """可解释的NPC"""
    
    def forward_with_explanation(self, x):
        """
        返回预测及其解释
        """
        # 前向传播并记录路径
        path_trace = []
        value_trace = []
        
        def hook_fn(module, input, output):
            path_trace.append({
                'module': module.__class__.__name__,
                'module_id': id(module),
                'inputs': [i.clone() for i in input],
                'outputs': output.clone()
            })
            value_trace.append(output)
        
        # 注册hook
        handles = []
        for module in self.modules():
            if isinstance(module, (SummingNode, ProductNode)):
                handles.append(module.register_forward_hook(hook_fn))
        
        # 前向传播
        output = self.forward(x)
        
        # 移除hook
        for h in handles:
            h.remove()
        
        return output, path_trace, value_trace
    
    def generate_explanation(self, x, prediction):
        """
        生成自然语言解释
        """
        _, path_trace, value_trace = self.forward_with_explanation(x)
        
        explanation_parts = []
        for step in path_trace:
            if isinstance(step['module'], NeuralModule):
                explanation_parts.append(
                    f"神经网络提取特征: {step['outputs'].shape}"
                )
            elif isinstance(step['module'], SumsNode):
                explanation_parts.append(
                    f"加权组合 {len(step['outputs'])} 个候选项"
                )
        
        return " → ".join(explanation_parts)

5.2 因果链追踪

class CausalTrace:
    """因果链追踪"""
    
    def trace_causal_path(self, circuit, query_var, target_var):
        """
        追踪从 query_var 到 target_var 的因果路径
        """
        paths = []
        
        def dfs(node, current_path):
            if node == target_var:
                paths.append(current_path.copy())
                return
            
            for child in node.children:
                current_path.append(child)
                dfs(child, current_path)
                current_path.pop()
        
        dfs(query_var, [query_var])
        return paths
    
    def compute_path_effects(self, circuit, paths):
        """
        计算各因果路径的效应
        """
        effects = {}
        for path in paths:
            effect = self.compute_path_effect(circuit, path)
            effects[tuple(path)] = effect
        return effects

5.3 置信度校准

class CalibratedNPC(nn.Module):
    """置信度校准的NPC"""
    
    def __init__(self, circuit, temperature=1.0):
        super().__init__()
        self.circuit = circuit
        self.temperature = nn.Parameter(torch.tensor(temperature))
    
    def calibrate(self, val_loader):
        """
        使用验证集校准置信度
        """
        optimizer = torch.optim.LBFGS([self.temperature], lr=0.01)
        
        def closure():
            optimizer.zero_grad()
            total_nll = 0
            for x, y in val_loader:
                logits = self.forward(x) / self.temperature
                nll = F.cross_entropy(logits, y)
                total_nll += nll
            total_nll.backward()
            return total_nll
        
        optimizer.step(closure)
        return self.temperature.item()
    
    def predict(self, x):
        """
        返回校准后的预测和置信度
        """
        logits = self.forward(x) / self.temperature
        probs = F.softmax(logits, dim=-1)
        predictions = probs.argmax(dim=-1)
        confidences = probs.max(dim=-1).values
        
        return predictions, confidences

6. 应用场景

6.1 组合优化

NPC可以用于求解组合优化问题,同时提供最优性保证

class CombinatorialNPC:
    """
    组合优化NPC
    示例: 旅行商问题(TSP)
    """
    def __init__(self, num_cities):
        self.num_cities = num_cities
        self.circuit = self.build_tsp_circuit()
    
    def build_tsp_circuit(self):
        """
        构建TSP的NPC表示
        """
        # 决策变量: X[i,j] = 1 表示从城市i到城市j
        circuit = ProbabilisticCircuit()
        
        # 添加路径约束节点
        for i in range(self.num_cities):
            # 每个城市恰好离开一次
            out_sum = SumNode(f"out_{i}")
            for j in range(self.num_cities):
                if i != j:
                    out_sum.add_child(
                        ProductNode(f"edge_{i}_{j}"),
                        weight=1.0
                    )
            
            # 每个城市恰好进入一次
            in_sum = SumNode(f"in_{j}")
            for i in range(self.num_cities):
                if i != j:
                    in_sum.add_child(
                        ProductNode(f"edge_{i}_{j}"),
                        weight=1.0
                    )
        
        # 添加距离成本
        self.add_distance_costs()
        
        return circuit
    
    def solve(self, distances):
        """
        求解TSP
        返回: 最优路径及其概率/成本
        """
        # 固定距离证据
        evidence = self.set_distance_evidence(distances)
        
        # MAP推断
        best_assignment = self.circuit.map_inference(evidence)
        
        # 提取路径
        path = self.extract_path(best_assignment)
        cost = self.compute_cost(path, distances)
        
        return path, cost

6.2 知识图谱推理

class KnowledgeGraphNPC(nn.Module):
    """
    知识图谱推理NPC
    """
    def __init__(self, num_entities, num_relations):
        super().__init__()
        self.entity_embeddings = nn.Embedding(num_entities, embed_dim)
        self.relation_embeddings = nn.Embedding(num_relations, embed_dim)
        self.npc = self.build_npc()
    
    def build_npc(self):
        """
        构建知识图谱的NPC
        """
        circuit = ProbabilisticCircuit()
        
        # 关系先验
        relation_prior = SumNode("relation_prior")
        
        # 实体条件概率
        for r in range(self.num_relations):
            condition_node = ProductNode(f"cond_{r}")
            condition_node.add_factor(
                self.relation_embeddings(r)
            )
            relation_prior.add_child(condition_node, weight=1.0)
        
        circuit.add_root(relation_prior)
        return circuit
    
    def predict_link(self, head, relation, tail):
        """
        预测链接存在概率
        P(tail | head, relation)
        """
        evidence = {
            'head': head,
            'relation': relation
        }
        
        # 条件推断
        probs = self.npc.conditional(
            query='tail',
            evidence=evidence
        )
        
        return probs[tail]
    
    def learn_rules(self, kg_data):
        """
        从知识图谱学习逻辑规则
        """
        rule_learner = RuleLearner(self.npc)
        rules = rule_learner.mine_rules(kg_data)
        
        # 过滤高质量规则
        high_quality = [
            r for r in rules 
            if r['confidence'] > 0.9 and r['support'] > 0.1
        ]
        
        return high_quality

6.3 医疗诊断

class MedicalDiagnosisNPC(nn.Module):
    """
    医疗诊断NPC
    支持因果干预和解释
    """
    def __init__(self, symptoms, diseases):
        super().__init__()
        self.symptoms = symptoms
        self.diseases = diseases
        self.circuit = self.build_diagnostic_circuit()
    
    def build_diagnostic_circuit(self):
        """
        构建诊断电路
        结构: 症状 → 疾病 → 治疗
        """
        circuit = ProbabilisticCircuit()
        
        # 疾病节点
        disease_nodes = {}
        for disease in self.diseases:
            disease_nodes[disease] = SumNode(f"disease_{disease}")
            # 先验概率
            disease_nodes[disease].set_weights(self.get_prior(disease))
        
        # 症状节点(条件于疾病)
        symptom_nodes = {}
        for symptom in self.symptoms:
            symptom_node = SumNode(f"symptom_{symptom}")
            
            for disease in self.diseases:
                cond_prob = self.get_likelihood(symptom, disease)
                symptom_node.add_child(
                    ProductNode(f"{symptom}|{disease}"),
                    weight=cond_prob
                )
            
            symptom_nodes[symptom] = symptom_node
        
        # 连接疾病和症状
        for disease in self.diseases:
            for symptom in self.symptoms:
                product = ProductNode(f"{symptom}|{disease}")
                product.add_factor(disease_nodes[disease])
                product.add_factor(symptom_nodes[symptom])
        
        return circuit
    
    def diagnose(self, observed_symptoms):
        """
        诊断: 计算各疾病的后验概率
        """
        evidence = {
            symptom: 1 for symptom in observed_symptoms
        }
        
        # 后验推断
        posteriors = {}
        for disease in self.diseases:
            posteriors[disease] = self.circuit.conditional(
                query=disease,
                evidence=evidence
            )
        
        return posteriors
    
    def explain_diagnosis(self, diagnosis, evidence):
        """
        生成诊断解释
        """
        # 获取因果路径
        paths = self.circuit.get_paths(
            query=diagnosis,
            evidence=list(evidence.keys())
        )
        
        # 计算各路径的贡献
        contributions = []
        for path in paths:
            contribution = self.compute_path_contribution(path, evidence)
            contributions.append({
                'path': path,
                'contribution': contribution,
                'probability': np.exp(contribution)
            })
        
        # 排序并生成解释
        contributions.sort(key=lambda x: x['probability'], reverse=True)
        
        explanation = f"诊断{diagnosis}的主要依据:\n"
        for i, c in enumerate(contributions[:3]):
            explanation += f"{i+1}. {c['path']}: 概率 {c['probability']:.3f}\n"
        
        return explanation
    
    def what_if_intervention(self, diagnosis, do_intervention):
        """
        反事实推理: 如果进行某干预会怎样?
        """
        # do(intervention): 强制设置某变量值
        do_evidence = {f"do({k})": v for k, v in do_intervention.items()}
        
        # 计算干预后的疾病概率
        new_posteriors = {}
        for disease in self.diseases:
            new_posteriors[disease] = self.circuit.intervene(
                query=disease,
                intervention=do_evidence
            )
        
        return new_posteriors

7. 与现有方法对比

7.1 方法对比表

维度NPC标准NN变分自编码器概率电路(SPN)
推断精度✓ 精确✗ 近似✗ 近似✓ 精确
表达能力✓ 强✓ 很强✓ 强中等
可解释性✓ 高✗ 低中等✓ 高
可组合性✓ 高✓ 高中等中等
可扩展性中等✓ 高✓ 高中等
因果推断✓ 原生支持

7.2 适用场景

推荐使用NPC的场景:

  • 需要精确不确定性量化的应用
  • 决策需要可解释性的高风险场景
  • 需要因果干预和反事实推理
  • 组合约束必须满足的优化问题

不推荐使用NPC的场景:

  • 超大规模数据的高效学习
  • 极深网络的需求
  • 实时推理的高吞吐需求

8. 实现细节

8.1 完整PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli, Categorical
 
class Node(nn.Module):
    """NPC节点基类"""
    def __init__(self, name):
        super().__init__()
        self.name = name
        self.children = []
        self.parents = []
    
    def forward(self):
        raise NotImplementedError
    
    def marginal(self):
        raise NotImplementedError
 
 
class InputNode(Node):
    """输入节点 - 编码神经网络模块"""
    def __init__(self, name, encoder_net):
        super().__init__(name)
        self.encoder = encoder_net
    
    def forward(self, x):
        return self.encoder(x)
    
    def marginal(self, x):
        params = self.forward(x)
        return Bernoulli(logits=params)
 
 
class SumNode(Node):
    """求和节点 - 边际化"""
    def __init__(self, name):
        super().__init__(name)
        self.weights = None
    
    def set_weights(self, weights):
        self.weights = weights
    
    def add_child(self, child, weight):
        self.children.append(child)
        child.parents.append(self)
        if self.weights is None:
            self.weights = [weight]
        else:
            self.weights.append(weight)
    
    def forward(self):
        total = 0
        for i, child in enumerate(self.children):
            w = self.weights[i] if isinstance(self.weights, list) else self.weights[i]
            total += w * child.forward()
        return total
    
    def marginal(self):
        return Categorical(probs=F.softmax(torch.tensor(self.weights), dim=-1))
 
 
class ProductNode(Node):
    """乘积节点 - 条件独立"""
    def __init__(self, name):
        super().__init__(name)
        self.factors = []
    
    def add_factor(self, factor):
        self.factors.append(factor)
    
    def forward(self):
        product = 1
        for factor in self.factors:
            if isinstance(factor, Node):
                product *= factor.forward()
            else:
                product *= factor
        return product
 
 
class NeuralProbabilisticCircuit(nn.Module):
    """神经概率电路主类"""
    def __init__(self):
        super().__init__()
        self.nodes = nn.ModuleDict()
        self.root = None
    
    def add_node(self, node):
        self.nodes[node.name] = node
    
    def set_root(self, node):
        self.root = node
    
    def forward(self, x):
        if self.root is None:
            raise ValueError("Root node not set")
        return self.root.forward(x)
    
    def marginal(self, x):
        """精确边际推断"""
        return torch.exp(self.forward(x))
    
    def conditional(self, x, evidence):
        """精确条件推断"""
        joint = self.forward(x, evidence)
        marginal = self.forward(evidence)
        return joint / (marginal + 1e-10)
    
    def map_inference(self, evidence):
        """MAP推断 - 最可能配置"""
        # 实现Viterbi-style搜索
        pass

8.2 训练循环

def train_npc(model, train_loader, num_epochs, lr=1e-3):
    """NPC训练循环"""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in train_loader:
            x, y = batch
            
            optimizer.zero_grad()
            
            # 前向传播
            log_prob = model(x)
            
            # 负对数似然损失
            nll_loss = -log_prob.mean()
            
            # 逻辑正则化(可选)
            logic_loss = model.soft_logic_loss(x)
            
            # 总损失
            loss = nll_loss + 0.1 * logic_loss
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch}: Loss = {total_loss / len(train_loader):.4f}")
    
    return model

9. 局限性与挑战

9.1 当前局限

问题描述潜在解决方案
结构设计复杂需要领域知识设计PC结构自动结构学习
计算复杂度某些操作仍可能指数级近似方法、层次化
规模化挑战大规模数据上的训练分布式训练、稀疏化

9.2 未来方向

  1. 自动化结构学习: 端到端学习PC结构
  2. 与LLM集成: 融合语言模型的推理能力
  3. 多模态扩展: 处理图像、文本、音频混合输入
  4. 因果发现: 自动从数据中发现因果结构

10. 参考


相关文档: 概率电路与深度学习融合专题 | 概率电路基础 | 因果神经概率电路

Footnotes

  1. Chen et al. (2025): Neural Probabilistic Circuits: Enabling Compositional and Interpretable Predictions through Logical Reasoning. arXiv:2501.07021. UIUC.