概述
机制可解释性(Mechanistic Interpretability)旨在理解神经网络内部工作原理的”电路”机制。形式化方法为这一领域带来了严格的理论基础,使我们能够对电路发现给出可证明的正确性保证。1
本文系统梳理形式化机制可解释性理论的核心内容:
- 电路发现的理论框架
- SMT(Satisfiability Modulo Theories)验证方法
- 稳定性与鲁棒性保证
- 形式化电路的正确性证明
电路发现的形式化基础
从神经元到电路
电路(Circuit)是由神经元之间的特定连接组成的子图,这些连接共同实现某种计算功能。
设 为神经网络, 为其计算图。电路 是边的子集。
电路发现的目标:找到最小的边集 使得移除 后网络的输出发生显著变化。
定义与符号
设 为神经网络, 为参数。电路可以形式化为:
定义(电路):电路 是参数子集 的索引,使得:
- 完整性: 在验证集 上
- 最小性:对于任意真子集 ,
其中 表示在某种度量下近似。
电路发现作为组合优化
电路发现问题可以形式化为:
其中 是仅使用电路 中权重的网络输出。
SMT验证框架
SMT基础
SMT(Satisfiability Modulo Theories)求解器可以检查一阶逻辑公式在特定理论(如实数算术、位向量)下的可满足性。
对于电路验证,我们使用SMT来证明电路的正确性。
电路验证的SMT编码
设 为输入, 为候选电路。验证问题编码为:
这可以转化为SMT公式:
工具与实现
import torch
import numpy as np
from z3 import *
class SMTBasedCircuitVerifier:
"""
基于SMT的电路验证器
使用Z3求解器验证电路的正确性
"""
def __init__(self, model, circuit_nodes):
self.model = model
self.circuit_nodes = set(circuit_nodes) # 电路中的节点索引
self.device = next(model.parameters()).device
def create_input_constraints(self, input_bounds):
"""
创建输入约束:x_i ∈ [l_i, u_i]
"""
constraints = []
z3_vars = {}
for i, (lower, upper) in enumerate(input_bounds):
x_i = Real(f'x_{i}')
z3_vars[i] = x_i
constraints.append(x_i >= lower)
constraints.append(x_i <= upper)
return constraints, z3_vars
def encode_network(self, z3_vars):
"""
将神经网络编码为Z3公式
假设简单MLP: h = ReLU(Wx + b)
"""
# 获取网络参数
layers = []
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
W = module.weight.data.cpu().numpy()
b = module.bias.data.cpu().numpy()
layers.append((W, b))
# 创建中间变量
z3_layers = []
current = z3_vars # 输入层
for i, (W, b) in enumerate(layers):
layer_vars = {}
for j in range(W.shape[0]):
# 线性组合
linear_comb = sum(W[j, k] * current[k] for k in range(W.shape[1]))
linear_comb = linear_comb + b[j]
# ReLU激活
h_j = If(linear_comb > 0, linear_comb, 0)
layer_vars[j] = h_j
current = layer_vars
z3_layers.append(layer_vars)
return z3_layers
def verify_circuit(self, input_bounds, output_tolerance=1e-6):
"""
验证电路的正确性
返回: (is_valid, counterexample)
"""
s = Optimize() if False else Solver()
# 创建输入变量和约束
constraints, z3_vars = self.create_input_constraints(input_bounds)
for c in constraints:
s.add(c)
# 创建完整网络
full_layers = self.encode_network(z3_vars)
# 创建电路网络(只包含电路节点)
circuit_layers = []
for i, layer_vars in enumerate(full_layers):
circuit_layer = {}
for j, var in layer_vars.items():
if j in self.circuit_nodes:
circuit_layer[j] = var
circuit_layers.append(circuit_layer)
# 添加输出差异约束
final_layer = full_layers[-1]
circuit_final = circuit_layers[-1] if circuit_layers else {}
# 简化:比较第一维输出
if final_layer and (circuit_final or len(circuit_final) == len(final_layer)):
diff = final_layer[0] - (circuit_final.get(0, RealVal(0)))
s.add(abs(diff) > output_tolerance)
# 检查可满足性
result = s.check()
if result == sat:
model = s.model()
counterexample = {str(k): model[k] for k in z3_vars.values()}
return False, counterexample
else:
return True, None
class ApproximateCircuitVerifier:
"""
近似电路验证器(当精确SMT不可行时使用)
使用采样和局部验证
"""
def __init__(self, model, circuit_nodes):
self.model = model
self.circuit_nodes = circuit_nodes
self.device = next(model.parameters()).device
def compute_ablation_effect(self, x, y):
"""
计算消融电路节点的效果
"""
# 完整输出
self.model.eval()
with torch.no_grad():
full_output = self.model(x)
# 创建消融版本
def circuit_ablation_hook(module, input, output):
# 将电路节点置零
output_ablate = output.clone()
for idx in self.circuit_nodes:
if idx < output.shape[-1]:
output_ablate[..., idx] = 0
return output_ablate
# 注册钩子
handles = []
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
handle = module.register_forward_hook(circuit_ablation_hook)
handles.append(handle)
with torch.no_grad():
ablated_output = self.model(x)
# 移除钩子
for handle in handles:
handle.remove()
# 计算差异
diff = torch.norm(full_output - ablated_output, p=2) / full_output.numel()
return diff.item()
def statistical_verification(self, dataloader, n_samples=1000,
error_threshold=1e-3):
"""
统计验证:检查电路消融的误差
"""
errors = []
self.model.eval()
for i, (x, y) in enumerate(dataloader):
if i >= n_samples:
break
x = x.to(self.device)
if x.ndim > 2:
x = x.view(x.size(0), -1)
error = self.compute_ablation_effect(x, y)
errors.append(error)
errors = np.array(errors)
return {
'mean_error': errors.mean(),
'std_error': errors.std(),
'max_error': errors.max(),
'percentile_95': np.percentile(errors, 95),
'is_valid': errors.mean() < error_threshold
}稳定性与鲁棒性保证
电路稳定性定义
定义(电路稳定性):电路 称为 -稳定的,如果对于任意输入扰动 :
其中 是Lipschitz常数。
随机子采样稳定性
定理(子采样稳定性):设 是从数据子集 发现的电路, 是完整数据集。若 ,则:
其中 是正确/错误电路之间的性能差距。
class CircuitStabilityAnalyzer:
"""
电路稳定性分析器
"""
def __init__(self, model, circuit_nodes):
self.model = model
self.circuit_nodes = circuit_nodes
self.device = next(model.parameters()).device
def compute_circuit_sensitivity(self, x, y):
"""
计算电路节点的敏感度
敏感度 = 移除节点后输出的变化
"""
self.model.eval()
# 完整输出
with torch.no_grad():
full_output = self.model(x)
sensitivities = {}
for node in self.circuit_nodes:
# 消融该节点
def create_hook(node_idx):
def hook(module, input, output):
output_ablate = output.clone()
if node_idx < output.shape[-1]:
output_ablate[..., node_idx] = 0
return output_ablate
return hook
handle = None
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
handle = module.register_forward_hook(create_hook(node))
break
with torch.no_grad():
ablated_output = self.model(x)
if handle:
handle.remove()
# 计算敏感度
sensitivity = torch.norm(full_output - ablated_output, p=2).item()
sensitivities[node] = sensitivity
return sensitivities
def compute_lipschitz_bound(self, x):
"""
计算电路的Lipschitz常数上界
"""
self.model.eval()
# 使用幂迭代法计算谱范数
from torch.linalg import svd
lipschitz_bounds = []
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
W = module.weight.data
# 谱范数(最大奇异值)
s = torch.linalg.svd(W, compute_uv=False)
lipschitz_bounds.append(s.max().item())
# 乘积给出Lipschitz上界
lipschitz_bound = np.prod(lipschitz_bounds)
return lipschitz_bound
def stability_under_perturbation(self, x, y, epsilon=0.1, n_perturbations=100):
"""
测试电路在扰动下的稳定性
"""
self.model.eval()
# 原始输出
with torch.no_grad():
original_output = self.model(x)
# 扰动输出
perturbed_errors = []
for _ in range(n_perturbations):
delta_x = torch.randn_like(x) * epsilon
x_perturbed = x + delta_x
with torch.no_grad():
perturbed_output = self.model(x_perturbed)
error = torch.norm(original_output - perturbed_output, p=2).item()
perturbed_errors.append(error)
return {
'mean_error': np.mean(perturbed_errors),
'std_error': np.std(perturbed_errors),
'max_error': np.max(perturbed_errors),
'lipschitz_bound': self.compute_lipschitz_bound(x)
}电路最小性与验证
最小电路定义
定义(最小电路):电路 是最小的,如果不存在真子集 使得 。
验证方法
class CircuitMinimalityVerifier:
"""
电路最小性验证器
"""
def __init__(self, model, circuit_nodes):
self.model = model
self.circuit_nodes = circuit_nodes
self.device = next(model.parameters()).device
def check_minimality(self, dataloader, tolerance=1e-4):
"""
检查电路的最小性
尝试移除每个节点,检查是否仍然等效
"""
self.model.eval()
# 参考输出(完整电路)
ref_outputs = []
with torch.no_grad():
for x, _ in dataloader:
x = x.to(self.device)
output = self.model(x)
ref_outputs.append(output)
# 对每个节点,检查是否可以移除
redundant_nodes = []
essential_nodes = []
for node in self.circuit_nodes:
# 创建消融版本
def create_hook(node_idx):
def hook(module, input, output):
output_ablate = output.clone()
if node_idx < output.shape[-1]:
output_ablate[..., node_idx] = 0
return output_ablate
return hook
handle = None
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
handle = module.register_forward_hook(create_hook(node))
break
# 计算消融后的输出
ablated_outputs = []
with torch.no_grad():
for x, _ in dataloader:
x = x.to(self.device)
output = self.model(x)
ablated_outputs.append(output)
if handle:
handle.remove()
# 比较输出
ref_cat = torch.cat(ref_outputs, dim=0)
ablated_cat = torch.cat(ablated_outputs, dim=0)
error = torch.norm(ref_cat - ablated_cat, p=2).item() / ref_cat.numel()
if error < tolerance:
redundant_nodes.append(node)
else:
essential_nodes.append(node)
return {
'essential_nodes': essential_nodes,
'redundant_nodes': redundant_nodes,
'minimality_ratio': len(essential_nodes) / len(self.circuit_nodes)
}
def find_minimal_subcircuit(self, dataloader, tolerance=1e-4):
"""
使用贪心搜索找到最小子电路
"""
self.model.eval()
# 计算原始电路输出
ref_outputs = []
with torch.no_grad():
for x, _ in dataloader:
x = x.to(self.device)
output = self.model(x)
ref_outputs.append(output)
ref_cat = torch.cat(ref_outputs, dim=0)
# 贪心移除
current_nodes = set(self.circuit_nodes)
minimal_nodes = set(self.circuit_nodes)
# 按敏感度排序
sensitivities = {}
for node in current_nodes:
sensitivity = self._compute_node_sensitivity(node, dataloader, ref_cat)
sensitivities[node] = sensitivity
sorted_nodes = sorted(current_nodes, key=lambda n: sensitivities[n])
# 尝试移除低敏感度节点
for node in sorted_nodes:
current_nodes.remove(node)
# 检查是否仍然等效
if self._check_equivalence(current_nodes, dataloader, ref_cat, tolerance):
minimal_nodes.remove(node)
else:
current_nodes.add(node)
return {
'minimal_nodes': list(minimal_nodes),
'reduction_ratio': 1 - len(minimal_nodes) / len(self.circuit_nodes)
}
def _compute_node_sensitivity(self, node, dataloader, ref_output):
"""计算节点的敏感度"""
# 简化实现
return 1.0
def _check_equivalence(self, nodes, dataloader, ref_output, tolerance):
"""检查移除节点后是否等效"""
# 简化实现
return False电路可迁移性分析
迁移性定义
定义(电路迁移性):电路 从模型 到 是可迁移的,如果 在 中仍然实现相同的计算功能。
迁移性度量
class CircuitTransferabilityAnalyzer:
"""
电路迁移性分析器
"""
def __init__(self, source_model, target_model, circuit_nodes):
self.source_model = source_model
self.target_model = target_model
self.circuit_nodes = circuit_nodes
self.device = next(source_model.parameters()).device
def compute_transferability_score(self, dataloader, task):
"""
计算电路的迁移性分数
"""
# 在源模型上提取电路
source_circuit = self._extract_circuit()
# 在目标模型上测试
target_model = self.target_model.to(self.device)
target_model.eval()
# 测试性能
correct = 0
total = 0
with torch.no_grad():
for x, y in dataloader:
x = x.to(self.device)
y = y.to(self.device)
output = target_model(x)
pred = output.argmax(dim=1)
correct += (pred == y).sum().item()
total += y.size(0)
accuracy = correct / total
return {
'transfer_accuracy': accuracy,
'source_circuit_size': len(source_circuit),
'transferability_score': accuracy # 简化的迁移性分数
}
def _extract_circuit(self):
"""提取源模型中的电路"""
return self.circuit_nodes与形式化验证的结合
电路正确性证明框架
class FormalCircuitProver:
"""
形式化电路证明器
结合SMT和符号执行进行电路验证
"""
def __init__(self, model):
self.model = model
def prove_circuit_correctness(self, circuit, spec, input_domain):
"""
证明电路正确性
参数:
circuit: 候选电路
spec: 规格(电路应满足的属性)
input_domain: 输入域定义
"""
# 符号执行
symbolic_trace = self.symbolic_execute(circuit, input_domain)
# SMT验证
smt_formula = self.generate_smt_formula(symbolic_trace, spec)
# 求解
solver = Solver()
solver.add(smt_formula)
if solver.check() == unsat:
return {'proved': True, 'counterexample': None}
else:
return {'proved': False, 'counterexample': solver.model()}
def symbolic_execute(self, circuit, input_domain):
"""符号执行电路"""
# 实现省略
pass
def generate_smt_formula(self, trace, spec):
"""生成SMT公式"""
# 实现省略
pass实践应用
电路发现pipeline
class CircuitDiscoveryPipeline:
"""
完整的电路发现流程
"""
def __init__(self, model):
self.model = model
self.device = next(model.parameters()).device
def discover_circuits(self, dataloader, task_description,
confidence_threshold=0.95):
"""
完整的电路发现流程
"""
# 1. 激活分析
activation_analyzer = ActivationAnalyzer(self.model)
important_neurons = activation_analyzer.find_important_neurons(dataloader)
# 2. 电路构建
circuit_builder = CircuitBuilder(self.model)
candidate_circuit = circuit_builder.build(important_neurons)
# 3. 验证
verifier = SMTBasedCircuitVerifier(self.model, candidate_circuit)
is_valid, counterexample = verifier.verify_circuit(
input_bounds=[(-1, 1)] * 784 # 假设MNIST
)
# 4. 稳定性分析
stability_analyzer = CircuitStabilityAnalyzer(self.model, candidate_circuit)
stability_result = stability_analyzer.stability_under_perturbation(
next(iter(dataloader))[0].to(self.device),
None
)
# 5. 最小性检查
minimality_verifier = CircuitMinimalityVerifier(self.model, candidate_circuit)
minimality_result = minimality_verifier.check_minimality(dataloader)
return {
'circuit': candidate_circuit,
'is_valid': is_valid,
'stability': stability_result,
'minimality': minimality_result,
'counterexample': counterexample
}
class ActivationAnalyzer:
"""激活分析器"""
def __init__(self, model):
self.model = model
self.activations = {}
def find_important_neurons(self, dataloader, top_k=100):
"""找到最重要的神经元"""
# 实现省略
return list(range(top_k))
class CircuitBuilder:
"""电路构建器"""
def __init__(self, model):
self.model = model
def build(self, important_neurons):
"""从重要神经元构建电路"""
return important_neurons数学附录
核心定义速查
| 概念 | 定义 | 验证方法 |
|---|---|---|
| 电路 | 实现特定功能的神经元子图 | 功能完整性 |
| 稳定性 | 扰动下输出不变 | Lipschitz分析 |
| 最小性 | 无冗余节点 | 贪心搜索 |
| 迁移性 | 跨模型泛化 | 任务准确率 |
| 正确性 | 满足规格 | SMT求解 |
SMT理论支持
| 理论 | 应用 |
|---|---|
| 实数算术 (RA) | 浮点网络验证 |
| 位向量 (BV) | 定点网络验证 |
| 数组理论 | 内存建模 |
| 非线性实数 (NRA) | 激活函数 |
参考文献
相关主题
Footnotes
-
Geiping, J., et al. (2026). Formal mechanistic interpretability. ICLR 2026. ↩