Sparse Feature Circuits
概述
Sparse Feature Circuits[1]是将SAE特征与电路发现方法相结合的方法,旨在更系统地发现和验证神经网络中的计算电路。
核心思想:
- 使用SAE作为“显微镜”,将叠加的表示分解为可解释的特征
- 利用电路发现方法(如激活Patching)验证特征之间的因果关系
- 构建基于特征的“电路图”,揭示模型的计算机制
1. 背景:电路发现与SAE
1.1 传统电路发现
传统电路发现方法[2]使用激活Patching技术:
# Core idea of traditional circuit discovery (illustrative sketch).
def activation_patching_experiment(
    model,
    clean_tokens,
    corrupted_tokens,
    layer_idx,
    position_idx,
):
    """Run one activation-patching (denoising) experiment.

    1. Run the model on the clean and the corrupted input and cache the
       intermediate activations.
    2. In a copy of the corrupted cache, replace the activation at
       ``layer_idx`` / ``position_idx`` with the clean activation.
    3. Re-run the corrupted input with the patched cache and compare outputs.

    ``run_with_cache`` is a helper not defined in this section. It is assumed
    that with two arguments it returns a per-layer cache, and with a third
    ``cache`` argument it runs the model with that cache substituted and
    returns the output -- TODO confirm its contract.

    Returns:
        ``patched_output - corrupted_output``: how much of the clean behaviour
        is recovered by restoring this single activation. Divide by
        ``clean_output - corrupted_output`` (the total effect) to obtain the
        recovered fraction.
    """
    clean_cache = run_with_cache(model, clean_tokens)
    corrupted_cache = run_with_cache(model, corrupted_tokens)

    # Copy the container AND the entry being patched: a shallow ``copy()``
    # shares tensors with ``corrupted_cache``, so an in-place write would
    # corrupt the baseline as well. Indexing fixes entries as rank-3 with the
    # position on axis 1; presumably (batch, seq, d_model) -- TODO confirm.
    patched_cache = corrupted_cache.copy()
    patched_cache[layer_idx] = corrupted_cache[layer_idx].clone()
    patched_cache[layer_idx][:, position_idx, :] = clean_cache[layer_idx][:, position_idx, :]

    corrupted_output = model(corrupted_tokens)
    patched_output = run_with_cache(model, corrupted_tokens, patched_cache)

    # Effect recovered by the patch. Note that
    # (clean - corrupted) - (patched - corrupted) would simplify to
    # clean - patched, i.e. the effect the patch did NOT recover.
    contribution = patched_output - corrupted_output
    return contribution


