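Certified Circuits: Stability Guarantees for Circuit Discovery
1. Background
1.1 Challenges in circuit discovery
Mechanistic interpretability seeks to understand a model by identifying circuits, the minimal subnetworks of a neural network responsible for a specific behavior. Circuit discovery, however, faces three serious challenges:
| Problem | Description | Consequence |
|---|---|---|
| Dataset sensitivity | The circuit depends strongly on the chosen concept dataset | May capture dataset artifacts rather than the true concept |
| Out-of-distribution failure | Circuits transfer poorly to out-of-distribution data | Insufficient reliability |
| Brittleness | Small dataset changes cause large changes in circuit components | Results are not reproducible |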
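1.2 Core problem: brittleness
Existing circuit discovery methods are brittle. Two datasets that differ only slightly can yield very different circuits:
Circuit discovered on dataset A:  components {N1, N2, N3, N4, N5}
Circuit discovered on dataset A': components {N1, N2, N6, N7, N8}
Here A' differs from A only slightly, yet only N1 and N2 are shared between the two circuits.
Problem: this instability makes circuits hard to trust as reliable explanations of a model.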
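To make the mismatch concrete, the two example circuits can be compared with plain set operations. This is a minimal sketch: the component names are the placeholders from the illustration, and reporting Jaccard similarity next to the symmetric difference is an added convenience rather than something the original text specifies.

```python
# Component sets copied from the illustration above.
circuit_a = {"N1", "N2", "N3", "N4", "N5"}
circuit_a_prime = {"N1", "N2", "N6", "N7", "N8"}

shared = circuit_a & circuit_a_prime
changed = circuit_a ^ circuit_a_prime  # symmetric difference
jaccard = len(shared) / len(circuit_a | circuit_a_prime)

print(sorted(shared))  # ['N1', 'N2']
print(len(changed))    # 6
print(jaccard)         # 0.25
```

Six of the eight distinct components differ between the two runs, which is the kind of instability the rest of the document sets out to bound.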
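2. The Certified Circuits framework
2.1 Core idea
Certified Circuits provides stability guarantees for circuit discovery by combining random data subsampling with an edit-distance bound:
- Random subsampling: repeatedly draw random subsets of the dataset
- Edit-distance bound: quantify how sensitive the circuit components are to changes in the dataset
- Stability certification: identify the components that are insensitive to dataset changes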
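2.2 Formal definitions
2.2.1 Edit distance
For two sets of circuit components $C_1$ and $C_2$, the edit distance is defined as:
$$d(C_1, C_2) = |C_1 \,\triangle\, C_2|$$
where $\triangle$ denotes the symmetric difference.
2.2.2 Stability guarantee
Let $D$ be the concept dataset, $D'$ a perturbed version of it (obtained by deleting or adding a small number of samples), and $\epsilon$ the allowed edit-distance threshold:
$$|D \,\triangle\, D'| \le k \;\Longrightarrow\; d\big(C(D), C(D')\big) \le \epsilon$$
Here $C(\cdot)$ denotes the circuit discovered on a dataset and $k$ bounds the number of samples that may be added or removed. Meaning: if the change to the dataset stays within its bound, the change in circuit components stays within its bound as well.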
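The implication can be checked on a single pair of datasets with a few lines of Python. This is only a sketch under stated assumptions: `k` (the number of samples allowed to change) and `epsilon` (the allowed edit distance) are illustrative names, `bound_holds` is a hypothetical helper, and `toy_discover` is a made-up stand-in for a real discovery method. Samples are assumed to be hashable so that the dataset change can be measured as a set difference.

```python
from typing import Callable, Hashable, Sequence, Set


def edit_distance(c1: Set[Hashable], c2: Set[Hashable]) -> int:
    """Size of the symmetric difference between two component sets."""
    return len(c1 ^ c2)


def bound_holds(
    discover: Callable[[Sequence[Hashable]], Set[Hashable]],
    dataset: Sequence[Hashable],
    perturbed: Sequence[Hashable],
    k: int,
    epsilon: int,
) -> bool:
    """Check the stability implication on a single (D, D') pair."""
    data_change = len(set(dataset) ^ set(perturbed))
    if data_change > k:
        return True  # premise not met, so the implication holds vacuously
    return edit_distance(discover(dataset), discover(perturbed)) <= epsilon


def toy_discover(samples: Sequence[int]) -> Set[str]:
    """Hypothetical rule: one component per distinct sample value mod 5."""
    return {f"N{x % 5}" for x in samples}


D = [1, 2, 3, 4]
D_prime = [1, 2, 3]  # one sample deleted

print(bound_holds(toy_discover, D, D_prime, k=1, epsilon=1))  # True
print(bound_holds(toy_discover, D, D_prime, k=1, epsilon=0))  # False
```

A single passing pair does not certify anything by itself; a guarantee has to hold for every perturbation within the bound.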
2.3 Framework architecture
```text
┌─────────────────────────────────────────────────────────────┐
│                Certified Circuits framework                 │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Dataset D ──► Random subsampling ──► M discovery runs      │
│      │                                      │               │
│      │                                      ▼               │
│      │                             ┌─────────────────┐      │
│      │                             │ C₁, C₂, ..., Cₘ │      │
│      │                             └─────────────────┘      │
│      │                                      │               │
│      ▼                                      ▼               │
│  ┌───────────────────────────────────────────────────────┐  │
│  │         Stability analysis and certification          │  │
│  │                                                       │  │
│  │  • Component stability scores                         │  │
│  │  • Edit-distance bound                                │  │
│  │  • Identify and remove unstable components            │  │
│  └───────────────────────────────────────────────────────┘  │
│                              │                              │
│                              ▼                              │
│                ┌──────────────────────────┐                 │
│                │ Certified stable circuit │                 │
│                │   (Certified Circuit)    │                 │
│                └──────────────────────────┘                 │
└─────────────────────────────────────────────────────────────┘
```
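3. Algorithm details
3.1 Component stability score
For each component $c$, compute its stability score as the fraction of the $M$ discovered circuits $C_1, \dots, C_M$ that contain it:
$$s(c) = \frac{1}{M} \sum_{i=1}^{M} \mathbb{1}[c \in C_i]$$
Stability levels:
- High stability ($s(c) \ge 0.9$): appears in almost every discovery run
- Medium stability ($0.5 \le s(c) < 0.9$): appears frequently
- Low stability ($s(c) < 0.5$): appears only occasionally or at random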
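The score is simply an occurrence frequency, which the following sketch computes from a list of discovered circuits. The function name `stability_scores` and the four toy runs are illustrative assumptions, not part of the original method description.

```python
from collections import Counter
from typing import Dict, Hashable, List, Set


def stability_scores(circuits: List[Set[Hashable]]) -> Dict[Hashable, float]:
    """Fraction of discovery runs in which each component appears."""
    counts: Counter = Counter()
    for circuit in circuits:
        counts.update(circuit)
    n_runs = len(circuits)
    return {component: n / n_runs for component, n in counts.items()}


# Four toy discovery runs (M = 4).
runs = [{"N1", "N2"}, {"N1", "N2"}, {"N1", "N3"}, {"N1"}]
scores = stability_scores(runs)

print(scores["N1"])  # 1.0
print(scores["N2"])  # 0.5
print(scores["N3"])  # 0.25
```

With the levels listed above, N1 would count as high stability, N2 as medium, and N3 as low.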
3.2 Random subsampling strategy
```python
from collections import Counter
from typing import Optional

import numpy as np


def certified_circuit_discovery(
    model,
    concept_dataset,
    n_iterations: int = 50,
    sample_ratio: float = 0.8,
    edit_distance_threshold: float = 0.2,
    stability_threshold: float = 0.5,
    top_k: Optional[int] = None,
):
    """
    Discover circuits with stability certification.

    Assumes `discover_circuit` (the base discovery method) and
    `CertifiedCircuit` (see Section 6) are defined elsewhere.

    Args:
        model: Neural network
        concept_dataset: Concept definition samples
        n_iterations: Number of random subsampling iterations
        sample_ratio: Ratio of samples to keep per iteration
        edit_distance_threshold: Maximum allowed edit distance
        stability_threshold: Minimum stability score for a component to be kept
        top_k: Number of components returned by the fallback branch
            (defaults to the number of components passing stability_threshold)
    Returns:
        certified_circuit: Stable circuit with certification
    """
    n_samples = len(concept_dataset)
    n_keep = int(n_samples * sample_ratio)
    component_counts = Counter()

    # Multiple rounds of circuit discovery with random subsampling
    for _ in range(n_iterations):
        # Random subsampling without replacement
        indices = np.random.choice(n_samples, n_keep, replace=False)
        subsampled_dataset = [concept_dataset[j] for j in indices]

        # Discover circuit on subsampled dataset
        circuit_i = discover_circuit(model, subsampled_dataset)

        # Count component occurrences
        component_counts.update(circuit_i.components)

    # Stability score = fraction of runs in which the component appears
    stability_scores = {
        component: count / n_iterations
        for component, count in component_counts.items()
    }

    # Filter to stable components. The stability score is compared with
    # stability_threshold, not with the edit-distance threshold.
    stable_components = {
        c for c, score in stability_scores.items()
        if score >= stability_threshold
    }

    # Verify edit distance bound
    verified = verify_edit_distance_bound(
        stability_scores,
        edit_distance_threshold,
    )
    if verified:
        return CertifiedCircuit(
            components=stable_components,
            stability_scores=stability_scores,
            certification_level="PROVABLY_STABLE",
            n_iterations=n_iterations,
        )

    # Fallback: return the k most stable components with a weaker guarantee
    ranked = sorted(stability_scores, key=stability_scores.get, reverse=True)
    k = top_k if top_k is not None else len(stable_components)
    return CertifiedCircuit(
        components=set(ranked[:k]),
        stability_scores=stability_scores,
        certification_level="STATISTICALLY_STABLE",
        n_iterations=n_iterations,
    )


def verify_edit_distance_bound(stability_scores, threshold):
    """
    Verify that edit distance is bounded by threshold.

    The total instability of all observed components is used as a
    heuristic proxy for the expected edit distance between two subsamples.
    """
    instability_sum = sum(1 - score for score in stability_scores.values())
    # If total instability is below the threshold, treat the bound as met
    return instability_sum <= threshold
```
3.3 Stability verification
3.3.1 Empirical verification
```python
def empirical_stability_check(
    model,
    certified_circuit,
    test_datasets,
    validation_set,
):
    """
    Empirically verify stability on held-out datasets.

    Assumes `discover_circuit`, `compute_edit_distance`, and
    `compute_behavioral_alignment` are defined elsewhere. `model` and
    `validation_set` are passed in explicitly because the body uses them.
    """
    results = []
    for test_dataset in test_datasets:
        # Discover circuit on test dataset
        test_circuit = discover_circuit(model, test_dataset)

        # Compute edit distance between the component sets
        edit_dist = compute_edit_distance(
            certified_circuit.components,
            test_circuit.components,
        )

        # Compute behavioral alignment on the validation set
        behavioral_align = compute_behavioral_alignment(
            certified_circuit,
            test_circuit,
            validation_set,
        )

        results.append({
            'edit_distance': edit_dist,
            'behavioral_alignment': behavioral_align,
        })
    return results
```
3.4 Component classification
```text
┌─────────────────────────────────────────────────────────────┐
│             Component stability classification              │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  High stability    ───────────►  Core component             │
│  (≥90%)                          Keep; high confidence      │
│                                                             │
│  Medium stability  ───────────►  Auxiliary component        │
│  (50-90%)                        Optional; may be removable │
│                                                             │
│  Low stability     ───────────►  Marginal component         │
│  (<50%)                          Unstable; use with caution │
│                                                             │
└─────────────────────────────────────────────────────────────┘
```
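The three tiers above can be applied mechanically once stability scores are available. The sketch below assumes scores are given as a plain dictionary; the function name `partition_by_stability`, the tier labels, and the example scores are illustrative, while the 0.9 and 0.5 cut-offs come from the classification above.

```python
from typing import Dict, Hashable, List


def partition_by_stability(
    scores: Dict[Hashable, float],
    core_min: float = 0.9,
    auxiliary_min: float = 0.5,
) -> Dict[str, List[Hashable]]:
    """Group components into core / auxiliary / marginal tiers."""
    tiers: Dict[str, List[Hashable]] = {"core": [], "auxiliary": [], "marginal": []}
    # Visit components from most to least stable so each tier is ordered.
    for component, score in sorted(scores.items(), key=lambda kv: (-kv[1], str(kv[0]))):
        if score >= core_min:
            tiers["core"].append(component)
        elif score >= auxiliary_min:
            tiers["auxiliary"].append(component)
        else:
            tiers["marginal"].append(component)
    return tiers


example_scores = {"N1": 1.0, "N2": 0.92, "N3": 0.7, "N4": 0.5, "N5": 0.1}
tiers = partition_by_stability(example_scores)

print(tiers["core"])       # ['N1', 'N2']
print(tiers["auxiliary"])  # ['N3', 'N4']
print(tiers["marginal"])   # ['N5']
```

A score of exactly 0.5 lands in the auxiliary tier here, and exactly 0.9 in the core tier; that boundary handling is a choice made for this sketch.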
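4. Experimental results
4.1 Validation on three architectures
| Architecture | Task | Accuracy gain | Component reduction |
|---|---|---|---|
| ResNet | ImageNet classification | +56% | -80% |
| ViT | OOD classification | +48% | -75% |
| GPT-2 | IOI (indirect object identification) task | +42% | -70% |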
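4.2 Out-of-distribution generalization
| Method | In-Distribution | OOD-1 | OOD-2 | OOD-3 |
|---|---|---|---|---|
| Baseline circuit | 85.2% | 62.1% | 58.3% | 55.7% |
| Certified | 88.7% | 78.9% | 76.2% | 74.1% |
Key results:
- Out-of-distribution accuracy improves substantially, by 16.8 to 18.4 percentage points across the three OOD splits
- Stability does not come at the expense of accuracy: in-distribution accuracy also rises, from 85.2% to 88.7%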
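The per-split gains can be read directly off the table. The short calculation below only restates the table's numbers; the variable names are illustrative.

```python
# Accuracy (%) copied from the table above.
baseline = {"In-Distribution": 85.2, "OOD-1": 62.1, "OOD-2": 58.3, "OOD-3": 55.7}
certified = {"In-Distribution": 88.7, "OOD-1": 78.9, "OOD-2": 76.2, "OOD-3": 74.1}

# Gain in percentage points for each split.
gains = {split: round(certified[split] - baseline[split], 1) for split in baseline}
ood_gains = [gains[s] for s in gains if s.startswith("OOD")]
mean_ood_gain = round(sum(ood_gains) / len(ood_gains), 1)

print(gains["In-Distribution"])  # 3.5
print(ood_gains)                 # [16.8, 17.9, 18.4]
print(mean_ood_gain)             # 17.7
```

The gap between in-distribution and OOD accuracy therefore narrows for the certified circuit rather than merely shifting upward.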
4.3 Component count comparison
| Architecture | Baseline components | Certified components | Reduction |
|---|---|---|---|
| ResNet-50 | 156 | 31 | 80% |
| ViT-B/16 | 203 | 48 | 76% |
| GPT-2-small | 89 | 27 | 70% |
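The reduction column follows from the two count columns as one minus the ratio of certified to baseline components. A quick check, using only the counts from the table:

```python
# (baseline components, certified components) copied from the table above.
counts = {"ResNet-50": (156, 31), "ViT-B/16": (203, 48), "GPT-2-small": (89, 27)}

reductions = {
    name: 1 - certified / baseline
    for name, (baseline, certified) in counts.items()
}

for name, reduction in reductions.items():
    print(f"{name}: {reduction:.1%}")
# ResNet-50: 80.1%
# ViT-B/16: 76.4%
# GPT-2-small: 69.7%
```

Rounded to whole percentages these give 80%, 76%, and 70%, matching the table.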
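5. Relation to formal interpretability
5.1 Method comparison
| Property | Formal interpretability | Certified Circuits |
|---|---|---|
| Type of guarantee | Deterministic | Probabilistic |
| Verification method | SMT solving | Random subsampling |
| Computational cost | High | Medium |
| Applicable scale | Small | Medium |
5.2 Complementary strengths
```text
┌────────────────────────────────────────────────────────────────────────┐
│               How the two methods complement each other                │
├────────────────────────────────────────────────────────────────────────┤
│                                                                        │
│  Formal interpretability ──► Deterministic, mathematically rigorous    │
│          │                                                             │
│          │  complementary                                              │
│          ▼                                                             │
│  Certified Circuits ──► Statistically valid, practical to use          │
│                                                                        │
│  Used together:                                                        │
│  • Certified Circuits filters for stable components                    │
│  • Formal methods verify key properties                                │
│                                                                        │
└────────────────────────────────────────────────────────────────────────┘
```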
6. PyTorch implementation
```python
from collections import Counter
from dataclasses import dataclass
from typing import Dict, List, Optional, Set

import numpy as np
import torch.nn as nn


# frozen=True makes Component hashable, which is required for storing it
# in sets and using it as a Counter / dict key.
@dataclass(frozen=True)
class Component:
    """Represents a circuit component."""
    layer_idx: int
    component_type: str  # 'attention', 'mlp', 'neuron'
    component_id: int


@dataclass
class CertifiedCircuit:
    """Circuit with stability certification."""
    components: Set[Component]
    stability_scores: Dict[Component, float]
    certification_level: str  # 'PROVABLY_STABLE', 'STATISTICALLY_STABLE'
    n_iterations: int


class CircuitDiscoverer:
    """Base circuit discovery method (e.g., activation patching)."""

    def __init__(self, model: nn.Module):
        self.model = model

    def discover(self, dataset):
        """Discover a circuit on the given dataset.

        Subclasses must return an object with a `.components` attribute.
        """
        raise NotImplementedError


class CertifiedCircuitDiscovery:
    """
    Certified Circuits: Stability-guaranteed circuit discovery.
    """

    def __init__(
        self,
        base_discoverer: CircuitDiscoverer,
        n_iterations: int = 50,
        sample_ratio: float = 0.8,
        stability_threshold: float = 0.5,
        seed: Optional[int] = None,
    ):
        self.discoverer = base_discoverer
        self.n_iterations = n_iterations
        self.sample_ratio = sample_ratio
        self.stability_threshold = stability_threshold
        # A seeded generator makes the subsampling reproducible
        self.rng = np.random.default_rng(seed)

    def discover(self, dataset: List, return_stability_info: bool = False):
        """
        Discover certified stable circuit.

        Returns a CertifiedCircuit, or the tuple
        (certified, all_circuits, component_counts) when
        return_stability_info is True.
        """
        n_samples = len(dataset)
        n_keep = int(n_samples * self.sample_ratio)
        component_counts = Counter()
        all_circuits = []

        # Multiple rounds of subsampled discovery
        for _ in range(self.n_iterations):
            # Random subsample without replacement
            indices = self.rng.choice(n_samples, size=n_keep, replace=False)
            subsampled = [dataset[j] for j in indices]

            # Discover circuit
            circuit = self.discoverer.discover(subsampled)
            all_circuits.append(circuit)
            component_counts.update(circuit.components)

        # Stability score = fraction of runs containing the component
        stability_scores = {
            component: count / self.n_iterations
            for component, count in component_counts.items()
        }

        # Filter to stable components
        stable_components = {
            comp for comp, score in stability_scores.items()
            if score >= self.stability_threshold
        }

        # Determine certification level from the mean stability
        # (0.0 if no component was ever discovered)
        avg_stability = (
            float(np.mean(list(stability_scores.values())))
            if stability_scores else 0.0
        )
        cert_level = (
            "PROVABLY_STABLE" if avg_stability > 0.8 else "STATISTICALLY_STABLE"
        )

        certified = CertifiedCircuit(
            components=stable_components,
            stability_scores=stability_scores,
            certification_level=cert_level,
            n_iterations=self.n_iterations,
        )
        if return_stability_info:
            return certified, all_circuits, component_counts
        return certified

    def verify_edit_distance(
        self,
        certified: CertifiedCircuit,
        test_circuits: List,
    ) -> Dict:
        """
        Verify edit distance bound empirically.
        """
        results = {
            'edit_distances': [],
            'avg_edit_distance': 0,
            'max_edit_distance': 0,
            'within_threshold': True,
        }
        threshold = len(certified.components) * (1 - self.stability_threshold)

        for test_circuit in test_circuits:
            edit_dist = self._compute_edit_distance(
                certified.components,
                test_circuit.components,
            )
            results['edit_distances'].append(edit_dist)
            if edit_dist > threshold:
                results['within_threshold'] = False

        # Guard against an empty list, where mean and max are undefined
        if results['edit_distances']:
            results['avg_edit_distance'] = float(np.mean(results['edit_distances']))
            results['max_edit_distance'] = max(results['edit_distances'])
        return results

    def _compute_edit_distance(
        self,
        components1: Set[Component],
        components2: Set[Component],
    ) -> int:
        """Compute symmetric difference size."""
        return len(components1 ^ components2)


def create_certified_circuit_workflow(
    discoverer: CircuitDiscoverer,
    concept_dataset: List,
) -> CertifiedCircuit:
    """
    Complete workflow for certified circuit discovery.

    `discoverer` must be an instance of a concrete CircuitDiscoverer
    subclass (for example, an activation-patching implementation);
    none is provided here.
    """
    # Initialize certified discovery
    cert_discoverer = CertifiedCircuitDiscovery(
        base_discoverer=discoverer,
        n_iterations=50,
        sample_ratio=0.8,
        stability_threshold=0.5,
    )

    # Discover certified circuit
    certified = cert_discoverer.discover(concept_dataset)

    # Additional verification on test datasets
    # (implementation details)
    return certified
```
7. Applications and practice
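7.1 Usage guide
- Choose a dataset: prepare a concept dataset containing both positive and negative samples
- Set the parameters: choose the number of iterations, the sampling ratio, and the stability threshold
- Run discovery: obtain the certified stable circuit
- Validate the result: check it on test datasets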
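The four steps can be walked through end to end on a toy problem. Everything in this sketch is an assumption made for illustration: `toy_discover` is a hypothetical stand-in for a real circuit discovery method, each "sample" is just a set of component names, and `certified_components` is a compact re-implementation of the subsampling loop rather than the implementation from Section 6.

```python
from collections import Counter

import numpy as np


def toy_discover(samples):
    """Hypothetical discovery rule: keep components present in more than half of the samples."""
    counts = Counter()
    for sample in samples:
        counts.update(sample)
    return {c for c, n in counts.items() if n > len(samples) / 2}


def certified_components(dataset, n_iterations=50, sample_ratio=0.8,
                         stability_threshold=0.5, seed=0):
    """Subsample, rediscover, and keep components that recur often enough."""
    rng = np.random.default_rng(seed)
    n_keep = int(len(dataset) * sample_ratio)
    counts = Counter()
    for _ in range(n_iterations):
        idx = rng.choice(len(dataset), size=n_keep, replace=False)
        counts.update(toy_discover([dataset[i] for i in idx]))
    scores = {c: n / n_iterations for c, n in counts.items()}
    stable = {c for c, s in scores.items() if s >= stability_threshold}
    return stable, scores


# Step 1: choose a dataset. N1 and N2 occur in every sample, N3 in half.
dataset = [{"N1", "N2", "N3"}] * 5 + [{"N1", "N2"}] * 5

# Steps 2 and 3: set the parameters and run discovery.
stable, scores = certified_components(
    dataset, n_iterations=50, sample_ratio=0.8, stability_threshold=0.5
)
print(scores["N1"], scores["N2"])  # 1.0 1.0

# Step 4: validate against a circuit discovered on held-out data.
held_out = [{"N1", "N2"}] * 6 + [{"N1", "N2", "N3"}] * 2
edit_distance = len(stable ^ toy_discover(held_out))
print(sorted(stable), edit_distance)
```

N3 sits on the decision boundary of `toy_discover`, so it only enters a run's circuit when all five samples that contain it happen to be drawn; its stability score will typically fall well below the 0.5 threshold.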
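7.2 Caveats
- Computational cost: the base discovery method runs once per iteration, so cost grows linearly with the iteration count, while too few iterations give unreliable stability scores
- Threshold selection: the stability threshold needs to be tuned per task
- Validation is still necessary: even a certified circuit should be checked on real data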
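One rough way to reason about the iteration count is the standard error of an estimated stability score. The sketch below treats each subsampling run as an independent Bernoulli trial, which is a simplifying assumption introduced here and not something the method itself states; `stability_standard_error` is a hypothetical helper.

```python
import math


def stability_standard_error(score: float, n_iterations: int) -> float:
    """Binomial standard error of a stability score estimated from n_iterations runs."""
    return math.sqrt(score * (1.0 - score) / n_iterations)


# Worst case is a score of 0.5, right at the default stability threshold.
for n_iterations in (10, 50, 200):
    print(n_iterations, round(stability_standard_error(0.5, n_iterations), 3))
# 10 0.158
# 50 0.071
# 200 0.035
```

Under this approximation, quadrupling the number of iterations halves the uncertainty, so components whose scores sit near the threshold are the ones most affected by a small iteration budget.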
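8. Summary and outlook
8.1 Core contributions
- Stability guarantees: provable stability obtained through random subsampling
- Compact circuits: fewer components with stronger performance
- Out-of-distribution generalization: markedly better reliability on out-of-distribution data
8.2 Limitations
- High computational cost
- Still relies on an empirical base circuit discovery method
- Limited ability to identify very rare components
8.3 Future directions
- Deeper integration with formal methods
- More efficient subsampling strategies
- Adaptive threshold adjustment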