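Causal Interpretability of LLMs
1. Introduction
Large language models (LLMs) are often treated as "black boxes" because of their billions of parameters and deeply nonlinear structure. Causal interpretability methods aim to uncover the causal mechanisms inside an LLM, and so answer the basic question of how the model arrives at its outputs. [1]
1.1 Why causal interpretability is needed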
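| Problem | Causal perspective |
|---|---|
| Opaque model behavior | Expose information flow and causal pathways |
| Bias that is hard to explain | Identify the causal mechanisms through which bias propagates |
| Hallucinations of unknown origin | Trace erroneous generations back to their root cause |
| Safety risks | Understand the causal triggers of harmful outputs |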
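1.2 Challenges of LLM interpretability
- Scale: billions of parameters are hard to analyze one by one
- Nonlinearity: attention and MLP layers interact in complex, nonlinear ways
- Emergent abilities: whole-model behavior is hard to predict from the behavior of individual components
- Unknown causal structure: no explicit causal graph is given in advance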
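2. The Transformer circuits view
2.1 Foundations of the circuits view
The Transformer Circuits framework (Elhage et al., 2021) [2]:
Core idea: treat the Transformer as a circuit built from attention heads and MLPs, and analyze how each component processes and passes on information.
Transformer circuit decomposition: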
┌─────────────────────────────────────────────────────┐
│                     Transformer                     │
├─────────────────────────────────────────────────────┤
│                                                     │
│   Input Tokens                                      │
│        ↓                                            │
│   ┌─────────────────────────────────────────────┐   │
│   │               Attention Heads               │   │
│   │  Head 1.1    Head 1.2    ...    Head 1.h    │   │
│   │     ↓           ↓                  ↓        │   │
│   │  Key₁        Key₂               Keyₕ        │   │
│   │  Query₁      Query₂             Queryₕ      │   │
│   │  Value₁      Value₂             Valueₕ      │   │
│   └─────────────────────────────────────────────┘   │
│        ↓                                            │
│   ┌─────────────────────────────────────────────┐   │
│   │                 MLP Layers                  │   │
│   │  Hidden₁     Hidden₂     ...    Hiddenₙ     │   │
│   └─────────────────────────────────────────────┘   │
│        ↓                                            │
│   Output                                            │
│                                                     │
└─────────────────────────────────────────────────────┘
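One consequence of the circuit view is that every attention head and MLP writes additively into a shared residual stream, so (ignoring the final LayerNorm) the output logits split into one term per component. The sketch below illustrates that bookkeeping in plain Python. The component names, the vectors, the two-token vocabulary, and the `unembed` helper are all invented for the example and do not come from any real model.

```python
# Toy illustration: the residual stream is a sum of component outputs,
# so logits decompose linearly into per-component contributions.
# All numbers are invented; the final LayerNorm is ignored for simplicity.

def unembed(vec, W_U):
    """Map a residual-stream vector (d_model) to logits (vocab)."""
    return [sum(v * W_U[i][j] for i, v in enumerate(vec))
            for j in range(len(W_U[0]))]

# Hypothetical sizes: d_model = 3, vocab = 2
W_U = [[1.0, 0.0],
       [0.0, 1.0],
       [0.5, -0.5]]

# What each component wrote into the residual stream at one position
components = {
    "embed":    [0.2, 0.1, 0.0],
    "head_0.1": [1.0, 0.0, 0.5],
    "mlp_0":    [0.0, 0.3, -0.2],
}

# Final residual stream = sum of everything written to it
resid = [sum(vals) for vals in zip(*components.values())]
total_logits = unembed(resid, W_U)

# Per-component contribution to each logit
contributions = {name: unembed(vec, W_U) for name, vec in components.items()}
summed = [sum(c[j] for c in contributions.values()) for j in range(2)]

for name, c in contributions.items():
    print(f"{name:9s} -> {c}")
print("total     ->", total_logits)
```

Because the per-component terms add up to the full logits, a component's contribution can be read off directly. In a real model the final LayerNorm makes this only approximately linear.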
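2.2 Key circuit patterns
class TransformerCircuitPatterns:
    """
    Key circuit patterns in Transformers, stored as descriptive strings.
    """
    # Induction heads: pattern matching and copying
    INDUCTION_HEAD = """
    Circuit pattern: induction head
    Function: copy a relevant token from an earlier position in the sequence
    Mechanism:
    1. The query at the current token (a repeat of A) searches earlier keys for a match
    2. Attention lands on the B position of the earlier [A][B] pattern
    3. The head outputs B's value, raising the logit of B
    Applications: in-context learning, pattern completion
    """
    # "G max-likelihood" head (informal label): predicts the next token
    GMAX_HEAD = """
    Circuit pattern: G max-likelihood head
    Function: predict the next token directly from the unembedding logits
    Mechanism:
    1. Skip the value computation
    2. Use the query and key directly
    3. Maximize attention to the correct token
    """
    # Token-suppression head: filters specific tokens
    SUPPRESSION_HEAD = """
    Circuit pattern: token-suppression head
    Function: suppress or emphasize specific tokens
    Mechanism:
    1. The query of an anomalous token matches a special key
    2. A suppression signal is injected through the value
    """
3. Sparse Autoencoders (SAE)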
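3.1 Superposition and polysemanticity
The superposition hypothesis (Elhage et al., 2022) [3]:
Core insight: a model can represent more features than it has neurons by storing them in overlapping directions. As a result neurons are polysemantic: a single neuron responds to several unrelated features.
Superposition, schematically:
Ideal case (independent features):
Feature 1: ───●────────────
Feature 2: ──────●─────────
Feature 3: ──────────●─────
Actual case (superposition):
Feature 1: ───●─────●──────
Feature 2: ──●────────●────
Feature 3: ────●───●───────
Each neuron responds to several features!
Question: how can independent, interpretable features be separated out of a superposed representation?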
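A small numeric example makes the trade-off concrete. The sketch below packs three features into a two-dimensional activation space and reads them back with a dot product and a ReLU. The 120-degree geometry and the `embed` and `readout` helpers are assumptions chosen for illustration, not a description of how any particular model arranges its features.

```python
import math

# Three feature directions packed into a 2-D activation space, 120 degrees apart.
directions = [(math.cos(2 * math.pi * k / 3), math.sin(2 * math.pi * k / 3))
              for k in range(3)]

def embed(feature_values):
    """Superpose features: activation = sum_k f_k * d_k."""
    x = sum(f * d[0] for f, d in zip(feature_values, directions))
    y = sum(f * d[1] for f, d in zip(feature_values, directions))
    return (x, y)

def readout(activation):
    """Estimate each feature with a dot product followed by ReLU."""
    return [max(0.0, activation[0] * d[0] + activation[1] * d[1])
            for d in directions]

sparse = readout(embed([1.0, 0.0, 0.0]))  # one feature active
dense = readout(embed([1.0, 1.0, 0.0]))   # two features active at once

print("one feature active :", [round(v, 3) for v in sparse])
print("two features active:", [round(v, 3) for v in dense])
```

With a single active feature the ReLU clips the negative interference and the feature is recovered exactly. With two active features each one is read back at half strength. This is why superposition relies on features being sparse, and why a sparse dictionary is a natural tool for undoing it.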
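3.2 SAE architecture
Sparse autoencoders (SAEs) [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """
    Sparse autoencoder.
    Decomposes superposed activations into sparse, more independent features.
    """
    def __init__(self, d_model, n_features, sparsity_coef=1e-3):
        super().__init__()
        # Encoder: activation space -> (overcomplete) feature space
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        # Decoder: feature space -> activation space
        self.decoder = nn.Linear(n_features, d_model, bias=False)
        self.sparsity_coef = sparsity_coef

    def forward(self, x):
        """
        Forward pass.
        Args:
            x: residual-stream activations [batch, seq_len, d_model]
        Returns:
            reconstructed: reconstruction of x
            features: sparse feature activations
            l0_norm: number of active features per position
        """
        # Encode: extract non-negative sparse features
        features = F.relu(self.encoder(x))
        # Decode: map back to the original space
        reconstructed = self.decoder(features)
        return {
            'reconstructed': reconstructed,
            'features': features,
            'l0_norm': (features > 0).sum(dim=-1).float()
        }

    def loss(self, x, output):
        """
        SAE loss:
        L = ||x - x̂||² + λ · ||f||₁
        Both terms are averaged over batch and sequence positions.
        """
        reconstruction_loss = F.mse_loss(output['reconstructed'], x)
        # L1 penalty: sum over features, mean over batch and positions, so it stays
        # on a scale comparable to the mean-reduced reconstruction term
        sparsity_loss = self.sparsity_coef * output['features'].abs().sum(dim=-1).mean()
        total_loss = reconstruction_loss + sparsity_loss
        return {
            'total': total_loss,
            'reconstruction': reconstruction_loss,
            'sparsity': sparsity_loss
        }
3.3 Gemma Scope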
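Gemma Scope (Google DeepMind, 2024) [5]:
Core contribution: an open release of SAEs trained on the layers of the Gemma 2 models.
# Using Gemma Scope through SAELens.
# The release and sae_id strings follow the SAELens pretrained-SAE directory;
# check that directory for the names your installed version expects.
from sae_lens import SAE
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gemma-2-2b")
# Load the SAE (older SAELens versions return a tuple; newer ones may return only the SAE)
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",
    sae_id="layer_10/width_16k/canonical",
)
# Collect the activations the SAE was trained on.
# run_with_cache returns (logits, cache); the activations live in the cache.
tokens = model.to_tokens("def add(a, b): return a + b")
_, cache = model.run_with_cache(tokens)
activations = cache["blocks.10.hook_resid_post"]
features = sae.encode(activations)
# Look up features related to a concept
# (find_concept_features is a hypothetical helper, not part of any library)
relevant_features = find_concept_features(
    features,
    concept="programming_language"
)
3.4 Feature analysis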
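class FeatureAnalyzer:
    """
    Analysis of individual SAE features.
    The helpers _random_text, _get_activation_contexts, _infer_semantic_label
    and _compute_feature_stats are placeholders and are not shown.
    """
    def __init__(self, sae, model, tokenizer, hook_name):
        self.sae = sae
        self.model = model
        self.tokenizer = tokenizer
        # Name of the activation the SAE was trained on, e.g. "blocks.10.hook_resid_pre"
        self.hook_name = hook_name

    def analyze_feature(self, feature_idx):
        """
        Analyze the behavior of a single feature.
        """
        # Tokens on which the feature fires
        tokens = self._get_feature_activating_tokens(feature_idx)
        # Contexts in which it fires
        contexts = self._get_activation_contexts(tokens)
        # Infer a semantic label from those contexts
        semantic_label = self._infer_semantic_label(contexts)
        # Summary statistics for the feature
        stats = self._compute_feature_stats(feature_idx)
        return {
            'tokens': tokens,
            'contexts': contexts,
            'semantic_label': semantic_label,
            'stats': stats
        }

    def _get_feature_activating_tokens(self, feature_idx, n_samples=1000, threshold=0.5):
        """
        Collect the tokens on which the feature activates above `threshold`.
        """
        activating_tokens = []
        for _ in range(n_samples):
            # Random text sample
            text = self._random_text()
            tokens = self.tokenizer(text, return_tensors='pt')['input_ids']
            # run_with_cache returns (logits, cache); read the activations from the cache
            _, cache = self.model.run_with_cache(tokens)
            features = self.sae.encode(cache[self.hook_name])
            # Positions where the feature fires
            activated = features[:, :, feature_idx] > threshold
            for batch_idx, pos in activated.nonzero().tolist():
                activating_tokens.append(tokens[batch_idx, pos].item())
        return activating_tokens
4. Causal feature circuits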
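4.1 Sparse Feature Circuits (ICLR 2025)
Sparse Feature Circuits (Marks et al., ICLR 2025) [6]:
Core idea: combine SAE feature decomposition with circuit analysis, so that the causal role of individual features in a behavior can be measured.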
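The basic causal test behind this idea is an intervention on a single feature: encode an activation with the SAE, switch one feature off, decode, and see how much a downstream quantity changes. The sketch below works through that on a hand-built toy. The weights, the `downstream_metric` stand-in for the rest of the model, and the helper names are all invented for illustration. Keeping the SAE's reconstruction error fixed during the intervention is a design choice made here so that only the chosen feature changes.

```python
# Toy SAE with 3 features over a 2-D activation (weights invented for the example)
W_enc = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]  # one row per feature
W_dec = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # one row per feature

def encode(x):
    """ReLU encoder: feature activations for activation x."""
    return [max(0.0, w[0] * x[0] + w[1] * x[1]) for w in W_enc]

def decode(f):
    """Linear decoder: reconstruct the activation from features."""
    return [sum(fi * w[j] for fi, w in zip(f, W_dec)) for j in range(2)]

def downstream_metric(x):
    """Stand-in for the rest of the model: any scalar function of the activation."""
    return 2.0 * x[0] - 1.0 * x[1]

def ablation_effect(x, feature_idx):
    """Change in the metric when one feature is zero-ablated."""
    f = encode(x)
    # Reconstruction error, held fixed so only the chosen feature changes
    error = [xi - xh for xi, xh in zip(x, decode(f))]
    f_ablated = list(f)
    f_ablated[feature_idx] = 0.0
    x_patched = [xh + e for xh, e in zip(decode(f_ablated), error)]
    return downstream_metric(x) - downstream_metric(x_patched)

x = [1.0, 0.5]
effects = [ablation_effect(x, i) for i in range(3)]
for i, e in enumerate(effects):
    print(f"feature {i}: effect on metric = {e:+.3f}")
```

Features whose ablation moves the metric the most are the candidates to keep as nodes in a circuit. In a real model the metric would be computed by running the remaining layers on the patched activation.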
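class SparseFeatureCircuit:
    """
    Sparse feature circuit analysis.
    _find_relevant_heads and _get_activations are placeholders and are not shown.
    For simplicity one SAE is reused for every layer; in practice each layer
    needs an SAE trained on that layer's activations.
    """
    def __init__(self, model, sae, threshold=0.1):
        self.model = model
        self.sae = sae
        # Minimum |effect| for a layer to be included in the circuit
        self.threshold = threshold

    def trace_feature_circuit(self, feature_idx, behavior):
        """
        Trace the causal circuit through which a feature implements a behavior.
        """
        # Attention heads associated with the feature
        relevant_heads = self._find_relevant_heads(feature_idx)
        # Container for the traced information flow
        circuit = {
            'feature': feature_idx,
            'relevant_heads': relevant_heads,
            'input_heads': [],
            'mlp_layers': [],
            'output_heads': []
        }
        # Causal contribution of the feature at each layer
        for layer in range(self.model.cfg.n_layers):
            causal_contribution = self._compute_causal_contribution(
                layer, feature_idx, behavior
            )
            if abs(causal_contribution) > self.threshold:
                circuit[f'layer_{layer}'] = causal_contribution
        return circuit

    def _compute_causal_contribution(self, layer, feature_idx, behavior):
        """
        Causal contribution of one feature at one layer, estimated by
        zero-ablation in SAE feature space (a form of activation patching).
        `behavior` maps layer activations to a scalar metric.
        """
        # Clean activations
        clean_acts = self._get_activations(layer)
        features = self.sae.encode(clean_acts)
        # Keep the reconstruction error so that only the chosen feature changes
        error = clean_acts - self.sae.decode(features)
        # Intervention: switch the feature off, then decode back
        ablated = features.clone()
        ablated[:, :, feature_idx] = 0
        patched_acts = self.sae.decode(ablated) + error
        # Change in behavior
        clean_behavior = behavior(clean_acts)
        patched_behavior = behavior(patched_acts)
        return clean_behavior - patched_behavior
4.2 Feature causal graph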
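class FeatureCausalGraph:
    """
    Causal graph over SAE features.
    """
    def build_feature_graph(self, model, sae, behaviors, threshold=0.1):
        """
        Build a causal graph over features.
        Nodes: SAE features
        Edges: causal relations (source -> target)
        Assumes the SAE object exposes `n_features`; _test_causal_relation is a
        placeholder. Testing every ordered pair costs O(n_features²)
        interventions, so real analyses prune the candidate set first.
        """
        import networkx as nx
        G = nx.DiGraph()
        # Add every feature as a node
        for feature_idx in range(sae.n_features):
            G.add_node(feature_idx)
        # Test causal relations between features
        for source in range(sae.n_features):
            for target in range(sae.n_features):
                if source != target:
                    # Does intervening on source change target?
                    causal_strength = self._test_causal_relation(
                        source, target, behaviors
                    )
                    if abs(causal_strength) > threshold:
                        G.add_edge(source, target, weight=causal_strength)
        return G
5. Causal mediation analysis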
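5.1 Causal mediation analysis in Transformers
Causal mediation analysis [7]:
Core idea: identify the intermediate components (mediators) that carry a given causal effect from input to output.
Causal mediation framework:
Input X
↓
Layer 1 ← direct effect
↓
Layer 2 ← mediated effect 1
↓
Layer 3 ← mediated effect 2
↓
Output Y
Total effect = direct effect + indirect effect (through the mediator)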
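The decomposition is easiest to see on a model small enough to compute by hand. The sketch below uses a linear toy system in which the input X affects the output Y both directly and through a mediator M. The functions `mediator` and `outcome` and their coefficients are invented for the example. The additive split holds exactly here because the toy model is linear; in a nonlinear network it is only an approximation.

```python
def mediator(x, a=2.0):
    """Mediator value as a function of the input: M = a * X."""
    return a * x

def outcome(x, m, c=0.5, b=1.5):
    """Output as a function of input and mediator: Y = c * X + b * M."""
    return c * x + b * m

x0, x1 = 0.0, 1.0  # baseline input and counterfactual input
m0, m1 = mediator(x0), mediator(x1)

# Total effect: change the input and let the mediator follow
total_effect = outcome(x1, m1) - outcome(x0, m0)
# Direct effect: change the input, freeze the mediator at its baseline value
direct_effect = outcome(x1, m0) - outcome(x0, m0)
# Indirect effect: keep the input at baseline, move only the mediator
indirect_effect = outcome(x0, m1) - outcome(x0, m0)

print("total   :", total_effect)
print("direct  :", direct_effect)
print("indirect:", indirect_effect)
print("share mediated:", indirect_effect / total_effect)
```

In a Transformer the same three quantities are obtained by running the model on a clean and a counterfactual prompt and patching the mediator's activation from one run into the other.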
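5.2 Implementing mediation analysis for a Transformer
import torch

class TransformerMediationAnalysis:
    """
    Causal mediation analysis for a Transformer.
    Simplified: the mediator is removed by zero-ablation and effects are measured
    on the logit of the token the clean run predicts. A fuller analysis compares
    against a counterfactual input rather than a zeroed activation.
    """
    def __init__(self, model):
        self.model = model  # a TransformerLens HookedTransformer

    def _get_output(self, prompt, fwd_hooks=()):
        """Final-position logits, optionally with forward hooks applied."""
        logits = self.model.run_with_hooks(prompt, fwd_hooks=list(fwd_hooks))
        return logits[0, -1]

    def analyze_mediation(self, prompt, target_layer, target_head=None):
        """
        Decompose the effect on the predicted token's logit:
        Total Effect = Direct Effect + Indirect Effect (via mediator)
        target_layer is a hook name such as "blocks.5.hook_mlp_out".
        target_head is currently unused.
        """
        # Baseline output and the token it predicts
        baseline_logits = self._get_output(prompt)
        target_token = baseline_logits.argmax()
        total = baseline_logits[target_token].item()
        # Direct effect: what remains when the target layer is bypassed
        direct = self._get_direct_effect(prompt, target_layer)[target_token].item()
        # Indirect effect: the part that flowed through the target layer
        indirect = total - direct
        return {
            'total_effect': total,
            'direct_effect': direct,
            'indirect_effect': indirect,
            'mediation_ratio': indirect / total if total != 0 else float('nan')
        }

    def _get_direct_effect(self, prompt, target_layer):
        """
        Output with the mediator removed.
        Intervention: set the target layer's activation to zero.
        """
        def hook_fn(activation, hook):
            return torch.zeros_like(activation)
        return self._get_output(prompt, fwd_hooks=[(target_layer, hook_fn)])

    def _get_indirect_effect(self, prompt, target_layer):
        """
        Indirect effect = total effect - direct effect
        """
        total = self._get_output(prompt)
        direct = self._get_direct_effect(prompt, target_layer)
        return total - direct
5.3 Attention as a causal mediator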
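class AttentionMediationAnalysis:
    """
    Mediation analysis for individual attention heads.
    Assumes a TransformerLens HookedTransformer; _compute_mediated_info is a
    placeholder and is not shown.
    """
    def __init__(self, model):
        self.model = model

    def analyze_head_mediation(self, prompt, head_idx):
        """
        Analyze the mediating role of one attention head.
        head_idx is a flat index over all (layer, head) pairs.
        """
        n_heads = self.model.cfg.n_heads
        layer = head_idx // n_heads
        head = head_idx % n_heads
        # Run the model and cache the attention patterns
        _, cache = self.model.run_with_cache(prompt)
        # Attention weights of this head: [query_pos, key_pos]
        attention_pattern = cache["pattern", layer][0, head]
        # Analyze attention as a mediator of information transfer
        mediated_info = self._compute_mediated_info(
            attention_pattern,
            prompt
        )
        return {
            'head_idx': head_idx,
            'layer': layer,
            'head': head,
            'attention_pattern': attention_pattern,
            'mediated_info': mediated_info
        }
6. Causal attribution methods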
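6.1 Activation attribution
import torch

class ActivationAttribution:
    """
    Activation attribution.
    Attributes a change in the output to specific activation patterns.
    """
    def __init__(self, model, sae, hook_name):
        self.model = model
        self.sae = sae
        # Name of the activation the SAE was trained on
        self.hook_name = hook_name

    def attribute_to_features(self, tokens, target_token):
        """
        Attribute the logit of target_token (final position) to SAE features.
        Simplification: the SAE reconstruction replaces the original activation,
        so the SAE's reconstruction error is ignored.
        """
        # Get the activations and their SAE features
        _, cache = self.model.run_with_cache(tokens)
        features = self.sae.encode(cache[self.hook_name]).detach().requires_grad_(True)
        # Re-run the model with the decoded features spliced in at the hook point
        def splice_hook(activation, hook):
            return self.sae.decode(features)
        logits = self.model.run_with_hooks(
            tokens, fwd_hooks=[(self.hook_name, splice_hook)]
        )
        output = logits[0, -1, target_token]
        # Gradient d_output / d_feature for every feature, in one backward pass
        grads = torch.autograd.grad(outputs=output, inputs=features)[0]
        # Average over batch and positions: one contribution per feature
        contributions = grads.mean(dim=(0, 1)).tolist()
        return contributions
6.2 Path attribution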
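class PathAttribution:
    """
    Path attribution.
    Traces causal paths from input to output.
    _initialize_attribution and _propagate_mlp are placeholders and are not shown.
    """
    def __init__(self, model, cache):
        self.model = model
        self.cache = cache  # activation cache from model.run_with_cache

    def trace_causal_path(self, input_tokens, output_token):
        """
        Trace the causal path.
        """
        # Initial attribution
        attr = self._initialize_attribution(input_tokens, output_token)
        # Propagate the attribution layer by layer
        for layer in range(self.model.cfg.n_layers):
            # Propagate through attention
            attr = self._propagate_attention(attr, layer)
            # Propagate through the MLP
            attr = self._propagate_mlp(attr, layer)
        return attr

    def _propagate_attention(self, attr, layer):
        """
        Propagate attribution through an attention layer.
        """
        # Attention weights [batch, head, query_pos, key_pos], averaged over heads
        # (a simplification: heads are treated as interchangeable)
        attn_weights = self.cache["pattern", layer].mean(dim=1)
        # Attribution = attention weights × downstream attribution
        propagated = attn_weights @ attr
        return propagated
7. Practical tools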
7.1 TransformerLens
# TransformerLens: a library for Transformer interpretability research
from transformer_lens import HookedTransformer
# Load the model
model = HookedTransformer.from_pretrained("gpt2")
# Run once and cache all intermediate activations
logits, cache = model.run_with_cache("The cat sat on the mat")
# Inspect a specific activation with a forward hook
def hook_fn(activation, hook):
    print(f"Attention pattern shape: {activation.shape}")
    return activation
model.run_with_hooks(
    "Hello world",
    fwd_hooks=[("blocks.1.attn.hook_pattern", hook_fn)]
)
7.2 SAELens
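# SAELens: a library for training and analyzing SAEs
from sae_lens import SAE, ActivationsStore
# Load a pretrained SAE. Release and SAE ids come from the SAELens pretrained-SAE
# directory; newer SAELens versions may return only the SAE instead of a tuple.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.10.hook_resid_pre"
)
# Stream activations from a dataset for larger-scale analysis
# (the available arguments vary across SAELens versions)
activations_store = ActivationsStore.from_sae(model=model, sae=sae)
# Encode: read the hooked activation from the cache, then apply the SAE
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
activations = cache["blocks.10.hook_resid_pre"]
features = sae.encode(activations)
7.3 Complete analysis workflow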
class LLMInterpretabilityAnalysis:
    """
    End-to-end LLM interpretability workflow.
    The underscore-prefixed helpers are placeholders for the individual steps.
    """
    def __init__(self, model_name):
        # Load the model
        self.model = HookedTransformer.from_pretrained(model_name)
        # Load the matching SAE
        self.sae = self._load_sae(model_name)
        self.tokenizer = self.model.tokenizer

    def complete_analysis(self, prompt, target_behavior):
        """
        Run the full analysis pipeline.
        """
        # 1. Basic activation analysis
        activations = self._analyze_activations(prompt)
        # 2. SAE feature decomposition
        features = self._decompose_features(activations)
        # 3. Identify the key features
        key_features = self._identify_key_features(features, target_behavior)
        # 4. Trace causal paths
        causal_paths = self._trace_causal_paths(key_features)
        # 5. Generate an explanation
        explanation = self._generate_explanation(
            key_features,
            causal_paths
        )
        return {
            'activations': activations,
            'features': features,
            'key_features': key_features,
            'causal_paths': causal_paths,
            'explanation': explanation
        }
8. Evaluation and benchmarks
8.1 Evaluating explanation quality
class ExplanationQualityEvaluator:
    """
    Evaluates the quality of an explanation.
    The completeness, succinctness and actionability scorers are placeholders.
    """
    def evaluate_explanation(self, explanation, ground_truth):
        """
        Evaluate an explanation along four dimensions:
        1. Correctness: agreement with the ground truth
        2. Completeness: coverage of all relevant components
        3. Succinctness: how compact the explanation is
        4. Actionability: whether it can guide interventions
        """
        metrics = {
            'correctness': self._evaluate_correctness(explanation, ground_truth),
            'completeness': self._evaluate_completeness(explanation),
            'succinctness': self._evaluate_succinctness(explanation),
            'actionability': self._evaluate_actionability(explanation)
        }
        return metrics

    def _evaluate_correctness(self, explanation, ground_truth):
        """
        Correctness: compare the features or circuit components named by the
        explanation with those in the ground truth.
        """
        explained_features = set(explanation['features'])
        true_features = set(ground_truth['features'])
        # Precision / recall
        tp = len(explained_features & true_features)
        fp = len(explained_features - true_features)
        fn = len(true_features - explained_features)
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        return {'precision': precision, 'recall': recall, 'f1': f1}
8.2 Benchmark datasets
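| Benchmark | Description | Use |
|---|---|---|
| Induction task | Induction head test | Validating circuit analyses |
| Greater-Than | Numerical comparison test | Validating mechanisms |
| Indirect Object Identification (IOI) | Identifying the indirect object of a sentence | Attention analysis |
| SAE Feature Benchmark | Feature interpretability | Evaluating SAE quality |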
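Circuit benchmarks such as IOI are commonly scored with a logit difference between the correct and the distractor answer, and a patching experiment is then reported as the fraction of the clean logit difference it recovers. The sketch below shows that arithmetic in plain Python. The token ids and logit values are invented for illustration, and `fraction_recovered` is a helper name chosen for this example.

```python
def logit_diff(logits, correct_id, incorrect_id):
    """Logit of the correct answer minus logit of the distractor."""
    return logits[correct_id] - logits[incorrect_id]

def fraction_recovered(clean, corrupted, patched):
    """Normalize a patched run: 0 = no better than corrupted, 1 = matches clean."""
    if clean == corrupted:
        raise ValueError("clean and corrupted runs must differ")
    return (patched - corrupted) / (clean - corrupted)

# Prompt: "When Mary and John went to the store, John gave a drink to"
# Correct completion: " Mary"; distractor: " John".
MARY, JOHN = 0, 1  # hypothetical ids in a two-token vocabulary

clean_logits = [5.0, 2.0]      # model run on the original prompt
corrupted_logits = [2.5, 2.0]  # model run on a prompt with the names scrambled
patched_logits = [4.0, 2.0]    # corrupted run with one component patched from clean

clean = logit_diff(clean_logits, MARY, JOHN)
corrupted = logit_diff(corrupted_logits, MARY, JOHN)
patched = logit_diff(patched_logits, MARY, JOHN)

score = fraction_recovered(clean, corrupted, patched)
print(f"clean={clean}, corrupted={corrupted}, patched={patched}")
print(f"fraction of logit difference recovered: {score:.2f}")
```

A component that recovers most of the clean logit difference when patched in is a strong candidate for membership in the circuit.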
9. References
[1] Elhage, N., et al. (2021). A Mathematical Framework for Transformer Circuits. Transformer Circuits Thread.
[2] Elhage, N., et al. (2022). Softmax Linear Units. Transformer Circuits Thread.
[3] Elhage, N., et al. (2022). Toy Models of Superposition. Transformer Circuits Thread.
[4] Bricken, T., et al. (2023). Towards Monosemanticity. Transformer Circuits Thread.
[5] Lieberum, T., et al. (2024). Gemma Scope: Open Sparse Autoencoders Everywhere All At Once on Gemma 2. Google DeepMind.
[6] Marks, S., et al. (2025). Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models. ICLR 2025.
[7] Geiger, A., et al. (2024). Causal Abstraction in Large Language Models.