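Overview

Causal Neural Probabilistic Circuits (Causal NPC) is a framework that combines causal inference with neural probabilistic circuits (NPCs). [1]

Conventional machine learning models, including standard NPCs, support only associative inference: they learn correlations from observational data. Such inference cannot answer causal questions such as:

  • "If I intervene on X, what will happen?" (interventional question)
  • "If Y had not happened, would X still have happened?" (counterfactual question)

By introducing a causal graph structure and the do-operator, Causal NPC enables a model to:

  1. Support concept interventions at test time
  2. Perform counterfactual reasoning
  3. Estimate causal effects
  4. Retain the tractable inference of probabilistic circuits

The framework is particularly valuable in the following settings:

  • Medical diagnosis: evaluating the effects of different treatment plans
  • Recommender systems: estimating the effect of recommendations on user behavior
  • Autonomous driving: assessing the safety risks of different decisions
  • Scientific discovery: inferring causal relationships between variables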

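1. Background: Concept Bottleneck Models

1.1 Introduction to Concept Bottleneck Models

Concept Bottleneck Models (CBMs) are an architecture designed to improve interpretability: [2]

Input → [Concept layer] → [Label layer]
               ↑                ↑
         intervenable     interpretable

Core idea:

  1. The model first predicts a set of intermediate concepts (for an image, e.g. "has wings", "is red")
  2. It then predicts the final label from those concepts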
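
The two stages above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the concept names, the hand-set weights, and the helper names (`predict_concepts`, `predict_label`, `predict`) are hypothetical, and a real CBM would learn both stages from data.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


# Hypothetical hand-set weights: name -> (feature weights, bias)
CONCEPT_WEIGHTS = {
    "has_wings": ([2.0, -1.0], 0.0),
    "is_red": ([-1.0, 2.0], 0.0),
}
LABEL_WEIGHTS = {"has_wings": 1.5, "is_red": 2.5}
LABEL_BIAS = -3.0


def predict_concepts(x):
    """Stage 1: map input features to concept probabilities."""
    return {
        name: sigmoid(sum(w * xi for w, xi in zip(weights, x)) + bias)
        for name, (weights, bias) in CONCEPT_WEIGHTS.items()
    }


def predict_label(concepts):
    """Stage 2: map concept values to P(label = 1); it never sees x."""
    score = LABEL_BIAS + sum(LABEL_WEIGHTS[n] * v for n, v in concepts.items())
    return sigmoid(score)


def predict(x, corrections=None):
    concepts = predict_concepts(x)
    concepts.update(corrections or {})  # test-time concept correction
    return concepts, predict_label(concepts)


x = [1.0, 0.2]
_, p_before = predict(x)
_, p_after = predict(x, corrections={"is_red": 1.0})
```

Because the label stage depends on the input only through the concepts, overwriting a concept value is enough to change the prediction; here correcting `is_red` to 1 raises the predicted label probability.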

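1.2 Advantages of CBMs

Advantage          Description
-----------------  -----------------------------------------------------------
Intervenability    Incorrect concept predictions can be corrected at test time
Interpretability   Predictions can be explained in terms of concepts
Domain knowledge   Expert knowledge can be injected to constrain concept relations

1.3 Limitations of CBMs

Problem                     Description
--------------------------  ------------------------------------------------
Imprecise inference         Typically rely on deterministic predictions
Limited causal capability   No support for true causal interventions
No uncertainty estimates    Uncertainty of concept predictions is not quantified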

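1.4 The Causal NPC Solution

Causal NPC combines the concept bottleneck of CBMs with probabilistic circuits, providing:

  • ✓ A probabilistic representation of the concept layer
  • ✓ Causal interventions through the do-operator
  • ✓ Exact marginal and conditional inference
  • ✓ Uncertainty quantification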
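
To make "exact marginal and conditional inference" concrete, the following is a minimal sketch of a probabilistic circuit over two binary variables. The circuit structure, the variable names `A` and `B`, and all parameters are assumptions chosen for illustration; they are not taken from the Causal NPC paper.

```python
def leaf(var, p_one):
    """Bernoulli leaf; evaluates to 1.0 when its variable is marginalized out."""
    def evaluate(evidence):
        if var not in evidence:
            return 1.0
        return p_one if evidence[var] == 1 else 1.0 - p_one
    return evaluate


def product(*children):
    """Product node over children with disjoint variables (decomposability)."""
    def evaluate(evidence):
        result = 1.0
        for child in children:
            result *= child(evidence)
        return result
    return evaluate


def weighted_sum(weighted_children):
    """Sum node: a mixture of children defined over the same variables."""
    def evaluate(evidence):
        return sum(w * child(evidence) for w, child in weighted_children)
    return evaluate


circuit = weighted_sum([
    (0.3, product(leaf("A", 0.9), leaf("B", 0.8))),
    (0.7, product(leaf("A", 0.2), leaf("B", 0.1))),
])

p_joint = circuit({"A": 1, "B": 1})    # P(A=1, B=1)
p_marginal = circuit({"A": 1})         # P(A=1), with B summed out exactly
p_conditional = p_joint / p_marginal   # P(B=1 | A=1)
```

Each query is a single bottom-up pass: a marginalized variable is handled by letting its leaves return 1, which is valid for smooth and decomposable circuits such as this one.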

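2. Causal Graph Fundamentals

2.1 Structural Causal Models

Causal NPC builds on the Structural Causal Model (SCM).

Definition: an SCM is a four-tuple $\mathcal{M} = \langle V, U, F, P(U) \rangle$, where:

  • $V$: the set of endogenous variables
  • $U$: the set of exogenous variables
  • $F$: the family of causal mechanism functions
  • $P(U)$: the joint distribution of the exogenous variables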
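
As a rough illustration of the four components, the sketch below samples from a toy SCM with three binary endogenous variables. The graph (Z → X, Z → Y, X → Y), the uniform noise, and every probability are assumptions made up for this example.

```python
import random

# F: each endogenous variable is a deterministic function of the values
# of its parents (v) and of its own exogenous noise term (u).
MECHANISMS = {
    "Z": lambda v, u: int(u < 0.5),
    "X": lambda v, u: int(u < (0.8 if v["Z"] else 0.2)),
    "Y": lambda v, u: int(u < 0.1 + 0.3 * v["X"] + 0.5 * v["Z"]),
}
ORDER = ["Z", "X", "Y"]  # a topological order of the endogenous variables V


def sample(rng):
    """Draw one joint sample of V by pushing noise through the mechanisms."""
    values = {}
    for name in ORDER:
        u = rng.random()  # exogenous noise, here U ~ Uniform(0, 1)
        values[name] = MECHANISMS[name](values, u)
    return values


rng = random.Random(0)
samples = [sample(rng) for _ in range(20000)]
mean_y = sum(s["Y"] for s in samples) / len(samples)
```

All randomness enters through the exogenous terms; given their values, every endogenous variable is fully determined by its mechanism.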

2.2 Causal Graph Representation

        Z₁         Z₂
       ↙   ↘     ↙   ↘
      ↓       ↓   ↓       ↓
      X₁ ────→ Y ←──── X₂
           ↑
           Z₃

where:

  • Z₁, Z₂, Z₃: concept variables
  • X₁, X₂: input features
  • Y: output label

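2.3 The do-Operator

The do-operator is the core tool of causal inference: $do(X = x)$ denotes forcing $X$ to take the value $x$ by intervention instead of merely observing it.

Intuition:

  • $P(Y \mid X = x)$: the probability of $Y$ when $X = x$ is observed (association)
  • $P(Y \mid do(X = x))$: the probability of $Y$ when $X$ is forced to $x$ (causation)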
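
The difference between the two quantities can be checked by exact enumeration on a small model. The sketch below assumes a toy graph Z → X, Z → Y, X → Y with made-up conditional probabilities; the intervention is implemented by dropping the factor P(X | Z) from the joint distribution (graph surgery).

```python
from itertools import product

P_Z = {0: 0.5, 1: 0.5}             # P(Z=z)
P_X1_GIVEN_Z = {0: 0.2, 1: 0.8}    # P(X=1 | Z=z)


def p_y1(x, z):
    """P(Y=1 | X=x, Z=z)."""
    return 0.1 + 0.3 * x + 0.5 * z


def joint(z, x, y, do_x=None):
    """Joint probability; under do(X=do_x) the factor P(X | Z) is removed."""
    if do_x is None:
        p_x = P_X1_GIVEN_Z[z] if x == 1 else 1 - P_X1_GIVEN_Z[z]
    else:
        p_x = 1.0 if x == do_x else 0.0
    p_y = p_y1(x, z) if y == 1 else 1 - p_y1(x, z)
    return P_Z[z] * p_x * p_y


def prob_y1_given_x1(do_x=None):
    numerator = sum(joint(z, 1, 1, do_x) for z in (0, 1))
    denominator = sum(joint(z, 1, y, do_x) for z, y in product((0, 1), repeat=2))
    return numerator / denominator


observational = prob_y1_given_x1()          # P(Y=1 | X=1)     = 0.80
interventional = prob_y1_given_x1(do_x=1)   # P(Y=1 | do(X=1)) = 0.65
```

Observing X = 1 also makes Z = 1 more likely, which inflates the observational value; the intervention leaves the distribution of Z untouched.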

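2.4 Causal Inference Rules

Causal NPC relies on the following causal inference rules.

Back-door adjustment formula, where $Z$ is a set of variables that blocks every back-door path from $X$ to $Y$:

$$P(Y \mid do(X = x)) = \sum_{z} P(Y \mid X = x, Z = z)\, P(Z = z)$$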
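
A small numeric sketch of back-door adjustment on observational data follows. The count table is hypothetical, and Z is assumed to be the only confounder of X and Y; the function names are illustrative.

```python
# (z, x, y) -> number of observed records; hypothetical data
COUNTS = {
    (0, 0, 0): 360, (0, 0, 1): 40, (0, 1, 0): 60, (0, 1, 1): 40,
    (1, 0, 0): 40, (1, 0, 1): 60, (1, 1, 0): 40, (1, 1, 1): 360,
}
TOTAL = sum(COUNTS.values())


def backdoor_adjust(x, y=1):
    """Estimate P(Y=y | do(X=x)) as sum_z P(Y=y | X=x, Z=z) * P(Z=z)."""
    effect = 0.0
    for z in (0, 1):
        n_z = sum(c for (zz, _, _), c in COUNTS.items() if zz == z)
        n_xz = sum(c for (zz, xx, _), c in COUNTS.items() if zz == z and xx == x)
        p_y_given_xz = COUNTS[(z, x, y)] / n_xz
        effect += p_y_given_xz * n_z / TOTAL
    return effect


def naive_conditional(x, y=1):
    """Plain conditional P(Y=y | X=x), with no adjustment for Z."""
    n_x = sum(c for (_, xx, _), c in COUNTS.items() if xx == x)
    n_xy = sum(c for (_, xx, yy), c in COUNTS.items() if xx == x and yy == y)
    return n_xy / n_x


ate = backdoor_adjust(1) - backdoor_adjust(0)             # 0.65 - 0.35 = 0.30
naive_gap = naive_conditional(1) - naive_conditional(0)   # 0.80 - 0.20 = 0.60
```

In this table the unadjusted contrast is twice the adjusted one, because Z raises both the chance of treatment and the chance of the outcome.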

Front-door adjustment formula: used when unobserved confounding is present.
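
For reference, the standard form of the front-door formula is given below. The mediator symbol $M$ is an assumption introduced here (the text above does not name it): $M$ must intercept every directed path from $X$ to $Y$, there must be no unblocked back-door path from $X$ to $M$, and every back-door path from $M$ to $Y$ must be blocked by $X$.

```latex
$$P(Y \mid do(X = x)) = \sum_{m} P(M = m \mid X = x) \sum_{x'} P(Y \mid X = x', M = m)\, P(X = x')$$
```

The outer sum propagates the effect of $X$ on $M$, and the inner sum is a back-door adjustment for the effect of $M$ on $Y$ that uses $X$ itself as the adjustment variable.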


3. Causal Neural Probabilistic Circuit Architecture

3.1 Core Architecture

┌─────────────────────────────────────────────────────────────┐
│             Causal Neural Probabilistic Circuit             │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  ┌─────────┐                                                │
│  │ Input X │ ────────────────────────────────────────────┐  │
│  └─────────┘                                             │  │
│       │                                                  │  │
│       ▼                                                  │  │
│  ┌─────────────────────────────────────────┐             │  │
│  │         Neural network encoder          │             │  │
│  │    h = Encoder_θ(X)                     │             │  │
│  └─────────────────────────────────────────┘             │  │
│       │                                                  │  │
│       ▼                                                  │  │
│  ┌─────────────────────────────────────────┐             │  │
│  │         Causal concept layer (C)        │             │  │
│  │    C = [C₁, C₂, ..., Cₖ]                │             │  │
│  │    P(C | do(X)) = ...                   │             │  │
│  └─────────────────────────────────────────┘             │  │
│       │                                                  │  │
│       ▼                                                  │  │
│  ┌─────────────────────────────────────────┐             │  │
│  │         Causal prediction layer (Y)     │             │  │
│  │    Y = f(C, do(θ))                      │             │  │
│  │    P(Y | do(C)) = ...                   │             │  │
│  └─────────────────────────────────────────┘             │  │
│                                                             │
└─────────────────────────────────────────────────────────────┘

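3.2 Node Types

Node type          Function                      Causal semantics
-----------------  ----------------------------  ----------------------------
Input node         Encodes input features        Observed variable
Concept node       Represents a causal concept   Intermediate causal variable
Intervention node  Represents a do-operation     Intervened variable
Output node        Final prediction              Response variable
Confounder node    Unobserved factors            Latent variable

3.3 Conditional Probability Table (CPT) Representation

The conditional probability of each causal variable can be represented by a probabilistic circuit: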

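import torch
import torch.nn as nn


class CausalConceptNode(nn.Module):
    """A causal concept node."""

    def __init__(self, name, parents, cpt_network, num_values=2):
        super().__init__()
        self.name = name
        self.parents = parents        # list of parent node names
        self.cpt_net = cpt_network    # neural network that implements the CPT
        self.num_values = num_values  # number of discrete values of the concept

    def forward(self, parent_values, do_interventions=None):
        """
        Compute the conditional distribution of the concept.

        Args:
            parent_values: 1-D float tensor with the current parent values
            do_interventions: intervention dict {var_name: value}
        """
        do_interventions = do_interventions or {}

        # If this variable is intervened on, return the point-mass
        # distribution on the intervention value
        if self.name in do_interventions:
            return self.do_distribution(do_interventions[self.name])

        # Otherwise compute the conditional distribution. The intervention
        # values of the parents are appended to the context
        # (0 for a parent that is not intervened on).
        parent_do = torch.tensor(
            [do_interventions.get(p, 0) for p in self.parents],
            dtype=parent_values.dtype,
        )
        context = torch.cat([parent_values, parent_do])

        return self.cpt_net(context)

    def do_distribution(self, intervention_value):
        """
        Compute P(C = c | do(C = intervention_value)),
        i.e. the distribution after forcing the concept to a given value.
        """
        # The do-operation makes the variable independent of its parents
        return torch.eye(self.num_values)[intervention_value]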

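4. Core Algorithms

4.1 Implementing the do-Operator

class CausalNPC(nn.Module):
    """
    Causal neural probabilistic circuit.

    The helper methods used below (disconnect, get_confounders,
    enumerate_assignments, compute_marginal, compute_conditional,
    abduction, predict) are assumed to be implemented elsewhere.
    """
    
    def __init__(self, causal_graph):
        super().__init__()
        self.graph = causal_graph  # causal graph structure
        self.concept_nodes = nn.ModuleDict()
        self.encoder = None
    
    def do_intervention(self, variable, value, circuit_state):
        """
        Perform the operation do(X = value).
        
        Effects:
        1. Remove every edge pointing into X
        2. Force X = value
        """
        # 1. Create the post-intervention circuit state
        intervened_state = circuit_state.copy()
        
        # 2. Set the intervened variable
        intervened_state[variable] = value
        
        # 3. Remove the dependence of X on its parents.
        #    parents is a method of the graph; the list is copied because
        #    disconnect may modify the graph while we iterate.
        for parent in list(self.graph.parents(variable)):
            # Cut the edge parent -> variable
            self.disconnect(parent, variable)
        
        return intervened_state
    
    def compute_causal_effect(self, cause, effect, cause_value, condition=None):
        """
        Compute the causal effect P(effect | do(cause = cause_value)).
        
        Uses the back-door adjustment formula:
        P(Y | do(X=x)) = Σ_z P(Y | X=x, Z=z) P(Z=z)
        
        If condition (evidence C=c on pre-treatment variables) is given,
        both factors are additionally conditioned on it.
        """
        condition = condition or {}
        # Confounders Z: parents of the cause and variables on back-door paths
        confounders = self.get_confounders(cause)
        
        total_effect = 0
        for z_value in self.enumerate_assignments(confounders):
            # P(Z = z)
            p_z = self.compute_marginal(confounders, z_value)
            if condition:
                # P(Z = z | C = c) = P(Z = z, C = c) / P(C = c)
                joint_vars = list(confounders) + list(condition)
                p_zc = self.compute_marginal(joint_vars, {**z_value, **condition})
                p_z = p_zc / self.compute_marginal(list(condition), condition)
            
            # P(Y | X=x, Z=z), plus C=c when a condition is given
            evidence = {cause: cause_value, **z_value, **condition}
            p_y_given = self.compute_conditional(effect, evidence)
            
            total_effect += p_y_given * p_z
        
        return total_effect
    
    def counterfactual(self, individual, hypothetical):
        """
        Counterfactual reasoning:
        given observed data, evaluate a hypothetical scenario.
        
        Three steps:
        1. Abduction: infer the latent variables from the observation
        2. Action: apply the intervention
        3. Prediction: predict the outcome
        """
        observation = individual['observation']
        intervention = hypothetical['intervention']
        query = hypothetical['query']
        
        # Step 1: Abduction
        # P(U | observed_data): posterior over the latent variables
        u_posterior = self.abduction(observation)
        
        # Step 2: Action
        # Apply do(intervention), one variable at a time
        modified_circuit = individual['circuit_state']
        for variable, value in intervention.items():
            modified_circuit = self.do_intervention(
                variable, value, modified_circuit
            )
        
        # Step 3: Prediction
        # Predict with the modified circuit and the posterior over U
        cf_outcome = self.predict(
            query,
            circuit=modified_circuit,
            latent_posterior=u_posterior
        )
        
        return cf_outcome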
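
The three steps can be traced by hand on a one-equation model. The sketch below assumes a hypothetical linear mechanism Y := 2·X + U_Y with additive noise, which makes abduction an exact inversion; in general, abduction yields a posterior over the noise rather than a single value.

```python
def mechanism_y(x, u_y):
    """Hypothetical structural equation Y := 2 * X + U_Y."""
    return 2.0 * x + u_y


def counterfactual_y(observed_x, observed_y, hypothetical_x):
    # 1. Abduction: recover the noise term consistent with the observation
    u_y = observed_y - 2.0 * observed_x
    # 2. Action: replace the mechanism of X by the constant hypothetical_x
    x = hypothetical_x
    # 3. Prediction: push the same noise through the modified model
    return mechanism_y(x, u_y)


# Observed X=1, Y=3. What would Y have been had X been 0?
y_cf = counterfactual_y(observed_x=1.0, observed_y=3.0, hypothetical_x=0.0)
```

The observation implies U_Y = 1, so the counterfactual outcome is Y = 1. Keeping the individual's noise fixed is what distinguishes a counterfactual from a plain intervention, which would average over U_Y.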

4.2 Estimating Intervention Effects

def estimate_average_treatment_effect(self, treatment_var, outcome_var, 
                                     dataset):
    """
    Estimate the average treatment effect (ATE):
    ATE = E[Y | do(T=1)] - E[Y | do(T=0)]
    """
    # Potential outcome under treatment
    y_do_1 = self.compute_causal_effect(
        cause=treatment_var,
        effect=outcome_var,
        cause_value=1
    )
    
    # Potential outcome under control
    y_do_0 = self.compute_causal_effect(
        cause=treatment_var,
        effect=outcome_var,
        cause_value=0
    )
    
    return y_do_1 - y_do_0
 
 
def estimate_conditional_treatment_effect(self, treatment_var, outcome_var,
                                         condition_var, condition_value):
    """
    Estimate the conditional average treatment effect (CATE):
    CATE = E[Y | do(T=1), C=c] - E[Y | do(T=0), C=c]
    """
    # Add the conditioning variable to the evidence
    evidence = {condition_var: condition_value}
    
    y_do_1 = self.compute_causal_effect(
        cause=treatment_var,
        effect=outcome_var,
        cause_value=1,
        condition=evidence
    )
    
    y_do_0 = self.compute_causal_effect(
        cause=treatment_var,
        effect=outcome_var,
        cause_value=0,
        condition=evidence
    )
    
    return y_do_1 - y_do_0

4.3 Test-Time Concept Correction

class ConceptCorrectionInterface:
    """Interface for correcting concepts at test time."""
    
    def __init__(self, causal_npc):
        self.model = causal_npc
    
    def predict_with_intervention(self, x, corrections=None):
        """
        Predict while allowing concept corrections.
        
        Args:
            x: input sample
            corrections: {concept_name: corrected_value}
        
        Returns:
            predictions: final predictions
            concept_probs: concept probability distributions (before correction)
            explanation: explanation of the prediction
        """
        corrections = corrections or {}
        
        # 1. Forward pass to obtain the concept distributions
        concept_probs = self.model.forward_concepts(x)
        
        # 2. Apply the corrections as do-operations
        interventions = dict(corrections)
        
        # 3. Predict under the interventions
        final_predictions = self.model.predict_with_do(
            x,
            do_interventions=interventions
        )
        
        # 4. Generate the explanation
        explanation = self.generate_explanation(
            concept_probs,
            corrections,
            final_predictions
        )
        
        return final_predictions, concept_probs, explanation
    
    def what_if_scenario(self, x, concept_changes):
        """
        "What if" scenario analysis.
        
        Example:
            what_if_scenario(x, {"has_wings": 1, "is_red": 0})
        
        Asks: how would the prediction change if the bird had wings but were not red?
        """
        interventions = concept_changes
        
        # Prediction under the interventions
        cf_predictions = self.model.predict_with_do(
            x,
            do_interventions=interventions
        )
        
        # Original prediction
        original_predictions = self.model(x)
        
        # Compare the two
        diff = self.compute_difference(
            cf_predictions,
            original_predictions
        )
        
        return {
            "original": original_predictions,
            "counterfactual": cf_predictions,
            "difference": diff
        }

5. Implementation Details

5.1 Building the Causal Graph

from collections import defaultdict
 
class CausalGraph:
    """A causal graph."""
    def __init__(self):
        self.adjacency = defaultdict(list)  # parent -> [children]
        self.reverse_adj = defaultdict(list)  # child -> [parents]
        self.nodes = set()
        self.observed_nodes = set()  # observed variables
        self.latent_nodes = set()  # latent variables
    
    def add_edge(self, parent, child, observed=True):
        """Add the causal edge parent -> child."""
        self.adjacency[parent].append(child)
        self.reverse_adj[child].append(parent)
        self.nodes.add(parent)
        self.nodes.add(child)
        
        if observed:
            self.observed_nodes.add(parent)
            self.observed_nodes.add(child)
    
    def add_latent_edge(self, parent, child):
        """Add an unobserved causal edge (drawn as a dashed line)."""
        self.add_edge(parent, child, observed=False)
        self.latent_nodes.add(parent)
        self.latent_nodes.add(child)
    
    def parents(self, node):
        """Return the parents of a node."""
        return self.reverse_adj.get(node, [])
    
    def children(self, node):
        """Return the children of a node."""
        return self.adjacency.get(node, [])
    
    def descendants(self, node):
        """Return all descendants of a node."""
        result = set()
        queue = [node]
        while queue:
            current = queue.pop()
            for child in self.children(current):
                if child not in result:
                    result.add(child)
                    queue.append(child)
        return result
    
    def ancestors(self, node):
        """Return all ancestors of a node."""
        result = set()
        queue = [node]
        while queue:
            current = queue.pop()
            for parent in self.parents(current):
                if parent not in result:
                    result.add(parent)
                    queue.append(parent)
        return result
    
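    def is_d_separated(self, x, y, z):
        """
        Check whether x and y are d-separated given the set z.
        d-separation implies that x and y are conditionally independent given z.
        """
        z = set(z)
        
        # Nodes that are in z or have a descendant in z.
        # A collider lets a trail through only if it belongs to this set.
        z_or_ancestors = set(z)
        for node in z:
            z_or_ancestors |= self.ancestors(node)
        
        # Search over (node, direction) states:
        #   "up"   = the trail reached node from one of its children
        #   "down" = the trail reached node from one of its parents
        visited = set()
        stack = [(x, "up")]
        while stack:
            node, direction = stack.pop()
            if (node, direction) in visited:
                continue
            visited.add((node, direction))
            
            if node == y and node not in z:
                return False  # an active trail from x to y exists
            
            if direction == "up" and node not in z:
                # Unobserved node reached from a child: chain or fork,
                # so the trail may continue in both directions
                for parent in self.parents(node):
                    stack.append((parent, "up"))
                for child in self.children(node):
                    stack.append((child, "down"))
            elif direction == "down":
                if node not in z:
                    # Chain through an unobserved node
                    for child in self.children(node):
                        stack.append((child, "down"))
                if node in z_or_ancestors:
                    # Collider that is observed or has an observed
                    # descendant: the trail may turn back to the parents
                    for parent in self.parents(node):
                        stack.append((parent, "up"))
        
        return True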
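
A d-separation routine is easy to get subtly wrong, so it helps to cross-check it against an independent method. The sketch below uses the classical moralized-ancestral-graph criterion instead of a trail search; the function name `d_separated` and the edge-list input format are assumptions for this example and do not depend on the `CausalGraph` class.

```python
from itertools import combinations


def d_separated(edges, x, y, given):
    """Test d-separation of x and y given a set, via the moral ancestral graph."""
    given = set(given)
    parents = {}
    for parent, child in edges:
        parents.setdefault(child, set()).add(parent)
        parents.setdefault(parent, set())

    # 1. Keep only x, y, the conditioning set, and all of their ancestors
    keep, stack = set(), [x, y, *given]
    while stack:
        node = stack.pop()
        if node not in keep:
            keep.add(node)
            stack.extend(parents.get(node, ()))

    # 2. Moralize: connect each node to its parents, connect parents that
    #    share a child, and drop edge directions
    neighbors = {node: set() for node in keep}
    for child in keep:
        node_parents = parents.get(child, set())
        for p in node_parents:
            neighbors[p].add(child)
            neighbors[child].add(p)
        for a, b in combinations(node_parents, 2):
            neighbors[a].add(b)
            neighbors[b].add(a)

    # 3. Delete the conditioning set and test whether x still reaches y
    seen, stack = set(), [x]
    while stack:
        node = stack.pop()
        if node == y:
            return False
        if node in seen or node in given:
            continue
        seen.add(node)
        stack.extend(neighbors[node])
    return True


chain = [("X", "Z"), ("Z", "Y")]
collider = [("X", "C"), ("Y", "C"), ("C", "D")]
```

On the chain, conditioning on Z separates X and Y; on the collider, X and Y are separated until C or its descendant D is conditioned on.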

5.2 Full Model Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalNeuralProbabilisticCircuit(nn.Module):
    """
    Full implementation of the causal neural probabilistic circuit.

    CausalStructure and ProbabilisticCircuitComponents are assumed to be
    defined elsewhere.
    """
    
    def __init__(self, concept_names, outcome_name, encoder_dim=512,
                 input_dim=128, num_classes=2):
        super().__init__()
        self.concept_names = concept_names
        self.outcome_name = outcome_name
        # input_dim and num_classes are placeholder defaults (assumption);
        # set them to match the data.
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, encoder_dim),
            nn.ReLU(),
            nn.Linear(encoder_dim, encoder_dim)
        )
        
        # Concept layers
        self.concept_layers = nn.ModuleDict()
        for name in concept_names:
            self.concept_layers[name] = ConceptLayer(
                input_dim=encoder_dim,
                output_dim=2  # binary concepts
            )
        
        # Causal structure (learnable)
        self.causal_structure = CausalStructure(
            concepts=concept_names,
            outcome=outcome_name
        )
        
        # Prediction layer
        self.predictor = PredictorLayer(
            input_dim=len(concept_names) * 2,
            output_dim=num_classes
        )
        
        # Probabilistic circuit components
        self.pc_components = ProbabilisticCircuitComponents()
    
    def forward(self, x, interventions=None):
        """
        Forward pass. Without interventions this is the plain prediction;
        concepts listed in interventions are clamped to the given values.
        """
        return self.predict_with_do(x, interventions or {})
    
    def predict_with_do(self, x, do_interventions):
        """
        Predict under do-interventions.
        
        do_interventions: {variable_name: value}, where value is 0 or 1
        (a Python int, or a LongTensor with one entry per sample)
        """
        # 1. Encode the input (x is expected to have shape (batch, input_dim))
        h = self.encoder(x)
        
        # 2. Infer the concepts
        concepts = {}
        for name, layer in self.concept_layers.items():
            if name in do_interventions:
                # do-operation: clamp the concept to a one-hot value, which
                # removes its dependence on the encoder output
                value = torch.as_tensor(
                    do_interventions[name], dtype=torch.long, device=x.device
                )
                one_hot = F.one_hot(value, num_classes=2).float()
                concepts[name] = one_hot.expand(x.shape[0], -1)
            else:
                concepts[name] = layer(h)
        
        # 3. Build the concept representation
        concept_repr = torch.cat([concepts[name] for name in self.concept_names], dim=-1)
        
        # 4. Predict
        logits = self.predictor(concept_repr)
        
        return logits
    
    def compute_causal_effect(self, cause, effect, cause_value, x):
        """
        Compute the effect of a binary concept on the prediction for input x,
        as the difference between the logits under do(cause = cause_value)
        and under do(cause = 1 - cause_value).
        
        effect is not used: the model has a single outcome.
        """
        # Logits under do(cause = cause_value)
        y_do = self.predict_with_do(x, {cause: cause_value})
        
        # Logits under do(cause = other value)
        other_value = 1 - cause_value
        y_do_other = self.predict_with_do(x, {cause: other_value})
        
        return y_do - y_do_other
 
 
class ConceptLayer(nn.Module):
    """Concept layer."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2),
            nn.ReLU(),
            nn.Linear(input_dim // 2, output_dim)
        )
    
    def forward(self, h):
        logits = self.net(h)
        return F.softmax(logits, dim=-1)
 
 
class PredictorLayer(nn.Module):
    """Prediction layer."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(input_dim, output_dim)
        )
    
    def forward(self, x):
        return self.net(x)

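6. Application Examples

6.1 Medical Diagnosis

class MedicalDiagnosisCausalNPC:
    """
    Causal NPC for medical diagnosis.
    
    Causal graph:
    Symptoms (S) ← Disease (D) → Treatment (T)
    
    Concepts: fever, cough, chest pain, blood sugar, ...
    Label: disease type
    """
    
    def __init__(self):
        # Define the causal structure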
        self.concepts = ['fever', 'cough', 'chest_pain', 'blood_sugar']
        self.outcome = 'disease'
        
        self.model = CausalNeuralProbabilisticCircuit(
            concept_names=self.concepts,
            outcome_name=self.outcome
        )
    
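    def diagnose_with_intervention(self, patient_data, corrections=None):
        """
        Diagnose while letting a clinician correct concept predictions.
        
        corrections: concept values corrected by the clinician,
        e.g. {"fever": 1, "cough": 0}
        """
        corrections = corrections or {}
        x = self.prepare_patient_data(patient_data)
        
        # Predict, applying any corrections
        predictions = self.model.predict_with_do(
            x,
            do_interventions=corrections
        )
        
        # Collect the concept probabilities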
        concept_probs = {}
        for name in self.concepts:
            if name not in corrections:
                concept_probs[name] = self.model.concept_layers[name](
                    self.model.encoder(x)
                )
            else:
                concept_probs[name] = F.one_hot(
                    torch.tensor(corrections[name]),
                    num_classes=2
                ).float()
        
        # Generate the diagnostic report
        report = self.generate_report(
            predictions,
            concept_probs,
            corrections
        )
        
        return report
    
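    def estimate_treatment_effect(self, patient_data, treatment, outcome):
        """
        Estimate the effect of a treatment.
        
        Example:
            effect = estimate_treatment_effect(
                patient_data,
                treatment='antibiotics',
                outcome='recovery'
            )
        """
        x = self.prepare_patient_data(patient_data)
        
        # Effect for this patient: prediction under do(treatment=1) minus
        # prediction under do(treatment=0). The intervention only takes
        # effect if treatment is one of the model's concept names.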
        ate = self.model.compute_causal_effect(
            cause=treatment,
            effect=outcome,
            cause_value=1,
            x=x
        )
        
        return ate
    
    def what_if_treatment(self, patient_data, treatment_changes):
        """
        Counterfactual: compare the effects of different treatment plans.
        
        Example:
            what_if_treatment(patient, {"surgery": 1, "medication": 0})
        x = self.prepare_patient_data(patient_data)
        
        results = {}
        for treatment_plan, value in treatment_changes.items():
            prediction = self.model.predict_with_do(
                x,
                do_interventions={treatment_plan: value}
            )
            results[treatment_plan] = prediction
        
        return results

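6.2 Recommender Systems

class RecommenderCausalNPC:
    """
    Causal NPC for recommendation.
    
    Causal graph:
    User features (U) → Preference (P) → Rating (R) ← Item features (I)

                        Recommendation (A)
    """
    
    def estimate_recommendation_effect(self, user_data, item_data, 
                                     recommendation):
        """
        Estimate the causal effect of a recommendation on user behavior.
        
        For example:
        - Does recommending product A raise the user's purchase probability?
        - Does recommending movie B affect user satisfaction?
        """
        # Build the input
        x = self.combine_features(user_data, item_data)
        
        # Compute the intervention effect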
        effect = self.model.compute_causal_effect(
            cause='recommendation',
            effect='engagement',
            cause_value=recommendation,
            x=x
        )
        
        return effect
    
    def counterfactual_recommendation(self, historical_data, 
                                     alternative_recommendation):
        """
        Counterfactual recommendation analysis.
        
        What would have happened if a different item had been recommended?
        """
        # Infer the latent factors
        u_posterior = self.abduction(historical_data)
        
        # Apply the counterfactual intervention
        cf_outcome = self.counterfactual(
            individual={
                'observation': historical_data,
                'circuit_state': self.get_circuit_state(historical_data)
            },
            hypothetical={
                'intervention': {'recommendation': alternative_recommendation},
                'query': 'engagement'
            }
        )
        
        return cf_outcome

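7. Comparison with Other Methods

7.1 Comparison of Causal Inference Methods

Method | Inference type | Scalability | Interpretability | Uncertainty
Causal NPC | Exact causal | Medium
Standard CBM | Associative | Medium
Variational causal discovery | Approximate causal
Structural equation models | Exact causal
Causal forests | Approximate causal | Medium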

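7.2 Advantages of Causal NPC

  1. Exact inference: uses probabilistic circuits to perform exact causal inference
  2. Interpretable: the reasoning path is fully transparent
  3. Intervenable: native support for do-operations
  4. Uncertainty: the probabilistic representation supports uncertainty quantification

7.3 Limitations

  1. Graph structure assumption: the causal graph must be specified in advance
  2. Scalability: inference may become computationally hard for complex graph structures
  3. Latent variables: large numbers of unobserved confounders are difficult to handle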
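8. Practical Guide

8.1 Recommendations for Building the Causal Graph

  1. Start from domain knowledge: use expert knowledge to define the causal relationships
  2. Validate assumptions: use causal discovery algorithms to check or supplement them
  3. Simplify the structure: avoid overly complex graphs
  4. Handle confounding: identify and mark unobserved variables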
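
One check that is cheap to automate when assembling a graph from expert input is that it contains no directed cycle, since a causal graph is assumed to be a DAG. The following is a minimal sketch using Kahn's algorithm; the function name `topological_order` and the edge-list format are assumptions for illustration.

```python
from collections import defaultdict, deque


def topological_order(edges):
    """Return a topological order of the nodes, or raise if there is a cycle."""
    children = defaultdict(list)
    indegree = defaultdict(int)
    nodes = set()
    for parent, child in edges:
        children[parent].append(child)
        indegree[child] += 1
        nodes.update((parent, child))

    # Start from the nodes that have no parents
    ready = deque(sorted(n for n in nodes if indegree[n] == 0))
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for child in children[node]:
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(child)

    # Nodes on a cycle never reach indegree 0 and are therefore missing
    if len(order) != len(nodes):
        raise ValueError("causal graph contains a directed cycle")
    return order


order = topological_order([("disease", "fever"), ("disease", "cough")])
```

The returned order is also the order in which the structural mechanisms of an SCM can be evaluated.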

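8.2 Training Tips

import torch
import torch.nn.functional as F


# Causal consistency regularization
def causal_consistency_loss(model, x, y):
    """
    Encourage the model to learn causal relations rather than spurious associations.
    """
    # 1. Original prediction
    pred_original = model(x)
    
    # 2. Prediction under intervention: flip every binary concept.
    #    extract_concepts is assumed to return {name: (batch, 2) probabilities}.
    concepts = model.extract_concepts(x)
    interventions = {k: 1 - v.argmax(dim=-1) for k, v in concepts.items()}
    pred_intervened = model.predict_with_do(x, interventions)
    
    # 3. Causal consistency loss.
    #    If an intervention changes the concepts, the prediction should
    #    change accordingly, so the loss has to shrink as the two
    #    predictions move apart. exp(-MSE) is one bounded choice with
    #    that property.
    consistency_loss = torch.exp(-F.mse_loss(pred_original, pred_intervened))
    
    return consistency_loss
 
# Multi-task training
def train_with_causal_regularization(model, loader, optimizer, alpha=0.1):
    """
    Train with causal regularization.
    """
    for x, y in loader:
        optimizer.zero_grad()
        
        # Standard loss
        pred = model(x)
        ce_loss = F.cross_entropy(pred, y)
        
        # Causal consistency loss
        causal_loss = causal_consistency_loss(model, x, y)
        
        # Total loss
        loss = ce_loss + alpha * causal_loss
        
        loss.backward()
        optimizer.step()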

9. References


Related documents: Neural Probabilistic Circuits | Probabilistic Graph Circuits | Advanced Causal Inference

Footnotes

  1. Chen & Zhao (2026): Causal Neural Probabilistic Circuits. arXiv:2603.01372. UIUC.

  2. Koh et al. (2020): Concept Bottleneck Models. ICML 2020.