# 1.2 SAE作为“显微镜”
SAE解决了电路发现的一个关键问题:叠加。
问题:在叠加的表示中,单个神经元往往同时参与编码多个不相关的特征(多义性),因此很难确定哪些神经元属于哪个电路
解决方案:SAE将叠加分解为稀疏特征,每个特征对应一个"电路节点"
1.3 Sparse Feature Circuits的结合
Sparse Feature Circuits = SAE特征 + 电路发现方法
┌──────────────────────────────────────────────────┐
│                                                  │
│   输入 ──→ [层1] ──→ [层2] ──→ [层3] ──→ 输出    │
│              │         │         │               │
│              ↓         ↓         ↓               │
│            特征A     特征B     特征C             │
│              │         │         │               │
│              └─────────┴─────────┘               │
│                        ↓                         │
│                 电路图(特征之间)                 │
│                                                  │
└──────────────────────────────────────────────────┘
2. 方法论
2.1 特征级Patching
class FeaturePatching:
    """Patching experiments that operate on SAE features.

    Classic activation patching swaps raw activations (whole positions,
    neurons or heads). Here a single SAE feature is transplanted from a clean
    run into a corrupted run, and the change in the output metric is measured.

    Assumes ``model.model.layers`` holds the decoder blocks, ``model(tokens)``
    returns an object with a rank-3 ``.logits`` tensor, and ``sae`` provides
    ``encode`` / ``decode`` -- TODO confirm for models with other layouts.
    """

    def __init__(self, model, sae):
        self.model = model
        self.sae = sae
        self.device = next(model.parameters()).device

    @staticmethod
    def _hidden_states(output):
        """Return the hidden-state tensor from a decoder layer's forward output.

        Handles layers that return a tuple whose first element is the hidden
        states, as well as layers that return the tensor directly.
        """
        return output[0] if isinstance(output, tuple) else output

    def feature_patching_experiment(
        self,
        clean_tokens,
        corrupted_tokens,
        target_layer,
        target_feature_idx,
    ) -> dict:
        """Patch one SAE feature from the clean run into the corrupted run.

        1. Encode the ``target_layer`` output of both runs with the SAE.
        2. In the corrupted features, replace feature ``target_feature_idx``
           (at every position) with its clean value.
        3. Run the corrupted input with the patched activation and compare.

        Only the *change* in SAE reconstruction is added to the corrupted
        activation. This keeps the SAE reconstruction error intact, so the
        measured effect is due to the patched feature alone.

        Returns:
            dict with ``clean_score``, ``corrupted_score``, ``patched_score``,
            ``patch_effect`` (patched - corrupted), ``total_effect``
            (clean - corrupted) and ``patch_ratio`` (patch_effect /
            total_effect, or 0.0 when the total effect is zero).
        """
        clean_activations = self.get_layer_activations(clean_tokens, target_layer)
        corrupted_activations = self.get_layer_activations(corrupted_tokens, target_layer)

        with torch.no_grad():
            clean_features = self.sae.encode(clean_activations)
            corrupted_features = self.sae.encode(corrupted_activations)

            patched_features = corrupted_features.clone()
            patched_features[:, :, target_feature_idx] = clean_features[:, :, target_feature_idx]

            # corrupted + (decode(patched) - decode(corrupted)). Substituting
            # the bare reconstruction would also remove the SAE error term and
            # wrongly attribute that change to the patched feature.
            patched_activations = corrupted_activations + (
                self.sae.decode(patched_features) - self.sae.decode(corrupted_features)
            )

            clean_output = self.model(clean_tokens)
            corrupted_output = self.model(corrupted_tokens)

        patched_output = self.run_with_modified_activation(
            corrupted_tokens, target_layer, patched_activations
        )

        clean_score = self.get_metric(clean_output)
        corrupted_score = self.get_metric(corrupted_output)
        patched_score = self.get_metric(patched_output)

        results = {
            "clean_score": clean_score,
            "corrupted_score": corrupted_score,
            "patched_score": patched_score,
            "patch_effect": patched_score - corrupted_score,
            "total_effect": clean_score - corrupted_score,
        }
        total_effect = results["total_effect"]
        results["patch_ratio"] = (
            results["patch_effect"] / total_effect if total_effect != 0 else 0.0
        )
        return results

    def get_layer_activations(self, tokens, layer_idx):
        """Return the output hidden states of decoder layer ``layer_idx``."""
        cache = {}

        def hook_fn(module, inputs, output):
            cache["activation"] = self._hidden_states(output)

        handle = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                self.model(tokens)
        finally:
            # Always detach the hook so a failed forward pass cannot leave it
            # attached and silently alter later runs.
            handle.remove()
        return cache["activation"]

    def run_with_modified_activation(self, tokens, layer_idx, modified_activation):
        """Run the model with layer ``layer_idx``'s hidden states replaced.

        Returns the model output of that forward pass.
        """

        def hook_fn(module, inputs, output):
            # A forward hook's non-None return value replaces the output.
            if isinstance(output, tuple):
                return (modified_activation,) + output[1:]
            return modified_activation

        handle = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                output = self.model(tokens)
        finally:
            handle.remove()
        return output

    def get_metric(self, output):
        """Return the batch mean of the largest last-position logit.

        Replace with a task-specific metric (loss, logit difference, ...) as
        needed.
        """
        return output.logits[:, -1, :].max(dim=-1).values.mean().item()


