SAEBench基准测试框架
概述
SAEBench1是由多个研究机构于2025年3月联合发布的Sparse Autoencoder(Sparse Autoencoder, SAE)综合评估基准。它提供了一套标准化的评估协议,用于系统性地比较不同SAE架构和训练方法的质量。
核心目标:
- 标准化SAE评估流程
- 提供多维度质量指标
- 揭示不同SAE方法的权衡取舍
- 推动SAE研究的可复现性
1. 评估维度
SAEBench定义了五个核心评估维度:
┌─────────────────────────────────────────────────────────────┐
│ SAEBench评估维度 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1. Loss Recovered (LR) 重建质量 │
│ ↓ │
│ 2. Automated Interpretability (AI) 自动可解释性 │
│ ↓ │
│ 3. Absorption Metrics (AM) 吸收率 │
│ ↓ │
│ 4. Sparse Covering Radius (SCR) 稀疏覆盖半径 │
│ ↓ │
│ 5. Sparse Probing (SP) 稀疏探测 │
│ │
└─────────────────────────────────────────────────────────────┘
1.1 Loss Recovered (LR)
定义:SAE重构所恢复的损失减少比例。
其中:
- 是SAE重构损失
- 是单位映射损失(直接输出输入)
物理意义:
- LR = 0.8 表示SAE恢复了80%的原本会丢失的信息
- LR越高,重构质量越好
计算示例:
def compute_loss_recovered(sae, dataloader, device):
"""计算SAE的Loss Recovered指标"""
total_recon_loss = 0.0
total_identity_loss = 0.0
n_samples = 0
for batch in dataloader:
x = batch.to(device)
# SAE前向传播
with torch.no_grad():
features = sae.encode(x)
recon = sae.decode(features)
# 计算损失
recon_loss = F.mse_loss(recon, x, reduction='sum')
identity_loss = F.mse_loss(x, torch.zeros_like(x), reduction='sum')
total_recon_loss += recon_loss.item()
total_identity_loss += identity_loss.item()
n_samples += x.shape[0]
avg_recon = total_recon_loss / n_samples
avg_identity = total_identity_loss / n_samples
lr = 1 - (avg_recon / avg_identity)
return lr
# 使用示例
lr_score = compute_loss_recovered(sae, test_loader, device='cuda')
print(f"Loss Recovered: {lr_score:.4f}")1.2 Automated Interpretability (AI)
定义:使用LLM自动评估特征的可解释性。
方法:
- 收集每个特征激活最高的前 个文本片段
- 使用GPT-4等LLM评估这些片段是否具有一致的语义主题
- 计算”可解释”特征的比例
评分标准:
评分范围: 1-5
1 - 无一致性:片段之间没有明显关联
2 - 弱一致性:片段之间有部分关联
3 - 中等一致性:大多数片段围绕同一主题
4 - 强一致性:几乎所有片段都是同一概念的变体
5 - 完全一致:片段完美对应一个清晰概念
实现代码:
from anthropic import Anthropic
class AutomatedInterpretability:
"""自动可解释性评估器"""
def __init__(self, model_name="claude-sonnet-4-20250514"):
self.client = Anthropic()
self.model_name = model_name
def evaluate_feature(self, feature_activations: list[str], feature_idx: int) -> dict:
"""
评估单个特征的可解释性
Args:
feature_activations: 特征激活最高的文本片段列表
feature_idx: 特征索引
Returns:
包含评分和解释的字典
"""
prompt = f"""You are analyzing a feature (index {feature_idx}) from a Sparse Autoencoder.
The feature activates most strongly on these text examples:
{chr(10).join([f"{i+1}. {text}" for i, text in enumerate(feature_activations)])}
Please evaluate:
1. Do these examples share a common semantic concept or theme?
2. If yes, describe the concept in 1-2 sentences.
3. Rate the consistency from 1-5.
Respond in JSON format:
{{
"concept": "description of the concept or 'none'",
"rating": 1-5,
"reasoning": "brief explanation"
}}"""
response = self.client.messages.create(
model=self.model_name,
max_tokens=300,
messages=[{"role": "user", "content": prompt}]
)
import json
result = json.loads(response.content[0].text)
return result
def batch_evaluate(self, sae, dataset, n_features=100, top_k=20) -> float:
"""
批量评估特征的自动可解释性分数
Returns:
平均可解释性评分
"""
results = []
# 获取每个特征的高激活片段
feature_texts = self._collect_top_activations(sae, dataset, n_features, top_k)
for feat_idx, texts in feature_texts.items():
result = self.evaluate_feature(texts, feat_idx)
results.append(result["rating"])
return sum(results) / len(results)
def _collect_top_activations(self, sae, dataset, n_features, top_k):
"""收集每个特征激活最高的文本"""
feature_texts = {i: [] for i in range(n_features)}
feature_scores = {i: [] for i in range(n_features)}
for batch_idx, batch in enumerate(dataset):
x = batch["tokens"]
texts = batch["texts"]
with torch.no_grad():
features = sae.encode(x)
# 找出最活跃的特征
for i in range(len(x)):
feat_vals = features[i]
top_indices = torch.topk(feat_vals, top_k).indices
for idx in top_indices:
if idx.item() < n_features:
feature_texts[idx.item()].append(texts[i])
feature_scores[idx.item()].append(feat_vals[idx].item())
# 返回每个特征得分最高的片段
result = {}
for feat_idx in range(n_features):
if feature_scores[feat_idx]:
top_indices = torch.topk(
torch.tensor(feature_scores[feat_idx]),
min(10, len(feature_scores[feat_idx]))
).indices
result[feat_idx] = [feature_texts[feat_idx][i] for i in top_indices]
return result1.3 Absorption Metrics (AM)
定义:衡量原始模型激活中有多大比例被SAE特征”吸收”(解释)。
数学形式:
设原始激活为 ,SAE编码为 ,解码为 。
吸收率:
残差吸收率(衡量非线性交互):
其中 是线性投影。
实现:
def compute_absorption_metrics(sae, dataloader, device):
"""计算SAE的吸收率指标"""
total_absorption = 0.0
total_residual_absorption = 0.0
n_samples = 0
for batch in dataloader:
x = batch.to(device)
with torch.no_grad():
features = sae.encode(x)
recon = sae.decode(features)
# 直接吸收率
recon_norm_sq = (recon ** 2).sum(-1)
x_norm_sq = (x ** 2).sum(-1)
absorption = (recon_norm_sq / (x_norm_sq + 1e-8)).mean()
# 残差吸收率
residual = x - recon
# 线性投影到特征空间
projected = sae.W_dec(features) # 等价于 recon
# 非线性残差
nonlinear_residual = residual - projected + recon
nonlinear_residual_norm = (nonlinear_residual ** 2).sum(-1)
residual_norm = (residual ** 2).sum(-1)
residual_absorption = 1 - (nonlinear_residual_norm / (residual_norm + 1e-8)).mean()
total_absorption += absorption.item() * x.shape[0]
total_residual_absorption += residual_absorption.item() * x.shape[0]
n_samples += x.shape[0]
return {
"absorption": total_absorption / n_samples,
"residual_absorption": total_residual_absorption / n_samples,
}1.4 Sparse Covering Radius (SCR)
定义:衡量特征空间覆盖效率的指标。
概念:在 维特征空间中, 个活跃特征的覆盖半径定义为:
其中 是第 个输入的激活向量, 是第 个活跃特征的解码器方向, 是活跃特征集合。
物理意义:
- SCR越小,表示活跃特征能更好地覆盖输入空间
- SCR与稀疏度密切相关:更稀疏 → 需要更大覆盖半径
实现:
def compute_sparse_covering_radius(sae, dataloader, device, n_samples=1000):
"""计算SAE的稀疏覆盖半径"""
# 收集输入激活和对应特征
all_activations = []
all_features = []
for batch in dataloader:
x = batch.to(device)
with torch.no_grad():
features = sae.encode(x)
all_activations.append(x)
all_features.append(features)
if sum(a.shape[0] for a in all_activations) >= n_samples:
break
activations = torch.cat(all_activations)[:n_samples]
features = torch.cat(all_features)[:n_samples]
# 获取解码器方向
decoder_directions = sae.W_dec.weight.data # [d_model, n_features]
# 计算每个样本到其激活特征的最小距离
min_distances = []
for i in range(n_samples):
x_i = activations[i]
feat_i = features[i]
# 获取活跃特征索引
active_indices = (feat_i > 0).nonzero().squeeze()
if len(active_indices) == 0:
continue
# 计算到最近活跃特征的距离
active_directions = decoder_directions[:, active_indices] # [d_model, n_active]
# 计算重建向量
recon = feat_i[active_indices] @ active_directions.T
# 到覆盖集的距离
diff = x_i - recon
min_dist = diff.norm().item()
min_distances.append(min_dist)
scr = np.mean(min_distances)
scr_std = np.std(min_distances)
return {
"scr": scr,
"scr_std": scr_std,
"avg_active_features": features.sum(-1).mean().item(),
}1.5 Sparse Probing (SP)
定义:通过探测任务评估SAE特征的线性可分性。
方法:
- 在SAE特征上训练线性分类器
- 评估分类器在各种语义探测任务上的表现
- 比较与原始模型激活的探测性能
常用探测任务:
- POS(词性标注)
- 依存关系
- 实体识别
- 情感分类
- 语义角色标注
实现:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
class SparseProbing:
"""基于SAE特征的探测分类器"""
def __init__(self, sae, device='cuda'):
self.sae = sae
self.device = device
self.classifiers = {}
def train_probing_classifier(self, dataloader, task_name, label_key):
"""
训练探测分类器
Args:
dataloader: 包含tokens和labels的数据加载器
task_name: 任务名称
label_key: 标签在batch中的key
"""
all_features = []
all_labels = []
for batch in dataloader:
x = batch["tokens"].to(self.device)
labels = batch[label_key].numpy()
with torch.no_grad():
features = self.sae.encode(x)
all_features.append(features.cpu())
all_labels.append(labels)
X = torch.cat(all_features, dim=0).numpy()
y = np.concatenate(all_labels)
# 训练线性分类器
clf = LogisticRegression(max_iter=1000, random_state=42)
clf.fit(X, y)
self.classifiers[task_name] = clf
return clf
def evaluate(self, dataloader, task_name, label_key):
"""评估探测分类器"""
clf = self.classifiers.get(task_name)
if clf is None:
raise ValueError(f"Classifier for {task_name} not found")
all_features = []
all_labels = []
for batch in dataloader:
x = batch["tokens"].to(self.device)
labels = batch[label_key].numpy()
with torch.no_grad():
features = self.sae.encode(x)
all_features.append(features.cpu())
all_labels.append(labels)
X = torch.cat(all_features, dim=0).numpy()
y = np.concatenate(all_labels)
y_pred = clf.predict(X)
return {
"accuracy": accuracy_score(y, y_pred),
"f1_macro": f1_score(y, y_pred, average='macro'),
"f1_weighted": f1_score(y, y_pred, average='weighted'),
}
def compare_with_baseline(self, dataloader, baseline_activations, task_name, label_key):
"""
比较SAE特征与原始激活的探测性能
"""
# SAE探测性能
sae_metrics = self.evaluate(dataloader, task_name, label_key)
# 基线探测性能(使用原始激活)
baseline_features = baseline_activations
baseline_clf = LogisticRegression(max_iter=1000, random_state=42)
all_labels = []
for batch in dataloader:
all_labels.append(batch[label_key].numpy())
y = np.concatenate(all_labels)
baseline_clf.fit(baseline_features[:len(y)], y)
baseline_pred = baseline_clf.predict(baseline_features[:len(y)])
baseline_metrics = {
"accuracy": accuracy_score(y, baseline_pred),
"f1_macro": f1_score(y, baseline_pred, average='macro'),
}
return {
"sae": sae_metrics,
"baseline": baseline_metrics,
"relative_performance": sae_metrics["accuracy"] / baseline_metrics["accuracy"],
}2. SAEBench协议
2.1 标准评估流程
SAEBench评估流程:
┌─────────────────────────────────────────────────────────────┐
│ Step 1: 准备评估数据 │
│ ├── 预训练激活缓存 (Pile, C4, OpenWebText) │
│ ├── 标准探测数据集 (BLiMP, SuperGLUE) │
│ └── LLM评估样本 │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 2: 运行SAE推理 │
│ ├── 编码所有激活向量 │
│ ├── 记录活跃特征分布 │
│ └── 计算重构激活 │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 3: 计算自动指标 │
│ ├── Loss Recovered │
│ ├── Absorption Metrics │
│ ├── Sparse Covering Radius │
│ └── Sparse Probing │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 4: LLM辅助评估 │
│ ├── 收集特征高激活片段 │
│ ├── 调用LLM进行可解释性评分 │
│ └── 汇总Automated Interpretability分数 │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Step 5: 生成报告 │
│ └── 综合评分卡 │
└─────────────────────────────────────────────────────────────┘
2.2 配置参数
| 参数 | 默认值 | 说明 |
|---|---|---|
n_eval_samples | 10,000 | 评估样本数量 |
n_topk_activations | 20 | 每特征收集的top-k激活 |
batch_size | 256 | 批处理大小 |
probing_tasks | [POS, DEP, NER] | 探测任务列表 |
llm_model | claude-3-5-sonnet | LLM评估模型 |
3. 主要发现
3.1 架构对比结果
SAEBench对多种SAE架构进行了系统性比较:
| 架构 | LR | AI | AM | SCR | SP |
|---|---|---|---|---|---|
| 标准ReLU SAE | 0.82 | 3.2 | 0.78 | 0.45 | 0.71 |
| JumpReLU SAE | 0.85 | 3.5 | 0.81 | 0.42 | 0.74 |
| TopK SAE | 0.84 | 3.4 | 0.79 | 0.38 | 0.73 |
| Matryoshka SAE | 0.83 | 3.6 | 0.80 | 0.41 | 0.72 |
| Gemma Scope | 0.87 | 3.8 | 0.84 | 0.36 | 0.76 |
3.2 关键洞察
3.2.1 重建与稀疏性的权衡
重建质量 (LR)
↑
0.9│ ● JumpReLU
│ ●
0.8│ ● ● TopK
│ ● ● ReLU
0.7│●
│
└───────────────────────────→ 稀疏度
0.02 0.05 0.10 0.20
发现:JumpReLU在相同稀疏度下实现了更好的重建质量。
3.2.2 特征规模的影响
| 放大倍数 | LR | 平均活跃特征 | 推荐场景 |
|---|---|---|---|
| 2× | 0.72 | 1.2% | 实时系统 |
| 4× | 0.82 | 2.1% | 通用分析 |
| 8× | 0.87 | 3.8% | 详细研究 |
| 16× | 0.91 | 7.2% | 离线分析 |
3.2.3 自动可解释性与手动评估的相关性
SAEBench发现,自动可解释性评分与人类专家评估的相关性为 ,表明LLM评估是手动评估的有效代理。
4. 使用指南
4.1 安装
pip install saebench4.2 快速评估
from saebench import SAEBench
# 初始化评估器
evaluator = SAEBench(
sae_model_path="path/to/your/sae",
eval_data_path="path/to/activations",
)
# 运行完整评估
results = evaluator.run_full_benchmark()
# 查看结果
print(results.summary())4.3 自定义评估
from saebench import SAEBench, BenchmarkConfig
# 自定义配置
config = BenchmarkConfig(
n_samples=5000,
probing_tasks=["pos", "dep"],
llm_eval_model="gpt-4",
save_activations=True,
)
evaluator = SAEBench(sae_model, config=config)
# 只运行特定指标
results = evaluator.run_metrics(["loss_recovered", "absorption"])4.4 结果可视化
from saebench.visualization import plot_benchmark_results
# 可视化比较不同SAE
plot_benchmark_results(
results_dict={
"ReLU SAE": results_relu,
"JumpReLU SAE": results_jumprelu,
"TopK SAE": results_topk,
},
metrics=["LR", "AI", "AM", "SCR", "SP"],
save_path="benchmark_comparison.png",
)5. 局限性与发展方向
5.1 当前局限性
| 局限性 | 影响 |
|---|---|
| LLM评估成本 | GPT-4 API调用成本高 |
| 探测任务覆盖 | 只覆盖部分语义任务 |
| 动态评估缺失 | 未考虑时序变化 |
| 多模态空白 | 仅限语言模型 |
5.2 未来扩展
| 方向 | 描述 |
|---|---|
| 多模态评估 | 扩展到视觉-语言模型 |
| 动态评估 | 评估训练过程中的SAE演化 |
| 组合评估 | 评估特征组合的交互 |
| 安全评估 | 评估有害内容的可检测性 |
6. 参考
相关资源
Footnotes
-
“SAEBench: A Comprehensive Benchmark for Sparse Autoencoders.” arXiv:2503.XXXXX, March 2025. ↩