概述

机制可解释性(Mechanistic Interpretability)旨在理解神经网络内部工作原理的”电路”机制。形式化方法为这一领域带来了严格的理论基础,使我们能够对电路发现给出可证明的正确性保证。1

本文系统梳理形式化机制可解释性理论的核心内容:

  • 电路发现的理论框架
  • SMT(Satisfiability Modulo Theories)验证方法
  • 稳定性与鲁棒性保证
  • 形式化电路的正确性证明

电路发现的形式化基础

从神经元到电路

电路(Circuit)是由神经元之间的特定连接组成的子图,这些连接共同实现某种计算功能。

为神经网络, 为其计算图。电路 是边的子集。

电路发现的目标:找到最小的边集 使得移除 后网络的输出发生显著变化。

定义与符号

为神经网络, 为参数。电路可以形式化为:

定义(电路):电路 是参数子集 的索引,使得:

  1. 完整性 在验证集
  2. 最小性:对于任意真子集

其中 表示在某种度量下近似。

电路发现作为组合优化

电路发现问题可以形式化为:

其中 是仅使用电路 中权重的网络输出。


SMT验证框架

SMT基础

SMT(Satisfiability Modulo Theories)求解器可以检查一阶逻辑公式在特定理论(如实数算术、位向量)下的可满足性。

对于电路验证,我们使用SMT来证明电路的正确性。

电路验证的SMT编码

为输入, 为候选电路。验证问题编码为:

这可以转化为SMT公式:

工具与实现

import torch
import numpy as np
from z3 import *
 
 
class SMTBasedCircuitVerifier:
    """
    基于SMT的电路验证器
    
    使用Z3求解器验证电路的正确性
    """
    def __init__(self, model, circuit_nodes):
        self.model = model
        self.circuit_nodes = set(circuit_nodes)  # 电路中的节点索引
        self.device = next(model.parameters()).device
    
    def create_input_constraints(self, input_bounds):
        """
        创建输入约束:x_i ∈ [l_i, u_i]
        """
        constraints = []
        z3_vars = {}
        
        for i, (lower, upper) in enumerate(input_bounds):
            x_i = Real(f'x_{i}')
            z3_vars[i] = x_i
            constraints.append(x_i >= lower)
            constraints.append(x_i <= upper)
        
        return constraints, z3_vars
    
    def encode_network(self, z3_vars):
        """
        将神经网络编码为Z3公式
        
        假设简单MLP: h = ReLU(Wx + b)
        """
        # 获取网络参数
        layers = []
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                W = module.weight.data.cpu().numpy()
                b = module.bias.data.cpu().numpy()
                layers.append((W, b))
        
        # 创建中间变量
        z3_layers = []
        current = z3_vars  # 输入层
        
        for i, (W, b) in enumerate(layers):
            layer_vars = {}
            
            for j in range(W.shape[0]):
                # 线性组合
                linear_comb = sum(W[j, k] * current[k] for k in range(W.shape[1]))
                linear_comb = linear_comb + b[j]
                
                # ReLU激活
                h_j = If(linear_comb > 0, linear_comb, 0)
                layer_vars[j] = h_j
            
            current = layer_vars
            z3_layers.append(layer_vars)
        
        return z3_layers
    
    def verify_circuit(self, input_bounds, output_tolerance=1e-6):
        """
        验证电路的正确性
        
        返回: (is_valid, counterexample)
        """
        s = Optimize() if False else Solver()
        
        # 创建输入变量和约束
        constraints, z3_vars = self.create_input_constraints(input_bounds)
        for c in constraints:
            s.add(c)
        
        # 创建完整网络
        full_layers = self.encode_network(z3_vars)
        
        # 创建电路网络(只包含电路节点)
        circuit_layers = []
        for i, layer_vars in enumerate(full_layers):
            circuit_layer = {}
            for j, var in layer_vars.items():
                if j in self.circuit_nodes:
                    circuit_layer[j] = var
            circuit_layers.append(circuit_layer)
        
        # 添加输出差异约束
        final_layer = full_layers[-1]
        circuit_final = circuit_layers[-1] if circuit_layers else {}
        
        # 简化:比较第一维输出
        if final_layer and (circuit_final or len(circuit_final) == len(final_layer)):
            diff = final_layer[0] - (circuit_final.get(0, RealVal(0)))
            s.add(abs(diff) > output_tolerance)
        
        # 检查可满足性
        result = s.check()
        
        if result == sat:
            model = s.model()
            counterexample = {str(k): model[k] for k in z3_vars.values()}
            return False, counterexample
        else:
            return True, None
 
 
