PIE Framework: Efficient Circuit Discovery with Cross-Layer Transcoders

1. Background

1.1 Challenges in Existing Circuit Discovery

Conventional circuit discovery methods face severe computational challenges:

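| Problem           | Description                                             | Impact                                  |
|-------------------|---------------------------------------------------------|-----------------------------------------|
| Exhaustive search | Interprets uniformly sampled components                 | Wastes compute on irrelevant components |
| High cost         | Interpretation cost grows linearly with component count | Does not scale to large models          |
| Imprecision       | Fails to identify key cross-layer features              | Incomplete circuits                     |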

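1.2 Key Insight

Core problem: not every component matters for the target behavior, yet existing methods treat all components alike.

PIE's answer: prune first, then interpret, then evaluate.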

2. The PIE Framework in Detail

2.1 Overview

PIE stands for Prune, Interpret, Evaluate. It is an end-to-end pruning framework native to the CLT (Cross-Layer Transcoder):

┌────────────────────────────────────────────────────────────────────────┐
│                             PIE Framework                              │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│   Input: model + target behavior                                       │
│      │                                                                 │
│      ▼                                                                 │
│   ┌───────────────────────┐                                            │
│   │         Prune         │ ←── pruning comes before interpretation    │
│   │   (feature pruning)   │                                            │
│   └───────────────────────┘                                            │
│      │                                                                 │
│      ▼                                                                 │
│   ┌───────────────────────┐                                            │
│   │       Interpret       │ ←── interpret only the surviving features  │
│   │ (auto-interpretation) │                                            │
│   └───────────────────────┘                                            │
│      │                                                                 │
│      ▼                                                                 │
│   ┌───────────────────────┐                                            │
│   │       Evaluate        │ ←── behavioral fidelity evaluation         │
│   │ (interpretation eval) │                                            │
│   └───────────────────────┘                                            │
│      │                                                                 │
│      ▼                                                                 │
│   Output: an efficient, trustworthy circuit                            │
└────────────────────────────────────────────────────────────────────────┘

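2.2 The Three Stages

2.2.1 Prune Stage

Goal: identify the features that matter most for the target behavior.

Strategy: prune by behavioral relevance rather than sampling uniformly.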

def prune_features(model, target_behavior, budget_k):
    """
    Prune to top-k features based on behavioral relevance.
    
    Args:
        model: Neural network with CLT
        target_behavior: Definition of target behavior
        budget_k: Number of features to keep
    
    Returns:
        pruned_features: Set of important features
    """
    # Compute behavioral importance for all features
    feature_importances = {}
    
    for feature in model.clt_features:
        # Measure importance via activation patching
        importance = measure_behavioral_importance(
            model, feature, target_behavior
        )
        feature_importances[feature] = importance
    
    # Select top-k features
    sorted_features = sorted(
        feature_importances.items(),
        key=lambda x: x[1],
        reverse=True
    )
    
    pruned_features = {f for f, _ in sorted_features[:budget_k]}
    
    return pruned_features
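
The helper `measure_behavioral_importance` is not defined in this write-up. One way it could work is ablation: zero out a single feature and measure how far a behavior metric moves. The sketch below assumes a toy linear readout; `behavior_metric`, `ablation_importance`, and `top_k_features` are hypothetical names for illustration, not part of PIE.

```python
from typing import Dict, List


def behavior_metric(activations: Dict[str, float], weights: Dict[str, float]) -> float:
    """Toy behavior metric: a linear readout over feature activations."""
    return sum(weights[name] * value for name, value in activations.items())


def ablation_importance(feature: str, activations: Dict[str, float],
                        weights: Dict[str, float]) -> float:
    """Importance = absolute change in the metric when the feature is zeroed."""
    ablated = dict(activations)
    ablated[feature] = 0.0
    return abs(behavior_metric(activations, weights) - behavior_metric(ablated, weights))


def top_k_features(activations: Dict[str, float], weights: Dict[str, float],
                   budget_k: int) -> List[str]:
    """Rank features by ablation importance and keep the budget_k largest."""
    scores = {f: ablation_importance(f, activations, weights) for f in activations}
    return sorted(scores, key=scores.get, reverse=True)[:budget_k]


# Hypothetical activations and readout weights for four features
activations = {"f0": 1.0, "f1": 0.5, "f2": 2.0, "f3": 0.1}
weights = {"f0": 0.2, "f1": -3.0, "f2": 0.6, "f3": 0.5}

kept = top_k_features(activations, weights, budget_k=2)
print(kept)
```

Taking the absolute value means a feature that strongly suppresses the behavior (here `f1`) ranks as important too. Whether PIE uses signed or unsigned importance is not stated in the source.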

2.2.2 Interpret Stage

Goal: explain what each surviving feature does.

Method: Feature Attribution Patching (FAP)

def interpret_feature(feature, model, context):
    """
    Interpret a CLT feature using Feature Attribution Patching.
    """
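    # Clear stale gradients so earlier passes do not contaminate the score
    model.zero_grad()

    # backward() requires a scalar. Summing the output is a stand-in for a
    # behavior-specific metric (e.g. a logit difference) and is an assumption.
    output = model(context)
    output.sum().backward()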
    
    # Compute FAP score
    # Aggregates gradient-weighted write contributions
    fap_score = compute_fap_score(feature, model)
    
    return fap_score
 
 
def compute_fap_score(feature, model):
    """
    Compute Feature Attribution Patching score.
    
    FAP = Σ (∂output / ∂feature_write) × feature_value
    """
    grad = feature.gradient  # From backward pass
    value = feature.value
    
    # Gradient-weighted contribution
    fap = (grad * value).sum()
    
    return fap.item()
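
To make the gradient-times-value formula concrete, here is a small sketch that scores three features of a toy scalar function. It uses finite differences in place of autograd so that it runs with the standard library alone. `toy_output` and the input values are hypothetical, and this is not the CLT write-contribution aggregation itself.

```python
from typing import Callable, List


def numerical_gradient(fn: Callable[[List[float]], float],
                       values: List[float], eps: float = 1e-6) -> List[float]:
    """Central finite-difference gradient of fn at values."""
    grads = []
    for i in range(len(values)):
        up = list(values)
        down = list(values)
        up[i] += eps
        down[i] -= eps
        grads.append((fn(up) - fn(down)) / (2 * eps))
    return grads


def fap_scores(fn: Callable[[List[float]], float], values: List[float]) -> List[float]:
    """Per-feature attribution: gradient of the output times the feature value."""
    grads = numerical_gradient(fn, values)
    return [g * v for g, v in zip(grads, values)]


def toy_output(a: List[float]) -> float:
    # Hypothetical behavior metric: one linear term and one interaction term
    return 2.0 * a[0] + a[1] * a[2]


scores = fap_scores(toy_output, [1.0, 3.0, 0.5])
print([round(s, 4) for s in scores])
```

The interaction term `a[1] * a[2]` gives both participating features the same score, because each one's gradient equals the other's value. A first-order attribution like this cannot separate their joint contribution, which is one motivation for a pairwise synergy term.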

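2.2.3 Evaluate Stage

Goal: assess the quality of the interpretations.

Metrics: behavior preservation measured by KL divergence, plus FADE-style interpretation quality.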

import numpy as np
import torch.nn.functional as F


def evaluate_interpretation(
    original_circuit,
    interpreted_circuit,
    test_behaviors
):
    """
    Evaluate interpretation quality.
    """
    # Behavioral fidelity: KL divergence
    kl_divergences = []
    
    for behavior in test_behaviors:
        # Sample responses
        original_response = evaluate_behavior(
            original_circuit, behavior
        )
        interpreted_response = evaluate_behavior(
            interpreted_circuit, behavior
        )
        
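        # KL(original || interpreted). F.kl_div takes the approximating
        # distribution as log-probabilities first and the reference second.
        # Assumes both responses are probability tensors of shape (batch, vocab).
        kl = F.kl_div(
            interpreted_response.log(),
            original_response,
            reduction='batchmean'
        )
        kl_divergences.append(kl.item())
    
    avg_kl = float(np.mean(kl_divergences))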
    
    return {
        'behavioral_fidelity': 1.0 - min(avg_kl, 1.0),  # Convert to accuracy-like metric
        'kl_divergences': kl_divergences
    }
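
The fidelity score above is a clamped transformation of the average KL divergence. The sketch below reproduces that arithmetic on hand-written distributions using only the standard library; the three distributions are made-up examples, not model outputs.

```python
import math
from typing import List, Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) in nats for two discrete distributions on the same support."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def behavioral_fidelity(kl_values: List[float]) -> float:
    """Same conversion as above: 1 - min(mean KL, 1)."""
    avg_kl = sum(kl_values) / len(kl_values)
    return 1.0 - min(avg_kl, 1.0)


original = [0.7, 0.2, 0.1]    # reference next-token distribution
close = [0.65, 0.25, 0.10]    # circuit that nearly preserves it
far = [0.1, 0.2, 0.7]         # circuit that flips the prediction

kl_close = kl_divergence(original, close)
kl_far = kl_divergence(original, far)
fidelity = behavioral_fidelity([kl_close, kl_far])
print(f"{kl_close:.4f} {kl_far:.4f} {fidelity:.4f}")
```

Because of the clamp, any average KL of 1 nat or more maps to a fidelity of 0, so the score cannot distinguish a poor circuit from a much worse one. Reporting the raw KL values alongside it, as the returned dictionary does, keeps that information.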

3. The FAP-Synergy Method

3.1 Synergy-Aware Reranking

FAP-Synergy adds a systematic synergy-aware reranking step:

from itertools import combinations


def fap_synergy_reranking(
    features,
    behaviors,
    model,
    k: int
):
    """
    Synergy-aware reranking for strict budget constraints.
    
    Args:
        features: Candidate features
        behaviors: Set of target behaviors
        model: Neural network model
        k: Budget constraint
    
    Returns:
        selected_features: Optimally selected k features
    """
    # Compute individual FAP scores
    individual_scores = {
        f: compute_fap_score(f, model)
        for f in features
    }
    
    # Compute synergy scores
    synergy_matrix = {}
    for f1, f2 in combinations(features, 2):
        synergy = compute_synergy(f1, f2, behaviors, model)
        synergy_matrix[(f1, f2)] = synergy
    
    # Greedy selection with synergy
    selected = []
    remaining = set(features)
    
    # Stop early if k exceeds the number of candidates
    while len(selected) < k and remaining:
        best_feature = None
        best_score = float('-inf')
        
        for f in remaining:
            # Individual score
            score = individual_scores[f]
            
            # Synergy bonus with selected features
            synergy_bonus = sum(
                synergy_matrix.get((f, s), 0) + synergy_matrix.get((s, f), 0)
                for s in selected
            ) / max(len(selected), 1)
            
            # Combined score
            combined_score = score + synergy_bonus
            
            if combined_score > best_score:
                best_score = combined_score
                best_feature = f
        
        selected.append(best_feature)
        remaining.remove(best_feature)
    
    return set(selected)
 
 
def compute_synergy(f1, f2, behaviors, model):
    """
    Compute synergy score between two features.
    """
    # Joint importance
    joint_importance = measure_behavioral_importance(
        model, {f1, f2}, behaviors
    )
    
    # Individual importances
    ind1 = measure_behavioral_importance(model, {f1}, behaviors)
    ind2 = measure_behavioral_importance(model, {f2}, behaviors)
    
    # Synergy = joint - sum of individuals
    synergy = joint_importance - (ind1 + ind2)
    
    return max(synergy, 0)  # Only positive synergy
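
A toy case shows how the synergy bonus can change the selection. In the sketch below, two features are weak alone but strong together, so plain top-k misses the pair while the greedy synergy-aware loop finds it. The `importance` function and its numbers are hypothetical and stand in for `measure_behavioral_importance`.

```python
from itertools import combinations
from typing import Iterable, List


def importance(features: Iterable[str]) -> float:
    """Hypothetical set-level importance with one interacting pair."""
    fs = set(features)
    solo = {"f_a": 0.3, "f_b": 0.1, "f_c": 0.25, "f_d": 0.2}
    score = sum(solo[f] for f in fs)
    if {"f_a", "f_b"} <= fs:
        score += 1.0  # f_a and f_b only pay off when both are kept
    return score


def synergy(f1: str, f2: str) -> float:
    """Joint importance minus the sum of individual importances, floored at 0."""
    return max(importance({f1, f2}) - importance({f1}) - importance({f2}), 0.0)


def greedy_select(features: List[str], k: int) -> List[str]:
    """Greedy selection: individual score plus mean synergy with picks so far."""
    individual = {f: importance({f}) for f in features}
    pair = {frozenset(p): synergy(*p) for p in combinations(features, 2)}
    selected: List[str] = []
    remaining = list(features)
    while len(selected) < k and remaining:
        def score(f: str) -> float:
            bonus = sum(pair[frozenset((f, s))] for s in selected)
            return individual[f] + bonus / max(len(selected), 1)
        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected


features = ["f_a", "f_b", "f_c", "f_d"]
top_k = sorted(features, key=lambda f: importance({f}), reverse=True)[:2]
synergy_pick = greedy_select(features, 2)
print(top_k, synergy_pick)
```

Computing the full pairwise table costs one joint importance measurement per pair, which is quadratic in the number of candidates. That suggests running the reranking on an already reduced candidate pool rather than on every feature.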

3.2 How It Operates Under Budget Constraints

┌────────────────────────────────────────────────────────────────────────┐
│          Best strategy under different budgets                         │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  Loose budget (K ≥ 200):                                               │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │ Basic FAP and the adapted baselines perform robustly             │  │
│  │ Strategy: select the top-k directly                              │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                                                                        │
│  Strict budget (K < 100):                                              │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │ FAP-Synergy performs best                                        │  │
│  │ Strategy: synergy-aware reranking                                │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│                                                                        │
│  "Effective budget" advantage:                                         │
│  FAP-Synergy at K=50 matches the behavioral fidelity                   │
│  of the baseline at K=75                                               │
│  → saves 33% of the interpretation cost                                │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘

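4. Experimental Results

4.1 Performance on the IOI Task

| Method       | K=50  | K=100 | K=200 | K=400 | K=800 |
|--------------|-------|-------|-------|-------|-------|
| Baseline FAP | 72.3% | 81.2% | 87.5% | 91.2% | 94.1% |
| FAP-Synergy  | 78.9% | 85.3% | 89.1% | 92.1% | 94.5% |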

4.2 Doc-String Task

| Method       | K=50  | K=100 | K=200 |
|--------------|-------|-------|-------|
| Baseline FAP | 68.5% | 76.2% | 82.4% |
| FAP-Synergy  | 74.1% | 80.3% | 84.1% |

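4.3 Efficiency Comparison

| Method             | Behavioral fidelity | Interpretation cost |
|--------------------|---------------------|---------------------|
| Baseline (K=75)    | 78.9%               | 75 units            |
| FAP-Synergy (K=50) | 78.9%               | 50 units            |
| Savings            | -                   | 33%                 |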
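
The 33% figure follows directly from the two budgets if interpretation cost is assumed to scale linearly with K (one cost unit per feature, as in the table). A minimal check, with `interpretation_savings` as a hypothetical helper name:

```python
def interpretation_savings(baseline_k: int, method_k: int) -> float:
    """Fraction of interpretation cost saved, assuming cost is proportional to K."""
    return (baseline_k - method_k) / baseline_k


savings = interpretation_savings(baseline_k=75, method_k=50)
print(f"{savings:.1%}")  # prints 33.3%
```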

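5. Advantages of the Framework

5.1 Prune-First Paradigm

| Conventional methods       | PIE                                  |
|----------------------------|--------------------------------------|
| Interpret everything first | Prune down to the key features first |
| O(N) interpretation cost   | O(K) interpretation cost, K << N     |
| Imprecise                  | High precision                       |

5.2 Synergy Awareness

  • Identifies complementary features: selects features that work together
  • Avoids redundancy: reduces repeated interpretation of overlapping features
  • Budget efficiency: performs better under strict budgets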

6. Practical Guide

6.1 Workflow

def pie_circuit_discovery(model, target_behavior, budget_k):
    """
    Complete PIE circuit discovery workflow.
    """
    # Stage 1: Prune
    print("Stage 1: Pruning features...")
    pruned_features = prune_features(model, target_behavior, budget_k)
    print(f"Pruned to {len(pruned_features)} features")
    
    # Stage 2: Interpret
    print("Stage 2: Interpreting features...")
    interpretations = {}
    for feature in pruned_features:
        interpretations[feature] = interpret_feature(feature, model, target_behavior)
    
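    # Optional: FAP-Synergy reranking under a strict budget.
    # Note: pruned_features already holds budget_k features at this point, so
    # reranking only changes the selection if Stage 1 is run with a candidate
    # pool larger than budget_k.
    if budget_k < 100: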
        print("Applying FAP-Synergy reranking...")
        pruned_features = fap_synergy_reranking(
            pruned_features,
            target_behavior.behaviors,
            model,
            budget_k
        )
    
    # Stage 3: Evaluate
    print("Stage 3: Evaluating interpretation...")
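    # evaluate_interpretation takes (original_circuit, interpreted_circuit,
    # test_behaviors), so the interpretations dict is not passed here
    results = evaluate_interpretation(
        model,
        pruned_features,
        target_behavior.test_set
    )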
    
    return {
        'features': pruned_features,
        'interpretations': interpretations,
        'evaluation': results
    }

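6.2 Notes

  1. Budget selection: choose K according to your accuracy requirements
  2. Behavior definition: define the target behavior clearly
  3. Synergy awareness: use FAP-Synergy under strict budgets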
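
For budget selection, one simple approach is to pick the smallest K whose reported score meets a target. The sketch below does this with the IOI percentages from Section 4.1, read here as behavioral fidelity. `smallest_budget` is a hypothetical helper, and the lookup covers only the budgets that were actually reported.

```python
from typing import Dict, Optional

# Percentages copied from the IOI table in Section 4.1
IOI_FIDELITY: Dict[str, Dict[int, float]] = {
    "Baseline FAP": {50: 72.3, 100: 81.2, 200: 87.5, 400: 91.2, 800: 94.1},
    "FAP-Synergy": {50: 78.9, 100: 85.3, 200: 89.1, 400: 92.1, 800: 94.5},
}


def smallest_budget(fidelity_by_k: Dict[int, float], target: float) -> Optional[int]:
    """Return the smallest reported K that reaches the target, or None."""
    for k in sorted(fidelity_by_k):
        if fidelity_by_k[k] >= target:
            return k
    return None


for method, table in IOI_FIDELITY.items():
    print(method, smallest_budget(table, target=85.0))
```

For an 85% target, this lookup returns K=200 for the baseline and K=100 for FAP-Synergy. Values between the reported budgets are unknown, so the true minimum K may be lower than what the table lookup returns.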

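7. Summary and Outlook

7.1 Core Contributions

  1. Prune-first paradigm: reduces interpretation cost from O(N) to O(K)
  2. FAP: gradient-weighted feature attribution
  3. FAP-Synergy: synergy-aware feature selection

7.2 Practicality

  • Matches the baseline's K=75 behavioral fidelity at K=50
  • Saves 33% of the interpretation cost
  • Adapts to a range of budget constraints

7.3 Future Directions

  • Automatic budget selection
  • Multi-objective synergy optimization
  • Integration with other interpretability methods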
