电路分析方法论

电路分析(Circuit Analysis)是机制可解释性(Mechanistic Interpretability)的核心方法,旨在识别和理解神经网络中执行特定计算的可解释子结构。本章系统介绍电路分析的理论基础、主要技术和实践方法。

1. 电路假说

1.1 形式化定义

电路假说(Circuit Hypothesis) 认为,深度神经网络可以理解为一组相对局部化、可解释的子电路的组合,每个子电路负责特定的计算功能。

形式化地,给定一个神经网络 和一个行为 (如正确预测某个词),电路假说声称:

其中 表示网络 中所有可能的电路集合, 表示电路 是否被激活。

1.2 电路的组成要素

一个典型电路由以下要素组成:

  • 注意头(Attention Heads):Q/K/V 投影及注意力计算
  • MLP 神经元:非线性变换层
  • 残差连接(Residual Connections):信息直接传递路径
  • 嵌入(Embeddings):输入表示

1.3 电路发现流程

# 电路发现伪代码
def circuit_discovery(model, behavior, target_layer):
    # 1. 识别关键位置
    key_positions = identify_key_positions(model, behavior)
    
    # 2. 初始化电路
    circuit = initialize_circuit(key_positions)
    
    # 3. 迭代扩展
    while not converged:
        new_components = expand_circuit(model, circuit, behavior)
        circuit = prune_circuit(circuit + new_components)
    
    return circuit

2. 激活修补技术

2.1 核心思想

激活修补(Activation Patching),也称为因果追溯(Causal Tracing),通过有控制地替换网络中特定位置的激活来测量其对输出的因果贡献。

给定:

  • 干净输入
  • 损坏输入
  • 目标位置 (第 层第 个神经元)

修补后的输出:

2.2 因果效应度量

定义位置 对行为 因果效应为:

其中 表示在位置 进行修补后的模型输出。

2.3 Logit Lens

Logit Lens 是一种基于激活修补的可解释性技术,通过逐层修补隐藏状态并映射到词表空间来追踪信息流动:

import torch
import torch.nn.functional as F
from transformer_lens import HookedTransformer
 
