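Certified Circuits: Stability Guarantees for Circuit Discovery

1. Background

1.1 Challenges in Circuit Discovery

Mechanistic interpretability seeks to understand a model by identifying circuits, the minimal subnetworks of a neural network responsible for a specific behavior. Circuit discovery, however, faces serious challenges: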

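| Problem | Description | Impact |
|---|---|---|
| Dataset sensitivity | Circuits depend strongly on the chosen concept dataset | May capture dataset artifacts rather than the true concept |
| Out-of-distribution failure | Circuits transfer poorly to out-of-distribution data | Insufficient reliability |
| Brittleness | Small dataset changes cause drastic changes in circuit components | Results are not reproducible |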

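1.2 Core Problem: Brittleness

Brittleness of existing circuit discovery methods:

Circuit discovered on dataset A:   components {N1, N2, N3, N4, N5}
Circuit discovered on dataset A':  components {N1, N2, N6, N7, N8}
                              ↑
           A' differs from A only slightly

Problem: this instability makes circuits hard to trust as reliable explanations of the model.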
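
To put a number on the example above, here is a minimal sketch that counts the components that differ between the two circuits and measures their overlap. The variable names are illustrative only and do not come from any library.

```python
# Component sets from the example above.
circuit_a = {"N1", "N2", "N3", "N4", "N5"}
circuit_a_prime = {"N1", "N2", "N6", "N7", "N8"}

# Components present in exactly one of the two circuits.
changed = circuit_a ^ circuit_a_prime
# Components shared by both circuits.
shared = circuit_a & circuit_a_prime

# Jaccard overlap: shared components over all components seen.
jaccard = len(shared) / len(circuit_a | circuit_a_prime)

print(sorted(changed))  # ['N3', 'N4', 'N5', 'N6', 'N7', 'N8']
print(len(changed))     # 6
print(jaccard)          # 0.25
```

Six of the eight distinct components appear in only one circuit, so the two explanations share just a quarter of their parts.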

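2. The Certified Circuits Framework

2.1 Core Idea

Certified Circuits provides stability guarantees for circuit discovery through random data subsampling and edit-distance bounding:

  1. Random subsampling: repeatedly draw random subsets of the dataset
  2. Edit-distance bounding: quantify how sensitive the circuit components are to dataset changes
  3. Stability certification: identify the components that are insensitive to dataset changes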
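
The subsampling step can be sketched as follows. This is a minimal illustration that assumes subsets are drawn without replacement at a fixed ratio, as in the code in Section 3.2; the function name `draw_subsamples` and the `seed` parameter are hypothetical additions for reproducibility.

```python
import numpy as np

def draw_subsamples(n_samples, n_iterations=50, sample_ratio=0.8, seed=0):
    """Return a list of index arrays, one per subsampling round."""
    rng = np.random.default_rng(seed)  # seeded so runs are repeatable
    n_keep = int(n_samples * sample_ratio)
    # Each round keeps n_keep distinct indices (sampling without replacement).
    return [
        rng.choice(n_samples, size=n_keep, replace=False)
        for _ in range(n_iterations)
    ]

subsets = draw_subsamples(n_samples=100, n_iterations=5)
print(len(subsets), len(subsets[0]))  # 5 80
```

Each index array would then be used to select the samples passed to the base discovery method for that round.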

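2.2 Formal Definitions

2.2.1 Edit Distance

For two sets of circuit components $C_1$ and $C_2$, define the edit distance

$$d(C_1, C_2) = |C_1 \triangle C_2|$$

where $\triangle$ denotes the symmetric difference.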

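2.2.2 Stability Guarantee

Let $D$ be the concept dataset, $D'$ a perturbed version of it (obtained by deleting or adding a small number of samples), and $\epsilon$ the allowed edit-distance threshold. Writing $C(D)$ for the circuit discovered on $D$, $d(\cdot,\cdot)$ for the edit distance of Section 2.2.1, and $k$ for the maximum number of samples changed:

$$|D \triangle D'| \le k \;\Longrightarrow\; d\big(C(D), C(D')\big) \le \epsilon$$

Interpretation: if the dataset change stays within its bound, the change in circuit components stays within its bound as well.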
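
A property of this kind can be probed empirically. The sketch below removes a few samples at random, re-runs discovery, and records the worst edit distance observed. It is a toy illustration under stated assumptions: `toy_discover` is a hypothetical stand-in for a real circuit discovery method, and `k`, `epsilon`, and `n_trials` are illustrative parameters. Random trials give evidence, not a certificate.

```python
import numpy as np

def toy_discover(data, cutoff=0.5):
    """Hypothetical stand-in for circuit discovery: the 'components'
    are the column indices whose mean value exceeds cutoff."""
    return {j for j, m in enumerate(data.mean(axis=0)) if m > cutoff}

def check_stability(data, k, epsilon, n_trials=100, seed=0):
    """Remove k random samples n_trials times and report whether the
    circuit ever changed by more than epsilon components."""
    rng = np.random.default_rng(seed)
    reference = toy_discover(data)
    worst = 0
    for _ in range(n_trials):
        keep = rng.choice(len(data), size=len(data) - k, replace=False)
        perturbed = toy_discover(data[keep])
        # Edit distance = size of the symmetric difference.
        worst = max(worst, len(reference ^ perturbed))
    return worst <= epsilon, worst

rng = np.random.default_rng(1)
# 200 samples, 6 candidate components with clearly separated means.
means = np.array([0.9, 0.8, 0.1, 0.2, 0.9, 0.1])
data = rng.normal(loc=means, scale=0.1, size=(200, 6))

ok, worst = check_stability(data, k=5, epsilon=1)
print(ok, worst)
```

In this toy setup the component means sit far from the cutoff, so dropping five samples should not change the discovered set.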

2.3 Framework Architecture

┌─────────────────────────────────────────────────────────────┐
│                Certified Circuits Framework                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   Dataset D ──► Random subsampling ──► Repeated discovery   │
│       │                 │                                   │
│       │                 ▼                                   │
│       │        ┌─────────────────┐                          │
│       │        │     M runs      │                          │
│       │        │ C₁, C₂, ..., Cₘ │                          │
│       │        └─────────────────┘                          │
│       │                 │                                   │
│       ▼                 ▼                                   │
│   ┌────────────────────────────────────────┐                │
│   │  Stability analysis and certification  │                │
│   │                                        │                │
│   │  • Component stability scores          │                │
│   │  • Edit-distance bounding              │                │
│   │  • Find and remove unstable components │                │
│   └────────────────────────────────────────┘                │
│                        │                                    │
│                        ▼                                    │
│           ┌──────────────────────────┐                      │
│           │ Certified stable circuit │                      │
│           │   (Certified Circuit)    │                      │
│           └──────────────────────────┘                      │
└─────────────────────────────────────────────────────────────┘

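3. Algorithm Details

3.1 Component Stability Score

For each component $c$, compute its stability score

$$s(c) = \frac{1}{M} \sum_{i=1}^{M} \mathbb{1}[c \in C_i]$$

where $C_1, \dots, C_M$ are the circuits found in the $M$ subsampling rounds.

Stability levels

  • High stability (≥0.9): appears in almost every discovery run
  • Medium stability (0.5-0.9): appears often
  • Low stability (<0.5): appears occasionally or at random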
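
As a small worked example, the score is simply the fraction of discovery runs in which a component appears. The four runs below are made up for illustration.

```python
from collections import Counter

# Four hypothetical discovery runs (M = 4), each a set of component names.
circuits = [
    {"N1", "N2", "N3"},
    {"N1", "N2"},
    {"N1", "N2", "N4"},
    {"N1", "N3"},
]

counts = Counter()
for circuit in circuits:
    counts.update(circuit)

# Stability score: fraction of runs in which the component appears.
scores = {c: n / len(circuits) for c, n in counts.items()}

for name in sorted(scores):
    print(name, scores[name])
# N1 1.0
# N2 0.75
# N3 0.5
# N4 0.25
```

Under the levels listed above, N1 is high stability, N2 and N3 are medium, and N4 is low.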

3.2 Random Subsampling Strategy

import numpy as np
from collections import Counter


def certified_circuit_discovery(
    model,
    concept_dataset,
    n_iterations: int = 50,
    sample_ratio: float = 0.8,
    stability_threshold: float = 0.5,
    edit_distance_threshold: float = 0.2
):
    """
    Discover circuits with stability certification.
    
    Args:
        model: Neural network
        concept_dataset: Concept definition samples
        n_iterations: Number of random subsampling iterations
        sample_ratio: Ratio of samples to keep per iteration
        stability_threshold: Minimum stability score for a component to be kept
        edit_distance_threshold: Maximum allowed edit distance
    
    Returns:
        certified_circuit: Stable circuit with certification
    """
    n_samples = len(concept_dataset)
    n_keep = int(n_samples * sample_ratio)
    component_counts = Counter()
    
    # Multiple rounds of circuit discovery with random subsampling
    for _ in range(n_iterations):
        # Random subsampling without replacement
        indices = np.random.choice(n_samples, n_keep, replace=False)
        subsampled_dataset = [concept_dataset[j] for j in indices]
        
        # Discover circuit on subsampled dataset; discover_circuit is the
        # base discovery method and is assumed to be defined elsewhere
        circuit_i = discover_circuit(model, subsampled_dataset)
        
        # Count component occurrences
        component_counts.update(circuit_i.components)
    
    # Compute stability scores: fraction of rounds containing each component
    stability_scores = {
        component: count / n_iterations
        for component, count in component_counts.items()
    }
    
    # Filter to stable components (scores are compared against the
    # stability threshold, not the edit-distance threshold)
    stable_components = {
        c for c, score in stability_scores.items()
        if score >= stability_threshold
    }
    
    # Verify edit distance bound
    verified = verify_edit_distance_bound(
        stability_scores,
        edit_distance_threshold
    )
    
    # The stable components are returned either way; a failed check only
    # downgrades the certification label to the weaker guarantee
    return CertifiedCircuit(
        components=stable_components,
        stability_scores=stability_scores,
        certification_level="PROVABLY_STABLE" if verified else "STATISTICALLY_STABLE",
        n_iterations=n_iterations
    )
 
 
def verify_edit_distance_bound(
    stability_scores,
    threshold
):
    """
    Heuristic check that the edit distance stays within threshold.
    """
    # Each component contributes (1 - score), the fraction of subsampling
    # rounds in which it was absent. The sum is used here as a proxy for
    # the expected edit distance between rounds.
    instability_sum = sum(1 - score for score in stability_scores.values())
    
    # Accept when the total instability is within the threshold
    return instability_sum <= threshold
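
One way to relate stability scores to edit distance, offered as a sketch rather than as the method's formal bound: by linearity of expectation, the average edit distance between a fixed component set S and the M discovered circuits equals the sum of 1 - s(c) over components in S plus the sum of s(c) over components outside S. The helper name `mean_edit_distance_from_scores` and the toy runs are hypothetical.

```python
def mean_edit_distance_from_scores(selected, scores):
    """Average |selected ^ C_i| over the runs, computed from scores alone."""
    # Selected components that a run fails to contain.
    missing = sum(1 - scores.get(c, 0.0) for c in selected)
    # Unselected components that a run does contain.
    extra = sum(s for c, s in scores.items() if c not in selected)
    return missing + extra

# Four made-up discovery runs.
circuits = [{"N1", "N2", "N3"}, {"N1", "N2"}, {"N1", "N2", "N4"}, {"N1", "N3"}]
components = set().union(*circuits)
scores = {
    c: sum(c in run for run in circuits) / len(circuits)
    for c in components
}

# Keep components with score >= 0.5, then compare both computations.
selected = {c for c, s in scores.items() if s >= 0.5}
direct = sum(len(selected ^ run) for run in circuits) / len(circuits)
via_scores = mean_edit_distance_from_scores(selected, scores)
print(direct, via_scores)  # 1.0 1.0
```

The two numbers agree because each component contributes to the edit distance in exactly those runs where its membership differs from the selected set.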

3.3 Stability Verification

3.3.1 Empirical Verification

def empirical_stability_check(
    model,
    certified_circuit,
    test_datasets,
    validation_set
):
    """
    Empirically verify stability on held-out datasets.

    discover_circuit, compute_edit_distance and compute_behavioral_alignment
    are assumed to be provided by the surrounding codebase.
    """
    results = []
    
    for test_dataset in test_datasets:
        # Discover circuit on test dataset
        test_circuit = discover_circuit(model, test_dataset)
        
        # Compute edit distance (size of the symmetric difference)
        edit_dist = compute_edit_distance(
            certified_circuit.components,
            test_circuit.components
        )
        
        # Compute behavioral alignment on the shared validation set
        behavioral_align = compute_behavioral_alignment(
            certified_circuit,
            test_circuit,
            validation_set
        )
        
        results.append({
            'edit_distance': edit_dist,
            'behavioral_alignment': behavioral_align
        })
    
    return results

3.4 Component Classification

┌─────────────────────────────────────────────────────────────┐
│                  Component Stability Tiers                  │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  High stability   ────────────────────► Core components     │
│  (≥90%)                         Keep, high confidence       │
│                                                             │
│  Medium stability ────────────────────► Auxiliary components│
│  (50-90%)                       Optional, may be removable  │
│                                                             │
│  Low stability    ────────────────────► Marginal components │
│  (<50%)                         Unstable, use with caution  │
│                                                             │
└─────────────────────────────────────────────────────────────┘
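
The tiers above can be applied mechanically once stability scores are available. The sketch below is one possible reading: the function name `classify_components` is hypothetical, and the boundary handling (a score of exactly 0.9 counts as core, exactly 0.5 as auxiliary) is an assumption.

```python
def classify_components(scores, high=0.9, mid=0.5):
    """Partition components into core / auxiliary / marginal tiers."""
    tiers = {"core": set(), "auxiliary": set(), "marginal": set()}
    for component, score in scores.items():
        if score >= high:
            tiers["core"].add(component)       # keep, high confidence
        elif score >= mid:
            tiers["auxiliary"].add(component)  # optional
        else:
            tiers["marginal"].add(component)   # unstable
    return tiers

# Made-up scores for illustration.
scores = {"N1": 1.0, "N2": 0.75, "N3": 0.5, "N4": 0.25}
tiers = classify_components(scores)
for name, members in tiers.items():
    print(name, sorted(members))
```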

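4. Experimental Results

4.1 Validation on Three Architectures

| Architecture | Task | Accuracy gain | Component reduction |
|---|---|---|---|
| ResNet | ImageNet classification | +56% | -80% |
| ViT | OOD classification | +48% | -75% |
| GPT-2 | IOI task | +42% | -70% |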

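4.2 Out-of-Distribution Generalization

| Method | In-Distribution | OOD-1 | OOD-2 | OOD-3 |
|---|---|---|---|---|
| Baseline circuit | 85.2% | 62.1% | 58.3% | 55.7% |
| Certified | 88.7% | 78.9% | 76.2% | 74.1% |

Key results

  • Out-of-distribution accuracy improves by 16.8 to 18.4 points over the baseline circuit across OOD-1 to OOD-3
  • Stability does not come at the cost of accuracy: in-distribution accuracy also rises, from 85.2% to 88.7%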

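4.3 Component Count Comparison

| Architecture | Baseline components | Certified components | Reduction |
|---|---|---|---|
| ResNet-50 | 156 | 31 | 80% |
| ViT-B/16 | 203 | 48 | 76% |
| GPT-2-small | 89 | 27 | 70% |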
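
The reduction column follows directly from the two component counts, as this short check shows. Only the counts from the table are used; the variable names are illustrative.

```python
# (baseline, certified) component counts from the table above.
counts = {
    "ResNet-50": (156, 31),
    "ViT-B/16": (203, 48),
    "GPT-2-small": (89, 27),
}

# Reduction = share of baseline components that were dropped.
reductions = {
    arch: 100 * (1 - certified / baseline)
    for arch, (baseline, certified) in counts.items()
}

for arch, reduction in reductions.items():
    print(f"{arch}: {reduction:.1f}% fewer components")
# ResNet-50: 80.1% fewer components
# ViT-B/16: 76.4% fewer components
# GPT-2-small: 69.7% fewer components
```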

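5. Relationship to Formal Interpretability

5.1 Method Comparison

| Property | Formal interpretability | Certified Circuits |
|---|---|---|
| Guarantee type | Deterministic | Probabilistic |
| Verification method | SMT solving | Random subsampling |
| Computational cost | | |
| Applicable scale | Small | Medium |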

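5.2 Complementary Strengths

┌───────────────────────────────────────────────────────────────────────┐
│                 Complementarity of the Two Approaches                 │
├───────────────────────────────────────────────────────────────────────┤
│                                                                       │
│   Formal interpretability ──► Deterministic, mathematically rigorous  │
│        │                                                              │
│        │  complementary                                               │
│        ▼                                                              │
│   Certified Circuits ──► Statistically valid, practical to use        │
│                                                                       │
│   Used together:                                                      │
│   • Certified Circuits screens for stable components                  │
│   • Formal methods verify key properties                              │
│                                                                       │
└───────────────────────────────────────────────────────────────────────┘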

6. PyTorch Implementation

import torch
import torch.nn as nn
import numpy as np
from collections import Counter, defaultdict
from typing import Set, Dict, List, Tuple
from dataclasses import dataclass
 
 
@dataclass(frozen=True)  # frozen makes Component hashable, as required for sets, dict keys and Counter
class Component:
    """Represents a circuit component."""
    layer_idx: int
    component_type: str  # 'attention', 'mlp', 'neuron'
    component_id: int
 
 
@dataclass
class CertifiedCircuit:
    """Circuit with stability certification."""
    components: Set[Component]
    stability_scores: Dict[Component, float]
    certification_level: str  # 'PROVABLY_STABLE', 'STATISTICALLY_STABLE'
    n_iterations: int
    
 
class CircuitDiscoverer:
    """Base circuit discovery method (e.g., activation patching)."""
    def __init__(self, model):
        self.model = model
        
    def discover(self, dataset):
        """Discover circuit on given dataset."""
        raise NotImplementedError
 
 
class CertifiedCircuitDiscovery:
    """
    Certified Circuits: Stability-guaranteed circuit discovery.
    """
    def __init__(
        self,
        base_discoverer: CircuitDiscoverer,
        n_iterations: int = 50,
        sample_ratio: float = 0.8,
        stability_threshold: float = 0.5
    ):
        self.discoverer = base_discoverer
        self.n_iterations = n_iterations
        self.sample_ratio = sample_ratio
        self.stability_threshold = stability_threshold
        
    def discover(
        self,
        dataset: List,
        return_stability_info: bool = False
    ) -> CertifiedCircuit:
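        """
        Discover certified stable circuit.

        Returns the CertifiedCircuit, or the tuple
        (certified, all_circuits, component_counts) when
        return_stability_info is True.
        """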
        n_samples = len(dataset)
        component_counts = Counter()
        all_circuits = []
        
        # Multiple rounds of subsampled discovery
        for i in range(self.n_iterations):
            # Random subsample
            n_keep = int(n_samples * self.sample_ratio)
            indices = np.random.choice(
                n_samples, 
                n_keep, 
                replace=False
            )
            subsampled = [dataset[j] for j in indices]
            
            # Discover circuit
            circuit = self.discoverer.discover(subsampled)
            all_circuits.append(circuit)
            component_counts.update(circuit.components)
        
        # Compute stability scores
        stability_scores = {
            component: count / self.n_iterations
            for component, count in component_counts.items()
        }
        
        # Filter to stable components
        stable_components = {
            comp for comp, score in stability_scores.items()
            if score >= self.stability_threshold
        }
        
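        # Determine certification level from the mean stability score
        # (an empty score dict is treated as 0.0 instead of letting np.mean return nan)
        avg_stability = float(np.mean(list(stability_scores.values()))) if stability_scores else 0.0
        cert_level = "PROVABLY_STABLE" if avg_stability > 0.8 else "STATISTICALLY_STABLE"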
        
        certified = CertifiedCircuit(
            components=stable_components,
            stability_scores=stability_scores,
            certification_level=cert_level,
            n_iterations=self.n_iterations
        )
        
        if return_stability_info:
            return certified, all_circuits, component_counts
        return certified
    
    def verify_edit_distance(
        self,
        certified: CertifiedCircuit,
        test_circuits: List
    ) -> Dict:
        """
        Verify edit distance bound empirically.
        """
        results = {
            'edit_distances': [],
            'avg_edit_distance': 0,
            'max_edit_distance': 0,
            'within_threshold': True
        }
        
        threshold = len(certified.components) * (1 - self.stability_threshold)
        
        for test_circuit in test_circuits:
            edit_dist = self._compute_edit_distance(
                certified.components,
                test_circuit.components
            )
            results['edit_distances'].append(edit_dist)
            
            if edit_dist > threshold:
                results['within_threshold'] = False
        
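        # np.max raises on an empty list, so aggregate only when at least
        # one test circuit was provided
        if results['edit_distances']:
            results['avg_edit_distance'] = np.mean(results['edit_distances'])
            results['max_edit_distance'] = np.max(results['edit_distances'])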
        
        return results
    
    def _compute_edit_distance(
        self,
        components1: Set[Component],
        components2: Set[Component]
    ) -> int:
        """Compute symmetric difference size."""
        symmetric_diff = components1 ^ components2
        return len(symmetric_diff)
 
 
def create_certified_circuit_workflow(model, concept_dataset):
    """
    Complete workflow for certified circuit discovery.
    """
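    # Initialize base discoverer. ActivationPatchingDiscoverer is assumed to be
    # a concrete CircuitDiscoverer subclass defined elsewhere; it is not part
    # of this listing
    discoverer = ActivationPatchingDiscoverer(model)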
    
    # Initialize certified discovery
    cert_discoverer = CertifiedCircuitDiscovery(
        base_discoverer=discoverer,
        n_iterations=50,
        sample_ratio=0.8,
        stability_threshold=0.5
    )
    
    # Discover certified circuit
    certified = cert_discoverer.discover(concept_dataset)
    
    # Additional verification on test datasets
    # (implementation details)
    
    return certified

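7. Applications and Practice

7.1 Usage Guide

  1. Choose a dataset: prepare a concept dataset containing positive and negative samples
  2. Set parameters: choose the number of iterations, the sampling ratio, and the stability threshold
  3. Run discovery: obtain the certified stable circuit
  4. Validate the results: validate on test datasets

7.2 Caveats

  • Computational cost: each iteration runs the base discovery method once, so cost grows with the number of iterations, while too few iterations make the stability scores less reliable
  • Threshold selection: the stability threshold needs to be tuned per task
  • Validation is still necessary: even with certification, validate on real data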
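
As a rough guide to how the iteration count affects reliability, one can treat each subsampling round as an independent draw and apply Hoeffding's inequality to a single component's stability score. This is an illustrative assumption, not a bound stated by the method, and it covers one component at a time rather than all components jointly. The function name `score_margin` is hypothetical.

```python
import math

def score_margin(n_iterations, delta=0.05):
    """Half-width t such that |estimated - true score| <= t holds with
    probability >= 1 - delta, by Hoeffding's inequality for i.i.d. rounds."""
    return math.sqrt(math.log(2 / delta) / (2 * n_iterations))

for m in (50, 200, 1000):
    print(m, round(score_margin(m), 3))
# 50 0.192
# 200 0.096
# 1000 0.043
```

With 50 iterations a measured score of 0.9 is only pinned down to within about 0.19 at the 95% level under this assumption, so thresholds close to a measured score deserve more iterations.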

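8. Summary and Outlook

8.1 Core Contributions

  1. Stability guarantee: provable stability achieved through random subsampling
  2. Compact circuits: fewer components (70% to 80% fewer in the reported experiments) with stronger performance
  3. Out-of-distribution generalization: markedly improved out-of-distribution reliability

8.2 Limitations

  • High computational cost
  • Still relies on an empirical base circuit discovery method
  • Limited ability to identify very rare components

8.3 Future Directions

  • Deeper integration with formal methods
  • More efficient subsampling strategies
  • Adaptive threshold adjustment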
