Circuit Discovery
Circuit Discovery(电路发现)是机械可解释性(Mechanistic Interpretability)领域的核心方法论,旨在通过识别神经网络中执行特定任务的稀疏子网络(子图)来理解模型的内部计算机制。1
引言:什么是 Circuit Discovery
机械可解释性的目标
机械可解释性试图”逆向工程”神经网络,将其内部计算分解为人类可理解的算法。与传统可解释性方法(如特征重要性、注意力可视化)不同,机械可解释性追求对模型行为的系统性、因果性理解。
黑盒模型 → 可解释的电路/算法
↓
电路 = 模型中执行特定任务的最小子图
电路的定义
在 Transformer 架构中,一个电路(Circuit) 由以下组件组成:
- 注意力头(Attention Heads):跨位置传递信息
- MLP 神经元:非线性变换
- 残差连接:信息直接传递路径
电路发现的任务是:给定一个模型行为,找到负责该行为的最小组件集合。
Activation Patching:核心因果分析方法
Activation Patching(激活修补),又称 Causal Tracing 或 Path Patching,是电路发现的核心因果推断技术。2
基本原理
Activation Patching 的核心思想是因果干预:
- 在干净文本(不包含目标行为)上运行模型,记录各层激活
- 在含目标行为的文本上运行模型
- 将目标位置的激活替换为干净文本的对应激活
- 观察输出是否恢复”干净”行为
若替换后输出恢复,则该位置对该行为不关键;若输出仍保持异常,则该位置关键。
直接注意头归因
对于注意力头,我们关注其在特定位置的输出:
def activation_patching_head(
model, # 待分析模型
clean_tokens, # 干净文本的token序列
corrupted_tokens,# 含目标行为的文本token序列
head_index, # 要patch的注意力头
layer_index, # 要patch的层
position # 要patch的位置
):
"""
测试特定注意力头对特定位置输出的贡献
"""
# 获取干净和损坏的激活
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)
# 运行模型,在指定位置使用干净激活替换损坏激活
def patching_hook(value, hook):
if hook.layer() == layer_index:
value[:, position, head_index, :] = \
clean_cache[hook.name][:, position, head_index, :]
return value
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(f"blocks.{layer_index}.attn.hook_v", patching_hook)]
)
return patched_logits路径 Patching
更精细的方法是同时 patch 多个路径的激活:
def path_patching(
model,
clean_tokens,
corrupted_tokens,
receiver_head, # 接收方注意力头 (layer, head)
sender_position, # 发送方位置
metric_fn # 评估指标函数
):
"""
分析从sender到receiver的路径贡献
"""
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)
def qoikv_hook(z, hook, receiver_layer, receiver_head_idx):
"""Query-Output-Key-Value patching"""
if hook.layer() == receiver_layer:
# 计算从sender位置到当前注意力头的贡献
q = z[:, sender_position, receiver_head_idx, :]
return z
return z
# 计算patch后的指标
patched_logits = model.run_with_hooks(
corrupted_tokens,
fwd_hooks=[(f"blocks.{receiver_layer[0]}.attn.hook_z", qoikv_hook)]
)
return metric_fn(patched_logits)度量标准
Activation Patching 常用的评估指标:
| 指标 | 公式 | 含义 |
|---|---|---|
| Logit Difference | 恢复程度 | |
| Average KL Divergence | 分布差异 | |
| Cosine Similarity | 方向一致性 |
经典电路案例
Induction Head:归纳头
Induction Head 是最早被深入研究的电路之一,在 in-context learning 中起关键作用。3
机制原理
Induction Head 执行 match-and-copy 操作:
序列:... [A] [B] ... [A] → ?
↑_______↑
查找A后的B,复制到当前位置
具体来说,两个注意力头协同工作:
- 前一个token头(Previous Token Head):在第一层,将信息从前一个token复制到当前token
- 归纳头(Induction Head):在第二层,基于”当前token”寻找之前出现的”相同token”,并 attending 到其后的 token
数学形式化
设 为前一个token头的输出, 为归纳头:
其中 表示 是 之前的 token。
与 In-context Learning 的关联
Olsson et al. (2022) 提出了六条证据链,证明归纳头可能是大多数 in-context learning 的机制来源:
- 宏观共现:训练早期出现”相变”,同时形成归纳头和 in-context learning 能力
- 架构共扰动:改变架构使归纳头无法形成时,in-context learning 同步退化
- 直接消融:消融归纳头后 in-context learning 大幅下降
- 泛化能力:归纳头可执行抽象的 pattern matching 而非仅复制
- 机制合理性:小模型中可精确解释归纳头的工作方式
- 尺度连续性:大小模型中行为一致
SVA (Suffix Vector Ablation)
SVA 是一种验证电路完整性的方法,通过移除(ablating)特定模式来测试电路的功能完整性。
def suffix_vector_ablation(
model,
tokens,
circuit_heads, # 发现的电路中的注意力头
pattern_length=10 # 后缀模式长度
):
"""
对电路中的头进行组合消融,验证是否等价于完全移除
"""
results = {}
# 1. 正常前向传播
baseline_logits = model(tokens)
# 2. 只消融电路中的头
circuit_ablated = ablate_heads(model, tokens, circuit_heads)
# 3. 验证两者的差异
diff = torch.norm(baseline_logits - circuit_ablated, p=2)
results["fidelity"] = diff.item()
return results电路发现算法
Automated Circuit Discovery (ACDC)
ACDC(NeurIPS 2023)是首个系统化的自动电路发现算法。1
算法流程
1. 选择任务和数据集
↓
2. 构建计算图
↓
3. 初始化电路(所有组件)
↓
4. 迭代剪枝
↓
5. 验证电路完整性
核心算法
def acdc(
model,
task_metric, # 评估任务完成度的指标
graph, # 计算图 (layers × heads)
epsilon=0.01, # 剪枝阈值
max_iterations=100
):
"""
Automated Circuit Discovery
"""
# Step 1: 计算所有组件的 importance
# 使用 logit diff 或其他指标
importance = compute_component_importance(model, graph, task_metric)
# Step 2: 初始化电路为所有组件
circuit = set(graph.nodes)
# Step 3: 迭代剪枝
for iteration in range(max_iterations):
for node in graph.nodes:
if node not in circuit:
continue
# 尝试移除该节点
temp_circuit = circuit - {node}
# 计算移除后的性能损失
performance_loss = compute_loss(model, temp_circuit, task_metric)
# 如果损失在阈值内,则正式移除
if performance_loss < epsilon:
circuit = temp_circuit
# 检查收敛
if is_stable(circuit, importance):
break
return circuit
def compute_component_importance(model, graph, metric):
"""
计算每个组件的重要性分数
使用 activation patching 估计因果贡献
"""
importance = {}
for node in graph.nodes:
# patch 单个组件的激活
patched_output = patch_single_component(model, node)
# 计算指标变化
original_metric = metric(model)
patched_metric = metric(patched_output)
importance[node] = original_metric - patched_metric
return importance实验结果
在 GPT-2 Small(约 32,000 条边)上:
| 任务 | ACDC 发现的边数 | 手工发现的边数 | 召回率 |
|---|---|---|---|
| Greater-Than | 68 | 68 | 100% |
| Induction | ~200 | ~200 | ~100% |
| Docstring Parsing | ~150 | ~150 | ~100% |
DiscoGP:可微图剪枝
DiscoGP(2024)提出一种基于可微掩码的电路发现方法,比 ACDC 更高效。4
核心思想
将电路发现重新形式化为一个连续优化问题:
其中 是组件的掩码向量, 是任务损失, 是 正则化(促进稀疏性)。
可微松弛
使用 Gumbel-Softmax 或 Straight-Through Estimator 实现可微掩码:
class DifferentiableMask(nn.Module):
def __init__(self, num_components, temperature=1.0):
super().__init__()
self.logits = nn.Parameter(torch.zeros(num_components))
self.temperature = temperature
def forward(self):
"""
返回伯努利分布的连续松弛
"""
probs = torch.sigmoid(self.logits)
# Gumbel-Softmax 松弛
gumbels = -torch.log(-torch.log(torch.rand_like(probs) + 1e-20) + 1e-20)
scores = (probs.log() + gumbels) / self.temperature
mask = torch.sigmoid(scores)
return mask
def hard_mask(self):
"""用于推理的硬掩码"""
return (self.logits > 0).float()
def discogp(
model,
graph,
task_metric,
lambda_l0=0.01,
lr=0.1,
max_iterations=500
):
"""
DiscoGP: Differentiable Graph Pruning for Circuit Discovery
"""
# 初始化可学习掩码
mask_module = DifferentiableMask(len(graph.nodes))
optimizer = torch.optim.Adam(mask_module.parameters(), lr=lr)
for iteration in range(max_iterations):
optimizer.zero_grad()
# 获取当前掩码
mask = mask_module()
# 应用掩码并计算损失
masked_model = apply_mask(model, graph, mask)
loss = -task_metric(masked_model) # 最大化任务指标 = 最小化负指标
# L0 正则化
l0_loss = lambda_l0 * mask.sum()
total_loss = loss + l0_loss
total_loss.backward()
optimizer.step()
# 返回发现的电路(使用硬掩码)
hard_mask = mask_module.hard_mask()
circuit = [node for node, m in zip(graph.nodes, hard_mask) if m > 0]
return circuit与 ACDC 的比较
| 特性 | ACDC | DiscoGP |
|---|---|---|
| 搜索方式 | 离散剪枝 | 连续优化 |
| 前向传递次数 | ~1000+ | ~2-5 |
| 需要微调 | 否 | 是 |
| 可扩展性 | 中等 | 较高 |
Contextual Decomposition (CD)
CD 方法利用 Transformer 的线性结构,对注意力模式进行分解,特别适合发现位置感知的电路。5
电路验证:如何验证发现的电路是正确的
完整性测试(Fidelity)
验证发现的电路是否真正负责目标行为:
def fidelity_test(
model,
original_circuit,
test_dataset,
task_metric
):
"""
测试电路的完整性
"""
results = {}
# 1. 基线性能
baseline = task_metric(model, test_dataset)
results["baseline"] = baseline
# 2. 只消融电路内组件
circuit_only = ablate_components(model, original_circuit)
results["circuit_ablated"] = task_metric(circuit_only, test_dataset)
# 3. 完全消融(所有组件)
all_ablated = ablate_all(model)
results["all_ablated"] = task_metric(all_ablated, test_dataset)
# 4. 计算 fidelity
# 如果电路完全负责行为,则 circuit_ablated ≈ all_ablated
fidelity = (baseline - results["circuit_ablated"]) / \
(baseline - results["all_ablated"])
results["fidelity_score"] = fidelity
return results最小性测试(Minimality)
验证电路是最小的——即没有冗余组件:
def minimality_test(
model,
circuit,
test_dataset,
task_metric,
epsilon=0.01
):
"""
测试电路的最小性
"""
minimal_circuit = set(circuit)
redundant_components = []
for component in circuit:
# 尝试移除单个组件
test_circuit = minimal_circuit - {component}
if is_performance_preserved(
model, test_circuit, test_dataset, task_metric, epsilon
):
redundant_components.append(component)
minimal_circuit = test_circuit
return {
"original_size": len(circuit),
"minimal_size": len(minimal_circuit),
"redundant": redundant_components,
"is_minimal": len(redundant_components) == 0
}跨模型一致性
验证同一电路在不同模型中的存在性和功能一致性:
def cross_model_consistency(
model1, model2,
circuit1, circuit2,
task_metric
):
"""
测试电路在不同模型间的一致性
"""
# 计算两个模型的电路在相同任务上的表现
perf1 = task_metric(model1, circuit1)
perf2 = task_metric(model2, circuit2)
# 计算相似性
similarity = cosine_similarity(
circuit1.get_mechanism_embedding(),
circuit2.get_mechanism_embedding()
)
return {
"performance_alignment": abs(perf1 - perf2) < epsilon,
"mechanism_similarity": similarity
}消融对照实验
系统性地消融电路内外的组件,观察行为变化:
| 消融方式 | 预期结果 |
|---|---|
| 只保留电路内组件 | 行为保持 |
| 只消融电路内组件 | 行为消失 |
| 消融电路外组件 | 行为保持 |
| 随机消融 | 行为下降最小 |
局限性
过度简化风险
- 单一电路假设:现实中的复杂行为往往由多个相互作用的电路共同实现
- 忽略交互效应:单独分析各组件可能无法捕捉组件间的非线性交互
- 任务定义的主观性:同一”行为”可能有多种定义方式
电路交互与嵌套
┌─────────────────────────────────────┐
│ 复杂行为 │
│ ┌─────────┐ ┌─────────┐ ┌─────┐ │
│ │ Circuit │ │ Circuit │ │ ... │ │
│ │ A │──│ B │──│ │ │
│ └─────────┘ └─────────┘ └─────┘ │
│ ↑ ↑ │
│ └────────────┴──────────────────│
│ 相互依赖 │
└─────────────────────────────────────┘
可扩展性挑战
| 问题 | 影响 |
|---|---|
| 计算复杂度 | 的组件数导致 的搜索空间 |
| 深层网络 | 深层 Transformer 的电路边界更模糊 |
| 权重共享 | 组件的多用途使”最小电路”定义困难 |
| 动态行为 | 同一组件在不同上下文可能扮演不同角色 |
因果推断的局限性
- Patching 粒度:patch 到注意力头级别可能过于粗糙
- 干净/损坏文本假设:文本的”干净”定义可能影响结果
- 间接效应:patch 某位置可能通过其他路径间接影响输出
参考文献
相关主题
Footnotes
-
Conmy, A., et al. (2023). Towards Automated Circuit Discovery for Mechanistic Interpretability. NeurIPS 2023. https://arxiv.org/abs/2304.14997 ↩ ↩2
-
Zhang, F., & Nanda, N. (2024). Towards Best Practices of Activation Patching in Language Models. ICLR 2024. https://iclr.cc/virtual/2024/poster/18984 ↩
-
Olsson, C., et al. (2022). In-context Learning and Induction Heads. Transformer Circuits Thread. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads ↩
-
Functional Faithfulness in the Wild: Circuit Discovery with Differentiable Computation Graph Pruning. (2024). https://arxiv.org/html/2407.03779v1 ↩
-
Efficient Automated Circuit Discovery in Transformers using Contextual Decomposition. (2024). https://arxiv.org/abs/2407.00886 ↩