Sparse Feature Circuits

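Overview

Sparse Feature Circuits [1] combine sparse autoencoder (SAE) features with circuit discovery methods in order to discover and validate the computational circuits inside a neural network more systematically.

Core ideas

  • Use an SAE as a "microscope" that decomposes superposed representations into interpretable features
  • Use circuit discovery methods (such as activation patching) to verify causal relationships between features
  • Build a feature-level "circuit graph" that exposes the model's computational mechanism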

1. Background: circuit discovery and SAEs

1.1 Traditional circuit discovery

Traditional circuit discovery methods [2] rely on activation patching:

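import torch


# Core idea of traditional circuit discovery
def activation_patching_experiment(
    model,
    clean_tokens,
    corrupted_tokens,
    layer_idx,
    position_idx,
):
    """
    Activation patching experiment.

    1. Run the model on the clean input and record one layer's activation.
    2. Run the model on the corrupted input while replacing that layer's
       activation at one position with the clean activation.
    3. Observe how the output changes.

    Assumes a Hugging Face style decoder (model.model.layers, output.logits)
    whose layers return a tuple, and clean/corrupted inputs of equal length.
    """
    layer = model.model.layers[layer_idx]
    cache = {}

    def save_hook(module, inputs, output):
        cache["clean"] = output[0].detach()  # residual stream

    def patch_hook(module, inputs, output):
        hidden = output[0].clone()
        hidden[:, position_idx, :] = cache["clean"][:, position_idx, :]
        return (hidden,) + tuple(output[1:])

    with torch.no_grad():
        # Clean run: record the activation
        handle = layer.register_forward_hook(save_hook)
        clean_logits = model(clean_tokens).logits
        handle.remove()

        # Corrupted run: baseline
        corrupted_logits = model(corrupted_tokens).logits

        # Patched run: corrupted input with one clean activation
        handle = layer.register_forward_hook(patch_hook)
        patched_logits = model(corrupted_tokens).logits
        handle.remove()

    # Contribution: how much of the clean-vs-corrupted gap the patch recovers.
    # Compare it with the total effect, clean_logits - corrupted_logits.
    contribution = patched_logits - corrupted_logits

    return contribution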
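
To see the mechanics without a full language model, here is a minimal sketch on a toy two-layer network. The network, the random inputs, and the helper name `patch_layer_output` are illustrative assumptions and not part of the method described above. Patching the entire output of the first layer with its clean value should reproduce the clean output exactly, which makes a convenient sanity check for a patching setup.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a model: two linear layers with a nonlinearity in between.
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))


def patch_layer_output(model, layer, clean_x, corrupted_x):
    """Run corrupted_x while replacing `layer`'s output with its clean value."""
    cache = {}

    def save_hook(module, inputs, output):
        cache["clean"] = output.detach()

    def patch_hook(module, inputs, output):
        return cache["clean"]  # a returned value replaces the layer output

    with torch.no_grad():
        handle = layer.register_forward_hook(save_hook)
        clean_out = model(clean_x)
        handle.remove()

        corrupted_out = model(corrupted_x)

        handle = layer.register_forward_hook(patch_hook)
        patched_out = model(corrupted_x)
        handle.remove()

    return clean_out, corrupted_out, patched_out


clean_x = torch.randn(2, 4)
corrupted_x = torch.randn(2, 4)
clean_out, corrupted_out, patched_out = patch_layer_output(
    toy_model, toy_model[0], clean_x, corrupted_x
)

# Effect recovered by the patch versus the total clean-vs-corrupted effect.
recovered = patched_out - corrupted_out
total = clean_out - corrupted_out
print(torch.allclose(recovered, total))
```

Patching a single position or a single feature instead of the whole layer recovers only part of the total effect, and that fraction is what patching experiments measure.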

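1.2 SAEs as a "microscope"

SAEs address a key obstacle in circuit discovery: superposition.

Problem: in a superposed representation, individual neurons respond to many unrelated features, so it is hard to tell which neurons belong to which circuit.

Solution: an SAE decomposes the superposed activation into sparse features, and each feature can serve as a "circuit node".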
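
The code in this document assumes an SAE object with `encode`, `decode`, and `n_features`. As a rough sketch of what such an object could look like, the following defines an untrained toy SAE; the class name `ToySAE`, the dimensions, and the random weights are assumptions for illustration only, and a real SAE would be trained with a reconstruction loss plus a sparsity penalty.

```python
import torch
import torch.nn as nn


class ToySAE(nn.Module):
    """Untrained sparse autoencoder with the interface assumed in this document."""

    def __init__(self, d_model: int, n_features: int):
        super().__init__()
        self.n_features = n_features
        self.encoder = nn.Linear(d_model, n_features)
        self.decoder = nn.Linear(n_features, d_model)

    def encode(self, activations: torch.Tensor) -> torch.Tensor:
        # ReLU keeps feature activations non-negative; a sparsity penalty
        # during training (not shown) pushes most of them to zero.
        return torch.relu(self.encoder(activations))

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        return self.decoder(features)


torch.manual_seed(0)
sae = ToySAE(d_model=16, n_features=64)
activations = torch.randn(1, 5, 16)       # (batch, sequence, d_model)
features = sae.encode(activations)        # (batch, sequence, n_features)
reconstruction = sae.decode(features)
error = activations - reconstruction      # the part the SAE does not explain

print(features.shape, reconstruction.shape)
```

The `error` term matters for patching: adding it back after decoding keeps an intervention on one feature from also swapping the real activation for an imperfect reconstruction.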

1.3 Combining the two: Sparse Feature Circuits

Sparse Feature Circuits = SAE features + circuit discovery methods

┌─────────────────────────────────────────────────────────────┐
│                                                             │
│  Input ──→ [Layer 1] ──→ [Layer 2] ──→ [Layer 3] ──→ Output │
│                │             │             │                │
│                ↓             ↓             ↓                │
│            Feature A     Feature B     Feature C            │
│                │             │             │                │
│                └─────────────┴─────────────┘                │
│                              ↓                              │
│              Circuit graph (between features)               │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. Methodology

2.1 Feature-level patching

import torch


class FeaturePatching:
    """
    Patching experiments on SAE features.

    Unlike traditional patching of raw activations, the unit being
    patched here is a single SAE feature.
    """

    def __init__(self, model, sae):
        self.model = model
        self.sae = sae
        self.device = next(model.parameters()).device

    def feature_patching_experiment(
        self,
        clean_tokens,
        corrupted_tokens,
        target_layer,
        target_feature_idx,
    ) -> dict:
        """
        Feature-level patching experiment.

        1. Run the clean input and encode the target layer into SAE features.
        2. On the corrupted input, replace the target feature with its clean value.
        3. Compare the resulting outputs.

        Assumes clean and corrupted inputs have the same shape.
        """
        # Clean run
        clean_activations = self.get_layer_activations(clean_tokens, target_layer)
        clean_features = self.sae.encode(clean_activations)

        # Corrupted run
        corrupted_activations = self.get_layer_activations(corrupted_tokens, target_layer)
        corrupted_features = self.sae.encode(corrupted_activations)

        # Patch: replace only the target feature
        patched_features = corrupted_features.clone()
        patched_features[:, :, target_feature_idx] = clean_features[:, :, target_feature_idx]

        # Decode, then add back the SAE reconstruction error of the corrupted run
        # so that the patched feature is the only change relative to that run
        sae_error = corrupted_activations - self.sae.decode(corrupted_features)
        patched_activations = self.sae.decode(patched_features) + sae_error

        # Forward passes (inference only, so gradients are not needed)
        with torch.no_grad():
            clean_output = self.model(clean_tokens)
            corrupted_output = self.model(corrupted_tokens)
        patched_output = self.run_with_modified_activation(
            corrupted_tokens, target_layer, patched_activations
        )

        # Collect the scores
        clean_score = self.get_metric(clean_output)
        corrupted_score = self.get_metric(corrupted_output)
        patched_score = self.get_metric(patched_output)
        results = {
            "clean_score": clean_score,
            "corrupted_score": corrupted_score,
            "patched_score": patched_score,
            "patch_effect": patched_score - corrupted_score,
            "total_effect": clean_score - corrupted_score,
        }

        # Fraction of the total effect recovered by the patch
        if results["total_effect"] != 0:
            results["patch_ratio"] = results["patch_effect"] / results["total_effect"]
        else:
            results["patch_ratio"] = 0.0

        return results
    
    def get_layer_activations(self, tokens, layer_idx):
        """Return the output of one decoder layer."""
        cache = {}

        def hook_fn(module, inputs, output):
            cache["activation"] = output[0]  # residual stream

        hook = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                self.model(tokens)
        finally:
            hook.remove()
        return cache["activation"]

    def run_with_modified_activation(self, tokens, layer_idx, modified_activation):
        """Run the model with one layer's output replaced."""

        def hook_fn(module, inputs, output):
            # A value returned from a forward hook replaces the layer output
            return (modified_activation,) + output[1:]

        hook = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                output = self.model(tokens)
        finally:
            hook.remove()
        return output

    def get_metric(self, output):
        """Scalar metric (could also be a loss or a logit difference).

        Here: the mean of the largest final-position logit.
        """
        return output.logits[:, -1, :].max(dim=-1).values.mean().item()
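
The `get_metric` method above takes the largest final-position logit, which does not depend on which answer is actually correct. When the task has a known correct and incorrect token, a logit difference is a commonly used alternative. The sketch below is one possible version; the function names `logit_diff_metric` and `patch_ratio` and the hand-made logits are assumptions for illustration.

```python
import torch


def logit_diff_metric(logits: torch.Tensor, correct_ids: torch.Tensor,
                      incorrect_ids: torch.Tensor) -> float:
    """Mean (correct - incorrect) logit at the final position.

    logits: (batch, sequence, vocab); correct_ids / incorrect_ids: (batch,)
    """
    final_logits = logits[:, -1, :]  # (batch, vocab)
    correct = final_logits.gather(1, correct_ids.unsqueeze(1)).squeeze(1)
    incorrect = final_logits.gather(1, incorrect_ids.unsqueeze(1)).squeeze(1)
    return (correct - incorrect).mean().item()


def patch_ratio(clean_score: float, corrupted_score: float, patched_score: float) -> float:
    """Fraction of the clean-vs-corrupted gap recovered by a patch."""
    total_effect = clean_score - corrupted_score
    if total_effect == 0:
        return 0.0
    return (patched_score - corrupted_score) / total_effect


# Hand-made logits for a vocabulary of 4 tokens and a batch of 2 prompts.
logits = torch.zeros(2, 3, 4)
logits[0, -1] = torch.tensor([2.0, 0.5, 0.0, 0.0])
logits[1, -1] = torch.tensor([0.0, 1.0, 3.0, 0.0])
score = logit_diff_metric(logits, torch.tensor([0, 2]), torch.tensor([1, 1]))
ratio = patch_ratio(clean_score=3.0, corrupted_score=-1.0, patched_score=1.0)
print(score, ratio)
```

A ratio near 1 means the patched feature carries most of the behavior; a ratio near 0 means it carries little.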

2.2 Feature-level attribution

import torch


class FeatureAttribution:
    """
    Attribute the model output to individual SAE features.
    """

    def __init__(self, model, sae):
        self.model = model
        self.sae = sae

    def compute_feature_importance(
        self,
        tokens,
        target_layer,
        n_features=None,
    ) -> dict:
        """
        Score each feature by how much the metric drops when it is zero-ablated.
        """
        if n_features is None:
            n_features = self.sae.n_features
        n_features = min(n_features, self.sae.n_features)

        # Original activations and features
        original_activations = self.get_layer_activations(tokens, target_layer)
        original_features = self.sae.encode(original_activations)

        # SAE reconstruction error, held fixed so that only the feature changes
        sae_error = original_activations - self.sae.decode(original_features)

        # Unmodified run, computed once rather than once per feature
        with torch.no_grad():
            original_score = self.get_metric(self.model(tokens))

        # Marginal contribution of each feature
        importance_scores = []

        for feat_idx in range(n_features):
            # Patch to the baseline (an all-zero feature activation)
            patched_features = original_features.clone()
            patched_features[:, :, feat_idx] = 0.0

            patched_activations = self.sae.decode(patched_features) + sae_error

            # Run and measure the effect
            patched_output = self.run_with_modified_activation(
                tokens, target_layer, patched_activations
            )

            # Marginal effect
            effect = original_score - self.get_metric(patched_output)

            importance_scores.append({
                "feature_idx": feat_idx,
                "importance": effect,
                # Active at the final position of the first sequence
                "is_active": bool(original_features[0, -1, feat_idx] > 0),
            })

        # Sort, most important first
        importance_scores.sort(key=lambda x: x["importance"], reverse=True)

        return {
            "scores": importance_scores,
            "top_features": importance_scores[:10],
            "total_importance": sum(s["importance"] for s in importance_scores),
        }
    
    def get_layer_activations(self, tokens, layer_idx):
        """Return the output of one decoder layer."""
        cache = {}

        def hook_fn(module, inputs, output):
            cache["activation"] = output[0]

        hook = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                self.model(tokens)
        finally:
            hook.remove()

        return cache["activation"]

    def run_with_modified_activation(self, tokens, layer_idx, modified_activation):
        """Run the model with one layer's output replaced."""
        def hook_fn(module, inputs, output):
            return (modified_activation,) + output[1:]

        hook = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                output = self.model(tokens)
        finally:
            hook.remove()

        return output

    def get_metric(self, output):
        """Scalar metric: mean of the largest final-position logit."""
        return output.logits[:, -1, :].max(dim=-1).values.mean().item()
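
The ablation loop above costs one forward pass per feature. A gradient-based linear approximation, often called attribution patching, is a different technique that estimates every ablation effect from a single forward and backward pass. The sketch below illustrates the idea on toy stand-ins; the random `decoder`, the linear `readout`, and all names are assumptions. Because the toy metric is linear in the features, the approximation is exact here, which serves as a check. For a real model it is only a first-order estimate.

```python
import torch

torch.manual_seed(0)
n_features, d_model = 6, 4

# Toy stand-ins: a linear SAE decoder and a linear readout as the "rest of the model".
decoder = torch.randn(n_features, d_model, dtype=torch.float64)
readout = torch.randn(d_model, dtype=torch.float64)


def metric_from_features(features: torch.Tensor) -> torch.Tensor:
    """Scalar metric as a function of the feature activations."""
    return (features @ decoder) @ readout


features = torch.rand(n_features, dtype=torch.float64).requires_grad_(True)

# One forward and one backward pass give the gradient for every feature.
metric = metric_from_features(features)
metric.backward()

# First-order estimate of the importance (original minus ablated metric)
# of zero-ablating each feature: grad * (f - 0).
approx_importance = (features.grad * features).detach()

# Exact importances by explicit ablation, for comparison.
exact_values = []
with torch.no_grad():
    for idx in range(n_features):
        ablated = features.detach().clone()
        ablated[idx] = 0.0
        exact_values.append((metric - metric_from_features(ablated)).item())
exact_importance = torch.tensor(exact_values, dtype=torch.float64)

print(torch.allclose(approx_importance, exact_importance))
```

In practice the approximate scores can be used to shortlist features, with exact ablation reserved for the shortlisted ones.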

2.3 Circuit construction

from collections import defaultdict

import networkx as nx
import torch


class SparseFeatureCircuit:
    """
    Build a sparse feature circuit from feature attributions.

    Note: a single SAE is applied at every layer here for simplicity;
    in practice each layer usually needs its own SAE.
    """

    def __init__(self, model, sae):
        self.model = model
        self.sae = sae
        self.attribution = FeatureAttribution(model, sae)

    def build_circuit(
        self,
        tokens,
        behavior_description: str,
        layers: list[int],
        importance_threshold: float = 0.1,
    ) -> nx.DiGraph:
        """
        Build a sparse feature circuit.

        Args:
            tokens: input token sequence
            behavior_description: description of the target behavior
            layers: layers to analyze, in ascending order
            importance_threshold: minimum absolute importance for a node

        Returns:
            The circuit as a directed graph
        """
        circuit = nx.DiGraph()
        circuit.graph["behavior"] = behavior_description

        # Collect all important features
        important_features = []

        for layer_idx in layers:
            attributions = self.attribution.compute_feature_importance(
                tokens, layer_idx
            )

            for score in attributions["scores"]:
                if abs(score["importance"]) > importance_threshold:
                    important_features.append({
                        "layer": layer_idx,
                        "feature_idx": score["feature_idx"],
                        "importance": score["importance"],
                        "is_active": score["is_active"],
                    })

        # Add nodes
        for feat in important_features:
            node_id = f"layer{feat['layer']}_feat{feat['feature_idx']}"
            circuit.add_node(
                node_id,
                layer=feat["layer"],
                feature_idx=feat["feature_idx"],
                importance=feat["importance"],
                is_active=feat["is_active"],
            )

        # Discover edges (by cross-layer patching). The same feature index can be
        # important in several layers, so deduplicate before testing.
        candidate_indices = sorted({f["feature_idx"] for f in important_features})
        edges = self.discover_edges(tokens, layers, candidate_indices)

        # Add edges; endpoints that did not pass the importance threshold
        # in their own layer are skipped
        for src, dst, weight in edges:
            src_id = f"layer{layers[src['layer_idx']]}_feat{src['feat_idx']}"
            dst_id = f"layer{layers[dst['layer_idx']]}_feat{dst['feat_idx']}"

            if circuit.has_node(src_id) and circuit.has_node(dst_id):
                circuit.add_edge(src_id, dst_id, weight=weight)

        return circuit
    
    def discover_edges(
        self,
        tokens,
        layers,
        feature_indices,
        threshold: float = 0.05,
    ) -> list[tuple]:
        """
        Discover edges between features.

        Ablate a source feature in one layer and measure how each candidate
        feature in the next analyzed layer changes. `layers` must be in
        ascending order.
        """
        edges = []
        valid_indices = [i for i in feature_indices if i < self.sae.n_features]

        # For each pair of adjacent analyzed layers
        for i in range(len(layers) - 1):
            src_layer = layers[i]
            dst_layer = layers[i + 1]

            # Activations of both layers on the unmodified run
            src_activations = self.attribution.get_layer_activations(tokens, src_layer)
            dst_activations = self.attribution.get_layer_activations(tokens, dst_layer)

            # Encode as features
            src_features = self.sae.encode(src_activations)
            dst_features = self.sae.encode(dst_activations)

            # Keep the SAE reconstruction error fixed while ablating
            src_error = src_activations - self.sae.decode(src_features)

            for src_feat_idx in valid_indices:
                # Zero-ablate the source feature
                patched_src_features = src_features.clone()
                patched_src_features[:, :, src_feat_idx] = 0
                patched_src_activations = self.sae.decode(patched_src_features) + src_error

                # One patched forward pass per source feature,
                # reused for every target feature
                patched_dst = self.run_and_encode(
                    tokens, src_layer, dst_layer, patched_src_activations
                )

                for dst_feat_idx in valid_indices:
                    # Change of the target feature at the final position
                    original_dst_value = dst_features[0, -1, dst_feat_idx].item()
                    patched_dst_value = patched_dst[0, -1, dst_feat_idx].item()

                    contribution = original_dst_value - patched_dst_value

                    if abs(contribution) > threshold:
                        edges.append((
                            {"layer_idx": i, "feat_idx": src_feat_idx},
                            {"layer_idx": i + 1, "feat_idx": dst_feat_idx},
                            contribution
                        ))

        return edges

    def run_and_encode(self, tokens, src_layer, dst_layer, src_activations):
        """Patch the output of src_layer and return the SAE features at dst_layer.

        Both hooks are active in the same forward pass, so the cached
        dst_layer activation reflects the patch.
        """
        cache = {}

        def patch_hook(module, inputs, output):
            return (src_activations,) + output[1:]

        def cache_hook(module, inputs, output):
            cache["activation"] = output[0]

        model_layers = self.model.model.layers
        patch_handle = model_layers[src_layer].register_forward_hook(patch_hook)
        cache_handle = model_layers[dst_layer].register_forward_hook(cache_hook)
        try:
            with torch.no_grad():
                self.model(tokens)
        finally:
            patch_handle.remove()
            cache_handle.remove()

        return self.sae.encode(cache["activation"])
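
Thresholding nodes and edges independently can leave features that have no path to the last analyzed layer. One possible cleanup step, sketched here as an assumption rather than part of the method above, keeps only nodes that lie on a path to a final-layer node. The helper name `prune_to_output_paths` and the hand-made graph are illustrative; node names follow the `layer{L}_feat{F}` scheme used in `build_circuit`.

```python
import networkx as nx


def prune_to_output_paths(circuit: nx.DiGraph, final_layer: int) -> nx.DiGraph:
    """Keep final-layer nodes plus every node with a path to one of them."""
    sinks = [n for n, data in circuit.nodes(data=True) if data["layer"] == final_layer]
    keep = set(sinks)
    for sink in sinks:
        keep |= nx.ancestors(circuit, sink)
    return circuit.subgraph(keep).copy()


# Hand-made example graph.
circuit = nx.DiGraph()
for layer, feat in [(0, 3), (0, 7), (1, 2), (1, 9), (2, 5)]:
    circuit.add_node(f"layer{layer}_feat{feat}", layer=layer, feature_idx=feat)
circuit.add_edge("layer0_feat3", "layer1_feat2", weight=0.4)
circuit.add_edge("layer1_feat2", "layer2_feat5", weight=0.9)
circuit.add_edge("layer0_feat7", "layer1_feat9", weight=0.2)  # dead end

pruned = prune_to_output_paths(circuit, final_layer=2)
print(sorted(pruned.nodes()))
```

The dead-end pair `layer0_feat7` and `layer1_feat9` is dropped because neither reaches layer 2.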

3. Circuit analysis

3.1 Visualization

from collections import defaultdict
from typing import Optional

import matplotlib.pyplot as plt
import networkx as nx


def visualize_circuit(circuit: nx.DiGraph, save_path: Optional[str] = None):
    """
    Visualize a sparse feature circuit.
    """
    fig, ax = plt.subplots(1, 1, figsize=(16, 12))

    # Group nodes by layer
    layers = defaultdict(list)
    for node in circuit.nodes():
        layer = circuit.nodes[node]["layer"]
        layers[layer].append(node)

    # Layout: one column per layer, nodes centered vertically in [-0.5, 0.5]
    pos = {}

    for layer_idx, (layer, nodes) in enumerate(sorted(layers.items())):
        n_nodes = len(nodes)
        for i, node in enumerate(nodes):
            x = layer_idx
            y = (i - (n_nodes - 1) / 2) / max(n_nodes - 1, 1)
            pos[node] = (x, y)

    # Draw nodes: color encodes signed importance, size encodes magnitude
    node_colors = []
    node_sizes = []

    for node in circuit.nodes():
        importance = circuit.nodes[node]["importance"]
        node_colors.append(importance)
        node_sizes.append(100 + abs(importance) * 500)

    nx.draw_networkx_nodes(
        circuit, pos,
        node_color=node_colors,
        node_size=node_sizes,
        cmap=plt.cm.RdBu_r,
        alpha=0.8,
        ax=ax
    )

    # Draw edges: color encodes signed weight, width encodes magnitude
    edges = list(circuit.edges())
    if edges:
        edge_weights = [abs(circuit.edges[e[0], e[1]]["weight"]) for e in edges]
        edge_colors = [circuit.edges[e[0], e[1]]["weight"] for e in edges]

        nx.draw_networkx_edges(
            circuit, pos,
            edgelist=edges,
            edge_color=edge_colors,
            width=[1 + w * 2 for w in edge_weights],
            alpha=0.6,
            edge_cmap=plt.cm.RdBu_r,  # the edge colormap argument is edge_cmap, not cmap
            ax=ax
        )

    # Add layer labels above each column; annotation_clip=False keeps them
    # visible even when they fall outside the plotted data range
    for layer_idx, layer in enumerate(sorted(layers.keys())):
        ax.annotate(
            f"Layer {layer}",
            xy=(layer_idx, 0.65),
            ha='center',
            fontsize=12,
            fontweight='bold',
            annotation_clip=False
        )

    ax.set_title(f"Sparse Feature Circuit: {circuit.graph.get('behavior', 'Unknown')}")
    ax.axis('off')

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig

3.2 Circuit statistics

from collections import defaultdict

import networkx as nx


def analyze_circuit_statistics(circuit: nx.DiGraph) -> dict:
    """
    Compute summary statistics for a circuit.
    """
    stats = {
        "n_nodes": circuit.number_of_nodes(),
        "n_edges": circuit.number_of_edges(),
        "density": nx.density(circuit),
    }

    # Connections between layers
    layer_connections = defaultdict(list)
    for u, v in circuit.edges():
        src_layer = circuit.nodes[u]["layer"]
        dst_layer = circuit.nodes[v]["layer"]
        layer_connections[(src_layer, dst_layer)].append((u, v))

    stats["layer_connections"] = dict(layer_connections)

    # In-degree and out-degree of each node
    in_degrees = dict(circuit.in_degree())
    out_degrees = dict(circuit.out_degree())

    # default=0 avoids a ValueError on an empty circuit
    stats["max_in_degree"] = max(in_degrees.values(), default=0)
    stats["max_out_degree"] = max(out_degrees.values(), default=0)

    # Key nodes (hubs): at least 3 incoming or 3 outgoing edges
    hub_nodes = [
        node for node, degree in in_degrees.items()
        if degree >= 3 or out_degrees[node] >= 3
    ]
    stats["hub_nodes"] = hub_nodes

    # Connected components
    stats["n_components"] = nx.number_weakly_connected_components(circuit)

    return stats
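
Beyond counts and degrees, one statistic that may be worth inspecting is the heaviest path through the circuit, i.e. the chain of features whose edges have the largest total magnitude. The sketch below is an illustrative addition, not part of the statistics above; it assumes the circuit is acyclic (true when edges only go from earlier to later layers) and uses a hand-made graph with a hypothetical `abs_weight` edge attribute.

```python
import networkx as nx

circuit = nx.DiGraph()
edges = [
    ("layer0_feat3", "layer1_feat2", 0.4),
    ("layer0_feat3", "layer1_feat9", -0.1),
    ("layer1_feat2", "layer2_feat5", 0.9),
    ("layer1_feat9", "layer2_feat5", 0.3),
]
for src, dst, weight in edges:
    # Store the magnitude separately so negative (inhibitory) edges still count.
    circuit.add_edge(src, dst, weight=weight, abs_weight=abs(weight))

strongest_path = nx.dag_longest_path(circuit, weight="abs_weight")
path_strength = nx.dag_longest_path_length(circuit, weight="abs_weight")
density = nx.density(circuit)

print(strongest_path, round(path_strength, 2), round(density, 3))
```

Here the path through `layer1_feat2` outweighs the one through `layer1_feat9`, so it is reported as the strongest.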

4. Applications

4.1 Induction head circuit

def analyze_induction_head_circuit(model, sae, tokens):
    """
    Analyze an induction head circuit.

    An induction head:
    1. Locates an earlier token that starts a pattern (for example "[")
    2. Copies what followed that pattern
    """
    circuit_builder = SparseFeatureCircuit(model, sae)

    # Layers to analyze. The range 15-25 assumes a fairly deep model
    # (around 32 layers) and should be adjusted per model.
    layers = list(range(15, 26))

    # Build the circuit
    circuit = circuit_builder.build_circuit(
        tokens,
        behavior_description="Induction Head (Pattern Completion)",
        layers=layers,
        importance_threshold=0.05,
    )

    # Analyze the circuit structure
    stats = analyze_circuit_statistics(circuit)

    # Visualize
    visualize_circuit(circuit, "induction_head_circuit.png")

    return circuit, stats
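
Induction behavior is often probed with a sequence of random tokens repeated twice, because the second half can only be predicted by copying from the first. The following sketch builds such an input; the helper name `make_repeated_sequence`, the vocabulary size, and the lengths are illustrative assumptions, and special tokens such as a beginning-of-sequence marker are not handled.

```python
import torch


def make_repeated_sequence(vocab_size: int, half_length: int, batch_size: int = 1,
                           seed: int = 0) -> torch.Tensor:
    """Random tokens followed by an exact repeat: [t1 .. tn, t1 .. tn]."""
    generator = torch.Generator().manual_seed(seed)
    first_half = torch.randint(
        0, vocab_size, (batch_size, half_length), generator=generator
    )
    return torch.cat([first_half, first_half], dim=1)


tokens = make_repeated_sequence(vocab_size=1000, half_length=8)
print(tokens.shape)
```

The resulting tensor can be passed as `tokens` to a function such as `analyze_induction_head_circuit`, provided the ids are valid for the model's vocabulary.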

4.2 IOI (Indirect Object Identification) circuit

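def analyze_ioi_circuit(model, sae, tokens):
    """
    Analyze the IOI (indirect object identification) circuit.

    The IOI circuit:
    1. Identifies the repeated subject name (e.g. "John" in
       "When Mary and John went to the store, John gave a drink to")
    2. Selects the remaining name as the indirect object (here "Mary")
    3. Outputs it at the final position
    """
    circuit_builder = SparseFeatureCircuit(model, sae)

    # All layers; assumes a 32-layer model
    layers = list(range(0, 32))

    circuit = circuit_builder.build_circuit(
        tokens,
        behavior_description="Indirect Object Identification",
        layers=layers,
        importance_threshold=0.08,
    )

    # Find key nodes (total degree of at least 5)
    key_nodes = [
        node for node in circuit.nodes()
        if circuit.in_degree(node) + circuit.out_degree(node) >= 5
    ]

    print(f"Key nodes: {key_nodes}")

    return circuit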
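
Patching experiments on IOI need paired clean and corrupted prompts. The sketch below shows one way such pairs could be generated from a template. The template, the names, the `IOIExample` container, and the choice of corruption (replacing the repeated subject with an unrelated third name) are assumptions for illustration; other corruption schemes are possible.

```python
from dataclasses import dataclass


@dataclass
class IOIExample:
    clean_prompt: str
    corrupted_prompt: str
    correct_answer: str    # indirect object
    incorrect_answer: str  # repeated subject


TEMPLATE = "When {a} and {b} went to the store, {s} gave a drink to"


def make_ioi_example(indirect_object: str, subject: str, distractor: str) -> IOIExample:
    """Build one clean/corrupted pair from the template above."""
    clean = TEMPLATE.format(a=indirect_object, b=subject, s=subject)
    # Corrupted prompt: the giver is an unrelated third name, so nothing marks
    # which of the first two names is the indirect object.
    corrupted = TEMPLATE.format(a=indirect_object, b=subject, s=distractor)
    # Leading spaces match how many tokenizers encode a word after a space.
    return IOIExample(clean, corrupted, " " + indirect_object, " " + subject)


example = make_ioi_example("Mary", "John", "Alice")
print(example.clean_prompt)
print(example.corrupted_prompt)
```

Position-wise patching requires both prompts to tokenize to the same length, so the names should be chosen so that each maps to the same number of tokens.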

5. Comparison with traditional circuit discovery

5.1 Method comparison

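Aspect                 | Traditional circuit discovery | Sparse Feature Circuits
-----------------------|-------------------------------|----------------------------------
Unit of analysis       | Token-level activations       | SAE features
Superposition          | Not handled                   | Addressed by the SAE decomposition
Interpreting the units | Difficult                     | Relatively easy
Completeness           | May miss components           | More comprehensive
Compute cost           | High                          | High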

5.2 Quantitative comparison

# Example comparison experiment
def compare_circuit_methods(
    model, sae, test_cases, target_behavior,
    traditional_circuit_discovery, estimate_explainability,
):
    """
    Compare traditional circuit discovery with Sparse Feature Circuits.

    traditional_circuit_discovery(model, tokens) and
    estimate_explainability(circuit) are not defined in this document;
    the caller must supply them as callables.
    """
    results = {
        "traditional": {
            "circuits_found": [],
            "feature_explainability": [],
            "coverage": [],  # not filled in here: no coverage metric is defined
        },
        "sparse_feature": {
            "circuits_found": [],
            "feature_explainability": [],
            "coverage": [],  # not filled in here: no coverage metric is defined
        }
    }

    for case in test_cases:
        tokens = case["tokens"]

        # Traditional method
        traditional_circuit = traditional_circuit_discovery(model, tokens)
        results["traditional"]["circuits_found"].append(traditional_circuit)
        results["traditional"]["feature_explainability"].append(
            estimate_explainability(traditional_circuit)
        )

        # Sparse Feature Circuits
        sf_circuit = SparseFeatureCircuit(model, sae).build_circuit(
            tokens, target_behavior, layers=case["layers"]
        )
        results["sparse_feature"]["circuits_found"].append(sf_circuit)
        results["sparse_feature"]["feature_explainability"].append(
            estimate_explainability(sf_circuit)
        )

    return results
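
The "coverage" entries in the comparison above are never filled in. One way they could be populated, offered here as an assumption rather than the method of this document, is a faithfulness-style score: the share of the full model's metric that the circuit alone recovers, relative to a baseline in which every candidate feature is ablated. The function name and the example numbers are hypothetical.

```python
def faithfulness(circuit_score: float, full_model_score: float, empty_score: float) -> float:
    """Share of the full model's metric that the circuit alone recovers.

    circuit_score: metric with only the circuit's features kept
    full_model_score: metric of the unmodified model
    empty_score: metric with every candidate feature ablated
    """
    gap = full_model_score - empty_score
    if gap == 0:
        raise ValueError("full model and empty baseline have the same score")
    return (circuit_score - empty_score) / gap


score = faithfulness(circuit_score=2.6, full_model_score=3.0, empty_score=1.0)
print(round(score, 3))
```

A value near 1 suggests the circuit accounts for most of the behavior, while a value near 0 suggests important components are missing.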

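6. Limitations and future directions

6.1 Current limitations

Limitation          | Description
--------------------|------------------------------------------------------------------
Compute cost        | The number of feature pairs to test grows quadratically with the number of features
Fuzzy boundaries    | The boundary of a circuit is hard to define precisely
Dynamic circuits    | The same circuit may vary across inputs
Feature reliability | Interpretations of SAE features can be misleading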
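
To make the compute-cost row concrete, the short sketch below counts candidate edges between adjacent analyzed layers. The function name and the feature counts are illustrative assumptions; the second count assumes one patched forward pass per source feature, reused for all target features.

```python
def count_patching_runs(features_per_layer: list[int]) -> dict:
    """Count the work needed to test edges between adjacent analyzed layers."""
    adjacent = list(zip(features_per_layer[:-1], features_per_layer[1:]))
    return {
        # One candidate edge per (source feature, target feature) pair.
        "candidate_edges": sum(src * dst for src, dst in adjacent),
        # One patched forward pass per source feature if each pass is
        # reused for every target feature.
        "forward_passes_shared": sum(src for src, _ in adjacent),
    }


small = count_patching_runs([50, 50, 50])
large = count_patching_runs([100, 100, 100])
print(small)
print(large)
```

Doubling the number of features per layer quadruples the number of candidate edges, while the number of shared forward passes only doubles.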

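6.2 Future directions

Direction           | Description
--------------------|------------------------------------------------------------------
Automation          | Discover and validate circuits automatically
Hierarchy           | Organize circuits into levels, from local to global
Dynamics            | Capture how circuits depend on the input
Safety applications | Use knowledge of circuits for safety interventions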

7. References

  1. Marks et al. "Sparse Feature Circuits: Discovering and Editing Interpretable Causal Graphs in Language Models." ICLR 2025.

  2. Wang et al. "Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small." ICLR 2023.