class ApproximateCircuitVerifier:
    """
    近似电路验证器(当精确SMT不可行时使用)
    
    使用采样和局部验证
    """
    def __init__(self, model, circuit_nodes):
        self.model = model
        self.circuit_nodes = circuit_nodes
        self.device = next(model.parameters()).device
    
    def compute_ablation_effect(self, x, y):
        """
        计算消融电路节点的效果
        """
        # 完整输出
        self.model.eval()
        with torch.no_grad():
            full_output = self.model(x)
        
        # 创建消融版本
        def circuit_ablation_hook(module, input, output):
            # 将电路节点置零
            output_ablate = output.clone()
            for idx in self.circuit_nodes:
                if idx < output.shape[-1]:
                    output_ablate[..., idx] = 0
            return output_ablate
        
        # 注册钩子
        handles = []
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                handle = module.register_forward_hook(circuit_ablation_hook)
                handles.append(handle)
        
        with torch.no_grad():
            ablated_output = self.model(x)
        
        # 移除钩子
        for handle in handles:
            handle.remove()
        
        # 计算差异
        diff = torch.norm(full_output - ablated_output, p=2) / full_output.numel()
        
        return diff.item()
    
    def statistical_verification(self, dataloader, n_samples=1000, 
                                 error_threshold=1e-3):
        """
        统计验证:检查电路消融的误差
        """
        errors = []
        
        self.model.eval()
        for i, (x, y) in enumerate(dataloader):
            if i >= n_samples:
                break
            
            x = x.to(self.device)
            if x.ndim > 2:
                x = x.view(x.size(0), -1)
            
            error = self.compute_ablation_effect(x, y)
            errors.append(error)
        
        errors = np.array(errors)
        
        return {
            'mean_error': errors.mean(),
            'std_error': errors.std(),
            'max_error': errors.max(),
            'percentile_95': np.percentile(errors, 95),
            'is_valid': errors.mean() < error_threshold
        }

稳定性与鲁棒性保证

电路稳定性定义

定义(电路稳定性):电路 称为 -稳定的,如果对于任意输入扰动

其中 是Lipschitz常数。

随机子采样稳定性

定理(子采样稳定性):设 是从数据子集 发现的电路, 是完整数据集。若 ,则:

其中 是正确/错误电路之间的性能差距。

class CircuitStabilityAnalyzer:
    """
    电路稳定性分析器
    """
    def __init__(self, model, circuit_nodes):
        self.model = model
        self.circuit_nodes = circuit_nodes
        self.device = next(model.parameters()).device
    
    def compute_circuit_sensitivity(self, x, y):
        """
        计算电路节点的敏感度
        
        敏感度 = 移除节点后输出的变化
        """
        self.model.eval()
        
        # 完整输出
        with torch.no_grad():
            full_output = self.model(x)
        
        sensitivities = {}
        
        for node in self.circuit_nodes:
            # 消融该节点
            def create_hook(node_idx):
                def hook(module, input, output):
                    output_ablate = output.clone()
                    if node_idx < output.shape[-1]:
                        output_ablate[..., node_idx] = 0
                    return output_ablate
                return hook
            
            handle = None
            for name, module in self.model.named_modules():
                if isinstance(module, torch.nn.Linear):
                    handle = module.register_forward_hook(create_hook(node))
                    break
            
            with torch.no_grad():
                ablated_output = self.model(x)
            
            if handle:
                handle.remove()
            
            # 计算敏感度
            sensitivity = torch.norm(full_output - ablated_output, p=2).item()
            sensitivities[node] = sensitivity
        
        return sensitivities
    
    def compute_lipschitz_bound(self, x):
        """
        计算电路的Lipschitz常数上界
        """
        self.model.eval()
        
        # 使用幂迭代法计算谱范数
        from torch.linalg import svd
        
        lipschitz_bounds = []
        
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                W = module.weight.data
                # 谱范数(最大奇异值)
                s = torch.linalg.svd(W, compute_uv=False)
                lipschitz_bounds.append(s.max().item())
        
        # 乘积给出Lipschitz上界
        lipschitz_bound = np.prod(lipschitz_bounds)
        
        return lipschitz_bound
    
    def stability_under_perturbation(self, x, y, epsilon=0.1, n_perturbations=100):
        """
        测试电路在扰动下的稳定性
        """
        self.model.eval()
        
        # 原始输出
        with torch.no_grad():
            original_output = self.model(x)
        
        # 扰动输出
        perturbed_errors = []
        
        for _ in range(n_perturbations):
            delta_x = torch.randn_like(x) * epsilon
            x_perturbed = x + delta_x
            
            with torch.no_grad():
                perturbed_output = self.model(x_perturbed)
            
            error = torch.norm(original_output - perturbed_output, p=2).item()
            perturbed_errors.append(error)
        
        return {
            'mean_error': np.mean(perturbed_errors),
            'std_error': np.std(perturbed_errors),
            'max_error': np.max(perturbed_errors),
            'lipschitz_bound': self.compute_lipschitz_bound(x)
        }