# 2.2 特征级归因
class FeatureAttribution:
    """Attribute the model output to individual SAE features by ablation.

    Each feature at a layer is zero-ablated, and the drop in the output metric
    is recorded as that feature's importance.

    Assumes ``model.model.layers`` holds the decoder blocks, ``model(tokens)``
    returns an object with a rank-3 ``.logits`` tensor, and ``sae`` provides
    ``encode`` / ``decode`` / ``n_features`` -- TODO confirm.
    """

    def __init__(self, model, sae):
        self.model = model
        self.sae = sae

    @staticmethod
    def _hidden_states(output):
        """Return the hidden-state tensor from a decoder layer's forward output.

        Handles layers that return a tuple whose first element is the hidden
        states, as well as layers that return the tensor directly.
        """
        return output[0] if isinstance(output, tuple) else output

    def compute_feature_importance(
        self,
        tokens,
        target_layer,
        n_features=None,
    ) -> dict:
        """Score SAE features at ``target_layer`` by zero-ablation.

        Args:
            tokens: input token ids; presumably (batch, seq) -- TODO confirm.
            target_layer: index into ``model.model.layers``.
            n_features: score only the first ``n_features`` features (capped at
                ``sae.n_features``). Defaults to all features.

        Returns:
            dict with:
            - ``scores``: one dict per feature, sorted by importance descending.
            - ``top_features``: the first 10 entries of ``scores``.
            - ``total_importance``: the sum of all importances.

            Each feature's ``importance`` is metric(original) -
            metric(ablated). ``is_active`` says whether the feature is positive
            at the last position of the first sequence.
        """
        if n_features is None:
            n_features = self.sae.n_features

        original_activations = self.get_layer_activations(tokens, target_layer)
        with torch.no_grad():
            original_features = self.sae.encode(original_activations)
            original_reconstruction = self.sae.decode(original_features)
            # The unpatched run does not depend on the feature: compute once.
            original_score = self.get_metric(self.model(tokens))

        importance_scores = []
        for feat_idx in range(min(n_features, self.sae.n_features)):
            if not (original_features[:, :, feat_idx] != 0).any():
                # Ablating a feature that is already zero everywhere leaves the
                # activation unchanged, so its effect is exactly zero.
                effect = 0.0
            else:
                with torch.no_grad():
                    ablated_features = original_features.clone()
                    ablated_features[:, :, feat_idx] = 0.0
                    # Apply only the change in reconstruction. This preserves
                    # the SAE error term instead of crediting it to every feature.
                    patched_activations = original_activations + (
                        self.sae.decode(ablated_features) - original_reconstruction
                    )
                patched_output = self.run_with_modified_activation(
                    tokens, target_layer, patched_activations
                )
                effect = original_score - self.get_metric(patched_output)

            importance_scores.append({
                "feature_idx": feat_idx,
                "importance": effect,
                "is_active": bool(original_features[0, -1, feat_idx] > 0),
            })

        importance_scores.sort(key=lambda s: s["importance"], reverse=True)
        return {
            "scores": importance_scores,
            "top_features": importance_scores[:10],
            "total_importance": sum(s["importance"] for s in importance_scores),
        }

    def get_layer_activations(self, tokens, layer_idx):
        """Return the output hidden states of decoder layer ``layer_idx``."""
        cache = {}

        def hook_fn(module, inputs, output):
            cache["activation"] = self._hidden_states(output)

        handle = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                self.model(tokens)
        finally:
            # Detach even if the forward pass raises.
            handle.remove()
        return cache["activation"]

    def run_with_modified_activation(self, tokens, layer_idx, modified_activation):
        """Run the model with layer ``layer_idx``'s hidden states replaced."""

        def hook_fn(module, inputs, output):
            if isinstance(output, tuple):
                return (modified_activation,) + output[1:]
            return modified_activation

        handle = self.model.model.layers[layer_idx].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                output = self.model(tokens)
        finally:
            handle.remove()
        return output

    def get_metric(self, output):
        """Return the batch mean of the largest last-position logit."""
        return output.logits[:, -1, :].max(dim=-1).values.mean().item()