def logit_lens(model, tokens, layer_idx):
    """
    Logit Lens: 追踪第layer_idx层隐藏状态到输出词的映射
    """
    # 前向传播并缓存所有层的激活
    _, cache = model.run_with_cache(tokens)
    
    # 获取指定层的激活
    hidden_states = cache[f"blocks.{layer_idx}.hook_resid_post"]
    
    # 投影到词表空间
    unembed = model.W_U  # (d_model, vocab_size)
    logits = hidden_states @ unembed
    
    # 获取top-k预测词
    top_k = 10
    probs = F.softmax(logits, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
    
    return top_indices, top_probs

2.4 逐层因果追溯实现

class CausalTracer:
    def __init__(self, model, clean_tokens, corrupt_tokens):
        self.model = model
        self.clean_tokens = clean_tokens
        self.corrupt_tokens = corrupt_tokens
        
    def trace_layer(self, layer_idx, position=-1):
        """
        追踪第layer_idx层指定位置对输出的因果效应
        """
        # 获取干净和损坏的激活
        with torch.no_grad():
            clean_logits, clean_cache = self.model.run_with_cache(self.clean_tokens)
            corrupt_logits, corrupt_cache = self.model.run_with_cache(self.corrupt_tokens)
        
        # 创建修补后的输入
        patched_tokens = self.corrupt_tokens.clone()
        
        # 对每层进行修补并计算效应
        effects = []
        for layer in range(layer_idx + 1):
            # 复制损坏激活
            patched_cache = {k: v.clone() for k, v in corrupt_cache.items()}
            
            # 修补目标层
            if position == -1:  # 所有位置
                clean_hook = clean_cache[f"blocks.{layer}.hook_resid_post"]
                patched_cache[f"blocks.{layer}.hook_resid_post"] = clean_hook
            else:
                clean_hook = clean_cache[f"blocks.{layer}.hook_resid_post"]
                patched_cache[f"blocks.{layer}.hook_resid_post"] = clean_hook[:, position:position+1, :]
            
            # 运行修补后的模型
            patched_logits = self.model.run_with_hooks(
                patched_tokens,
                fwd_hooks=[(k, lambda x, _: v) for k, v in patched_cache.items()]
            )
            
            # 计算效应
            effect = self.compute_effect(clean_logits, patched_logits)
            effects.append(effect)
        
        return effects
    
    def compute_effect(self, clean_logits, patched_logits):
        """计算因果效应"""
        # 关注正确token的概率变化
        return {
            "clean_prob": F.softmax(clean_logits, dim=-1)[0, -1].max().item(),
            "patched_prob": F.softmax(patched_logits, dim=-1)[0, -1].max().item(),
            "effect": F.softmax(clean_logits, dim=-1)[0, -1].max().item() - 
                      F.softmax(patched_logits, dim=-1)[0, -1].max().item()
        }

3. 路径归因方法

3.1 直接归因

直接归因(Direct Attribution) 衡量每个输入特征对输出的直接贡献:

3.2 积分梯度

积分梯度(Integrated Gradients) 通过对输入到基线的路径进行积分来计算归因:

def integrated_gradients(model, input_ids, baseline_ids, num_steps=50):
    """
    积分梯度实现
    """
    # 缩放因子
    scaled_inputs = [
        baseline_ids + (float(i) / num_steps) * (input_ids - baseline_ids)
        for i in range(num_steps + 1)
    ]
    
    # 计算梯度的数值积分
    gradients = []
    for scaled_input in scaled_inputs:
        scaled_input.requires_grad_(True)
        output = model(scaled_input)
        # 取目标类别或logit
        score = output[0, -1, target_token_id]
        score.backward()
        gradients.append(scaled_input.grad.clone())
    
    # 梯度的平均值乘以输入差
    avg_gradients = torch.stack(gradients).mean(dim=0)
    integrated_grads = (input_ids - baseline_ids) * avg_gradients
    
    return integrated_grads

3.3 注意力流

注意力流(Attention Flow) 将注意力权重与梯度信息结合来归因:

def attention_flow(model, tokens, layer_idx, head_idx):
    """
    注意力流:结合注意力权重和梯度信息
    """
    # 获取注意力权重
    _, cache = model.run_with_cache(tokens)
    attn_weights = cache[f"blocks.{layer_idx}.attn.hook_att"][0, head_idx]
    
    # 获取输出梯度
    tokens.requires_grad_(True)
    output = model(tokens)
    output[0, -1].backward()
    output_grad = tokens.grad[0]
    
    # 注意力流归因
    # 从源token到目标token的流
    flow = attn_weights @ torch.abs(output_grad)
    
    return flow

3.4 路径修补

路径修补(Path Patching) 通过修改注意力头之间的连接来测量信息流:

class PathPatcher:
    def __init__(self, model):
        self.model = model
        
    def patch_residual_path(self, tokens, from_layer, from_head, to_layer, to_head):
        """
        修补从(from_layer, from_head)到(to_layer, to_head)的路径
        """
        def hook_fn(value, hook, clean_value):
            """修改value为clean_value"""
            return clean_value
        
        # 运行损坏的模型获取缓存
        _, corrupt_cache = self.model.run_with_cache(corrupt_tokens)
        
        # 运行干净的模型获取缓存
        _, clean_cache = self.model.run_with_cache(clean_tokens)
        
        # 创建hook列表来修补路径
        hooks = []
        
        # 修补from_head的输出
        from_hook_name = f"blocks.{from_layer}.attn.hook_v"
        hooks.append((from_hook_name, lambda v, h: clean_cache[from_hook_name]))
        
        # 修补to_head的输入
        to_hook_name = f"blocks.{to_layer}.attn.hook_q"
        hooks.append((to_hook_name, lambda v, h: clean_cache[to_hook_name]))
        
        # 运行修补后的模型
        patched_logits = self.model.run_with_hooks(tokens, fwd_hooks=hooks)
        
        return patched_logits

4. 自动化电路发现

4.1 基于梯度的方法

基于梯度的方法 利用梯度信息来识别重要连接:

def gradient_based_discovery(model, tokens, target_token_id):
    """
    基于梯度的电路发现
    """
    tokens.requires_grad_(True)
    output = model(tokens)
    
    # 获取目标token的logit
    target_logit = output[0, -1, target_token_id]
    target_logit.backward()
    
    # 收集所有参数的梯度
    gradients = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            gradients[name] = param.grad.abs()
    
    # 识别高梯度区域
    important_connections = []
    for name, grad in gradients.items():
        # 识别显著的梯度
        threshold = grad.mean() + 3 * grad.std()
        mask = grad > threshold
        
        important_connections.append({
            'name': name,
            'magnitude': grad[mask].mean().item(),
            'location': torch.where(mask)
        })
    
    return important_connections

4.2 剪枝方法

def iterative_pruning(model, tokens, target_token_id, sparsity=0.9):
    """
    迭代剪枝发现核心电路
    """
    # 评估当前模型
    current_model = model
    
    for iteration in range(10):
        # 计算当前模型的重要性分数
        importance = compute_importance_scores(current_model, tokens)
        
        # 剪枝最低分数的连接
        threshold = np.percentile(importance, sparsity * 100)
        mask = importance > threshold
        
        # 更新模型
        current_model = apply_mask(current_model, mask)
        
        # 评估剪枝后的性能
        performance = evaluate_model(current_model, tokens, target_token_id)
        print(f"Iteration {iteration}: Performance = {performance:.4f}")
        
        if performance < target_threshold:
            break
    
    return current_model

4.3 因果追溯算法

class CausalTracer:
    def __init__(self, model):
        self.model = model
        
    def full_circuit_trace(self, clean_tokens, corrupt_tokens, target_behavior):
        """
        完整电路追溯
        """
        num_layers = self.model.cfg.n_layers
        
        # 存储每层每个位置的重要性
        importance_scores = torch.zeros(num_layers, self.model.cfg.n_heads)
        
        for layer_idx in range(num_layers):
            for head_idx in range(self.model.cfg.n_heads):
                # 修补该注意力头的输出
                patched_logits = self.patch_head(
                    corrupt_tokens,
                    layer_idx,
                    head_idx,
                    clean_tokens
                )
                
                # 计算因果效应
                effect = self.compute_causal_effect(
                    clean_logits, 
                    patched_logits, 
                    target_behavior
                )
                
                importance_scores[layer_idx, head_idx] = effect
        
        return importance_scores
    
    def patch_head(self, tokens, layer, head, clean_tokens):
        """修补特定注意力头的输出"""
        def head_hook(value, hook):
            _, clean_cache = self.model.run_with_cache(clean_tokens)
            clean_head_output = clean_cache[f"blocks.{layer}.attn.hook_v"][:, head, :]
            return clean_head_output
        
        hooks = [(f"blocks.{layer}.attn.hook_v", head_hook)]
        return self.model.run_with_hooks(tokens, fwd_hooks=hooks)

5. 电路验证

5.1 行为验证

验证电路是否真正负责目标行为:

def verify_circuit_behavior(circuit, test_cases):
    """
    验证电路的行为一致性
    """
    results = []
    
    for test_case in test_cases:
        # 完整模型预测
        full_output = circuit.model(test_case.input)
        
        # 仅使用电路计算
        circuit_output = circuit.compute(test_case.input)
        
        # 检查一致性
        is_consistent = torch.allclose(full_output, circuit_output, atol=1e-3)
        
        results.append({
            'input': test_case.input,
            'full_output': full_output,
            'circuit_output': circuit_output,
            'consistent': is_consistent
        })
    
    return results

5.2 扰动分析

def perturbation_analysis(circuit, base_input, perturbation_strengths):
    """
    扰动分析:测试电路对输入扰动的敏感性
    """
    responses = []
    
    base_output = circuit.compute(base_input)
    
    for strength in perturbation_strengths:
        # 添加随机扰动
        perturbation = torch.randn_like(base_input) * strength
        perturbed_input = base_input + perturbation
        
        # 计算电路响应
        response = circuit.compute(perturbed_input)
        
        # 计算响应变化
        delta = torch.norm(response - base_output) / torch.norm(base_output)
        
        responses.append({
            'strength': strength,
            'response_norm': torch.norm(response).item(),
            'relative_change': delta.item()
        })
    
    return responses

6. 电路分析实践案例

6.1 Induction Head 电路发现

Induction Head 是一种著名的电路,负责实现上下文中的”复制”模式:

def discover_induction_head(model, tokens):
    """
    发现Induction Head电路
    - Head 1 (Q-Former): 寻找"previous token"的相同token
    - Head 2 (K-Former): 基于前一个token的预测来匹配
    """
    traces = []
    
    # 追踪第1层和第5层之间的信息流
    for layer1_head in range(model.cfg.n_heads):
        for layer2_head in range(model.cfg.n_heads):
            # 测量层1到层2的信息传递
            effect = measure_inter_head_effect(
                model, tokens,
                from_layer=1, from_head=layer1_head,
                to_layer=5, to_head=layer2_head
            )
            
            traces.append({
                'from': (1, layer1_head),
                'to': (5, layer2_head),
                'effect': effect
            })
    
    # 识别高效应连接
    significant = [t for t in traces if t['effect'] > threshold]
    
    return significant

6.2 文档级问答电路

def extract_qa_circuit(model, question_tokens, context_tokens, answer_positions):
    """
    提取问答任务的电路
    """
    # 组合输入
    full_tokens = torch.cat([question_tokens, context_tokens], dim=1)
    
    # 运行因果追溯
    tracer = CausalTracer(model)
    importance = tracer.full_circuit_trace(
        full_tokens,
        corrupted_tokens,
        target_behavior='predict_answer'
    )
    
    # 提取关键注意力头
    top_heads = torch.topk(importance.flatten(), k=10)
    circuit_heads = [(i.item() // model.cfg.n_heads, i.item() % model.cfg.n_heads) 
                     for i in top_heads.indices]
    
    return circuit_heads

7. 与相关方法的关系

7.1 与 Sparse Autoencoders 的关系

电路分析识别特定行为相关的组件,而SAE识别数据驱动的特征。两者可以互补:

  • 电路分析 → 特定任务的因果结构
  • SAE分析 → 潜在的语义特征

7.2 与特征几何的关系

电路分析关注动态信息流,而特征几何关注静态表示空间。两者共同揭示神经网络的组织原理。

8. 总结

电路分析是理解神经网络内部工作机制的有力工具。通过激活修补、路径归因和自动化发现等方法,我们可以:

  1. 定位负责特定行为的网络组件
  2. 理解信息在网络中的流动方式
  3. 验证假设的电路结构
  4. 解释模型行为的根本原因

这些方法为构建更透明、可解释的AI系统提供了重要基础。


参考资料