电路最小性与验证

最小电路定义

定义(最小电路):电路 是最小的,如果不存在真子集 使得

验证方法

class CircuitMinimalityVerifier:
    """
    电路最小性验证器
    """
    def __init__(self, model, circuit_nodes):
        self.model = model
        self.circuit_nodes = circuit_nodes
        self.device = next(model.parameters()).device
    
    def check_minimality(self, dataloader, tolerance=1e-4):
        """
        检查电路的最小性
        
        尝试移除每个节点,检查是否仍然等效
        """
        self.model.eval()
        
        # 参考输出(完整电路)
        ref_outputs = []
        with torch.no_grad():
            for x, _ in dataloader:
                x = x.to(self.device)
                output = self.model(x)
                ref_outputs.append(output)
        
        # 对每个节点,检查是否可以移除
        redundant_nodes = []
        essential_nodes = []
        
        for node in self.circuit_nodes:
            # 创建消融版本
            def create_hook(node_idx):
                def hook(module, input, output):
                    output_ablate = output.clone()
                    if node_idx < output.shape[-1]:
                        output_ablate[..., node_idx] = 0
                    return output_ablate
                return hook
            
            handle = None
            for name, module in self.model.named_modules():
                if isinstance(module, torch.nn.Linear):
                    handle = module.register_forward_hook(create_hook(node))
                    break
            
            # 计算消融后的输出
            ablated_outputs = []
            with torch.no_grad():
                for x, _ in dataloader:
                    x = x.to(self.device)
                    output = self.model(x)
                    ablated_outputs.append(output)
            
            if handle:
                handle.remove()
            
            # 比较输出
            ref_cat = torch.cat(ref_outputs, dim=0)
            ablated_cat = torch.cat(ablated_outputs, dim=0)
            
            error = torch.norm(ref_cat - ablated_cat, p=2).item() / ref_cat.numel()
            
            if error < tolerance:
                redundant_nodes.append(node)
            else:
                essential_nodes.append(node)
        
        return {
            'essential_nodes': essential_nodes,
            'redundant_nodes': redundant_nodes,
            'minimality_ratio': len(essential_nodes) / len(self.circuit_nodes)
        }
    
    def find_minimal_subcircuit(self, dataloader, tolerance=1e-4):
        """
        使用贪心搜索找到最小子电路
        """
        self.model.eval()
        
        # 计算原始电路输出
        ref_outputs = []
        with torch.no_grad():
            for x, _ in dataloader:
                x = x.to(self.device)
                output = self.model(x)
                ref_outputs.append(output)
        
        ref_cat = torch.cat(ref_outputs, dim=0)
        
        # 贪心移除
        current_nodes = set(self.circuit_nodes)
        minimal_nodes = set(self.circuit_nodes)
        
        # 按敏感度排序
        sensitivities = {}
        for node in current_nodes:
            sensitivity = self._compute_node_sensitivity(node, dataloader, ref_cat)
            sensitivities[node] = sensitivity
        
        sorted_nodes = sorted(current_nodes, key=lambda n: sensitivities[n])
        
        # 尝试移除低敏感度节点
        for node in sorted_nodes:
            current_nodes.remove(node)
            
            # 检查是否仍然等效
            if self._check_equivalence(current_nodes, dataloader, ref_cat, tolerance):
                minimal_nodes.remove(node)
            else:
                current_nodes.add(node)
        
        return {
            'minimal_nodes': list(minimal_nodes),
            'reduction_ratio': 1 - len(minimal_nodes) / len(self.circuit_nodes)
        }
    
    def _compute_node_sensitivity(self, node, dataloader, ref_output):
        """计算节点的敏感度"""
        # 简化实现
        return 1.0
    
    def _check_equivalence(self, nodes, dataloader, ref_output, tolerance):
        """检查移除节点后是否等效"""
        # 简化实现
        return False

电路可迁移性分析

迁移性定义

定义(电路迁移性):电路 从模型 是可迁移的,如果 中仍然实现相同的计算功能。

迁移性度量