# 2.3 电路构建
from collections import defaultdict
import networkx as nx
class SparseFeatureCircuit:
    """Build a sparse feature circuit from SAE feature attributions.

    Nodes are SAE features whose zero-ablation changes the output metric by
    more than a threshold (computed with ``FeatureAttribution``). Edges link
    features in consecutive analysed layers when ablating the upstream feature
    changes the downstream feature's activation.

    NOTE(review): the same ``sae`` encodes every analysed layer. SAEs are
    normally trained per layer -- confirm this SAE is valid for all layers.
    """

    def __init__(self, model, sae):
        self.model = model
        self.sae = sae
        self.attribution = FeatureAttribution(model, sae)

    @staticmethod
    def _node_id(layer, feat_idx):
        """Return the graph node name for feature ``feat_idx`` at layer ``layer``."""
        return f"layer{layer}_feat{feat_idx}"

    def build_circuit(
        self,
        tokens,
        behavior_description: str,
        layers: list[int],
        importance_threshold: float = 0.1,
        edge_threshold: float = 0.05,
    ) -> nx.DiGraph:
        """Build the sparse feature circuit for ``tokens``.

        Args:
            tokens: input token sequence passed to the model.
            behavior_description: description of the target behaviour, stored
                as ``circuit.graph["behavior"]``.
            layers: model layers to analyse, in increasing order. Edges are
                only searched between consecutive entries.
            importance_threshold: a feature becomes a node when
                ``abs(importance) > importance_threshold``.
            edge_threshold: passed to ``discover_edges`` as ``threshold``.

        Returns:
            A directed graph. Nodes carry ``layer``, ``feature_idx``,
            ``importance`` and ``is_active``; edges carry ``weight``.
        """
        circuit = nx.DiGraph()
        circuit.graph["behavior"] = behavior_description

        # Important feature indices per layer. Edges are searched only among
        # features that are actually circuit nodes at the respective layer.
        important_by_layer = {layer_idx: [] for layer_idx in layers}
        for layer_idx in layers:
            attributions = self.attribution.compute_feature_importance(tokens, layer_idx)
            for score in attributions["scores"]:
                if abs(score["importance"]) > importance_threshold:
                    circuit.add_node(
                        self._node_id(layer_idx, score["feature_idx"]),
                        layer=layer_idx,
                        feature_idx=score["feature_idx"],
                        importance=score["importance"],
                        is_active=score["is_active"],
                    )
                    important_by_layer[layer_idx].append(score["feature_idx"])

        edges = self.discover_edges(
            tokens, layers, important_by_layer, threshold=edge_threshold
        )
        for src, dst, weight in edges:
            src_id = self._node_id(layers[src["layer_idx"]], src["feat_idx"])
            dst_id = self._node_id(layers[dst["layer_idx"]], dst["feat_idx"])
            if circuit.has_node(src_id) and circuit.has_node(dst_id):
                circuit.add_edge(src_id, dst_id, weight=weight)
        return circuit

    def discover_edges(
        self,
        tokens,
        layers,
        feature_indices,
        threshold: float = 0.05,
    ) -> list[tuple]:
        """Find feature-to-feature edges between consecutive entries of ``layers``.

        For each candidate source feature at ``layers[i]``:
        - zero-ablate it, keeping the SAE reconstruction error;
        - re-run the model once with that patch;
        - compare each candidate destination feature at ``layers[i + 1]``
          (value at the last position of the first sequence).

        An edge is recorded when ``abs(original - ablated) > threshold``.

        Args:
            tokens: input token sequence.
            layers: analysed model layers, in increasing order.
            feature_indices: either one iterable of feature indices used as
                candidates at every layer, or a dict ``layer -> indices``.
                Duplicates and indices >= ``sae.n_features`` are ignored.
            threshold: minimum absolute contribution for an edge.

        Returns:
            List of ``(src, dst, contribution)`` tuples. ``src`` is
            ``{"layer_idx": i, "feat_idx": s}`` and ``dst`` is
            ``{"layer_idx": i + 1, "feat_idx": d}``; ``layer_idx`` indexes
            into ``layers``.
        """
        n_features = self.sae.n_features
        # Materialise a shared candidate list once (it may be a generator).
        shared = None if isinstance(feature_indices, dict) else list(feature_indices)

        def candidates(layer):
            raw = feature_indices.get(layer, ()) if shared is None else shared
            # Deduplicate (keeping order) and drop indices the SAE lacks.
            return [idx for idx in dict.fromkeys(raw) if idx < n_features]

        edges = []
        for i, (src_layer, dst_layer) in enumerate(zip(layers, layers[1:])):
            src_candidates = candidates(src_layer)
            dst_candidates = candidates(dst_layer)
            if not src_candidates or not dst_candidates:
                continue

            src_activations = self.attribution.get_layer_activations(tokens, src_layer)
            dst_activations = self.attribution.get_layer_activations(tokens, dst_layer)
            with torch.no_grad():
                src_features = self.sae.encode(src_activations)
                src_reconstruction = self.sae.decode(src_features)
                dst_features = self.sae.encode(dst_activations)

            for src_feat_idx in src_candidates:
                if not (src_features[:, :, src_feat_idx] != 0).any():
                    # Ablating an all-zero feature changes nothing downstream.
                    continue
                with torch.no_grad():
                    ablated = src_features.clone()
                    ablated[:, :, src_feat_idx] = 0
                    patched_src_activations = src_activations + (
                        self.sae.decode(ablated) - src_reconstruction
                    )
                # One patched forward pass per source feature, shared by
                # every destination feature.
                patched_dst = self.run_and_encode(
                    tokens, dst_layer, patched_src_activations, src_layer=src_layer
                )
                for dst_feat_idx in dst_candidates:
                    original_dst_value = dst_features[0, -1, dst_feat_idx].item()
                    patched_dst_value = patched_dst[0, -1, dst_feat_idx].item()
                    contribution = original_dst_value - patched_dst_value
                    if abs(contribution) > threshold:
                        edges.append((
                            {"layer_idx": i, "feat_idx": src_feat_idx},
                            {"layer_idx": i + 1, "feat_idx": dst_feat_idx},
                            contribution,
                        ))
        return edges

    def run_and_encode(self, tokens, layer_idx, src_activations, src_layer=None):
        """Patch one layer, then SAE-encode another layer's output from the same pass.

        The output of ``src_layer`` is replaced by ``src_activations``, and
        the output of layer ``layer_idx`` from that same forward pass is
        returned encoded by the SAE. ``src_layer`` defaults to ``layer_idx``,
        in which case the result is the encoding of ``src_activations``.
        """
        if src_layer is None:
            src_layer = layer_idx
        model_layers = self.model.model.layers
        cache = {}

        def patch_hook(module, inputs, output):
            if isinstance(output, tuple):
                return (src_activations,) + output[1:]
            return src_activations

        def cache_hook(module, inputs, output):
            cache["activation"] = output[0] if isinstance(output, tuple) else output

        # Patch hook registered first: forward hooks run in registration
        # order, so on a shared layer the cache hook sees the patched output.
        handles = [
            model_layers[src_layer].register_forward_hook(patch_hook),
            model_layers[layer_idx].register_forward_hook(cache_hook),
        ]
        try:
            with torch.no_grad():
                self.model(tokens)
        finally:
            for handle in handles:
                handle.remove()

        with torch.no_grad():
            return self.sae.encode(cache["activation"])


