概述
因果神经概率电路(Causal Neural Probabilistic Circuits, Causal NPC)是将因果推断能力与神经概率电路相结合的框架。1
传统的机器学习模型(包括标准NPC)只能进行关联推断(associative inference),即从观测数据中学习相关性。但这种推断无法回答因果问题:
- “如果我进行干预X,会发生什么?“(干预问题)
- “如果Y没有发生,X还会发生吗?“(反事实问题)
Causal NPC通过引入因果图结构和do算子,使得模型能够:
- 支持测试时的概念干预
- 进行反事实推理
- 估计因果效应
- 保持概率电路的可处理推断性质
这一框架在以下场景特别有价值:
- 医疗诊断:评估不同治疗方案的效果
- 推荐系统:估计推荐对用户行为的影响
- 自动驾驶:评估不同决策的安全风险
- 科学发现:推断变量间的因果关系
1. 背景:概念瓶颈模型
1.1 概念瓶颈模型简介
概念瓶颈模型(Concept Bottleneck Models, CBM)是一种增强可解释性的架构设计:2
输入 → [概念层] → [标签层]
↑ ↑
可干预 可解释
核心思想:
- 模型首先预测一组中间概念(如图像中的”有翅膀”、“是红色”)
- 然后基于概念预测最终标签
1.2 CBM的优势
| 优势 | 描述 |
|---|---|
| 可干预性 | 测试时可修正错误的概念预测 |
| 可解释性 | 预测理由可以通过概念解释 |
| 领域知识 | 可注入专家知识约束概念关系 |
1.3 CBM的局限性
| 问题 | 描述 |
|---|---|
| 推断不精确 | 通常使用确定性预测 |
| 因果能力有限 | 不支持真正的因果干预 |
| 缺乏不确定性 | 不量化概念预测的不确定性 |
1.4 Causal NPC的解决方案
Causal NPC将CBM的概念瓶颈与概率电路结合,实现:
- ✓ 概念层的概率表示
- ✓ 支持do算子的因果干预
- ✓ 精确的边际/条件推断
- ✓ 不确定性量化
2. 因果图基础
2.1 结构因果模型
Causal NPC基于结构因果模型(Structural Causal Model, SCM):
定义: SCM是一个四元组 ,其中:
- :内生变量集合
- :外生变量集合
- :因果机制函数族
- :外生变量的联合分布
2.2 因果图表示
Z₁ Z₂
↙ ↘ ↙ ↘
↓ ↓ ↓ ↓
X₁ ────→ Y ←──── X₂
↑
Z₃
其中:
- :概念变量
- :输入特征
- :输出标签
2.3 do算子
do算子是因果推断的核心工具:
直观理解:
- :观测到时的概率(关联)
- :强制设置时的概率(因果)
2.4 因果推断规则
Causal NPC利用以下因果推断规则:
后门调整公式:
前门调整公式:用于存在未观测混淆的情况
3. 因果神经概率电路架构
3.1 核心架构
┌─────────────────────────────────────────────────────────────┐
│ 因果神经概率电路 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ │
│ │ 输入 X │ ──────────────────────────────────────────┐ │
│ └─────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────────────────────────────────┐ │ │
│ │ 神经网络编码器 │ │ │
│ │ h = Encoder_θ(X) │ │ │
│ └─────────────────────────────────────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────────────────────────────────┐ │ │
│ │ 因果概念层 (C) │ │ │
│ │ C = [C₁, C₂, ..., Cₖ] │ │ │
│ │ P(C | do(X)) = ... │ │ │
│ └─────────────────────────────────────────┘ │ │
│ │ │ │
│ ▼ │ │
│ ┌─────────────────────────────────────────┐ │ │
│ │ 因果预测层 (Y) │ │ │
│ │ Y = f(C, do(θ)) │ │ │
│ │ P(Y | do(C)) = ... │ │ │
│ └─────────────────────────────────────────┘ │ │
│ │
└─────────────────────────────────────────────────────────────┘
3.2 节点类型
| 节点类型 | 功能 | 因果语义 |
|---|---|---|
| 输入节点 | 编码输入特征 | 观测变量 |
| 概念节点 | 表示因果概念 | 中间因果变量 |
| 干预节点 | 表示do操作 | 干预变量 |
| 输出节点 | 最终预测 | 响应变量 |
| 混淆节点 | 未观测因素 | 潜在变量 |
3.3 条件概率表(CPT)表示
每个因果变量的条件概率可以用概率电路表示:
class CausalConceptNode(nn.Module):
"""因果概念节点"""
def __init__(self, name, parents, cpt_network):
super().__init__()
self.name = name
self.parents = parents # 父节点列表
self.cpt_net = cpt_network # 神经网络实现的CPT
def forward(self, parent_values, do_interventions={}):
"""
计算概念的条件概率分布
Args:
parent_values: 父节点的当前值
do_interventions: 干预字典 {var_name: value}
"""
# 如果当前变量被干预,直接返回干预值
if self.name in do_interventions:
return do_interventions[self.name]
# 否则计算条件分布
context = torch.cat([
parent_values,
torch.tensor([do_interventions.get(p, 0) for p in self.parents])
])
return self.cpt_net(context)
def do_distribution(self, intervention_value):
"""
计算 P(C = c | do(C = intervention_value))
即强制设置该概念为某值
"""
# do操作使变量独立于其父节点
return torch.eye(self.num_values)[intervention_value]4. 核心算法
4.1 do算子实现
class CausalNPC(nn.Module):
"""因果神经概率电路"""
def __init__(self, causal_graph):
super().__init__()
self.graph = causal_graph # 因果图结构
self.concept_nodes = nn.ModuleDict()
self.encoder = None
def do_intervention(self, variable, value, circuit_state):
"""
执行do(X = value)操作
效果:
1. 移除所有指向X的边
2. 强制X = value
"""
# 1. 创建干预后的电路状态
intervened_state = circuit_state.copy()
# 2. 设置干预变量
intervened_state[variable] = value
# 3. 移除从父节点到X的依赖
parents = self.graph.parents[variable]
for parent in parents:
# 断开边:parent -> variable
self.disconnect(parent, variable)
return intervened_state
def compute_causal_effect(self, cause, effect, cause_value):
"""
计算因果效应 P(effect | do(cause = cause_value))
使用后门调整公式:
P(Y | do(X=x)) = Σ_z P(Y | X=x, Z=z) P(Z=z)
"""
# 获取混淆变量Z(cause的父节点和后门路径上的变量)
confounders = self.get_confounders(cause)
total_effect = 0
for z_value in self.enumerate_assignments(confounders):
# P(Z = z)
p_z = self.compute_marginal(confounders, z_value)
# P(Y | X=x, Z=z)
evidence = {cause: cause_value, **z_value}
p_y_given = self.compute_conditional(effect, evidence)
total_effect += p_y_given * p_z
return total_effect
def counterfactual(self, individual, hypothetical):
"""
反事实推理:
给定观测数据,评估假设情景
三步过程:
1. Abduction: 根据观测推断潜在变量
2. Action: 执行干预
3. Prediction: 预测结果
"""
observation = individual['observation']
intervention = hypothetical['intervention']
query = hypothetical['query']
# Step 1: Abduction
# P(U | observed_data) - 推断潜在变量后验
u_posterior = self.abduction(observation)
# Step 2: Action
# 执行干预 do(intervention)
modified_circuit = self.apply_intervention(
intervention,
individual['circuit_state']
)
# Step 3: Prediction
# 使用更新后的电路和U的后验预测结果
cf_outcome = self.predict(
query,
circuit=modified_circuit,
latent_posterior=u_posterior
)
return cf_outcome4.2 干预效果估计
def estimate_average_treatment_effect(self, treatment_var, outcome_var,
dataset):
"""
估计平均 treatment effect (ATE):
ATE = E[Y | do(T=1)] - E[Y | do(T=0)]
"""
# 处理组潜在结果
y_do_1 = self.compute_causal_effect(
cause=treatment_var,
effect=outcome_var,
cause_value=1
)
# 对照组潜在结果
y_do_0 = self.compute_causal_effect(
cause=treatment_var,
effect=outcome_var,
cause_value=0
)
return y_do_1 - y_do_0
def estimate_conditional_treatment_effect(self, treatment_var, outcome_var,
condition_var, condition_value):
"""
估计条件 treatment effect (CATE):
CATE = E[Y | do(T=1), C=c] - E[Y | do(T=0), C=c]
"""
# 添加条件变量到证据
evidence = {condition_var: condition_value}
y_do_1 = self.compute_causal_effect(
cause=treatment_var,
effect=outcome_var,
cause_value=1,
condition=evidence
)
y_do_0 = self.compute_causal_effect(
cause=treatment_var,
effect=outcome_var,
cause_value=0,
condition=evidence
)
return y_do_1 - y_do_04.3 测试时概念修正
class ConceptCorrectionInterface:
"""测试时概念修正接口"""
def __init__(self, causal_npc):
self.model = causal_npc
def predict_with_intervention(self, x, corrections={}):
"""
预测并允许概念修正
Args:
x: 输入样本
corrections: {concept_name: corrected_value}
Returns:
predictions: 最终预测
concept_probs: 概念概率分布(修正前)
explanation: 预测解释
"""
# 1. 前向传播获取概念分布
concept_probs = self.model.forward_concepts(x)
# 2. 应用修正(do操作)
interventions = {}
for concept, corrected_value in corrections.items():
interventions[concept] = corrected_value
# 3. 预测(带有干预)
final_predictions = self.model.predict_with_do(
x,
do_interventions=interventions
)
# 4. 生成解释
explanation = self.generate_explanation(
concept_probs,
corrections,
final_predictions
)
return final_predictions, concept_probs, explanation
def what_if_scenario(self, x, concept_changes):
"""
"What if" 场景分析
Example:
what_if_scenario(x, {"has_wings": 1, "is_red": 0})
询问: 如果鸟有翅膀但不是红色的,预测会如何变化?
"""
interventions = concept_changes
# 计算干预后的预测
cf_predictions = self.model.predict_with_do(
x,
do_interventions=interventions
)
# 计算原始预测
original_predictions = self.model.forward(x)
# 比较差异
diff = self.compute_difference(
cf_predictions,
original_predictions
)
return {
"original": original_predictions,
"counterfactual": cf_predictions,
"difference": diff
}5. 实现细节
5.1 因果图构建
from collections import defaultdict
class CausalGraph:
"""因果图"""
def __init__(self):
self.adjacency = defaultdict(list) # parent -> [children]
self.reverse_adj = defaultdict(list) # child -> [parents]
self.nodes = set()
self.observed_nodes = set() # 观测变量
self.latent_nodes = set() # 潜在变量
def add_edge(self, parent, child, observed=True):
"""添加因果边 parent -> child"""
self.adjacency[parent].append(child)
self.reverse_adj[child].append(parent)
self.nodes.add(parent)
self.nodes.add(child)
if observed:
self.observed_nodes.add(parent)
self.observed_nodes.add(child)
def add_latent_edge(self, parent, child):
"""添加未观测的因果边(用虚线表示)"""
self.add_edge(parent, child, observed=False)
self.latent_nodes.add(parent)
self.latent_nodes.add(child)
def parents(self, node):
"""获取节点的父节点"""
return self.reverse_adj.get(node, [])
def children(self, node):
"""获取节点的子节点"""
return self.adjacency.get(node, [])
def descendants(self, node):
"""获取节点的后代"""
result = set()
queue = [node]
while queue:
current = queue.pop()
for child in self.children(current):
if child not in result:
result.add(child)
queue.append(child)
return result
def ancestors(self, node):
"""获取节点的祖先"""
result = set()
queue = [node]
while queue:
current = queue.pop()
for parent in self.parents(current):
if parent not in result:
result.add(parent)
queue.append(parent)
return result
def is_d_separated(self, x, y, z):
"""
检查X和Y是否在给定Z时d-分离
d-分离意味着X和Y条件独立于Z
"""
# 构建被Z阻隔的图
blocked = set(z)
# BFS寻找连接路径
def has_connection(start, end, blocked_set):
visited = set()
queue = [(start, None)] # (node, path_type)
while queue:
node, path_type = queue.pop(0)
if node == end:
return True
if node in visited:
continue
visited.add(node)
# 检查前向路径
for child in self.children(node):
if child not in blocked_set:
queue.append((child, "forward"))
# 检查后向路径
for parent in self.parents(node):
if parent not in blocked_set:
# 串行/分叉路径在collider处不被阻隔
if path_type != "backward" or node not in blocked_set:
queue.append((parent, "backward"))
# 检查V-结构(碰撞)
for child in self.children(node):
if child in blocked_set:
# 碰撞点在blocked时,路径被阻隔
if node not in blocked_set:
for grandchild in self.children(child):
queue.append((grandchild, "forward"))
return False
return not has_connection(x, y, blocked)5.2 完整模型实现
class CausalNeuralProbabilisticCircuit(nn.Module):
"""因果神经概率电路完整实现"""
def __init__(self, concept_names, outcome_name, encoder_dim=512):
super().__init__()
self.concept_names = concept_names
self.outcome_name = outcome_name
# 编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, encoder_dim),
nn.ReLU(),
nn.Linear(encoder_dim, encoder_dim)
)
# 概念层
self.concept_layers = nn.ModuleDict()
for name in concept_names:
self.concept_layers[name] = ConceptLayer(
input_dim=encoder_dim,
output_dim=2 # 二值概念
)
# 因果结构(可学习)
self.causal_structure = CausalStructure(
concepts=concept_names,
outcome=outcome_name
)
# 预测层
self.predictor = PredictorLayer(
input_dim=len(concept_names) * 2,
output_dim=num_classes
)
# 概率电路组件
self.pc_components = ProbabilisticCircuitComponents()
def forward(self, x, interventions={}):
"""
前向传播(无干预)
"""
# 1. 编码输入
h = self.encoder(x)
# 2. 推断概念
concepts = {}
for name, layer in self.concept_layers.items():
# 如果被干预,使用干预值
if name in interventions:
concepts[name] = interventions[name]
else:
concepts[name] = layer(h)
# 3. 构建概念表示
concept_repr = torch.cat([concepts[name] for name in self.concept_names], dim=-1)
# 4. 预测
logits = self.predictor(concept_repr)
return logits
def predict_with_do(self, x, do_interventions):
"""
预测(带有do干预)
do_interventions: {variable_name: value}
"""
# 移除被干预变量的父节点依赖
h = self.encoder(x)
concepts = {}
for name, layer in self.concept_layers.items():
if name in do_interventions:
# do操作:强制设置值
concepts[name] = F.one_hot(
torch.tensor(do_interventions[name]),
num_classes=2
).float().to(x.device)
else:
concepts[name] = layer(h)
concept_repr = torch.cat([concepts[name] for name in self.concept_names], dim=-1)
logits = self.predictor(concept_repr)
return logits
def compute_causal_effect(self, cause, effect, cause_value, x):
"""
计算因果效应
"""
# P(effect | do(cause))
y_do = self.predict_with_do(x, {cause: cause_value})
# P(effect | do(cause ≠ cause_value))
other_value = 1 - cause_value
y_do_other = self.predict_with_do(x, {cause: other_value})
return y_do - y_do_other
class ConceptLayer(nn.Module):
"""概念层"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, input_dim // 2),
nn.ReLU(),
nn.Linear(input_dim // 2, output_dim)
)
def forward(self, h):
logits = self.net(h)
return F.softmax(logits, dim=-1)
class PredictorLayer(nn.Module):
"""预测层"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, input_dim),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(input_dim, output_dim)
)
def forward(self, x):
return self.net(x)6. 应用示例
6.1 医疗诊断应用
class MedicalDiagnosisCausalNPC:
"""
医疗诊断因果NPC
因果图:
症状(S) ← 疾病(D) → 治疗(T)
概念:发烧、咳嗽、胸痛、血糖...
标签:疾病类型
"""
def __init__(self):
# 定义因果结构
self.concepts = ['fever', 'cough', 'chest_pain', 'blood_sugar']
self.outcome = 'disease'
self.model = CausalNeuralProbabilisticCircuit(
concept_names=self.concepts,
outcome_name=self.outcome
)
def diagnose_with_intervention(self, patient_data, corrections={}):
"""
诊断并允许医生修正概念预测
corrections: 医生修正的预测
例如:{"fever": 1, "cough": 0}
"""
x = self.prepare_patient_data(patient_data)
# 获取预测(允许修正)
predictions = self.model.predict_with_do(
x,
do_interventions=corrections
)
# 获取概念概率
concept_probs = {}
for name in self.concepts:
if name not in corrections:
concept_probs[name] = self.model.concept_layers[name](
self.model.encoder(x)
)
else:
concept_probs[name] = F.one_hot(
torch.tensor(corrections[name]),
num_classes=2
).float()
# 生成诊断报告
report = self.generate_report(
predictions,
concept_probs,
corrections
)
return report
def estimate_treatment_effect(self, patient_data, treatment, outcome):
"""
估计治疗效果
Example:
effect = estimate_treatment_effect(
patient_data,
treatment='antibiotics',
outcome='recovery'
)
"""
x = self.prepare_patient_data(patient_data)
# 计算ATE
ate = self.model.compute_causal_effect(
cause=treatment,
effect=outcome,
cause_value=1,
x=x
)
return ate
def what_if_treatment(self, patient_data, treatment_changes):
"""
反事实:不同治疗方案的效果对比
Example:
what_if_treatment(patient, {"surgery": 1, "medication": 0})
"""
x = self.prepare_patient_data(patient_data)
results = {}
for treatment_plan, value in treatment_changes.items():
prediction = self.model.predict_with_do(
x,
do_interventions={treatment_plan: value}
)
results[treatment_plan] = prediction
return results6.2 推荐系统应用
class RecommenderCausalNPC:
"""
因果推荐NPC
因果图:
用户特征(U) → 偏好(P) → 评分(R) ← 项目特征(I)
↓
推荐(A)
"""
def estimate_recommendation_effect(self, user_data, item_data,
recommendation):
"""
估计推荐对用户行为的因果影响
例如:
- 推荐商品A会提高用户购买概率吗?
- 推荐电影B会影响用户满意度吗?
"""
# 构建输入
x = self.combine_features(user_data, item_data)
# 计算干预效果
effect = self.model.compute_causal_effect(
cause='recommendation',
effect='engagement',
cause_value=recommendation,
x=x
)
return effect
def counterfactual_recommendation(self, historical_data,
alternative_recommendation):
"""
反事实推荐分析
如果推荐了不同的商品,结果会如何?
"""
# 推断潜在因素
u_posterior = self.abduction(historical_data)
# 应用反事实干预
cf_outcome = self.counterfactual(
individual={
'observation': historical_data,
'circuit_state': self.get_circuit_state(historical_data)
},
hypothetical={
'intervention': {'recommendation': alternative_recommendation},
'query': 'engagement'
}
)
return cf_outcome7. 与其他方法对比
7.1 因果推断方法对比
| 方法 | 推断类型 | 可扩展性 | 可解释性 | 不确定性 |
|---|---|---|---|---|
| Causal NPC | 精确因果 | 中等 | 高 | ✓ |
| 标准CBM | 关联 | 高 | 中等 | ✗ |
| 变分因果发现 | 近似因果 | 高 | 低 | ✓ |
| 结构方程模型 | 精确因果 | 低 | 高 | ✓ |
| 因果森林 | 近似因果 | 高 | 中等 | ✓ |
7.2 Causal NPC的优势
- 精确推断: 利用概率电路实现精确因果推断
- 可解释: 推理路径完全透明
- 可干预: 原生支持do操作
- 不确定性: 概率表示支持不确定性量化
7.3 局限性
- 图结构假设: 需要预先定义因果图
- 可扩展性: 复杂图结构可能计算困难
- 潜在变量: 难以处理大量未观测混淆
8. 实践指南
8.1 因果图构建建议
- 从领域知识出发: 利用专家知识定义因果关系
- 验证假设: 使用因果发现算法验证/补充假设
- 简化结构: 避免过度复杂的图结构
- 处理混淆: 识别并标注未观测变量
8.2 训练技巧
# 因果一致性正则化
def causal_consistency_loss(model, x, y):
"""
鼓励模型学习因果关系而非虚假关联
"""
# 1. 计算原始预测
pred_original = model(x)
# 2. 计算干预预测
concepts = model.extract_concepts(x)
interventions = {k: 1 - v for k, v in concepts.items()}
pred_intervened = model.predict_with_do(x, interventions)
# 3. 因果一致性损失
# 如果干预改变概念,预测应该相应改变
consistency_loss = F.mse_loss(pred_original, pred_intervened)
return consistency_loss
# 多任务训练
def train_with_causal_regularization(model, loader, alpha=0.1):
"""
带因果正则化的训练
"""
for x, y in loader:
# 标准损失
pred = model(x)
ce_loss = F.cross_entropy(pred, y)
# 因果一致性损失
causal_loss = causal_consistency_loss(model, x, y)
# 总损失
loss = ce_loss + alpha * causal_loss
loss.backward()
optimizer.step()