Foundations of Causal Explainability

1. Introduction

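Causal explainability aims to reveal the causal mechanisms behind a machine learning model's decisions, rather than mere statistical associations. Unlike traditional explanation methods, causal explanations can answer questions such as "how would the output change if the input changed?" and "why did the model make this decision?" [1]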

1.1 Why Causal Explainability Is Needed

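| Explanation type | Capability | Limitation |
|---|---|---|
| Feature importance | Identifies important features | Cannot distinguish correlation from causation |
| Attention visualization | Shows which regions the model attends to | Can be misleading |
| Counterfactual explanations | Reveal causal mechanisms | Computationally expensive |
| Concept explanations | Understandable by humans | Require predefined concepts |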

1.2 Goals of Causal Explanation

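  1. Faithfulness: the explanation accurately reflects the model's actual behavior
  2. Actionability: the explanation offers feasible suggestions for changing the outcome
  3. Human understandability: the explanation is consistent with how humans reason
  4. Causal validity: the explanation rests on correct causal assumptions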

2. Counterfactual Explanations

2.1 Basic Definition

Counterfactual Explanations [2]

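Definition: find an input $x'$, obtained by a minimal change to the original input $x$, such that the model output changes from $y = f(x)$ to a different outcome $y' \neq y$, while $x'$ stays as similar to $x$ as possible.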

Mathematical formalization

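For an input $x$ with prediction $y = f(x)$, a counterfactual explanation $x'$ satisfies:

  1. Validity: $f(x') = y' \neq y$
  2. Minimality: $d(x, x')$ is minimal
  3. Feasibility: $x' \in \mathcal{X}_{\text{feasible}}$, i.e. $x'$ satisfies given constraints (such as real-world feasibility)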
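For a binary linear classifier $f(x) = \mathbb{1}[w^\top x + b > 0]$ with $d$ the L2 distance (both assumptions, chosen because the problem then has a closed form), the minimal change is a step along $-w$ onto the decision boundary, $x' = x - (1 + \epsilon)\frac{w^\top x + b}{\lVert w \rVert^2} w$, where a small $\epsilon > 0$ pushes $x'$ just across it. The sketch below computes this $x'$ and checks the three conditions; the helper names and the box constraint used as the feasibility check are hypothetical:

```python
import numpy as np

def linear_counterfactual(x, w, b, eps=1e-3):
    """Minimal-L2 counterfactual for f(x) = 1[w.x + b > 0] (hypothetical helper)."""
    score = w @ x + b
    # Step along -w to the boundary w.x + b = 0, then slightly past it
    return x - (1 + eps) * score / (w @ w) * w

def check_counterfactual(x, x_cf, w, b, lower, upper):
    """Check validity, report the L2 distance, and check box feasibility."""
    valid = (w @ x + b > 0) != (w @ x_cf + b > 0)               # 1. prediction flips
    distance = np.linalg.norm(x_cf - x)                          # 2. quantity being minimized
    feasible = bool(np.all((x_cf >= lower) & (x_cf <= upper)))   # 3. assumed box constraint
    return bool(valid), distance, feasible

w, b = np.array([2.0, -1.0]), -1.0
x = np.array([1.0, 0.0])          # w.x + b = 1.0, so f(x) = 1
x_cf = linear_counterfactual(x, w, b)
print(check_counterfactual(x, x_cf, w, b, lower=np.zeros(2), upper=np.full(2, 10.0)))
```

The distance equals $(1 + \epsilon)\,|w^\top x + b| / \lVert w \rVert$, i.e. just over the distance from $x$ to the decision boundary, which no valid counterfactual can beat under L2.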

2.2 Optimization Framework

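import torch
import torch.nn.functional as F

class CounterfactualExplainer:
    """
    Counterfactual explainer (gradient-based search)
    
    Objective:
    min d(x, x') + λ · L_valid(x') + L_constraint(x')
    where L_valid pushes f(x') toward target_class, or away from f(x).
    Assumes `model` maps a 1-D float tensor of features to a 1-D tensor of logits.
    """
    
    def __init__(self, model, distance_fn='l2'):
        self.model = model
        self.distance_fn = distance_fn
    
    def find_counterfactual(self, x, target_class=None, 
                            constraints=None, lambda_reg=0.1,
                            lr=0.01, max_steps=1000):
        """
        Search for a counterfactual explanation
        
        Args:
            x: original input (1-D array-like)
            target_class: desired class (if None, any change of class is accepted)
            constraints: list of callables, each returning a non-negative,
                differentiable penalty for x'
            lambda_reg: weight of the validity loss (increase it if no
                counterfactual is found)
        Returns x'; check its prediction, since the search can stop at
        max_steps without success.
        """
        x_original = torch.as_tensor(x, dtype=torch.float32)
        with torch.no_grad():
            original_class = self.model(x_original).argmax().item()
        
        # Optimize a copy of the input; x_original stays fixed
        x_cf = x_original.clone().requires_grad_(True)
        optimizer = torch.optim.Adam([x_cf], lr=lr)
        
        for step in range(max_steps):
            optimizer.zero_grad()
            output = self.model(x_cf)
            predicted = output.argmax().item()
            
            # Check convergence: stop once the prediction has changed as required
            if target_class is not None and predicted == target_class:
                break
            if target_class is None and predicted != original_class:
                break
            
            # Distance loss: keep x' close to x
            if self.distance_fn == 'l2':
                distance_loss = torch.sum((x_cf - x_original) ** 2)
            elif self.distance_fn == 'l1':
                distance_loss = torch.sum(torch.abs(x_cf - x_original))
            else:
                raise ValueError(f"unknown distance_fn: {self.distance_fn}")
            
            # Validity loss
            if target_class is not None:
                validity_loss = F.cross_entropy(output.unsqueeze(0), 
                                                torch.tensor([target_class]))
            else:
                # Any class change: hinge on the margin between the original
                # class logit and the largest competing logit
                others = torch.cat([output[:original_class],
                                    output[original_class + 1:]])
                validity_loss = F.relu(output[original_class] - others.max())
            
            # Constraint loss
            constraint_loss = 0.0
            if constraints:
                for constraint_fn in constraints:
                    constraint_loss = constraint_loss + constraint_fn(x_cf)
            
            # Total loss
            total_loss = distance_loss + lambda_reg * validity_loss + constraint_loss
            
            total_loss.backward()
            optimizer.step()
        
        return x_cf.detach().numpy()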

2.3 The DiCE Framework

Diverse Counterfactual Explanations (DiCE) [3]

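Core idea: generate several diverse counterfactual explanations that together help users understand the model's decision boundary.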

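import torch
import torch.nn.functional as F

class DiCEExplainer:
    """
    DiCE: diverse counterfactual explanations (simplified sketch)
    
    Objective, optimized jointly over k counterfactuals c_1..c_k:
    min Σ_i [L_valid(c_i) + d(x, c_i)] - λ · Diversity(c_1, ..., c_k)
    s.t. f(c_i) ≠ f(x)
    Cross-entropy is used for L_valid and L1 distance for d.
    Assumes `model` maps a batch (k, n_features) to logits (k, n_classes).
    """
    
    def __init__(self, model, X_train, feature_types):
        self.model = model
        self.X_train = X_train              # e.g. for feature ranges (unused in this sketch)
        self.feature_types = feature_types  # continuous/categorical (unused in this sketch)
    
    def generate_dice(self, x, target_class, n_cf=5, diversity_weight=0.5,
                      lr=0.05, steps=500):
        """
        Generate diverse counterfactuals by optimizing all of them jointly
        """
        x = torch.as_tensor(x, dtype=torch.float32)
        # Start from small random perturbations of x so the candidates differ
        cfs = (x + 0.1 * torch.randn(n_cf, x.shape[0])).requires_grad_(True)
        targets = torch.full((n_cf,), target_class, dtype=torch.long)
        optimizer = torch.optim.Adam([cfs], lr=lr)
        
        for _ in range(steps):
            optimizer.zero_grad()
            validity_loss = F.cross_entropy(self.model(cfs), targets)
            proximity_loss = torch.sum(torch.abs(cfs - x), dim=1).mean()
            loss = (validity_loss + proximity_loss
                    + diversity_weight * self._diversity_penalty(cfs))
            loss.backward()
            optimizer.step()
        
        return cfs.detach().numpy()
    
    def _diversity_penalty(self, counterfactuals):
        """
        Diversity penalty: encourage the counterfactuals to differ from each other
        
        As in DiCE, diversity is det(K) with K_ij = 1 / (1 + dist(c_i, c_j)),
        which is bounded; the penalty is its negative, so minimizing the
        penalty maximizes diversity.
        """
        # Pairwise L1 distances between counterfactuals, shape (k, k)
        dist = (counterfactuals[:, None, :] - counterfactuals[None, :, :]).abs().sum(dim=-1)
        K = 1.0 / (1.0 + dist)
        return -torch.det(K)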
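To judge a generated set, the DiCE paper evaluates counterfactuals along dimensions that include validity, proximity and diversity. The helper below is a simplified sketch of those three quantities; the function name, the plain (unscaled) L1 distances, and the toy classifier are assumptions, and the paper's exact normalizations differ:

```python
import numpy as np
from itertools import combinations

def counterfactual_set_metrics(predict, x, cfs):
    """Simplified validity / proximity / diversity for a set of counterfactuals."""
    x, cfs = np.asarray(x, dtype=float), np.asarray(cfs, dtype=float)
    original = predict(x[None, :])[0]
    labels = predict(cfs)
    validity = float(np.mean(labels != original))             # fraction that flip the label
    proximity = float(np.mean(np.abs(cfs - x).sum(axis=1)))   # mean L1 distance to x (lower = closer)
    pairs = list(combinations(range(len(cfs)), 2))
    diversity = (float(np.mean([np.abs(cfs[i] - cfs[j]).sum() for i, j in pairs]))
                 if pairs else 0.0)                            # mean pairwise L1 distance
    return {'validity': validity, 'proximity': proximity, 'diversity': diversity}

predict = lambda X: (X.sum(axis=1) > 1.0).astype(int)   # toy classifier
x = np.array([0.2, 0.3])
cfs = np.array([[0.8, 0.3], [0.2, 0.9], [0.1, 0.1]])
print(counterfactual_set_metrics(predict, x, cfs))
```

Proximity and diversity pull in opposite directions, which is why DiCE trades them off with a weight rather than optimizing either alone.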

2.4 Executable Counterfactuals

Executable Counterfactuals (ICML 2025) [4]

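Core idea: generate counterfactuals that are not only valid but can actually be carried out in the real world.

Constraints on executable counterfactuals:

1. Feature feasibility constraints
   - Age: 18-100
   - Income: non-negative
   - Occupation: must come from a list of valid occupations

2. Change feasibility constraints
   - Principle of minimal change
   - Reversibility constraints
   - Temporal consistency

3. Causal feasibility constraints
   - Do not violate the causal structure
   - Intervene only along feasible paths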
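The feature and change constraints above can be enforced by projecting each candidate counterfactual back onto the feasible set. The sketch below is one simple way to do this, not the method of the cited paper; the feature layout (age, income, occupation code), the occupation codes, and the "age cannot decrease" rule standing in for temporal consistency are illustrative assumptions:

```python
import numpy as np

AGE, INCOME, OCCUPATION = 0, 1, 2           # hypothetical feature layout
VALID_OCCUPATIONS = np.array([0, 1, 2, 3])  # hypothetical occupation codes

def project_feasible(x_orig, x_cf):
    """Project a candidate counterfactual onto a simple feasible set."""
    x_cf = np.array(x_cf, dtype=float)
    # Feature feasibility: age in [18, 100], income non-negative
    x_cf[AGE] = np.clip(x_cf[AGE], 18, 100)
    x_cf[INCOME] = max(x_cf[INCOME], 0.0)
    # Occupation must be a valid code: snap to the nearest one
    nearest = np.argmin(np.abs(VALID_OCCUPATIONS - x_cf[OCCUPATION]))
    x_cf[OCCUPATION] = VALID_OCCUPATIONS[nearest]
    # Change feasibility (temporal consistency): age cannot decrease
    x_cf[AGE] = max(x_cf[AGE], x_orig[AGE])
    return x_cf

x = np.array([30.0, 20000.0, 1.0])
print(project_feasible(x, [25.0, -500.0, 2.6]))
```

Applying such a projection after every optimizer step keeps the search inside the feasible region; causal feasibility constraints need a causal model and are not covered by this sketch.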

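3. Causal Attribution

3.1 Fundamentals of Causal Attribution

Causal Attribution [5]

Core idea: attribute changes in the model's output to the causal effects of the input features.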
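One way to make "causal effect of a feature" concrete is to intervene on it: set $x_i$ to a reference value, $do(x_i := \bar{x}_i)$, while holding the other inputs fixed, and record how much the output changes. For a linear model $f(x) = w^\top x + b$ this effect is exactly $w_i (x_i - \bar{x}_i)$. The sketch below assumes a per-feature baseline vector (e.g. training means), which is a choice the framework does not prescribe:

```python
import numpy as np

def intervention_attribution(f, x, baseline):
    """Effect of do(x_i := baseline_i) on f, one feature at a time."""
    x = np.asarray(x, dtype=float)
    fx = f(x)
    effects = np.empty_like(x)
    for i in range(len(x)):
        x_do = x.copy()
        x_do[i] = baseline[i]          # intervene on feature i only
        effects[i] = fx - f(x_do)      # output change attributable to x_i
    return effects

w, b = np.array([3.0, -2.0, 0.5]), 1.0
f = lambda v: w @ v + b
x, baseline = np.array([1.0, 2.0, 4.0]), np.zeros(3)
print(intervention_attribution(f, x, baseline))   # equals w * (x - baseline) for a linear f
```

For nonlinear models the per-feature effects need not add up to $f(x) - f(\bar{x})$, because interaction effects are not split between features.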

3.2 A Causal Attribution Framework for Neural Networks

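import torch
import torch.nn as nn

class CausalAttribution:
    """
    Causal attribution framework
    
    Treats the neural network as a structural causal model: each neuron is a
    node whose value is a deterministic function of the previous layer.
    Assumes `model` is an nn.Sequential of nn.Linear layers and activations.
    """
    
    def __init__(self, model, baseline_value=0.0):
        self.model = model
        self.baseline_value = baseline_value  # value used in do(x_i := baseline)
    
    def attribute(self, x, y_target=None):
        """
        Compute the causal attribution of each input feature
        
        Returns, for each input node, the change in the target output caused
        by the intervention do(x_i := baseline_value)
        """
        # Build the causal graph
        graph = self._build_causal_graph(x)
        
        # Compute causal effects
        return self._compute_causal_effects(graph, x, y_target)
    
    def _compute_causal_effects(self, graph, x, y_target):
        x = torch.as_tensor(x, dtype=torch.float32)
        effects = {}
        with torch.no_grad():
            out = self.model(x)
            target = out.argmax().item() if y_target is None else y_target
            for node in graph:
                if not node.startswith('input_'):
                    continue
                i = int(node.split('_')[1])
                x_do = x.clone()
                x_do[i] = self.baseline_value  # intervention on input node i
                effects[node] = (out[target] - self.model(x_do)[target]).item()
        return effects
    
    def _build_causal_graph(self, x):
        """
        Build the causal graph of the network
        
        Each neuron is a causal node; with fully connected layers, every
        neuron in layer l-1 is a parent of every neuron in layer l.
        The last nn.Linear layer provides the output nodes.
        """
        graph = {}
        
        # Input layer
        prev_layer = []
        for i, feat in enumerate(x):
            name = f'input_{i}'
            graph[name] = {'value': float(feat), 'parents': [], 'children': []}
            prev_layer.append(name)
        
        # Hidden and output layers
        linear_layers = [m for m in self.model.modules() if isinstance(m, nn.Linear)]
        for layer_idx, layer in enumerate(linear_layers):
            current_layer = []
            for neuron in range(layer.out_features):
                name = f'hidden_{layer_idx}_{neuron}'
                graph[name] = {'value': None, 'parents': list(prev_layer), 'children': []}
                # Connect edges from every node in the previous layer
                for parent in prev_layer:
                    graph[parent]['children'].append(name)
                current_layer.append(name)
            prev_layer = current_layer
        
        return graph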

3.3 Causal Attribution via Interchange Interventions

Core method

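import torch

class InterchangeIntervention:
    """
    Interchange-intervention attribution
    
    Compares inputs by swapping a hidden activation computed on one input
    (baseline_x, the "source") into the forward pass on another (x, the "base")
    """
    
    def __init__(self, model, baseline_x):
        self.model = model
        self.baseline_x = baseline_x
    
    def compute_attribution(self, x, layer, neuron):
        """
        Compute the causal contribution of a specific neuron to the output
        
        layer: the nn.Module whose output contains the neuron (e.g. model[1])
        neuron: index of the neuron in that module's last output dimension
        
        Interchange intervention:
        1. record the neuron's activation on baseline_x
        2. replace the neuron's activation on x with the recorded value
        3. observe the change in the output
        """
        recorded = {}
        
        def record_hook(module, inputs, output):
            recorded['value'] = output[..., neuron].detach().clone()
        
        def patch_hook(module, inputs, output):
            output = output.clone()
            output[..., neuron] = recorded['value']
            return output  # returning a tensor replaces the module's output
        
        with torch.no_grad():
            # Baseline output on the unmodified input
            output_baseline = self.model(x)
            
            # Record the source activation
            handle = layer.register_forward_hook(record_hook)
            try:
                self.model(self.baseline_x)
            finally:
                handle.remove()
            
            # Output after the intervention
            handle = layer.register_forward_hook(patch_hook)
            try:
                output_intervened = self.model(x)
            finally:
                handle.remove()
        
        # Causal attribution = difference in output
        return output_intervened - output_baseline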

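4. Causal Abstraction

4.1 Definition of Causal Abstraction

Causal Abstraction (Beckers & Halpern, 2019) [6]

Definition: a correspondence between a high-level causal model and its low-level implementation (a neural network), such that interventions on high-level variables correspond to interventions on sets of low-level neurons.

Levels of causal abstraction: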

High level (causal model)    Low level (neural network)
┌───────────────────────┐    ┌───────────────────────┐
│  Causal graph G       │ ←→ │  Neural network N     │
│                       │    │                       │
│  X → Z → Y            │    │  layers_1,2,3         │
│  (concept level)      │    │  (neuron level)       │
└───────────────────────┘    └───────────────────────┘

The abstraction relation should satisfy:
1. Causal equivalence
2. Causal sufficiency
3. Causal precision

4.2 Causal Abstraction of Neural Networks

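class CausalAbstraction:
    """
    Causal abstraction analysis
    
    Tests whether a neural network implements a given high-level causal model.
    The helpers _compute_high_level_effect, _compute_low_level_effect and
    _measure_consistency are left abstract and must be provided by a subclass.
    """
    
    def __init__(self, model, high_level_graph):
        self.model = model
        self.high_level_graph = high_level_graph
    
    def test_abstraction(self, high_level_interventions, data):
        """
        Test the abstraction relation
        
        Compares, for each intervention:
        1. the causal effect of the high-level intervention (derived from the causal model)
        2. the effect of the corresponding low-level intervention (estimated empirically)
        """
        results = []
        
        for intervention in high_level_interventions:
            # High-level intervention effect (computed from the causal model)
            high_level_effect = self._compute_high_level_effect(
                intervention, 
                self.high_level_graph
            )
            
            # Low-level intervention effect (estimated on data)
            low_level_effect = self._compute_low_level_effect(
                intervention,
                data
            )
            
            # Consistency between the two effects
            consistency = self._measure_consistency(
                high_level_effect,
                low_level_effect
            )
            
            results.append({
                'intervention': intervention,
                'high_level_effect': high_level_effect,
                'low_level_effect': low_level_effect,
                'consistency': consistency
            })
        
        return results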
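A toy illustration of this test, in the spirit of the interchange interventions of Geiger et al. [6]: the high-level model computes $Z = X_1 + X_2$ and $Y = \mathbb{1}[Z > X_3]$, and a hand-built two-neuron "network" is checked against it. Both models, the alignment of $Z$ with hidden neuron 0, and the agreement-rate metric are assumptions made for illustration. Each intervention sets $Z$ (or its aligned neuron) to the value it takes on a second, "source" input; if the alignment is right, both levels give the same output on every pair:

```python
import numpy as np

rng = np.random.default_rng(0)

# High-level causal model: Z = X1 + X2, Y = 1[Z > X3]
def high_level(x, z_override=None):
    z = x[0] + x[1] if z_override is None else z_override
    return int(z > x[2])

# Low-level "network" (hypothetical): hidden h = W x, output 1[h0 - h1 > 0]
W = np.array([[1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])

def hidden(x):
    return W @ x

def low_level_from_hidden(h):
    return int(h[0] - h[1] > 0)

def interchange_accuracy(neuron, n_pairs=1000):
    """Fraction of (base, source) pairs where patching `neuron` matches setting Z."""
    agree = 0
    for _ in range(n_pairs):
        base, source = rng.normal(size=3), rng.normal(size=3)
        # High level: set Z to the value it takes on the source input
        y_high = high_level(base, z_override=source[0] + source[1])
        # Low level: copy the aligned neuron's activation from the source run
        h = hidden(base)
        h[neuron] = hidden(source)[neuron]
        agree += int(y_high == low_level_from_hidden(h))
    return agree / n_pairs

print(interchange_accuracy(neuron=0))  # candidate alignment Z <-> neuron 0
print(interchange_accuracy(neuron=1))  # candidate alignment Z <-> neuron 1
```

With the correct alignment the agreement is exactly 1.0; aligning $Z$ with neuron 1 (which actually carries $X_3$) drops it to roughly 0.5.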

4.3 The RICA Framework

Revised Interchange Intervention with Causal Abstraction

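class RICA:
    """
    RICA: interchange interventions for causal abstraction
    
    Analyzes the causal relations between levels of representation.
    The helpers _generate_interventions, _compute_interchange_effects and
    _measure_abstraction_level are left abstract here.
    """
    
    def analyze_representations(self, representations, causal_graph):
        """
        Analyze the causal abstraction of each layer's representations
        """
        # Intervene on the representation of every layer
        layer_results = {}
        
        for layer_name, representation in representations.items():
            # Interchange interventions
            interventions = self._generate_interventions(
                representation,
                causal_graph
            )
            
            # Compute causal effects
            effects = self._compute_interchange_effects(
                interventions
            )
            
            layer_results[layer_name] = {
                'effects': effects,
                'abstraction_level': self._measure_abstraction_level(effects)
            }
        
        return layer_results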

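5. Concept Bottleneck Models

5.1 Concept Bottleneck Model Basics

Concept Bottleneck Models (Koh et al., ICML 2020) [7]

Core idea: explicitly model human-understandable concepts and predict the label only from those concepts, which makes the prediction process transparent.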

Concept bottleneck model architecture:

Input X
    ↓
Concept layer Z = g(X)  ← intermediate bottleneck
    ↓
Prediction layer Y = h(Z)
    ↓
Output Ŷ

import torch
import torch.nn as nn

class ConceptBottleneckModel(nn.Module):
    """
    Concept bottleneck model
    """
    
    def __init__(self, input_dim, concept_dim, output_dim):
        super().__init__()
        
        # Concept predictor: x -> concept probabilities
        self.concept_predictor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, concept_dim),
            nn.Sigmoid()  # concept probabilities
        )
        
        # Label predictor: concepts -> class logits
        self.label_predictor = nn.Sequential(
            nn.Linear(concept_dim, 64),
            nn.ReLU(),
            nn.Linear(64, output_dim)
        )
    
    def forward(self, x, intervene_concepts=None):
        """
        Forward pass
        
        Args:
            x: input
            intervene_concepts: optional concept values that replace the predicted ones
        """
        # Predict concepts
        concepts = self.concept_predictor(x)
        
        # Concept intervention (for counterfactual analysis)
        if intervene_concepts is not None:
            concepts = intervene_concepts
        
        # Predict the label from the (possibly intervened) concepts
        logits = self.label_predictor(concepts)
        
        return {
            'concepts': concepts,
            'logits': logits
        }
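Koh et al. [7] compare several ways to train a CBM, including independent, sequential and joint training. A minimal sketch of the joint variant, which minimizes the label loss plus a weighted concept loss $L_Y + \lambda L_C$, is shown below; the layer sizes, $\lambda = 1.0$, and the synthetic data (concepts defined as signs of the first inputs, labels as the count of two concepts) are placeholders chosen for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n, input_dim, concept_dim, output_dim = 256, 10, 4, 3

# Placeholder data: binary concept labels and class labels
X = torch.randn(n, input_dim)
C = (X[:, :concept_dim] > 0).float()    # concepts derived from the inputs
Y = C[:, :2].sum(dim=1).long()          # label derived from concepts: 0, 1 or 2

concept_net = nn.Sequential(nn.Linear(input_dim, 32), nn.ReLU(), nn.Linear(32, concept_dim))
label_net = nn.Linear(concept_dim, output_dim)
params = list(concept_net.parameters()) + list(label_net.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)
lam = 1.0                               # weight of the concept loss

for epoch in range(200):
    optimizer.zero_grad()
    concept_logits = concept_net(X)
    concepts = torch.sigmoid(concept_logits)   # bottleneck: predicted concept probabilities
    label_logits = label_net(concepts)          # label sees only the concepts
    loss_c = F.binary_cross_entropy_with_logits(concept_logits, C)
    loss_y = F.cross_entropy(label_logits, Y)
    loss = loss_y + lam * loss_c                # joint objective L_Y + λ L_C
    loss.backward()
    optimizer.step()

with torch.no_grad():
    concept_acc = ((torch.sigmoid(concept_net(X)) > 0.5).float() == C).float().mean().item()
print(f"loss={loss.item():.3f}, concept accuracy={concept_acc:.3f}")
```

A larger $\lambda$ keeps the bottleneck closer to the annotated concepts, which is what makes the concept interventions in the next subsection meaningful.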

5.2 Concept-Level Interventions and Explanations

import numpy as np
import torch

class ConceptInterventionExplainer:
    """
    Concept-level counterfactual explanations
    
    Lets the user intervene on concept values and observe how the prediction changes
    """
    
    def __init__(self, cbm_model, concepts):
        self.model = cbm_model
        self.concepts = concepts  # list of concept names
    
    def explain_prediction(self, x, target_idx=None):
        """
        Explain the prediction for a single (unbatched) input x
        
        Returns concepts sorted by importance, where importance is the absolute
        change in logits when the concept is set to a neutral value.
        (target_idx is not used in this version.)
        """
        # Get the predicted concepts and logits
        with torch.no_grad():
            output = self.model(x)
            original_concepts = output['concepts'].numpy()
            original_logits = output['logits'].numpy()
        
        concept_importance = []
        
        for i, concept in enumerate(self.concepts):
            # Intervene on concept i: set it to a neutral value
            intervened_concepts = original_concepts.copy()
            intervened_concepts[i] = 0.5  # neutral value (maximal uncertainty)
            
            with torch.no_grad():
                intervened_output = self.model(
                    x, 
                    intervene_concepts=torch.tensor(intervened_concepts)
                )
                intervened_logits = intervened_output['logits'].numpy()
            
            # Per-class importance = |logit change| caused by the intervention
            importance = np.abs(original_logits - intervened_logits)
            concept_importance.append({
                'concept': concept,
                'importance': importance,
                'original_value': original_concepts[i],
                'most_affected_class': int(np.argmax(importance))
            })
        
        return sorted(concept_importance, 
                      key=lambda item: np.max(item['importance']), 
                      reverse=True)

5.3 Concept Bottleneck Variants

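| Variant | Characteristics | Paper |
|---|---|---|
| Standard CBM | Explicit concept layer | Koh et al., 2020 |
| Post-hoc CBM | Concept layer added after training | Yuksekgonul et al., 2022 |
| Label-Free CBM | No concept labels required | Oikarinen et al., 2023 |
| Hierarchical CBM | Hierarchical concept structure | 2023 |
| Probabilistic CBM | Models concept uncertainty | 2023 |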

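6. Mechanistic Interpretability

6.1 Circuit Analysis

Circuit Analysis [8]

Core idea: decompose a neural network into interpretable "circuits" and analyze how they process information.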

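class CircuitAnalysis:
    """
    Circuit analysis
    
    `task` is assumed to provide get_clean_tokens() and get_corrupted_tokens();
    the helpers _get_activations, _identify_key_nodes and _reconstruct_circuit
    are left abstract here.
    """
    
    def __init__(self, model, task):
        self.model = model
        self.task = task
    
    def identify_circuit(self, behavior):
        """
        Identify the circuit that implements a given behavior
        
        Returns:
            circuit: list of edges [(from_node, to_node, weight)]
        """
        # Compare a clean prompt with a corrupted prompt that lacks the behavior
        clean_tokens = self.task.get_clean_tokens()
        corrupted_tokens = self.task.get_corrupted_tokens()
        
        # Activations on the clean and corrupted inputs
        clean_activations = self._get_activations(clean_tokens)
        corrupted_activations = self._get_activations(corrupted_tokens)
        
        # Activation differences: a cheap screen for candidate nodes (activation
        # patching instead copies clean activations into the corrupted run and
        # measures how much of the clean output is restored)
        diff_activations = clean_activations - corrupted_activations
        
        # Identify key nodes
        key_nodes = self._identify_key_nodes(diff_activations)
        
        # Reconstruct the circuit
        circuit = self._reconstruct_circuit(key_nodes)
        
        return circuit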
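Activation patching can be sketched on a tiny MLP with PyTorch forward hooks: for each hidden neuron, its activation from the clean run is copied into the corrupted run, and we measure what fraction of the clean-vs-corrupted output gap is restored. The model, the random inputs, and the "recovered fraction" metric are illustrative assumptions, not the procedure of any specific paper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1)).double()
hidden_layer = model[1]                          # patch the post-ReLU activations
clean = torch.randn(4, dtype=torch.float64)
corrupted = torch.randn(4, dtype=torch.float64)

def run(x, patch=None):
    """Forward pass; `patch=(neuron, value)` overwrites one hidden activation."""
    cache = {}
    def hook(module, inputs, output):
        cache['act'] = output.detach().clone()   # activation before any patching
        if patch is not None:
            output = output.clone()
            output[patch[0]] = patch[1]
            return output                        # returned tensor replaces the output
    handle = hidden_layer.register_forward_hook(hook)
    try:
        with torch.no_grad():
            out = model(x)
    finally:
        handle.remove()
    return out.item(), cache['act']

y_clean, act_clean = run(clean)
y_corr, _ = run(corrupted)

# Fraction of the clean-vs-corrupted output gap restored by patching each neuron
recovered = []
for n in range(act_clean.shape[0]):
    y_patched, _ = run(corrupted, patch=(n, act_clean[n]))
    recovered.append((y_patched - y_corr) / (y_clean - y_corr))
print([round(r, 3) for r in recovered])
```

Because the readout layer here is linear in the hidden activations, the per-neuron fractions sum to 1; in deeper, nonlinear models they generally do not.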

6.2 Attention Circuits

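import numpy as np

class AttentionCircuitAnalysis:
    """
    Attention circuit analysis
    """
    
    def __init__(self, threshold=0.4):
        self.threshold = threshold  # minimum induction score (assumed value)
    
    def analyze_attention_circuit(self, tokens, attention_patterns):
        """
        Analyze the circuits behind attention patterns
        
        Identifies the nodes and edges that implement a specific attention behavior.
        `attention_patterns` maps (layer, head) -> array of shape (seq, seq),
        computed on `tokens`, a random token block repeated twice.
        """
        # Identify induction heads
        induction_heads = self._find_induction_heads(tokens, attention_patterns)
        
        # Analyze the composition of each circuit
        circuits = []
        for head in induction_heads:
            circuit = self._analyze_head_circuit(head, tokens)
            circuits.append(circuit)
        
        return circuits
    
    def _find_induction_heads(self, tokens, attention_patterns):
        """
        Find induction heads
        
        Induction heads implement a token-copying mechanism: on a sequence
        [A][B] ... [A], the head at the second A attends to the token that
        followed the earlier A (B), so B can be copied forward. For a block
        of length T repeated twice, position i >= T attends to i - T + 1.
        """
        T = len(tokens) // 2
        heads = []
        for head, pattern in attention_patterns.items():
            scores = [pattern[i, i - T + 1] for i in range(T, 2 * T)]
            if np.mean(scores) > self.threshold:
                heads.append(head)
        return heads
    
    def _analyze_head_circuit(self, head, tokens):
        # Placeholder: full tracing (e.g. finding the previous-token head that
        # composes with this head through its keys) is not implemented here
        return {'head': head, 'type': 'induction'}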

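7. Evaluating Causal Explanations

7.1 Evaluation Dimensions

| Dimension | Description | Evaluation method |
|---|---|---|
| Faithfulness | The explanation reflects the model's actual behavior | Intervention-response consistency |
| Consistency | Similar inputs yield similar explanations | Stability tests |
| Conciseness | The explanation is concise | Sparsity measures |
| Actionability | The explanation offers feasible suggestions | User studies |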

7.2 Faithfulness Testing

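import numpy as np

class FaithfulnessEvaluator:
    """
    Faithfulness evaluation (deletion test)
    
    Assumes explainer.explain(x) returns {'top_features': indices} and
    model.predict(x) returns a numeric score or score array for one input.
    """
    
    def __init__(self, model, explainer):
        self.model = model
        self.explainer = explainer
    
    def evaluate_faithfulness(self, test_data, baseline, n_samples=100):
        """
        Evaluate the faithfulness of the explanations
        
        Core idea: features ranked as important should change the prediction
        more when they are replaced by a baseline value.
        baseline: per-feature reference values (e.g. training-set means)
        """
        faithfulness_scores = []
        
        for x in test_data[:n_samples]:
            x = np.asarray(x, dtype=float)
            
            # Get the explanation
            explanation = self.explainer.explain(x)
            important_features = explanation['top_features']
            
            # Original prediction
            pred_orig = self.model.predict(x)
            
            # Prediction after intervening on the important features
            x_intervened = x.copy()
            x_intervened[important_features] = baseline[important_features]
            pred_intervened = self.model.predict(x_intervened)
            
            # Faithfulness = change in prediction after the intervention
            faithfulness = self._measure_change(pred_orig, pred_intervened)
            faithfulness_scores.append(faithfulness)
        
        return np.mean(faithfulness_scores)
    
    def _measure_change(self, pred_orig, pred_intervened):
        # Total absolute change in the model's output
        return float(np.sum(np.abs(np.asarray(pred_orig) - np.asarray(pred_intervened))))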
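A deletion score is hard to interpret on its own, so a useful sanity check compares it with deleting the same number of randomly chosen features. The sketch below does this for a toy linear scorer whose "explanation" ranks features by $|w_i (x_i - \bar{x}_i)|$; the model, the mean baseline, and $k = 2$ are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
d, k = 10, 2
w = rng.normal(size=d)
predict = lambda X: X @ w                       # toy linear scorer
X_test = rng.normal(size=(200, d))
baseline = X_test.mean(axis=0)                  # reference values for "deleted" features

def deletion_change(x, features):
    x_del = x.copy()
    x_del[features] = baseline[features]        # replace selected features by the baseline
    return abs(predict(x) - predict(x_del))

top_scores, random_scores = [], []
for x in X_test:
    top = np.argsort(-np.abs(w * (x - baseline)))[:k]   # explanation: top-k attributions
    rand = rng.choice(d, size=k, replace=False)          # control: k random features
    top_scores.append(deletion_change(x, top))
    random_scores.append(deletion_change(x, rand))

print(np.mean(top_scores), np.mean(random_scores))
```

If deleting the top-ranked features does not change the prediction more than deleting random ones, the explanation's ranking is not faithful to the model.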

8. References



Footnotes

  1. Pearl, J., & Mackenzie, D. (2018). The Book of Why. Basic Books.

  2. Wachter, S., et al. (2018). Counterfactual explanations without opening the black box: Automated decisions and the GDPR. Harvard Journal of Law & Technology.

  3. Mothilal, R. K., et al. (2020). Explaining machine learning classifiers through diverse counterfactual explanations. FAT*.

  4. Executable Counterfactuals Authors. (2025). Executable Counterfactuals. ICML 2025.

  5. Beckers, S., & Halpern, J. Y. (2019). Abstracting causal models. AAAI.

  6. Geiger, A., et al. (2021). Causal Abstractions of Neural Networks. NeurIPS.

  7. Koh, P. W., et al. (2020). Concept Bottleneck Models. ICML.

  8. Elhage, N., et al. (2021). A Mathematical Framework for Transformer Circuits. Transformer Circuits Thread.