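Causal Interpretability of LLMs

1. Introduction

Large language models (LLMs), with their billions of parameters and complex nonlinear structure, are commonly regarded as "black boxes". Causal interpretability methods aim to uncover the causal mechanisms behind an LLM's internal computation, answering the fundamental question of how the model arrives at its decisions. [1]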

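1.1 Why causal interpretability is needed

| Problem | What the causal view offers |
|---|---|
| Opaque model behavior | Reveals information flow and causal pathways |
| Bias that is hard to explain | Identifies the causal mechanisms through which bias propagates |
| Hallucinations of unknown origin | Traces erroneous generations back to their root cause |
| Safety risks | Explains the causal triggers of harmful outputs |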

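1.2 Challenges of LLM interpretability

  1. Scale: billions of parameters are hard to analyze
  2. Nonlinearity: complex attention mechanisms
  3. Emergent abilities: whole-model behavior is hard to predict from component behavior
  4. Unknown causal structure: no explicit causal graph is available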

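2. The Transformer Circuits Perspective

2.1 Foundations of the circuits view

Transformer Circuits framework (Elhage et al., 2021) [2]

Core idea: treat a Transformer as a circuit built from attention heads and MLPs, and analyze how that circuit processes information.

Transformer circuit decomposition: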

┌─────────────────────────────────────────────────────┐
│                    Transformer                       │
├─────────────────────────────────────────────────────┤
│                                                     │
│  Input Tokens                                       │
│       ↓                                              │
│  ┌─────────────────────────────────────────────┐   │
│  │           Attention Heads                     │   │
│  │  Head 1.1  Head 1.2  ... Head 1.h         │   │
│  │      ↓          ↓              ↓            │   │
│  │  Key₁      Key₂        Keyₕ                │   │
│  │  Query₁    Query₂      Queryₕ              │   │
│  │  Value₁    Value₂      Valueₕ              │   │
│  └─────────────────────────────────────────────┘   │
│       ↓                                              │
│  ┌─────────────────────────────────────────────┐   │
│  │           MLP Layers                         │   │
│  │       Hidden₁  Hidden₂  ... Hiddenₙ         │   │
│  └─────────────────────────────────────────────┘   │
│       ↓                                              │
│  Output                                             │
│                                                     │
└─────────────────────────────────────────────────────┘
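
One way to make the decomposition concrete is to note that each component adds its output into a shared residual stream, so a linear readout of the final stream splits into one term per component. The sketch below is a toy illustration only: the component names, vectors, and the readout direction are made-up numbers, and it ignores LayerNorm, which a real model would apply before the readout.

```python
# Toy illustration (all numbers are made up): the residual stream is a sum of
# component outputs, so a linear readout splits into per-component terms.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Hypothetical 3-dimensional outputs written into the residual stream.
components = {
    "embedding": [1.0, 0.0, 0.5],
    "attn_head": [0.0, 2.0, -0.5],
    "mlp":       [0.5, -1.0, 1.0],
}

# Final residual stream = element-wise sum of all component outputs.
residual = [sum(values) for values in zip(*components.values())]

# Hypothetical unembedding direction for one output token.
unembed_direction = [1.0, 1.0, 2.0]

# The logit of that token, and each component's share of it.
total_logit = dot(residual, unembed_direction)
per_component = {
    name: dot(vector, unembed_direction) for name, vector in components.items()
}

print(total_logit)     # 4.5
print(per_component)   # {'embedding': 2.0, 'attn_head': 1.0, 'mlp': 1.5}
```

Because the readout is linear, the per-component terms sum exactly to the total logit, which is what makes it possible to ask how much a single head or MLP contributes to a prediction.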

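2.2 Key circuit patterns

class TransformerCircuitPatterns:
    """
    Key circuit patterns in Transformers.
    """
    
    # Induction heads: pattern matching and copying
    INDUCTION_HEAD = """
    Circuit pattern: induction head
    
    Function: "copy" the relevant token from an earlier position in the sequence
    
    Mechanism:
    1. The Query of the current token (a repeated A) searches for a matching Key
    2. Attention lands on the B of an earlier "A B" pattern
    3. The head outputs B's Value, promoting B as the next token
    
    Applications: in-context learning, pattern completion
    """
    
    # G maximum-likelihood head: next-token prediction
    GMAX_HEAD = """
    Circuit pattern: G maximum-likelihood head
    
    Function: predict the next token directly from the unembedding logits
    
    Mechanism:
    1. Skip the Value computation
    2. Use the Query and Key directly
    3. Maximize attention to the correct token
    """
    
    # Token-suppression head: filtering specific tokens
    SUPPRESSION_HEAD = """
    Circuit pattern: token-suppression head
    
    Function: suppress or emphasize specific tokens
    
    Mechanism:
    1. The Query of an anomalous token matches a special Key
    2. A suppression signal is injected through the Value
    """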
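
The induction-head mechanism can be summarized as the rule "... A B ... A → B". The sketch below implements only that input/output rule on a plain token list; it is not an attention implementation, and the function name `induction_predict` is made up for illustration.

```python
def induction_predict(tokens):
    """Predict the next token with the rule '... A B ... A -> B'.

    Returns None when the current token has not appeared earlier.
    """
    if not tokens:
        return None
    current = tokens[-1]
    # Scan earlier positions, most recent first, for a previous copy of `current`.
    for pos in range(len(tokens) - 2, -1, -1):
        if tokens[pos] == current:
            # Copy the token that followed the earlier occurrence.
            return tokens[pos + 1]
    return None


print(induction_predict(["the", "cat", "sat", "on", "the"]))  # cat
print(induction_predict(["a", "b", "c"]))                     # None
```

In a real model the lookup in the loop is carried out by attention (the Query of the repeated token matching a Key that encodes "the previous token was A"), and the copy step is carried out by the head's Value and output projections.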

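3. Sparse Autoencoders (SAEs)

3.1 Superposition and polysemanticity

Superposition hypothesis (Elhage et al., 2022) [3]

Core insight: a model can represent more features than it has neurons by assigning features to non-orthogonal directions. As a result, neurons are polysemantic: a single neuron responds to several unrelated features.

Superposition, schematically:

Ideal case (independent features):
  Feature 1: ───●────────────
  Feature 2: ──────●─────────
  Feature 3: ──────────●─────
  
Actual case (superposition):
  Feature 1: ───●─────●──────
  Feature 2: ──●────────●────
  Feature 3: ────●───●───────
  
Every neuron responds to several features!

Question: how can independent, interpretable features be separated out of a superposed representation?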
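
A small numeric sketch can show both why superposition works and where it breaks. It assumes a made-up setup: three features stored in a two-dimensional space along directions 120 degrees apart, read back with a dot product followed by a ReLU. None of these choices come from the text above; they are only the simplest configuration that exhibits interference.

```python
import math

# Three feature directions squeezed into 2 dimensions, 120 degrees apart.
angles = [0.0, 2 * math.pi / 3, 4 * math.pi / 3]
directions = [(math.cos(a), math.sin(a)) for a in angles]

def embed(feature_values):
    """Superpose feature values into a single 2-d activation vector."""
    x = sum(f * d[0] for f, d in zip(feature_values, directions))
    y = sum(f * d[1] for f, d in zip(feature_values, directions))
    return (x, y)

def readout(activation):
    """Project the activation back onto each feature direction, then ReLU."""
    raw = [activation[0] * d[0] + activation[1] * d[1] for d in directions]
    return [max(0.0, r) for r in raw]

# Sparse input: only feature 0 is active, and the ReLU removes the
# negative interference on the other two features.
sparse = readout(embed([1.0, 0.0, 0.0]))

# Dense input: features 0 and 1 are active together and interfere,
# so each one is read back at only half strength.
dense = readout(embed([1.0, 1.0, 0.0]))

print([round(v, 3) for v in sparse])  # [1.0, 0.0, 0.0]
print([round(v, 3) for v in dense])   # [0.5, 0.5, 0.0]
```

The sparse case is recovered cleanly while the dense case is not, which is why sparsity of feature activations is central to both the superposition hypothesis and the sparse autoencoders used to undo it.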

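3.2 SAE architecture

Sparse autoencoders [4]

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """
    Sparse autoencoder.
    
    Decomposes superposed activations into sparse, more interpretable features.
    """
    
    def __init__(self, d_model, n_features, sparsity_coef=1e-3):
        super().__init__()
        
        # Encoder: activation space -> (usually much wider) feature space
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        
        # Decoder: feature space -> activation space
        self.decoder = nn.Linear(n_features, d_model, bias=False)
        
        self.sparsity_coef = sparsity_coef
    
    def forward(self, x):
        """
        Forward pass.
        
        Args:
            x: residual-stream activations [batch, seq_len, d_model]
        
        Returns:
            reconstructed: reconstruction of x
            features: sparse feature activations
            l0_norm: number of active features per position
        """
        # Encode: extract non-negative, sparse features
        features = F.relu(self.encoder(x))
        
        # Reconstruct: decode back into the original space
        reconstructed = self.decoder(features)
        
        return {
            'reconstructed': reconstructed,
            'features': features,
            'l0_norm': (features > 0).sum(dim=-1).float()
        }
    
    def loss(self, x, output):
        """
        SAE loss.
        
        L = ||x - x̂||² + λ · ||f||₁
        """
        # Mean squared reconstruction error (averaged over all elements)
        reconstruction_loss = F.mse_loss(
            output['reconstructed'], 
            x
        )
        
        # L1 penalty per position, averaged over batch and sequence so the
        # penalty does not grow with the batch size
        l1_per_position = output['features'].abs().sum(dim=-1)
        sparsity_loss = self.sparsity_coef * l1_per_position.mean()
        
        total_loss = reconstruction_loss + sparsity_loss
        
        return {
            'total': total_loss,
            'reconstruction': reconstruction_loss,
            'sparsity': sparsity_loss
        }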
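
To make the loss L = ||x - x̂||² + λ · ||f||₁ easier to check by hand, the following sketch runs one forward pass of a tiny SAE in plain Python. The weights, the input, and the coefficient `sparsity_coef = 0.1` are made-up values chosen so every step can be verified mentally; the reconstruction term is averaged over dimensions to mirror a mean squared error.

```python
# Tiny hand-checkable SAE forward pass and loss (all weights are made up).

def matvec(matrix, vector):
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]

W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]   # 3 features x 2 dimensions
b_enc = [0.0, 0.0, 0.0]
W_dec = [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]     # 2 dimensions x 3 features
sparsity_coef = 0.1

x = [2.0, -1.0]

# Encode: affine map followed by ReLU.
pre_activations = [p + b for p, b in zip(matvec(W_enc, x), b_enc)]  # [2, -1, -1]
features = [max(0.0, p) for p in pre_activations]                   # [2, 0, 0]

# Decode back into activation space.
reconstructed = matvec(W_dec, features)                             # [2, 0]

# Loss terms.
reconstruction_loss = sum((a - b) ** 2 for a, b in zip(x, reconstructed)) / len(x)
l1_penalty = sum(abs(f) for f in features)
l0_norm = sum(1 for f in features if f > 0)
total_loss = reconstruction_loss + sparsity_coef * l1_penalty

print(features, reconstructed)                       # [2.0, 0.0, 0.0] [2.0, 0.0]
print(reconstruction_loss, l1_penalty, l0_norm)      # 0.5 2.0 1
```

Only one of the three features fires, so the code is sparse (L0 = 1), but the second input coordinate is lost in the reconstruction. Training trades these two terms off through the sparsity coefficient.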

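3.3 Gemma Scope

Gemma Scope (Google DeepMind, 2024) [5]

Core contribution: a public release of SAEs trained on the layers of the Gemma 2 models.

# Using Gemma Scope through SAELens and TransformerLens.
# Assumption: the release and sae_id strings below follow the naming used in
# SAELens's pretrained-SAE directory; check it for your installed version.
from sae_lens import SAE
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gemma-2-2b")

# Load the layer-10 residual-stream SAE.
# Older SAELens versions return (sae, cfg_dict, sparsity); newer ones return the SAE.
loaded = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_10/width_16k/canonical",
)
sae = loaded[0] if isinstance(loaded, tuple) else loaded

# Cache activations at the hook point the SAE was trained on, then encode them
tokens = model.to_tokens("def add(a, b): return a + b")
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["blocks.10.hook_resid_post"])

# Look up features related to a concept.
# find_concept_features is a hypothetical helper, not part of any library.
relevant_features = find_concept_features(
    features, 
    concept="programming_language"
)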

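3.4 Feature analysis

class FeatureAnalyzer:
    """
    Analysis of individual SAE features.
    
    The helpers _random_text, _get_activation_contexts,
    _infer_semantic_label and _compute_feature_stats are not shown here.
    """
    
    def __init__(self, sae, model, tokenizer, hook_name):
        self.sae = sae
        self.model = model
        self.tokenizer = tokenizer
        self.hook_name = hook_name  # hook point the SAE was trained on
    
    def analyze_feature(self, feature_idx):
        """
        Analyze the behavior of a single feature.
        """
        # Collect the tokens that activate this feature
        tokens = self._get_feature_activating_tokens(feature_idx)
        
        # Gather the contexts in which it activates
        contexts = self._get_activation_contexts(tokens)
        
        # Infer a semantic label from those contexts
        semantic_label = self._infer_semantic_label(contexts)
        
        # Compute summary statistics for the feature
        stats = self._compute_feature_stats(feature_idx)
        
        return {
            'tokens': tokens,
            'contexts': contexts,
            'semantic_label': semantic_label,
            'stats': stats
        }
    
    def _get_feature_activating_tokens(self, feature_idx, n_samples=1000,
                                       threshold=0.5):
        """
        Collect the token ids on which the feature exceeds `threshold`.
        """
        activating_tokens = []
        
        for _ in range(n_samples):
            # Sample a text
            text = self._random_text()
            tokens = self.tokenizer(text, return_tensors='pt')['input_ids']
            
            # run_with_cache returns (logits, cache); activations are in the cache
            _, cache = self.model.run_with_cache(tokens)
            features = self.sae.encode(cache[self.hook_name])
            
            # Find the (batch, position) pairs where the feature is active
            activated = features[:, :, feature_idx] > threshold
            for batch_idx, pos in activated.nonzero().tolist():
                activating_tokens.append(tokens[batch_idx, pos].item())
        
        return activating_tokens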

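4. Causal feature circuits

4.1 Sparse Feature Circuits (ICLR 2025)

Sparse Feature Circuits (Marks et al., ICLR 2025) [6]

Core idea: combine the features obtained from SAE decomposition with circuit analysis in order to study the causal role of each feature.

class SparseFeatureCircuit:
    """
    Sparse feature circuit analysis.
    """
    
    def __init__(self, model, sae, threshold=0.1):
        self.model = model
        self.sae = sae
        self.threshold = threshold  # minimum |contribution| kept in the circuit
    
    def trace_feature_circuit(self, feature_idx, behavior):
        """
        Trace the causal circuit through which a feature implements a behavior.
        """
        # Attention heads associated with the feature (helper not shown)
        relevant_heads = self._find_relevant_heads(feature_idx)
        
        # Record of the information flow
        circuit = {
            'feature': feature_idx,
            'relevant_heads': relevant_heads,
            'input_heads': [],
            'mlp_layers': [],
            'output_heads': []
        }
        
        # Causal contribution of the feature at each layer
        for layer in range(self.model.cfg.n_layers):
            causal_contribution = self._compute_causal_contribution(
                layer, feature_idx, behavior
            )
            
            if abs(causal_contribution) > self.threshold:
                circuit[f'layer_{layer}'] = causal_contribution
        
        return circuit
    
    def _compute_causal_contribution(self, layer, feature_idx, behavior):
        """
        Causal contribution of one feature at one layer.
        
        Zero-ablates the feature in SAE space, a simple form of activation
        patching. `behavior` maps layer activations to a scalar metric.
        Assumes self.sae was trained on the activations of `layer`.
        """
        # Clean activations (helper not shown)
        clean_acts = self._get_activations(layer)
        features = self.sae.encode(clean_acts)
        
        # Keep the SAE reconstruction error so that only the feature changes
        error = clean_acts - self.sae.decode(features)
        
        # Intervention: zero the feature, then decode back to activations
        ablated = features.clone()
        ablated[..., feature_idx] = 0
        patched_acts = self.sae.decode(ablated) + error
        
        # Change in behavior caused by the intervention
        clean_behavior = behavior(clean_acts)
        patched_behavior = behavior(patched_acts)
        
        return clean_behavior - patched_behavior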
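
The ablation step can be worked through by hand on a toy example. The sketch below assumes a made-up two-dimensional activation, a three-feature dictionary, and a fixed linear readout standing in for the behavior metric; it also carries the reconstruction error along, so that zeroing a feature changes nothing else. All names and numbers are hypothetical.

```python
def decode(features, dictionary):
    """Sum of dictionary directions weighted by feature activations."""
    dim = len(dictionary[0])
    return [sum(f * d[i] for f, d in zip(features, dictionary)) for i in range(dim)]

def behavior(activation):
    """Hypothetical scalar metric: a fixed linear readout of the activation."""
    readout = [1.0, -1.0]
    return sum(a * r for a, r in zip(activation, readout))

dictionary = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # one direction per feature
features = [2.0, 0.5, 1.0]                         # feature activations
activation = [3.1, 1.4]                            # the model's real activation

# The dictionary does not reconstruct the activation perfectly.
reconstruction = decode(features, dictionary)      # [3.0, 1.5]
error = [a - r for a, r in zip(activation, reconstruction)]

def ablation_effect(feature_idx):
    """Drop in the metric when one feature is zeroed and everything else is kept."""
    ablated = list(features)
    ablated[feature_idx] = 0.0
    patched = [r + e for r, e in zip(decode(ablated, dictionary), error)]
    return behavior(activation) - behavior(patched)

effects = [ablation_effect(i) for i in range(len(features))]
print([round(e, 6) for e in effects])  # [2.0, -0.5, 0.0]
```

Feature 2 is strongly active yet has no effect on the metric, because its direction is orthogonal to the readout. This is the distinction between a feature being present and a feature being causally relevant.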

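4.2 Feature causal graph

import networkx as nx

class FeatureCausalGraph:
    """
    Causal graph over SAE features.
    """
    
    def build_feature_graph(self, model, sae, behaviors, threshold=0.1):
        """
        Build a causal graph over features.
        
        Nodes: SAE features
        Edges: causal relations whose absolute strength exceeds `threshold`
        
        The pairwise loop needs O(n_features²) tests, so in practice it
        should be restricted to a small set of candidate features.
        """
        G = nx.DiGraph()
        
        # Add every feature as a node
        for feature_idx in range(sae.n_features):
            G.add_node(feature_idx)
        
        # Estimate the causal relation between each ordered pair of features
        for source in range(sae.n_features):
            for target in range(sae.n_features):
                if source != target:
                    # Test whether intervening on source changes target
                    # (_test_causal_relation is not shown here)
                    causal_strength = self._test_causal_relation(
                        source, target, behaviors
                    )
                    
                    if abs(causal_strength) > threshold:
                        G.add_edge(source, target, 
                                  weight=causal_strength)
        
        return G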
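
As an illustration of what such a pairwise test produces, the sketch below builds the edge set for a made-up three-feature chain f0 → f1 → f2, using zero-ablation as the intervention. The structural equations and the function names `run` and `build_graph` are assumptions for this example only.

```python
def run(ablate=None):
    """Hypothetical 3-feature chain f0 -> f1 -> f2; `ablate` zeroes one feature."""
    values = {}
    values[0] = 0.0 if ablate == 0 else 1.0
    values[1] = 0.0 if ablate == 1 else 2.0 * values[0]
    values[2] = 0.0 if ablate == 2 else 0.5 * values[1]
    return values

def build_graph(n_features, threshold=0.1):
    """Add an edge source -> target when ablating source changes target."""
    clean = run()
    edges = {}
    for source in range(n_features):
        ablated = run(ablate=source)   # one intervention per source feature
        for target in range(n_features):
            if target == source:
                continue
            strength = clean[target] - ablated[target]
            if abs(strength) > threshold:
                edges[(source, target)] = strength
    return edges

edges = build_graph(3)
print(edges)  # {(0, 1): 2.0, (0, 2): 1.0, (1, 2): 1.0}
```

Two things are worth noting. One intervention per source feature is enough to score all of its targets, which reduces the number of forward passes from quadratic to linear in the number of features. Also, the edge (0, 2) appears even though f0 influences f2 only through f1: a plain ablation test measures total effects, so separating direct from mediated edges requires holding the intermediate feature fixed.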

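5. Causal mediation analysis

5.1 Causal mediation analysis in Transformers

Causal mediation analysis [7]

Core idea: identify the intermediate components of the model that carry (mediate) a specific causal effect from input to output.

Causal mediation analysis framework:

Input X
   ↓
Layer 1  ← direct effect
   ↓
Layer 2  ← mediated effect 1
   ↓
Layer 3  ← mediated effect 2
   ↓
Output Y

Total effect = direct effect + indirect effect (through mediators)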
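
The decomposition can be checked on a minimal structural model. The sketch below assumes a made-up linear system with one mediator, M = 2X and Y = 3X + 0.5M, and compares a clean input with a counterfactual one. The additive split is exact here only because the model is linear with no interaction between X and M; for a nonlinear network the direct and indirect effects need not sum to the total.

```python
def mediator(x):
    """Hypothetical mediator: M = 2X."""
    return 2.0 * x

def outcome(x, m):
    """Hypothetical outcome: Y = 3X + 0.5M (direct path plus mediated path)."""
    return 3.0 * x + 0.5 * m

x_clean, x_counterfactual = 0.0, 1.0
m_clean = mediator(x_clean)
m_counterfactual = mediator(x_counterfactual)
y_clean = outcome(x_clean, m_clean)

# Total effect: change the input and let the mediator follow.
total_effect = outcome(x_counterfactual, m_counterfactual) - y_clean

# Direct effect: change the input but hold the mediator at its clean value.
direct_effect = outcome(x_counterfactual, m_clean) - y_clean

# Indirect effect: keep the input clean but set the mediator to its
# counterfactual value (this is what activation patching does).
indirect_effect = outcome(x_clean, m_counterfactual) - y_clean

print(total_effect, direct_effect, indirect_effect)  # 4.0 3.0 1.0
```

In a Transformer, X corresponds to the prompt, M to the activation of a chosen layer, head, or neuron, and Y to an output metric such as a logit difference.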

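5.2 Implementing mediation analysis for a Transformer

import torch

class TransformerMediationAnalysis:
    """
    Ablation-based causal mediation analysis for a TransformerLens model.
    """
    
    def __init__(self, model):
        self.model = model  # a transformer_lens.HookedTransformer
    
    def analyze_mediation(self, prompt, target_layer, answer_token_id):
        """
        Split the answer logit into the part that passes through the
        target layer and the part that does not.
        
        Decomposition:
        Total Effect = Direct Effect + Indirect Effect (via mediator)
        
        Here "total" is the clean answer logit and "direct" is what survives
        zero-ablation of the layer. A full mediation analysis would instead
        compare a clean prompt with a counterfactual prompt.
        """
        # Baseline: answer logit on the unmodified model
        total = self._get_output(prompt, answer_token_id)
        
        # Direct part (bypassing the target layer)
        direct_effect = self._get_direct_effect(prompt, target_layer, answer_token_id)
        
        # Indirect part (passing through the target layer)
        indirect_effect = total - direct_effect
        
        return {
            'total_effect': total,
            'direct_effect': direct_effect,
            'indirect_effect': indirect_effect,
            'mediation_ratio': indirect_effect / total if total != 0 else float('nan')
        }
    
    def _get_output(self, prompt, answer_token_id, fwd_hooks=()):
        """
        Logit of the answer token at the last position, as a Python float.
        """
        logits = self.model.run_with_hooks(prompt, fwd_hooks=list(fwd_hooks))
        return logits[0, -1, answer_token_id].item()
    
    def _get_direct_effect(self, prompt, target_layer, answer_token_id):
        """
        Answer logit when the target layer writes nothing to the residual stream.
        
        Intervention: zero the layer's attention and MLP outputs.
        """
        def zero_hook(activation, hook):
            return torch.zeros_like(activation)
        
        hooks = [
            (f"blocks.{target_layer}.hook_attn_out", zero_hook),
            (f"blocks.{target_layer}.hook_mlp_out", zero_hook),
        ]
        return self._get_output(prompt, answer_token_id, hooks)
    
    def _get_indirect_effect(self, prompt, target_layer, answer_token_id):
        """
        Indirect effect = total effect - direct effect
        """
        total = self._get_output(prompt, answer_token_id)
        direct = self._get_direct_effect(prompt, target_layer, answer_token_id)
        
        return total - direct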

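5.3 Attention as a causal mediator

class AttentionMediationAnalysis:
    """
    Mediation analysis for individual attention heads.
    """
    
    def __init__(self, model):
        self.model = model  # a transformer_lens.HookedTransformer
    
    def analyze_head_mediation(self, prompt, head_idx):
        """
        Analyze the mediating role of one attention head.
        
        head_idx is a flat index over all heads in the model.
        """
        n_heads = self.model.cfg.n_heads
        layer = head_idx // n_heads
        head = head_idx % n_heads
        
        # Run the model and cache its attention patterns
        _, cache = self.model.run_with_cache(prompt)
        
        # [query_pos, key_pos] attention weights of this head (first batch item)
        attention_pattern = cache["pattern", layer][0, head]
        
        # Analyze attention as a channel that passes information between positions
        # (_compute_mediated_info is not shown here)
        mediated_info = self._compute_mediated_info(
            attention_pattern,
            prompt
        )
        
        return {
            'head_idx': head_idx,
            'layer': layer,
            'head': head,
            'attention_pattern': attention_pattern,
            'mediated_info': mediated_info
        }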

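6. Causal attribution methods

6.1 Activation attribution

import torch

class ActivationAttribution:
    """
    Activation attribution.
    
    Attributes a change in the output to specific activation patterns.
    """
    
    def __init__(self, model, sae, hook_name):
        self.model = model
        self.sae = sae
        self.hook_name = hook_name  # hook point the SAE was trained on
    
    def attribute_to_features(self, tokens, target_token_id):
        """
        Attribute the logit of target_token_id to SAE features.
        """
        captured = {}
        
        def splice_sae(activation, hook):
            # Route the forward pass through the SAE so gradients reach the features
            captured['features'] = self.sae.encode(activation)
            return self.sae.decode(captured['features'])
        
        logits = self.model.run_with_hooks(
            tokens, fwd_hooks=[(self.hook_name, splice_sae)]
        )
        output = logits[0, -1, target_token_id]
        
        # Gradient d_output / d_feature, computed once for all features
        grads = torch.autograd.grad(output, captured['features'])[0]
        
        # Average over batch and position: one score per feature
        contributions = grads.mean(dim=(0, 1)).tolist()
        
        return contributions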
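
A gradient alone measures sensitivity, not contribution; a common refinement is to multiply the gradient by the feature's activation. The sketch below illustrates this on a made-up downstream function (a linear readout followed by a square) and estimates the gradient by central finite differences, so it needs no autograd library. The function `output`, its weights, and the feature values are all hypothetical.

```python
def output(features):
    """Hypothetical downstream computation: linear readout, then a square."""
    s = 1.0 * features[0] - 2.0 * features[1] + 0.0 * features[2]
    return s * s

def numerical_gradient(fn, point, eps=1e-6):
    """Central finite-difference estimate of d fn / d point[i] for every i."""
    grads = []
    for i in range(len(point)):
        up = list(point)
        down = list(point)
        up[i] += eps
        down[i] -= eps
        grads.append((fn(up) - fn(down)) / (2 * eps))
    return grads

features = [3.0, 1.0, 5.0]

# Sensitivity of the output to each feature at this input.
grads = numerical_gradient(output, features)

# Gradient x activation: a first-order estimate of each feature's contribution.
attributions = [g * f for g, f in zip(grads, features)]

print([round(g, 4) for g in grads])          # [2.0, -4.0, 0.0]
print([round(a, 4) for a in attributions])   # [6.0, -4.0, 0.0]
```

Feature 2 has the largest activation but receives zero attribution, since the output does not depend on it. Because the estimate is first-order and local, it should be checked against an actual intervention (ablation or patching) before being read as a causal claim.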

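6.2 Path attribution

class PathAttribution:
    """
    Path attribution.
    
    Traces paths from input to output by propagating attribution layer by layer.
    The helpers _initialize_attribution and _propagate_mlp are not shown here.
    """
    
    def __init__(self, model, cache):
        self.model = model
        self.cache = cache  # ActivationCache returned by model.run_with_cache
    
    def trace_causal_path(self, input_tokens, output_token):
        """
        Trace the path from the input tokens to the output token.
        """
        # Initial attribution
        attr = self._initialize_attribution(input_tokens, output_token)
        
        # Propagate the attribution through every layer
        for layer in range(self.model.cfg.n_layers):
            # Propagation through attention
            attr = self._propagate_attention(attr, layer)
            
            # Propagation through the MLP
            attr = self._propagate_mlp(attr, layer)
        
        return attr
    
    def _propagate_attention(self, attr, layer):
        """
        Propagate attribution through one attention layer.
        """
        # Attention pattern [batch, head, query_pos, key_pos], averaged over heads
        attn_weights = self.cache["pattern", layer].mean(dim=1)
        
        # Attribution = attention weights × downstream attribution
        propagated = attn_weights @ attr
        
        return propagated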
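
Multiplying attention matrices across layers in this way closely resembles the technique known as attention rollout. The sketch below is a minimal version under stated assumptions: two layers, two positions, made-up causal attention matrices already averaged over heads, and an equal-weight mix with the identity matrix to stand in for the residual connection.

```python
def matmul(A, B):
    """Plain-Python matrix product for small nested lists."""
    return [
        [sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
        for i in range(len(A))
    ]

def add_residual(attn):
    """Mix attention with the identity so the skip connection is represented.

    Rows still sum to 1 after the mix.
    """
    n = len(attn)
    return [
        [0.5 * attn[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
        for i in range(n)
    ]

# Hypothetical head-averaged causal attention patterns [query_pos][key_pos].
layer_attention = [
    [[1.0, 0.0], [0.5, 0.5]],   # layer 0
    [[1.0, 0.0], [0.2, 0.8]],   # layer 1
]

# Start from the identity: each position initially depends only on itself.
rollout = [[1.0, 0.0], [0.0, 1.0]]
for attn in layer_attention:
    rollout = matmul(add_residual(attn), rollout)

# rollout[i][j]: share of output position i traced back to input position j.
print([[round(v, 3) for v in row] for row in rollout])  # [[1.0, 0.0], [0.325, 0.675]]
```

Rollout only tracks where attention weight flows. It ignores the Value vectors and the MLPs, so it is best treated as a heuristic to be confirmed with interventions rather than as a causal measurement in itself.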

7. Practical tools

7.1 TransformerLens

# TransformerLens: a research toolkit for Transformer interpretability
from transformer_lens import HookedTransformer
 
# Load the model
model = HookedTransformer.from_pretrained("gpt2")
 
# Run once and cache every intermediate activation
logits, cache = model.run_with_cache("The cat sat on the mat")
 
# Inspect a specific activation with a forward hook
def hook_fn(activation, hook):
    print(f"Attention pattern shape: {activation.shape}")
    return activation
 
model.run_with_hooks(
    "Hello world",
    fwd_hooks=[("blocks.1.attn.hook_pattern", hook_fn)]
)

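7.2 SAELens

# SAELens: a toolkit for training and analyzing SAEs
from sae_lens import SAE, ActivationsStore
 
# Load a pretrained SAE.
# Assumption: "gpt2-small-res-jb" is the release name of the GPT-2 small
# residual-stream SAEs in SAELens's pretrained directory.
# Older SAELens versions return (sae, cfg_dict, sparsity); newer ones return the SAE.
loaded = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.10.hook_resid_pre"
)
sae = loaded[0] if isinstance(loaded, tuple) else loaded
 
# Stream activations from a dataset for large-scale analysis
# (the accepted arguments vary across SAELens versions)
activations_store = ActivationsStore.from_sae(model=model, sae=sae)
 
# Encode the activations at the SAE's hook point
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["blocks.10.hook_resid_pre"])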

7.3 End-to-end analysis workflow

from transformer_lens import HookedTransformer

class LLMInterpretabilityAnalysis:
    """
    End-to-end LLM interpretability workflow.
    
    The underscore-prefixed helper methods are not shown here.
    """
    
    def __init__(self, model_name):
        # Load the model
        self.model = HookedTransformer.from_pretrained(model_name)
        
        # Load the matching SAE
        self.sae = self._load_sae(model_name)
        
        self.tokenizer = self.model.tokenizer
    
    def complete_analysis(self, prompt, target_behavior):
        """
        Run the full analysis pipeline.
        """
        # 1. Basic activation analysis
        activations = self._analyze_activations(prompt)
        
        # 2. SAE feature decomposition
        features = self._decompose_features(activations)
        
        # 3. Identification of key features
        key_features = self._identify_key_features(features, target_behavior)
        
        # 4. Causal path tracing
        causal_paths = self._trace_causal_paths(key_features)
        
        # 5. Explanation generation
        explanation = self._generate_explanation(
            key_features,
            causal_paths
        )
        
        return {
            'activations': activations,
            'features': features,
            'key_features': key_features,
            'causal_paths': causal_paths,
            'explanation': explanation
        }

8. Evaluation and benchmarks

8.1 Evaluating explanation quality

class ExplanationQualityEvaluator:
    """
    Evaluation of explanation quality.
    """
    
    def evaluate_explanation(self, explanation, ground_truth):
        """
        Evaluate the quality of an explanation.
        
        Dimensions:
        1. Correctness: agreement with the ground truth
        2. Completeness: coverage of all relevant components
        3. Succinctness: how concise the explanation is
        4. Actionability: whether it can guide interventions
        """
        metrics = {
            'correctness': self._evaluate_correctness(explanation, ground_truth),
            'completeness': self._evaluate_completeness(explanation),
            'succinctness': self._evaluate_succinctness(explanation),
            'actionability': self._evaluate_actionability(explanation)
        }
        
        return metrics
    
    def _evaluate_correctness(self, explanation, ground_truth):
        """
        Correctness evaluation.
        
        Compares the features/circuits named by the explanation with the ground truth.
        """
        explained_features = set(explanation['features'])
        true_features = set(ground_truth['features'])
        
        # Precision / recall over the identified components
        tp = len(explained_features & true_features)
        fp = len(explained_features - true_features)
        fn = len(true_features - explained_features)
        
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        
        return {'precision': precision, 'recall': recall, 'f1': f1}
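
A short worked example may help in reading these scores. It assumes made-up component labels of the form "L1H4" (layer 1, head 4) for both the explanation and the ground-truth circuit; the helper name `circuit_overlap` is also hypothetical.

```python
def circuit_overlap(explained, truth):
    """Precision, recall and F1 of an explained component set against a reference set."""
    explained, truth = set(explained), set(truth)
    tp = len(explained & truth)
    precision = tp / len(explained) if explained else 0.0
    recall = tp / len(truth) if truth else 0.0
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return precision, recall, f1

# Hypothetical head labels: the explanation names 3 heads, the reference circuit has 4.
explained_heads = {"L1H4", "L2H7", "L5H1"}
reference_heads = {"L1H4", "L2H7", "L3H0", "L4H2"}

precision, recall, f1 = circuit_overlap(explained_heads, reference_heads)
print(round(precision, 3), round(recall, 3), round(f1, 3))  # 0.667 0.5 0.571
```

Two of the three named heads are correct (precision 2/3), but only two of the four true heads are found (recall 1/2). An explanation can therefore look accurate while still missing half of the mechanism, which is why correctness is reported alongside completeness.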

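8.2 Benchmark datasets

| Benchmark | Description | Use |
|---|---|---|
| IOI (Indirect Object Identification) | Indirect-object identification task | Circuit analysis validation, attention analysis |
| Greater-Than | Numerical comparison test | Mechanism validation |
| SAE Feature Benchmark | Feature interpretability | SAE quality evaluation |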

9. References


  1. Elhage, N., et al. (2021). A Mathematical Framework for Transformer Circuits. Transformer Circuits Thread.

  2. Elhage, N., et al. (2022). Softmax Linear Units. Transformer Circuits Thread.

  3. Elhage, N., et al. (2022). Toy Models of Superposition. Transformer Circuits Thread.

  4. Bricken, T., et al. (2023). Towards Monosemanticity: Decomposing Language Models With Dictionary Learning. Transformer Circuits Thread.

  5. Lieberum, T., et al. (2024). Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2. Google DeepMind.

  6. Marks, S., et al. (2025). Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models. ICLR 2025.

  7. Geiger, A., et al. (2024). Causal Abstraction in Large Language Models.