Circuit Discovery

Circuit Discovery(电路发现)是机械可解释性(Mechanistic Interpretability)领域的核心方法论,旨在通过识别神经网络中执行特定任务的稀疏子网络(子图)来理解模型的内部计算机制。1

引言:什么是 Circuit Discovery

机械可解释性的目标

机械可解释性试图”逆向工程”神经网络,将其内部计算分解为人类可理解的算法。与传统可解释性方法(如特征重要性、注意力可视化)不同,机械可解释性追求对模型行为的系统性、因果性理解。

黑盒模型 → 可解释的电路/算法
    ↓
   电路 = 模型中执行特定任务的最小子图

电路的定义

在 Transformer 架构中,一个电路(Circuit) 由以下组件组成:

  • 注意力头(Attention Heads):跨位置传递信息
  • MLP 神经元:非线性变换
  • 残差连接:信息直接传递路径

电路发现的任务是:给定一个模型行为,找到负责该行为的最小组件集合。


Activation Patching:核心因果分析方法

Activation Patching(激活修补),又称 Causal TracingPath Patching,是电路发现的核心因果推断技术。2

基本原理

Activation Patching 的核心思想是因果干预

  1. 干净文本(不包含目标行为)上运行模型,记录各层激活
  2. 含目标行为的文本上运行模型
  3. 将目标位置的激活替换为干净文本的对应激活
  4. 观察输出是否恢复”干净”行为

若替换后输出恢复,则该位置对该行为不关键;若输出仍保持异常,则该位置关键

直接注意头归因

对于注意力头,我们关注其在特定位置的输出:

def activation_patching_head(
    model,           # 待分析模型
    clean_tokens,    # 干净文本的token序列
    corrupted_tokens,# 含目标行为的文本token序列
    head_index,     # 要patch的注意力头
    layer_index,    # 要patch的层
    position        # 要patch的位置
):
    """
    测试特定注意力头对特定位置输出的贡献
    """
    # 获取干净和损坏的激活
    _, clean_cache = model.run_with_cache(clean_tokens)
    _, corrupted_cache = model.run_with_cache(corrupted_tokens)
    
    # 运行模型,在指定位置使用干净激活替换损坏激活
    def patching_hook(value, hook):
        if hook.layer() == layer_index:
            value[:, position, head_index, :] = \
                clean_cache[hook.name][:, position, head_index, :]
        return value
    
    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[(f"blocks.{layer_index}.attn.hook_v", patching_hook)]
    )
    
    return patched_logits

路径 Patching

更精细的方法是同时 patch 多个路径的激活:

def path_patching(
    model,
    clean_tokens,
    corrupted_tokens,
    receiver_head,     # 接收方注意力头 (layer, head)
    sender_position,    # 发送方位置
    metric_fn          # 评估指标函数
):
    """
    分析从sender到receiver的路径贡献
    """
    _, clean_cache = model.run_with_cache(clean_tokens)
    _, corrupted_cache = model.run_with_cache(corrupted_tokens)
    
    def qoikv_hook(z, hook, receiver_layer, receiver_head_idx):
        """Query-Output-Key-Value patching"""
        if hook.layer() == receiver_layer:
            # 计算从sender位置到当前注意力头的贡献
            q = z[:, sender_position, receiver_head_idx, :]
            return z
        return z
    
    # 计算patch后的指标
    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=[(f"blocks.{receiver_layer[0]}.attn.hook_z", qoikv_hook)]
    )
    
    return metric_fn(patched_logits)

度量标准

Activation Patching 常用的评估指标:

指标公式含义
Logit Difference恢复程度
Average KL Divergence分布差异
Cosine Similarity方向一致性

经典电路案例

Induction Head:归纳头

Induction Head 是最早被深入研究的电路之一,在 in-context learning 中起关键作用。3

机制原理

Induction Head 执行 match-and-copy 操作:

序列:... [A] [B] ... [A] → ?
        ↑_______↑
        查找A后的B,复制到当前位置

具体来说,两个注意力头协同工作:

  1. 前一个token头(Previous Token Head):在第一层,将信息从前一个token复制到当前token
  2. 归纳头(Induction Head):在第二层,基于”当前token”寻找之前出现的”相同token”,并 attending 到其后的 token

数学形式化

为前一个token头的输出, 为归纳头:

其中 表示 之前的 token。

与 In-context Learning 的关联

Olsson et al. (2022) 提出了六条证据链,证明归纳头可能是大多数 in-context learning 的机制来源:

  1. 宏观共现:训练早期出现”相变”,同时形成归纳头和 in-context learning 能力
  2. 架构共扰动:改变架构使归纳头无法形成时,in-context learning 同步退化
  3. 直接消融:消融归纳头后 in-context learning 大幅下降
  4. 泛化能力:归纳头可执行抽象的 pattern matching 而非仅复制
  5. 机制合理性:小模型中可精确解释归纳头的工作方式
  6. 尺度连续性:大小模型中行为一致

SVA (Suffix Vector Ablation)