# 3. 电路分析
3.1 可视化
import matplotlib.pyplot as plt
import networkx as nx
def visualize_circuit(circuit: nx.DiGraph, save_path: str = None):
    """Plot a sparse feature circuit with one column per layer.

    Node colour and size encode the node's ``importance``; edge colour and
    width encode the edge's ``weight``. Both colour scales are centred on zero,
    so the diverging ``RdBu_r`` colormap shows sign (red positive, blue
    negative).

    Args:
        circuit: directed graph whose nodes carry ``layer`` and ``importance``
            attributes and whose edges carry ``weight``.
        save_path: if truthy, the figure is also written to this path.

    Returns:
        The matplotlib ``Figure`` containing the plot.
    """
    fig, ax = plt.subplots(1, 1, figsize=(16, 12))

    nodes_by_layer = defaultdict(list)
    for node in circuit.nodes():
        nodes_by_layer[circuit.nodes[node]["layer"]].append(node)
    sorted_layers = sorted(nodes_by_layer)

    # Column x = rank of the layer; rows spread symmetrically around y = 0.
    pos = {}
    for column, layer in enumerate(sorted_layers):
        column_nodes = nodes_by_layer[layer]
        span = max(len(column_nodes) - 1, 1)
        center = (len(column_nodes) - 1) / 2
        for row, node in enumerate(column_nodes):
            pos[node] = (column, (row - center) / span)

    nodelist = list(circuit.nodes())
    importances = [circuit.nodes[node]["importance"] for node in nodelist]
    node_sizes = [100 + abs(value) * 500 for value in importances]
    node_limit = max((abs(value) for value in importances), default=0.0) or 1.0
    nx.draw_networkx_nodes(
        circuit, pos,
        nodelist=nodelist,
        node_color=importances,
        node_size=node_sizes,
        cmap=plt.cm.RdBu_r,
        vmin=-node_limit,
        vmax=node_limit,
        alpha=0.8,
        ax=ax,
    )

    edgelist = list(circuit.edges())
    if edgelist:
        weights = [circuit.edges[u, v]["weight"] for u, v in edgelist]
        edge_limit = max(abs(w) for w in weights) or 1.0
        nx.draw_networkx_edges(
            circuit, pos,
            edgelist=edgelist,
            edge_color=weights,
            width=[1 + abs(w) * 2 for w in weights],
            alpha=0.6,
            # The colormap keyword for edges is ``edge_cmap``; ``cmap`` is not
            # accepted by draw_networkx_edges.
            edge_cmap=plt.cm.RdBu_r,
            edge_vmin=-edge_limit,
            edge_vmax=edge_limit,
            # Node sizes let arrowheads stop at the node boundary.
            nodelist=nodelist,
            node_size=node_sizes,
            ax=ax,
        )

    # Place labels with x in data coordinates and y in axes fraction. Labels
    # at a data y outside the autoscaled limits would be clipped away.
    label_transform = ax.get_xaxis_transform()
    for column, layer in enumerate(sorted_layers):
        ax.text(
            column,
            1.02,
            f"Layer {layer}",
            transform=label_transform,
            ha='center',
            va='bottom',
            fontsize=12,
            fontweight='bold',
        )

    # Extra padding keeps the title clear of the layer labels.
    ax.set_title(
        f"Sparse Feature Circuit: {circuit.graph.get('behavior', 'Unknown')}",
        pad=30,
    )
    ax.axis('off')
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return fig


