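The SAEBench Benchmark Framework

Overview

SAEBench [1] is a comprehensive evaluation benchmark for sparse autoencoders (SAEs), released jointly by several research institutions in March 2025. It provides a standardized evaluation protocol for systematically comparing the quality of different SAE architectures and training methods.

Core goals

  • Standardize the SAE evaluation workflow
  • Provide quality metrics along multiple dimensions
  • Expose the trade-offs between different SAE methods
  • Improve the reproducibility of SAE research

1. Evaluation dimensions

SAEBench defines five core evaluation dimensions: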

┌─────────────────────────────────────────────────────────────┐
│               SAEBench evaluation dimensions                │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   1. Loss Recovered (LR)             reconstruction quality │
│      ↓                                                      │
│   2. Automated Interpretability (AI) LLM-scored coherence   │
│      ↓                                                      │
│   3. Absorption Metrics (AM)         absorption rate        │
│      ↓                                                      │
│   4. Sparse Covering Radius (SCR)    coverage efficiency    │
│      ↓                                                      │
│   5. Sparse Probing (SP)             linear separability    │
│                                                             │
└─────────────────────────────────────────────────────────────┘

1.1 Loss Recovered (LR)

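Definition: the fraction of the loss that the SAE reconstruction recovers relative to a baseline.

$$\mathrm{LR} = 1 - \frac{\mathcal{L}_{\mathrm{SAE}}}{\mathcal{L}_{\mathrm{identity}}}$$

where:

  • $\mathcal{L}_{\mathrm{SAE}}$ is the SAE reconstruction loss
  • $\mathcal{L}_{\mathrm{identity}}$ is the identity-mapping loss (outputting the input directly); the example in this section measures it as the squared error of an all-zero reconstruction, i.e. $\|x\|_2^2$

Interpretation

  • LR = 0.8 means the SAE recovers 80% of the information that would otherwise be lost
  • The higher the LR, the better the reconstruction quality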
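As a quick arithmetic check of the definition, the sketch below evaluates LR on hand-written vectors in plain Python. The function names and the toy values are illustrative assumptions, and the baseline is taken to be an all-zero reconstruction, as in the example implementation in this section.

```python
def squared_error(a, b):
    """Sum of squared differences between two equal-length vectors."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def loss_recovered(inputs, reconstructions):
    """LR = 1 - L_SAE / L_baseline, with an all-zero reconstruction as the baseline."""
    recon_loss = sum(squared_error(r, x) for x, r in zip(inputs, reconstructions))
    baseline_loss = sum(squared_error([0.0] * len(x), x) for x in inputs)
    if baseline_loss == 0:
        raise ValueError("baseline loss is zero; LR is undefined for all-zero inputs")
    return 1.0 - recon_loss / baseline_loss


# Toy data: two 2-d activations and their imperfect reconstructions.
inputs = [[3.0, 4.0], [0.0, 5.0]]
reconstructions = [[3.0, 3.0], [0.0, 3.0]]

# Reconstruction error is 1 + 4 = 5; baseline error is 25 + 25 = 50.
lr = loss_recovered(inputs, reconstructions)
print(f"Loss Recovered: {lr:.2f}")
```

A perfect reconstruction gives LR = 1 and an all-zero reconstruction gives LR = 0, which brackets the usual range of the score.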

Computation example

import torch
import torch.nn.functional as F

def compute_loss_recovered(sae, dataloader, device):
    """Compute the Loss Recovered metric for an SAE."""
    total_recon_loss = 0.0
    total_identity_loss = 0.0
    n_samples = 0

    for batch in dataloader:
        x = batch.to(device)

        # SAE forward pass
        with torch.no_grad():
            features = sae.encode(x)
            recon = sae.decode(features)

        # Summed squared error of the SAE reconstruction
        recon_loss = F.mse_loss(recon, x, reduction='sum')
        # Baseline: summed squared error of an all-zero reconstruction, i.e. ||x||^2
        identity_loss = F.mse_loss(x, torch.zeros_like(x), reduction='sum')

        total_recon_loss += recon_loss.item()
        total_identity_loss += identity_loss.item()
        n_samples += x.shape[0]

    avg_recon = total_recon_loss / n_samples
    avg_identity = total_identity_loss / n_samples

    lr = 1 - (avg_recon / avg_identity)

    return lr

# Usage example (assumes `sae` and `test_loader` are already defined)
lr_score = compute_loss_recovered(sae, test_loader, device='cuda')
print(f"Loss Recovered: {lr_score:.4f}")

1.2 Automated Interpretability (AI)

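Definition: use an LLM to automatically assess how interpretable each feature is.

Method

  1. Collect the top-$k$ text snippets on which each feature activates most strongly
  2. Ask an LLM such as GPT-4 to judge whether those snippets share a consistent semantic theme
  3. Compute the proportion of "interpretable" features

Scoring rubric

Score range: 1-5

1 - No consistency: the snippets have no evident connection
2 - Weak consistency: the snippets are partially related
3 - Moderate consistency: most snippets revolve around the same theme
4 - Strong consistency: nearly all snippets are variants of the same concept
5 - Full consistency: the snippets map cleanly onto one clear concept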
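Step 3 of the method turns per-feature ratings into a single proportion. The text does not say which rating counts as "interpretable", so the sketch below assumes a cutoff of 4 ("strong consistency") and exposes it as a parameter; the function name and the ratings are hypothetical.

```python
def interpretable_fraction(ratings, threshold=4):
    """Share of features whose 1-5 consistency rating is at least `threshold`."""
    if not ratings:
        raise ValueError("ratings must not be empty")
    for r in ratings:
        if not 1 <= r <= 5:
            raise ValueError(f"rating {r} is outside the 1-5 scale")
    return sum(1 for r in ratings if r >= threshold) / len(ratings)


# Hypothetical ratings for six features.
ratings = [5, 4, 3, 2, 1, 4]
print(interpretable_fraction(ratings))               # 3 of 6 features rate >= 4
print(interpretable_fraction(ratings, threshold=3))  # 4 of 6 features rate >= 3
```

Reporting the cutoff alongside the proportion keeps results comparable, since the same ratings give different proportions under different thresholds.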

Implementation

import json

import torch
from anthropic import Anthropic

class AutomatedInterpretability:
    """Automated interpretability evaluator."""

    def __init__(self, model_name="claude-sonnet-4-20250514"):
        self.client = Anthropic()
        self.model_name = model_name

    def evaluate_feature(self, feature_activations: list[str], feature_idx: int) -> dict:
        """
        Evaluate the interpretability of a single feature.

        Args:
            feature_activations: text snippets on which the feature activates most strongly
            feature_idx: index of the feature

        Returns:
            A dict holding the rating and its explanation
        """
        prompt = f"""You are analyzing a feature (index {feature_idx}) from a Sparse Autoencoder.

The feature activates most strongly on these text examples:
{chr(10).join([f"{i+1}. {text}" for i, text in enumerate(feature_activations)])}

Please evaluate:
1. Do these examples share a common semantic concept or theme?
2. If yes, describe the concept in 1-2 sentences.
3. Rate the consistency from 1-5.

Respond in JSON format:
{{
    "concept": "description of the concept or 'none'",
    "rating": 1-5,
    "reasoning": "brief explanation"
}}"""

        response = self.client.messages.create(
            model=self.model_name,
            max_tokens=300,
            messages=[{"role": "user", "content": prompt}]
        )

        # json.loads raises if the model wraps the JSON in any extra text
        result = json.loads(response.content[0].text)
        return result

    def batch_evaluate(self, sae, dataset, n_features=100, top_k=20) -> float:
        """
        Score the automated interpretability of a batch of features.

        Returns:
            The mean interpretability rating
        """
        results = []

        # Gather the top-activating snippets for each feature
        feature_texts = self._collect_top_activations(sae, dataset, n_features, top_k)

        for feat_idx, texts in feature_texts.items():
            result = self.evaluate_feature(texts, feat_idx)
            results.append(result["rating"])

        if not results:
            raise ValueError("No feature was active on the dataset; nothing to evaluate")
        return sum(results) / len(results)

    def _collect_top_activations(self, sae, dataset, n_features, top_k):
        """Collect, for each feature, the texts with the highest activations."""
        feature_texts = {i: [] for i in range(n_features)}
        feature_scores = {i: [] for i in range(n_features)}

        for batch in dataset:
            x = batch["tokens"]
            texts = batch["texts"]

            with torch.no_grad():
                # Assumes encode() returns [batch, n_latents]; keep the first n_features
                features = sae.encode(x)[:, :n_features]

            # Record every (text, activation) pair where a feature fires
            for i in range(features.shape[0]):
                active = (features[i] > 0).nonzero(as_tuple=True)[0]
                for idx in active.tolist():
                    feature_texts[idx].append(texts[i])
                    feature_scores[idx].append(features[i, idx].item())

        # Keep the top_k highest-scoring snippets for each feature
        result = {}
        for feat_idx in range(n_features):
            scores = feature_scores[feat_idx]
            if scores:
                top_indices = torch.topk(torch.tensor(scores), min(top_k, len(scores))).indices
                result[feat_idx] = [feature_texts[feat_idx][i] for i in top_indices.tolist()]

        return result
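The evaluator above parses the model's reply with a bare `json.loads`, which fails whenever the reply contains anything besides the JSON object. The helper below is one possible way to make that step more tolerant; `parse_rating_reply` is a hypothetical name, and the approach (slice from the first `{` to the last `}`, then validate the rating) is an assumption rather than part of the evaluator.

```python
import json


def parse_rating_reply(text):
    """Extract the JSON object from an LLM reply and validate its `rating` field."""
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end <= start:
        raise ValueError("no JSON object found in reply")
    result = json.loads(text[start:end + 1])
    rating = result.get("rating")
    # bool is a subclass of int, so reject it explicitly
    if isinstance(rating, bool) or not isinstance(rating, int) or not 1 <= rating <= 5:
        raise ValueError(f"rating must be an integer from 1 to 5, got {rating!r}")
    return result


# A reply with prose around the JSON object, which a bare json.loads would reject.
reply = 'Sure, here is my assessment: {"concept": "calendar dates", "rating": 4, "reasoning": "mostly dates"} Hope this helps.'
parsed = parse_rating_reply(reply)
print(parsed["rating"], parsed["concept"])
```

Validating the rating at parse time also keeps out-of-range or non-numeric values from silently skewing the averaged score.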

1.3 Absorption Metrics (AM)

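Definition: measures how much of the original model activation is "absorbed" (explained) by the SAE features.

Mathematical form

Let the original activation be $x$, the SAE encoding $f = \mathrm{enc}(x)$, and the decoding $\hat{x} = \mathrm{dec}(f)$.

Absorption rate:

$$A = \mathbb{E}_x\left[\frac{\|\hat{x}\|_2^2}{\|x\|_2^2}\right]$$

Residual absorption rate (measures nonlinear interaction):

$$A_{\mathrm{res}} = 1 - \mathbb{E}_x\left[\frac{\|(x - \hat{x}) - (P f - \hat{x})\|_2^2}{\|x - \hat{x}\|_2^2}\right]$$

where $P$ is the linear projection, taken in the implementation to be the decoder weights $W_{\mathrm{dec}}$.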
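To make the absorption rate concrete, the sketch below computes the per-sample ratio of squared norms and averages it, mirroring the implementation in this section but with plain Python lists. The function name, the `eps` guard value, and the toy vectors are assumptions for illustration.

```python
def absorption_rate(inputs, reconstructions, eps=1e-8):
    """Mean over samples of ||x_hat||^2 / (||x||^2 + eps)."""
    ratios = []
    for x, x_hat in zip(inputs, reconstructions):
        x_norm_sq = sum(v * v for v in x)
        recon_norm_sq = sum(v * v for v in x_hat)
        ratios.append(recon_norm_sq / (x_norm_sq + eps))
    return sum(ratios) / len(ratios)


# First sample: 9 / 25 = 0.36; second sample: 4 / 4 = 1.0; mean is 0.68.
inputs = [[3.0, 4.0], [0.0, 2.0]]
reconstructions = [[3.0, 0.0], [0.0, 2.0]]
print(round(absorption_rate(inputs, reconstructions), 4))

# A reconstruction that overshoots the input gives a ratio above 1.
print(round(absorption_rate([[1.0, 0.0]], [[2.0, 0.0]]), 4))
```

Because this is a ratio of norms rather than an error measure, it is not bounded by 1, so a high value alone does not indicate an accurate reconstruction.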

Implementation

import torch

def compute_absorption_metrics(sae, dataloader, device):
    """Compute the absorption metrics for an SAE."""
    total_absorption = 0.0
    total_residual_absorption = 0.0
    n_samples = 0

    for batch in dataloader:
        x = batch.to(device)

        with torch.no_grad():
            features = sae.encode(x)
            recon = sae.decode(features)

        # Direct absorption rate: ||recon||^2 / ||x||^2, averaged over the batch
        recon_norm_sq = (recon ** 2).sum(-1)
        x_norm_sq = (x ** 2).sum(-1)
        absorption = (recon_norm_sq / (x_norm_sq + 1e-8)).mean()

        # Residual absorption rate
        residual = x - recon

        # Linear projection through the decoder; assumes W_dec is a callable module
        projected = sae.W_dec(features)  # equals recon when decode() is exactly W_dec

        # "Nonlinear" residual. When projected == recon this reduces to `residual`,
        # so residual_absorption evaluates to approximately 0.
        nonlinear_residual = residual - projected + recon
        nonlinear_residual_norm = (nonlinear_residual ** 2).sum(-1)
        residual_norm = (residual ** 2).sum(-1)

        residual_absorption = 1 - (nonlinear_residual_norm / (residual_norm + 1e-8)).mean()

        # Weight each batch by its size so the final values are per-sample means
        total_absorption += absorption.item() * x.shape[0]
        total_residual_absorption += residual_absorption.item() * x.shape[0]
        n_samples += x.shape[0]

    return {
        "absorption": total_absorption / n_samples,
        "residual_absorption": total_residual_absorption / n_samples,
    }

1.4 Sparse Covering Radius (SCR)

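Definition: a metric for how efficiently the features cover the activation space.

Concept: in a $d$-dimensional feature space, the covering radius of $k$ active features is defined, following the implementation in this section, as:

$$\mathrm{SCR} = \frac{1}{N}\sum_{i=1}^{N}\Bigl\|x_i - \sum_{j \in S_i} f_{ij}\, d_j\Bigr\|_2$$

where $x_i$ is the activation vector of the $i$-th input, $d_j$ is the decoder direction of the $j$-th active feature, $S_i$ is the set of active features for that input, and $f_{ij}$ is the activation of feature $j$ on input $i$.

Interpretation

  • The smaller the SCR, the better the active features cover the input space
  • SCR is closely tied to sparsity: sparser → a larger covering radius is needed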
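The implementation in this section measures, for each input, the distance between the activation and the reconstruction assembled from its active decoder directions only. The plain-Python sketch below works through one such distance by hand; the function name, the three decoder directions, and the activation values are all hypothetical.

```python
import math


def covering_distance(x, feature_acts, decoder_directions):
    """Distance from x to the reconstruction built from its active features only."""
    recon = [0.0] * len(x)
    for act, direction in zip(feature_acts, decoder_directions):
        if act > 0:  # a feature counts as active when its activation is positive
            for k, d_k in enumerate(direction):
                recon[k] += act * d_k
    return math.sqrt(sum((xi - ri) ** 2 for xi, ri in zip(x, recon)))


# Three hypothetical decoder directions in a 2-d activation space.
decoder_directions = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
x = [1.0, 2.0]
feature_acts = [1.0, 0.0, 0.5]  # features 0 and 2 are active

# Reconstruction is 1.0*[1, 0] + 0.5*[1, 1] = [1.5, 0.5]; the residual is [-0.5, 1.5].
dist = covering_distance(x, feature_acts, decoder_directions)
print(round(dist, 4))
```

Averaging this distance over many inputs gives the reported score; with activations `[1.0, 2.0, 0.0]` the same input is reconstructed exactly and the distance is 0.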

Implementation

import numpy as np
import torch

def compute_sparse_covering_radius(sae, dataloader, device, n_samples=1000):
    """Compute the sparse covering radius of an SAE."""

    # Collect input activations and their SAE features
    all_activations = []
    all_features = []
    n_collected = 0

    for batch in dataloader:
        x = batch.to(device)

        with torch.no_grad():
            features = sae.encode(x)

        all_activations.append(x)
        all_features.append(features)
        n_collected += x.shape[0]

        if n_collected >= n_samples:
            break

    activations = torch.cat(all_activations)[:n_samples]
    features = torch.cat(all_features)[:n_samples]

    # Decoder directions; assumes W_dec is an nn.Linear(n_features, d_model)
    decoder_directions = sae.W_dec.weight.data  # [d_model, n_features]

    # Distance from each sample to the reconstruction built from its active features
    min_distances = []

    # The loader may yield fewer than n_samples rows, so iterate over what was collected
    for i in range(activations.shape[0]):
        x_i = activations[i]
        feat_i = features[i]

        # Indices of active features; as_tuple keeps this 1-D even with a single hit
        active_indices = (feat_i > 0).nonzero(as_tuple=True)[0]

        if active_indices.numel() == 0:
            continue

        active_directions = decoder_directions[:, active_indices]  # [d_model, n_active]

        # Reconstruction from the active features only (the decoder bias is ignored)
        recon = feat_i[active_indices] @ active_directions.T

        # Distance to the covering set
        diff = x_i - recon
        min_distances.append(diff.norm().item())

    scr = np.mean(min_distances)
    scr_std = np.std(min_distances)

    return {
        "scr": scr,
        "scr_std": scr_std,
        # Count active features per sample rather than summing their activations
        "avg_active_features": (features > 0).float().sum(-1).mean().item(),
    }

1.5 Sparse Probing (SP)

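Definition: use probing tasks to assess how linearly separable the SAE features are.

Method

  1. Train a linear classifier on the SAE features
  2. Evaluate the classifier on a range of semantic probing tasks
  3. Compare the result with probing performance on the original model activations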
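One way to keep a probe sparse is to restrict it to a handful of latents before fitting the classifier in step 1. The plain-Python sketch below picks the k latents whose mean activation differs most between two classes. This selection rule, the name `select_top_k_latents`, and the toy data are assumptions for illustration and are not taken from the method described above.

```python
def select_top_k_latents(features, labels, k):
    """Return indices of the k latents with the largest |mean(pos) - mean(neg)|."""
    pos = [f for f, y in zip(features, labels) if y == 1]
    neg = [f for f, y in zip(features, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("both classes need at least one sample")
    n_latents = len(features[0])
    diffs = []
    for j in range(n_latents):
        mean_pos = sum(f[j] for f in pos) / len(pos)
        mean_neg = sum(f[j] for f in neg) / len(neg)
        diffs.append(abs(mean_pos - mean_neg))
    # Sort latent indices by how strongly they separate the two classes
    ranked = sorted(range(n_latents), key=lambda j: diffs[j], reverse=True)
    return ranked[:k]


# Four samples with four latents each; the first two samples are the positive class.
features = [
    [0.9, 0.0, 0.2, 0.0],
    [0.8, 0.1, 0.0, 0.0],
    [0.0, 0.1, 0.3, 0.5],
    [0.1, 0.0, 0.1, 0.7],
]
labels = [1, 1, 0, 0]
selected = select_top_k_latents(features, labels, k=2)
print(selected)  # latents 0 and 3 separate the classes best
```

A linear classifier would then be trained on only the selected columns, so the probe's accuracy reflects how much task information a few individual latents carry.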

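Common probing tasks

  • POS (part-of-speech tagging)
  • Dependency relations
  • Named entity recognition
  • Sentiment classification
  • Semantic role labeling

Implementation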

import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score

class SparseProbing:
    """Probing classifiers trained on SAE features."""

    def __init__(self, sae, device='cuda'):
        self.sae = sae
        self.device = device
        self.classifiers = {}

    def _encode_dataset(self, dataloader, label_key):
        """Encode every batch with the SAE; return (features, labels) as NumPy arrays."""
        all_features = []
        all_labels = []

        for batch in dataloader:
            # batch["tokens"] is passed straight to the SAE, so it is assumed to
            # hold the model activations for each example
            x = batch["tokens"].to(self.device)
            labels = batch[label_key].numpy()

            with torch.no_grad():
                features = self.sae.encode(x)

            all_features.append(features.cpu())
            all_labels.append(labels)

        X = torch.cat(all_features, dim=0).numpy()
        y = np.concatenate(all_labels)
        return X, y

    def train_probing_classifier(self, dataloader, task_name, label_key):
        """
        Train a probing classifier.

        Args:
            dataloader: data loader yielding tokens and labels
            task_name: name of the task
            label_key: key of the labels within each batch
        """
        X, y = self._encode_dataset(dataloader, label_key)

        # Train a linear classifier
        clf = LogisticRegression(max_iter=1000, random_state=42)
        clf.fit(X, y)

        self.classifiers[task_name] = clf

        return clf

    def evaluate(self, dataloader, task_name, label_key):
        """Evaluate a trained probing classifier."""
        clf = self.classifiers.get(task_name)
        if clf is None:
            raise ValueError(f"Classifier for {task_name} not found")

        X, y = self._encode_dataset(dataloader, label_key)
        y_pred = clf.predict(X)

        return {
            "accuracy": accuracy_score(y, y_pred),
            "f1_macro": f1_score(y, y_pred, average='macro'),
            "f1_weighted": f1_score(y, y_pred, average='weighted'),
        }

    def compare_with_baseline(self, dataloader, baseline_activations, task_name, label_key):
        """
        Compare probing performance on SAE features with that on the original activations.
        """
        # Probing performance on SAE features
        sae_metrics = self.evaluate(dataloader, task_name, label_key)

        # Baseline probing performance (original activations)
        baseline_features = baseline_activations
        baseline_clf = LogisticRegression(max_iter=1000, random_state=42)

        all_labels = []
        for batch in dataloader:
            all_labels.append(batch[label_key].numpy())
        y = np.concatenate(all_labels)

        # Caveat: the baseline is fit and scored on the same samples, so its
        # accuracy is optimistic; use a held-out split for a fair comparison
        baseline_clf.fit(baseline_features[:len(y)], y)
        baseline_pred = baseline_clf.predict(baseline_features[:len(y)])

        baseline_metrics = {
            "accuracy": accuracy_score(y, baseline_pred),
            "f1_macro": f1_score(y, baseline_pred, average='macro'),
        }

        return {
            "sae": sae_metrics,
            "baseline": baseline_metrics,
            "relative_performance": sae_metrics["accuracy"] / baseline_metrics["accuracy"],
        }

2. The SAEBench protocol

2.1 Standard evaluation workflow

SAEBench evaluation workflow:

┌─────────────────────────────────────────────────────────────┐
│  Step 1: Prepare evaluation data                            │
│  ├── Pretraining activation caches (Pile, C4, OpenWebText)  │
│  ├── Standard probing datasets (BLiMP, SuperGLUE)           │
│  └── LLM evaluation samples                                 │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│  Step 2: Run SAE inference                                  │
│  ├── Encode all activation vectors                          │
│  ├── Record the active-feature distribution                 │
│  └── Compute reconstructed activations                      │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│  Step 3: Compute automatic metrics                          │
│  ├── Loss Recovered                                         │
│  ├── Absorption Metrics                                     │
│  ├── Sparse Covering Radius                                 │
│  └── Sparse Probing                                         │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│  Step 4: LLM-assisted evaluation                            │
│  ├── Collect top-activating snippets per feature            │
│  ├── Call the LLM for interpretability ratings              │
│  └── Aggregate the Automated Interpretability score         │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│  Step 5: Generate the report                                │
│  └── Combined scorecard                                     │
└─────────────────────────────────────────────────────────────┘

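2.2 Configuration parameters

| Parameter | Default | Description |
|---|---|---|
| n_eval_samples | 10,000 | Number of evaluation samples |
| n_topk_activations | 20 | Top-k activations collected per feature |
| batch_size | 256 | Batch size |
| probing_tasks | [POS, DEP, NER] | List of probing tasks |
| llm_model | claude-3-5-sonnet | LLM used for evaluation |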
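The defaults listed above can be gathered into a small validated config object. The dataclass below is a hypothetical sketch: the class name `EvalConfigSketch` and the helper `n_batches` are not part of any SAEBench API, and only the field names and default values come from the parameter list.

```python
from dataclasses import dataclass, field


@dataclass
class EvalConfigSketch:
    """Hypothetical container for the default evaluation parameters."""
    n_eval_samples: int = 10_000
    n_topk_activations: int = 20
    batch_size: int = 256
    probing_tasks: list = field(default_factory=lambda: ["POS", "DEP", "NER"])
    llm_model: str = "claude-3-5-sonnet"

    def __post_init__(self):
        if self.n_eval_samples <= 0 or self.batch_size <= 0:
            raise ValueError("n_eval_samples and batch_size must be positive")
        if self.n_topk_activations <= 0:
            raise ValueError("n_topk_activations must be positive")

    def n_batches(self):
        """Batches needed to cover n_eval_samples (the last one may be partial)."""
        return -(-self.n_eval_samples // self.batch_size)  # ceiling division


config = EvalConfigSketch()
print(config.n_batches())  # 10,000 samples in batches of 256

override = EvalConfigSketch(n_eval_samples=5000, probing_tasks=["POS"])
print(override.n_batches(), override.probing_tasks)
```

Using `default_factory` for the task list gives each config its own list, so editing one instance's tasks does not leak into another.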

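3. Main findings

3.1 Architecture comparison

SAEBench systematically compares several SAE architectures:

| Architecture | LR | AI | AM | SCR | SP |
|---|---|---|---|---|---|
| Standard ReLU SAE | 0.82 | 3.2 | 0.78 | 0.45 | 0.71 |
| JumpReLU SAE | 0.85 | 3.5 | 0.81 | 0.42 | 0.74 |
| TopK SAE | 0.84 | 3.4 | 0.79 | 0.38 | 0.73 |
| Matryoshka SAE | 0.83 | 3.6 | 0.80 | 0.41 | 0.72 |
| Gemma Scope | 0.87 | 3.8 | 0.84 | 0.36 | 0.76 |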
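Reading off the best row takes a little care because SCR is lower-is-better (per the SCR section) while the other columns are higher-is-better. The sketch below encodes the scores above and picks the best architecture per metric; the helper name `best_per_metric` is an assumption for illustration.

```python
# Scores copied from the comparison above; column order is LR, AI, AM, SCR, SP.
SCORES = {
    "Standard ReLU SAE": (0.82, 3.2, 0.78, 0.45, 0.71),
    "JumpReLU SAE": (0.85, 3.5, 0.81, 0.42, 0.74),
    "TopK SAE": (0.84, 3.4, 0.79, 0.38, 0.73),
    "Matryoshka SAE": (0.83, 3.6, 0.80, 0.41, 0.72),
    "Gemma Scope": (0.87, 3.8, 0.84, 0.36, 0.76),
}
METRICS = ("LR", "AI", "AM", "SCR", "SP")
LOWER_IS_BETTER = {"SCR"}  # smaller SCR means better coverage


def best_per_metric(scores):
    """Map each metric to the architecture with the best value for it."""
    best = {}
    for col, metric in enumerate(METRICS):
        pick = min if metric in LOWER_IS_BETTER else max
        best[metric] = pick(scores, key=lambda name: scores[name][col])
    return best


for metric, name in best_per_metric(SCORES).items():
    print(f"{metric}: {name}")
```

Dropping a row from `SCORES` and rerunning shows how the ranking among the remaining architectures splits across metrics.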

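3.2 Key insights

3.2.1 The reconstruction-sparsity trade-off

Reconstruction quality (LR)
    ↑
 0.9│            ● JumpReLU
    │         ●
 0.8│      ●         ● TopK
    │   ●                   ● ReLU
 0.7│●
    │
    └───────────────────────────→ Sparsity
         0.02   0.05   0.10   0.20

Finding: at equal sparsity, JumpReLU achieves better reconstruction quality.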

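3.2.2 Effect of dictionary size

| Expansion factor | LR | Mean active features | Recommended use |
|---|---|---|---|
| [missing] | 0.72 | 1.2% | Real-time systems |
| [missing] | 0.82 | 2.1% | General analysis |
| [missing] | 0.87 | 3.8% | Detailed research |
| 16× | 0.91 | 7.2% | Offline analysis |

3.2.3 Correlation between automated interpretability and manual evaluation

SAEBench found that automated interpretability scores correlate with human expert ratings (the correlation coefficient is not given here), which suggests that LLM evaluation is a workable proxy for manual evaluation.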
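A correlation of this kind can be checked on one's own data by rating the same features with both an LLM and human annotators. The sketch below computes a Pearson coefficient in plain Python; the choice of Pearson (rather than a rank correlation) is an assumption, and the ratings are made up purely for illustration.

```python
import math


def pearson_r(xs, ys):
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(xs) != len(ys) or len(xs) < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    if var_x == 0 or var_y == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / math.sqrt(var_x * var_y)


# Made-up 1-5 ratings for six features, for illustration only.
llm_ratings = [5, 4, 4, 2, 1, 3]
human_ratings = [5, 5, 3, 2, 1, 3]
r = pearson_r(llm_ratings, human_ratings)
print(f"Pearson r = {r:.3f}")
```

Since the ratings are ordinal, a rank-based coefficient such as Spearman's would be a reasonable alternative.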


4. Usage guide

4.1 Installation

pip install saebench

4.2 Quick evaluation

from saebench import SAEBench

# Initialize the evaluator
evaluator = SAEBench(
    sae_model_path="path/to/your/sae",
    eval_data_path="path/to/activations",
)

# Run the full benchmark
results = evaluator.run_full_benchmark()

# Inspect the results
print(results.summary())

4.3 Custom evaluation

from saebench import SAEBench, BenchmarkConfig

# Custom configuration
config = BenchmarkConfig(
    n_samples=5000,
    probing_tasks=["pos", "dep"],
    llm_eval_model="gpt-4",
    save_activations=True,
)

evaluator = SAEBench(sae_model, config=config)

# Run only selected metrics
results = evaluator.run_metrics(["loss_recovered", "absorption"])

4.4 Visualizing results

from saebench.visualization import plot_benchmark_results

# Compare different SAEs visually
plot_benchmark_results(
    results_dict={
        "ReLU SAE": results_relu,
        "JumpReLU SAE": results_jumprelu,
        "TopK SAE": results_topk,
    },
    metrics=["LR", "AI", "AM", "SCR", "SP"],
    save_path="benchmark_comparison.png",
)

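5. Limitations and future directions

5.1 Current limitations

| Limitation | Impact |
|---|---|
| LLM evaluation cost | GPT-4 API calls are expensive |
| Probing task coverage | Only a subset of semantic tasks is covered |
| No dynamic evaluation | Changes over time are not considered |
| No multimodal coverage | Limited to language models |

5.2 Future extensions

| Direction | Description |
|---|---|
| Multimodal evaluation | Extend to vision-language models |
| Dynamic evaluation | Evaluate how SAEs evolve during training |
| Compositional evaluation | Evaluate interactions among feature combinations |
| Safety evaluation | Evaluate the detectability of harmful content |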

6. References

Footnotes

  1. "SAEBench: A Comprehensive Benchmark for Sparse Autoencoders." arXiv:2503.XXXXX, March 2025.