SVA 是一种验证电路完整性的方法,通过移除(ablating)特定模式来测试电路的功能完整性。

def suffix_vector_ablation(
    model,
    tokens,
    circuit_heads,     # 发现的电路中的注意力头
    pattern_length=10  # 后缀模式长度
):
    """
    对电路中的头进行组合消融,验证是否等价于完全移除
    """
    results = {}
    
    # 1. 正常前向传播
    baseline_logits = model(tokens)
    
    # 2. 只消融电路中的头
    circuit_ablated = ablate_heads(model, tokens, circuit_heads)
    
    # 3. 验证两者的差异
    diff = torch.norm(baseline_logits - circuit_ablated, p=2)
    results["fidelity"] = diff.item()
    
    return results

电路发现算法

Automated Circuit Discovery (ACDC)

ACDC(NeurIPS 2023)是首个系统化的自动电路发现算法。1

算法流程

1. 选择任务和数据集
        ↓
2. 构建计算图
        ↓
3. 初始化电路(所有组件)
        ↓
4. 迭代剪枝
        ↓
5. 验证电路完整性

核心算法

def acdc(
    model,
    task_metric,       # 评估任务完成度的指标
    graph,             # 计算图 (layers × heads)
    epsilon=0.01,      # 剪枝阈值
    max_iterations=100
):
    """
    Automated Circuit Discovery
    """
    # Step 1: 计算所有组件的 importance
    # 使用 logit diff 或其他指标
    importance = compute_component_importance(model, graph, task_metric)
    
    # Step 2: 初始化电路为所有组件
    circuit = set(graph.nodes)
    
    # Step 3: 迭代剪枝
    for iteration in range(max_iterations):
        for node in graph.nodes:
            if node not in circuit:
                continue
            
            # 尝试移除该节点
            temp_circuit = circuit - {node}
            
            # 计算移除后的性能损失
            performance_loss = compute_loss(model, temp_circuit, task_metric)
            
            # 如果损失在阈值内,则正式移除
            if performance_loss < epsilon:
                circuit = temp_circuit
        
        # 检查收敛
        if is_stable(circuit, importance):
            break
    
    return circuit
 
def compute_component_importance(model, graph, metric):
    """
    计算每个组件的重要性分数
    使用 activation patching 估计因果贡献
    """
    importance = {}
    
    for node in graph.nodes:
        # patch 单个组件的激活
        patched_output = patch_single_component(model, node)
        
        # 计算指标变化
        original_metric = metric(model)
        patched_metric = metric(patched_output)
        
        importance[node] = original_metric - patched_metric
    
    return importance

实验结果

在 GPT-2 Small(约 32,000 条边)上:

任务ACDC 发现的边数手工发现的边数召回率
Greater-Than6868100%
Induction~200~200~100%
Docstring Parsing~150~150~100%

DiscoGP:可微图剪枝

DiscoGP(2024)提出一种基于可微掩码的电路发现方法,比 ACDC 更高效。4

核心思想

将电路发现重新形式化为一个连续优化问题

其中 是组件的掩码向量, 是任务损失, 正则化(促进稀疏性)。

可微松弛

使用 Gumbel-Softmax 或 Straight-Through Estimator 实现可微掩码:

class DifferentiableMask(nn.Module):
    def __init__(self, num_components, temperature=1.0):
        super().__init__()
        self.logits = nn.Parameter(torch.zeros(num_components))
        self.temperature = temperature
    
    def forward(self):
        """
        返回伯努利分布的连续松弛
        """
        probs = torch.sigmoid(self.logits)
        
        # Gumbel-Softmax 松弛
        gumbels = -torch.log(-torch.log(torch.rand_like(probs) + 1e-20) + 1e-20)
        scores = (probs.log() + gumbels) / self.temperature
        mask = torch.sigmoid(scores)
        
        return mask
    
    def hard_mask(self):
        """用于推理的硬掩码"""
        return (self.logits > 0).float()
 
def discogp(
    model,
    graph,
    task_metric,
    lambda_l0=0.01,
    lr=0.1,
    max_iterations=500
):
    """
    DiscoGP: Differentiable Graph Pruning for Circuit Discovery
    """
    # 初始化可学习掩码
    mask_module = DifferentiableMask(len(graph.nodes))
    optimizer = torch.optim.Adam(mask_module.parameters(), lr=lr)
    
    for iteration in range(max_iterations):
        optimizer.zero_grad()
        
        # 获取当前掩码
        mask = mask_module()
        
        # 应用掩码并计算损失
        masked_model = apply_mask(model, graph, mask)
        loss = -task_metric(masked_model)  # 最大化任务指标 = 最小化负指标
        
        # L0 正则化
        l0_loss = lambda_l0 * mask.sum()
        total_loss = loss + l0_loss
        
        total_loss.backward()
        optimizer.step()
    
    # 返回发现的电路(使用硬掩码)
    hard_mask = mask_module.hard_mask()
    circuit = [node for node, m in zip(graph.nodes, hard_mask) if m > 0]
    
    return circuit

与 ACDC 的比较