# 3.2 电路统计
def analyze_circuit_statistics(circuit: nx.DiGraph, hub_degree: int = 3) -> dict:
    """Summarise the structure of a sparse feature circuit.

    Args:
        circuit: directed graph whose nodes carry a ``layer`` attribute.
        hub_degree: a node is a hub when its in-degree or its out-degree is
            at least this value.

    Returns:
        dict with:
        - ``n_nodes``, ``n_edges``, ``density``;
        - ``layer_connections``: edges grouped by (source layer, destination layer);
        - ``max_in_degree``, ``max_out_degree``: 0 for an empty circuit;
        - ``hub_nodes``;
        - ``n_components``: number of weakly connected components.
    """
    layer_connections = defaultdict(list)
    for u, v in circuit.edges():
        key = (circuit.nodes[u]["layer"], circuit.nodes[v]["layer"])
        layer_connections[key].append((u, v))

    in_degrees = dict(circuit.in_degree())
    out_degrees = dict(circuit.out_degree())

    return {
        "n_nodes": circuit.number_of_nodes(),
        "n_edges": circuit.number_of_edges(),
        "density": nx.density(circuit),
        "layer_connections": dict(layer_connections),
        # default=0: max() of an empty sequence would raise ValueError.
        "max_in_degree": max(in_degrees.values(), default=0),
        "max_out_degree": max(out_degrees.values(), default=0),
        "hub_nodes": [
            node for node, degree in in_degrees.items()
            if degree >= hub_degree or out_degrees[node] >= hub_degree
        ],
        "n_components": nx.number_weakly_connected_components(circuit),
    }


# 4. 应用案例
4.1 归纳头电路
def analyze_induction_head_circuit(
    model,
    sae,
    tokens,
    layers=None,
    save_path="induction_head_circuit.png",
):
    """Build, summarise and plot an induction-head feature circuit.

    Induction heads complete repeated patterns ``[A][B] ... [A] -> [B]``:
    1. find an earlier occurrence of the current token ``[A]``;
    2. attend to the token that followed it (``[B]``) and copy it into the
       prediction.
    ``tokens`` should therefore contain such a repeated sequence.

    Args:
        model: model passed to ``SparseFeatureCircuit``.
        sae: SAE passed to ``SparseFeatureCircuit``.
        tokens: input token sequence.
        layers: layers to analyse. Defaults to 15-25 inclusive. Where
            induction heads sit depends on the model, so pass this explicitly
            for other architectures.
        save_path: path the circuit plot is written to.

    Returns:
        ``(circuit, stats)``: the circuit graph and its statistics dict.
    """
    if layers is None:
        layers = list(range(15, 26))

    circuit_builder = SparseFeatureCircuit(model, sae)
    circuit = circuit_builder.build_circuit(
        tokens,
        behavior_description="Induction Head (Pattern Completion)",
        layers=list(layers),
        importance_threshold=0.05,
    )

    stats = analyze_circuit_statistics(circuit)
    visualize_circuit(circuit, save_path)
    return circuit, stats