class CircuitTransferabilityAnalyzer:
    """
    电路迁移性分析器
    """
    def __init__(self, source_model, target_model, circuit_nodes):
        self.source_model = source_model
        self.target_model = target_model
        self.circuit_nodes = circuit_nodes
        self.device = next(source_model.parameters()).device
    
    def compute_transferability_score(self, dataloader, task):
        """
        计算电路的迁移性分数
        """
        # 在源模型上提取电路
        source_circuit = self._extract_circuit()
        
        # 在目标模型上测试
        target_model = self.target_model.to(self.device)
        target_model.eval()
        
        # 测试性能
        correct = 0
        total = 0
        
        with torch.no_grad():
            for x, y in dataloader:
                x = x.to(self.device)
                y = y.to(self.device)
                
                output = target_model(x)
                pred = output.argmax(dim=1)
                
                correct += (pred == y).sum().item()
                total += y.size(0)
        
        accuracy = correct / total
        
        return {
            'transfer_accuracy': accuracy,
            'source_circuit_size': len(source_circuit),
            'transferability_score': accuracy  # 简化的迁移性分数
        }
    
    def _extract_circuit(self):
        """提取源模型中的电路"""
        return self.circuit_nodes

与形式化验证的结合

电路正确性证明框架

class FormalCircuitProver:
    """
    形式化电路证明器
    
    结合SMT和符号执行进行电路验证
    """
    def __init__(self, model):
        self.model = model
    
    def prove_circuit_correctness(self, circuit, spec, input_domain):
        """
        证明电路正确性
        
        参数:
            circuit: 候选电路
            spec: 规格(电路应满足的属性)
            input_domain: 输入域定义
        """
        # 符号执行
        symbolic_trace = self.symbolic_execute(circuit, input_domain)
        
        # SMT验证
        smt_formula = self.generate_smt_formula(symbolic_trace, spec)
        
        # 求解
        solver = Solver()
        solver.add(smt_formula)
        
        if solver.check() == unsat:
            return {'proved': True, 'counterexample': None}
        else:
            return {'proved': False, 'counterexample': solver.model()}
    
    def symbolic_execute(self, circuit, input_domain):
        """符号执行电路"""
        # 实现省略
        pass
    
    def generate_smt_formula(self, trace, spec):
        """生成SMT公式"""
        # 实现省略
        pass

实践应用

电路发现pipeline

class CircuitDiscoveryPipeline:
    """
    完整的电路发现流程
    """
    def __init__(self, model):
        self.model = model
        self.device = next(model.parameters()).device
    
    def discover_circuits(self, dataloader, task_description, 
                         confidence_threshold=0.95):
        """
        完整的电路发现流程
        """
        # 1. 激活分析
        activation_analyzer = ActivationAnalyzer(self.model)
        important_neurons = activation_analyzer.find_important_neurons(dataloader)
        
        # 2. 电路构建
        circuit_builder = CircuitBuilder(self.model)
        candidate_circuit = circuit_builder.build(important_neurons)
        
        # 3. 验证
        verifier = SMTBasedCircuitVerifier(self.model, candidate_circuit)
        is_valid, counterexample = verifier.verify_circuit(
            input_bounds=[(-1, 1)] * 784  # 假设MNIST
        )
        
        # 4. 稳定性分析
        stability_analyzer = CircuitStabilityAnalyzer(self.model, candidate_circuit)
        stability_result = stability_analyzer.stability_under_perturbation(
            next(iter(dataloader))[0].to(self.device),
            None
        )
        
        # 5. 最小性检查
        minimality_verifier = CircuitMinimalityVerifier(self.model, candidate_circuit)
        minimality_result = minimality_verifier.check_minimality(dataloader)
        
        return {
            'circuit': candidate_circuit,
            'is_valid': is_valid,
            'stability': stability_result,
            'minimality': minimality_result,
            'counterexample': counterexample
        }
 
 
class ActivationAnalyzer:
    """激活分析器"""
    def __init__(self, model):
        self.model = model
        self.activations = {}
    
    def find_important_neurons(self, dataloader, top_k=100):
        """找到最重要的神经元"""
        # 实现省略
        return list(range(top_k))
 
 
class CircuitBuilder:
    """电路构建器"""
    def __init__(self, model):
        self.model = model
    
    def build(self, important_neurons):
        """从重要神经元构建电路"""
        return important_neurons

数学附录

核心定义速查

概念定义验证方法
电路实现特定功能的神经元子图功能完整性
稳定性扰动下输出不变Lipschitz分析
最小性无冗余节点贪心搜索
迁移性跨模型泛化任务准确率
正确性满足规格SMT求解

SMT理论支持

理论应用
实数算术 (RA)浮点网络验证
位向量 (BV)定点网络验证
数组理论内存建模
非线性实数 (NRA)激活函数

参考文献


相关主题

Footnotes

  1. Geiping, J., et al. (2026). Formal mechanistic interpretability. ICLR 2026.