特性ACDCDiscoGP
搜索方式离散剪枝连续优化
前向传递次数~1000+~2-5
需要微调
可扩展性中等较高

Contextual Decomposition (CD)

CD 方法利用 Transformer 的线性结构,对注意力模式进行分解,特别适合发现位置感知的电路。5


电路验证:如何验证发现的电路是正确的

完整性测试(Fidelity)

验证发现的电路是否真正负责目标行为:

def fidelity_test(
    model,
    original_circuit,
    test_dataset,
    task_metric
):
    """
    测试电路的完整性
    """
    results = {}
    
    # 1. 基线性能
    baseline = task_metric(model, test_dataset)
    results["baseline"] = baseline
    
    # 2. 只消融电路内组件
    circuit_only = ablate_components(model, original_circuit)
    results["circuit_ablated"] = task_metric(circuit_only, test_dataset)
    
    # 3. 完全消融(所有组件)
    all_ablated = ablate_all(model)
    results["all_ablated"] = task_metric(all_ablated, test_dataset)
    
    # 4. 计算 fidelity
    # 如果电路完全负责行为,则 circuit_ablated ≈ all_ablated
    fidelity = (baseline - results["circuit_ablated"]) / \
               (baseline - results["all_ablated"])
    
    results["fidelity_score"] = fidelity
    
    return results

最小性测试(Minimality)

验证电路是最小的——即没有冗余组件:

def minimality_test(
    model,
    circuit,
    test_dataset,
    task_metric,
    epsilon=0.01
):
    """
    测试电路的最小性
    """
    minimal_circuit = set(circuit)
    redundant_components = []
    
    for component in circuit:
        # 尝试移除单个组件
        test_circuit = minimal_circuit - {component}
        
        if is_performance_preserved(
            model, test_circuit, test_dataset, task_metric, epsilon
        ):
            redundant_components.append(component)
            minimal_circuit = test_circuit
    
    return {
        "original_size": len(circuit),
        "minimal_size": len(minimal_circuit),
        "redundant": redundant_components,
        "is_minimal": len(redundant_components) == 0
    }

跨模型一致性

验证同一电路在不同模型中的存在性和功能一致性:

def cross_model_consistency(
    model1, model2,
    circuit1, circuit2,
    task_metric
):
    """
    测试电路在不同模型间的一致性
    """
    # 计算两个模型的电路在相同任务上的表现
    perf1 = task_metric(model1, circuit1)
    perf2 = task_metric(model2, circuit2)
    
    # 计算相似性
    similarity = cosine_similarity(
        circuit1.get_mechanism_embedding(),
        circuit2.get_mechanism_embedding()
    )
    
    return {
        "performance_alignment": abs(perf1 - perf2) < epsilon,
        "mechanism_similarity": similarity
    }

消融对照实验

系统性地消融电路内外的组件,观察行为变化:

消融方式预期结果
只保留电路内组件行为保持
只消融电路内组件行为消失
消融电路外组件行为保持
随机消融行为下降最小

局限性

过度简化风险

  1. 单一电路假设:现实中的复杂行为往往由多个相互作用的电路共同实现
  2. 忽略交互效应:单独分析各组件可能无法捕捉组件间的非线性交互
  3. 任务定义的主观性:同一”行为”可能有多种定义方式

电路交互与嵌套

┌─────────────────────────────────────┐
│          复杂行为                    │
│  ┌─────────┐  ┌─────────┐  ┌─────┐  │
│  │ Circuit │  │ Circuit │  │ ... │  │
│  │    A    │──│    B    │──│     │  │
│  └─────────┘  └─────────┘  └─────┘  │
│      ↑            ↑                  │
│      └────────────┴──────────────────│
│           相互依赖                    │
└─────────────────────────────────────┘

可扩展性挑战

问题影响
计算复杂度 的组件数导致 的搜索空间
深层网络深层 Transformer 的电路边界更模糊
权重共享组件的多用途使”最小电路”定义困难
动态行为同一组件在不同上下文可能扮演不同角色

因果推断的局限性

  1. Patching 粒度:patch 到注意力头级别可能过于粗糙
  2. 干净/损坏文本假设:文本的”干净”定义可能影响结果
  3. 间接效应:patch 某位置可能通过其他路径间接影响输出

参考文献


相关主题

Footnotes

  1. Conmy, A., et al. (2023). Towards Automated Circuit Discovery for Mechanistic Interpretability. NeurIPS 2023. https://arxiv.org/abs/2304.14997 2

  2. Zhang, F., & Nanda, N. (2024). Towards Best Practices of Activation Patching in Language Models. ICLR 2024. https://iclr.cc/virtual/2024/poster/18984

  3. Olsson, C., et al. (2022). In-context Learning and Induction Heads. Transformer Circuits Thread. https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads

  4. Functional Faithfulness in the Wild: Circuit Discovery with Differentiable Computation Graph Pruning. (2024). https://arxiv.org/html/2407.03779v1

  5. Efficient Automated Circuit Discovery in Transformers using Contextual Decomposition. (2024). https://arxiv.org/abs/2407.00886