# 4.2 IOI (Indirect Object Identification) 电路
def analyze_ioi_circuit(model, sae, tokens, layers=None, key_node_degree=5):
    """Build an Indirect Object Identification (IOI) feature circuit.

    In IOI prompts such as "When Mary and John went to the store, John gave a
    drink to", the model must predict the indirect object ("Mary"):
    1. detect the names in context and which one is repeated (the subject,
       "John");
    2. suppress the repeated subject;
    3. output the remaining name (the indirect object) at the final position.

    Args:
        model: model passed to ``SparseFeatureCircuit``. Its decoder blocks
            are read from ``model.model.layers``.
        sae: SAE passed to ``SparseFeatureCircuit``.
        tokens: input token sequence.
        layers: layers to analyse. Defaults to every decoder layer of
            ``model``, rather than assuming a 32-layer model.
        key_node_degree: nodes with in-degree + out-degree >= this value are
            reported as key nodes.

    Returns:
        The circuit graph. Key nodes are printed and also stored in
        ``circuit.graph["key_nodes"]``.
    """
    if layers is None:
        layers = list(range(len(model.model.layers)))

    circuit_builder = SparseFeatureCircuit(model, sae)
    circuit = circuit_builder.build_circuit(
        tokens,
        behavior_description="Indirect Object Identification",
        layers=list(layers),
        importance_threshold=0.08,
    )

    # On a DiGraph, degree() is in-degree + out-degree.
    key_nodes = [
        node for node, degree in circuit.degree()
        if degree >= key_node_degree
    ]
    circuit.graph["key_nodes"] = key_nodes
    print(f"Key nodes: {key_nodes}")
    return circuit


# 5. 与传统电路发现的对比
5.1 方法对比
| 方面 | 传统电路发现 | Sparse Feature Circuits |
|---|---|---|
| 操作单元 | 模型组件激活(神经元、注意力头、残差流位置) | SAE特征 |
| 叠加处理 | 难以处理 | 通过SAE分解缓解(受重构误差限制) |
| 特征解释 | 困难 | 相对容易 |
| 完整性 | 可能有遗漏 | 更全面 |
| 计算成本 | 较高 | 较高 |
5.2 定量对比
# Example comparison experiment.
def compare_circuit_methods(
    model, sae, test_cases, target_behavior
):
    """Run traditional circuit discovery and Sparse Feature Circuits side by side.

    Every test case must provide ``"tokens"`` and ``"layers"``. For each case
    and each method, the discovered circuit and its explainability estimate
    are appended to that method's record.

    NOTE: the ``"coverage"`` lists are created but never filled here.
    ``traditional_circuit_discovery`` and ``estimate_explainability`` are not
    defined in this section.
    """
    def empty_record():
        return {
            "circuits_found": [],
            "feature_explainability": [],
            "coverage": [],
        }

    results = {"traditional": empty_record(), "sparse_feature": empty_record()}
    traditional = results["traditional"]
    sparse = results["sparse_feature"]

    for case in test_cases:
        tokens = case["tokens"]

        # Baseline: component-level circuit discovery.
        baseline_circuit = traditional_circuit_discovery(model, tokens)
        traditional["circuits_found"].append(baseline_circuit)
        traditional["feature_explainability"].append(
            estimate_explainability(baseline_circuit)
        )

        # Sparse Feature Circuits on the same input.
        builder = SparseFeatureCircuit(model, sae)
        feature_circuit = builder.build_circuit(
            tokens, target_behavior, layers=case["layers"]
        )
        sparse["circuits_found"].append(feature_circuit)
        sparse["feature_explainability"].append(
            estimate_explainability(feature_circuit)
        )

    return results


# 6. 局限性与未来方向
6.1 当前局限性
| 局限性 | 描述 |
|---|---|
| 计算成本 | Patching实验数量随特征数平方增长 |
| 边界模糊 | 电路边界难以精确定义 |
| 动态电路 | 同一电路在不同输入可能有变化 |
| 特征可靠性 | SAE特征的解释可能被误导 |
6.2 未来方向
| 方向 | 描述 |
|---|---|
| 自动化 | 自动发现和验证电路 |
| 层次化 | 从局部到全局的电路层次 |
| 动态 | 捕捉电路的输入依赖性 |
| 安全应用 | 使用电路知识进行安全干预 |