概述

稀疏自编码器(Sparse Autoencoder, SAE) 是一种用于神经网络可解释性的强大工具,通过无监督方式从模型激活中分解出可解释的特征。1

SAE的核心思想是:训练一个自编码器在重建原始激活的同时,强制中间表示保持稀疏性。这种稀疏性约束使得每个潜在维度倾向于对应一个有意义的、可解释的特征方向。

原始Transformer层:
input → [MLP层] → output

使用SAE分解:
input → [MLP层] → 激活x → SAE解码 → 重建x̂
                        ↓
                   稀疏特征h = f(x)
                        ↓
              h的每个维度 = 一个特征(如"大写字母"、"数学运算"等)

1. 引言:多义性(Polysemanticity)问题

1.1 神经网络的”一神经元多含义”现象

传统观点认为单个神经元可能对应一个可解释的概念(如”猫神经元”)。然而,现实情况更为复杂:

多义性(Polysemanticity) 指的是单个神经元同时编码多个不相关特征的现象。2

神经元激活模式传统解释实际情况
高激活时单一概念可能同时响应”狗”和”猫”
低激活时无特征可能抑制其他不相关特征
不同语境不同含义上下文依赖的多义响应

1.2 为什么会出现多义性?

原因一:超完备表示(Overcomplete Representation)

神经网络需要表示的特征数量(可能是数百万)远超其维度。通过叠加(Superposition),模型可以将多个特征压缩到有限维度中。

原因二:训练目标的隐式偏好

模型被训练来最小化损失函数,而非显式地分离特征。梯度下降自然倾向于将相关特征放置在相似方向上。

1.3 可解释性的挑战

多义性给可解释性研究带来根本性挑战:

  1. 方向不易解释:单个神经元的激活方向可能不对应任何单一概念
  2. 特征纠缠:相关特征的方向可能混合在一起
  3. 非线性复杂性:特征可能是非线性的,不对应简单方向

2. Superposition现象:Toy Models of Superposition

2.1 核心发现

Anthropic在2022年的论文”Toy Models of Superposition”中通过构造简化的玩具模型,系统研究了Superposition现象。2

核心问题:神经网络如何在维度受限的情况下表示超过其维度的特征?

2.2 玩具模型设置

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
 
class ToySuperpositionModel(nn.Module):
    """
    玩具模型:训练一个小网络表示稀疏特征
    
    假设我们有 n_features 个独立特征,但只有 d_hidden 个神经元。
    当 n_features > d_hidden 时,必须发生Superposition。
    """
    def __init__(self, n_features=100, n_hidden=10, n_samples=10000):
        super().__init__()
        self.n_features = n_features
        self.n_hidden = n_hidden
        
        # 输入:n_features 维的one-hot向量(稀疏特征)
        # 隐藏层:n_hidden 维
        # 输出:重建 n_features 维
        
        self.encoding = nn.Linear(n_features, n_hidden, bias=False)
        self.decoding = nn.Linear(n_hidden, n_features, bias=False)
        
        # 初始化为近似正交
        nn.init.xavier_uniform_(self.encoding.weight)
        nn.init.zeros_(self.decoding.weight)
        
    def forward(self, x):
        # x: (batch, n_features) - 稀疏的one-hot向量
        h = self.encoding(x)  # (batch, n_hidden)
        x_recon = self.decoding(h)  # (batch, n_features)
        return x_recon, h

2.3 理论分析

设置 个独立的稀疏特征, 个隐藏神经元。

关键发现

  1. 当特征稀疏且数量适中时),每个特征可以获得”专属”神经元方向。

  2. 当特征过多时),发生Superposition:

    • 多个特征共享同一神经元
    • 通过非线性(如ReLU的分区效应)分离特征
    • 产生” polysemantic”神经元

数学表述

假设我们有 个特征共享同一个神经元 。该神经元对第 个特征的响应为:

其中 是第 个特征的激活值, 是对应的权重。

2.4 干扰(Interference)与非线性

Superposition的代价是干扰:一个特征的激活可能意外激活其他特征。

解决方案:利用非线性过滤。

假设两个特征 f₁ 和 f₂ 共享神经元 h:
h = w₁f₁ + w₂f₂

如果 f₁ > 0 且 f₂ = 0 → h > 0(正确激活)
如果 f₁ = 0 且 f₂ > 0 → h > 0(正确激活)
如果 f₁ > 0 且 f₂ > 0 → h 过大(干扰!)

使用ReLU分区:
h = ReLU(w₁f₁ + w₂f₂ - θ)
只有当组合激活超过阈值时才激活 → 减少干扰

2.5 关键结论

Superposition的核心洞察:当特征稀疏时,神经网络可以”超载”单个神经元,并通过非线性激活函数来分离特征。SAE的目标正是从这种叠加表示中”解压”出原始的独立特征。


3. SAE基本原理

3.1 架构

SAE是一种特殊的自编码器,由编码器(Encoder)和解码器(Decoder)组成:1

SAE架构:

输入激活 x (d_model维)
    ↓
[编码器] h = f(x) = ReLU(W_enc · x + b_enc)
    ↓
稀疏潜在表示 h (d_sae维,通常 d_sae >> d_model)
    ↓
[解码器] x̂ = W_dec · h + b_dec
    ↓
重建激活 x̂ (d_model维)

关键点

  • 潜在维度 通常远大于输入维度 (超完备)
  • 编码器通常没有非线性(线性编码),或使用ReLU
  • 解码器是线性的
  • 潜在表示 被强制稀疏

3.2 训练目标

SAE的损失函数由两部分组成:1

其中:

符号含义
重建损失(均方误差)
$|h|_1 = \sum_ih_i
稀疏惩罚系数(通常 0.001~0.1)

双重目标

  1. 保真度:准确重建原始激活
  2. 稀疏性:使用尽可能少的特征

3.3 为什么L1正则化产生稀疏性?

L1正则化的几何解释:

L1球(diamond) vs L2球(sphere):

L2球:                    L1球:
  ╭───╮                    ╱╲
 ╱     ╲                  ╱  ╲
│       │                ╱    ╲
╲       ╱               ╱      ╲
 ╲     ╱               ╱________╲
  ╰───╯

与损失函数等高线交点:

L2: 圆与圆的交点(连续解)
L1: 角与角的交点(稀疏解)
    ↑ 角点通常落在坐标轴上 → 多个维度为0

4. 训练细节

4.1 数据收集

SAE训练需要收集目标模型的中间激活:

class ActivationCollector:
    """收集Transformer模型中间层的激活"""
    
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = []
        
    def hook_fn(self, module, input, output):
        # 捕获目标层的输出
        # output可能是tuple (hidden_states,) 或直接是tensor
        if isinstance(output, tuple):
            act = output[0]
        else:
            act = output
        self.activations.append(act.detach().cpu())
        
    def collect(self, texts, batch_size=8):
        """收集大量文本上的激活"""
        self.activations = []
        hooks = []
        
        # 注册hook
        for name, module in self.model.named_modules():
            if name == self.target_layer:
                hooks.append(
                    module.register_forward_hook(self.hook_fn)
                )
        
        # 收集激活
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            with torch.no_grad():
                self.model(batch)
        
        # 移除hook
        for h in hooks:
            h.remove()
            
        return torch.cat(self.activations, dim=0)

4.2 损失函数实现

class SAELoss(nn.Module):
    """
    SAE损失函数实现
    
    L = ||x - x̂||² + λ·||h||₁
    """
    def __init__(self, lambda_sparse=0.01):
        super().__init__()
        self.lambda_sparse = lambda_sparse
        
    def forward(self, x, x_recon, h):
        # 重建损失
        recon_loss = torch.sum((x - x_recon) ** 2, dim=-1).mean()
        
        # L1稀疏损失
        sparse_loss = torch.sum(torch.abs(h), dim=-1).mean()
        
        # 总损失
        total_loss = recon_loss + self.lambda_sparse * sparse_loss
        
        return total_loss, {
            'recon_loss': recon_loss.item(),
            'sparse_loss': sparse_loss.item()
        }

4.3 超参数选择

超参数典型值范围说明
4× ~ 64× 潜在维度,越大能分解更多特征
0.001 ~ 0.1L1系数,需要根据激活分布调整
学习率1e-4 ~ 3e-4建议使用学习率调度
Batch Size4096 ~ 8192越大越稳定
训练步数100k ~ 1M需要充分训练直到收敛

4.4 训练配置示例

# SAE训练配置
sae_config = {
    'd_in': 5120,          # 输入维度(MLP输出)
    'd_sae': 32768,        # SAE潜在维度(通常 4-8倍)
    'lambda': 0.003,       # L1稀疏系数
    'lr': 1e-4,            # 学习率
    'warmup_steps': 1000,  # 预热步数
    'train_steps': 200000, # 训练步数
}
 
class SparseAutoencoder(nn.Module):
    def __init__(self, d_in, d_sae, lambda_reg):
        super().__init__()
        self.d_in = d_in
        self.d_sae = d_sae
        
        # 编码器:线性 + ReLU
        self.W_enc = nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        
        # 解码器:线性
        self.W_dec = nn.Parameter(torch.zeros(d_in, d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))
        
        # 初始化解码器权重为编码器权重的转置的伪逆
        # 这样可以加速训练收敛
        
    def encode(self, x):
        h = F.relu(x @ self.W_enc.T + self.b_enc)
        return h
    
    def decode(self, h):
        x_recon = h @ self.W_dec.T + self.b_dec
        return x_recon
    
    def forward(self, x):
        h = self.encode(x)
        x_recon = self.decode(h)
        return x_recon, h

4.5 训练技巧

  1. 解码器权重初始化

    # 使用W_enc^T的伪逆初始化W_dec
    with torch.no_grad():
        self.W_dec.copy_(torch.linalg.pinv(self.W_enc))
  2. 学习率调度

    # 余弦退火或线性warmup + 常数
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=train_steps
    )
  3. Ghost Gradients:当特征从未激活时,其梯度也为零。可使用”dead feature restart”技术。


5. SAE变体

5.1 Gated SAE

论文:Improving Dictionary Learning with Gated Sparse Autoencoders (NeurIPS 2024)3

问题:标准SAE中的L1惩罚会引入Shrinkage问题——特征激活值被系统性地低估。

核心洞察:将功能分离为两个子问题:

  1. 选择:决定使用哪些特征
  2. 估计:计算这些特征的激活强度

架构

class GatedSAE(nn.Module):
    """
    Gated SAE:分离特征选择和幅度估计
    """
    def __init__(self, d_in, d_sae, lambda_sparse):
        super().__init__()
        self.d_in = d_in
        self.d_sae = d_sae
        
        # 门控网络(决定"是否"激活)
        self.gate_encoder = nn.Linear(d_in, d_sae, bias=False)
        
        # 幅度网络(决定"激活多少")
        self.magnitude_encoder = nn.Linear(d_in, d_sae, bias=True)
        
        # 解码器
        self.W_dec = nn.Parameter(torch.zeros(d_in, d_sae))
        
    def forward(self, x):
        # 门控信号(sigmoid归一化到[0,1])
        gate = torch.sigmoid(self.gate_encoder(x))  # (batch, d_sae)
        
        # 幅度信号
        magnitude = F.relu(self.magnitude_encoder(x))  # (batch, d_sae)
        
        # 组合:gate * magnitude
        h = gate * magnitude
        
        # 重建
        x_recon = h @ self.W_dec.T
        
        return x_recon, h, gate, magnitude

损失函数

关键改进

  • L1正则化只应用于门控信号,不应用于幅度
  • 避免了Shrinkage问题
  • 激活值更准确地反映特征强度

5.2 JumpReLU SAE

核心思想:使用阈值函数实现更激进的稀疏性。

class JumpReLUSAE(nn.Module):
    """
    JumpReLU SAE:使用可学习的跳跃阈值
    """
    def __init__(self, d_in, d_sae):
        super().__init__()
        self.W_enc = nn.Linear(d_in, d_sae, bias=False)
        self.threshold = nn.Parameter(torch.zeros(d_sae))  # 可学习阈值
        
    def encode(self, x):
        pre_acts = self.W_enc(x)
        # JumpReLU: 如果超过阈值则激活,否则为0
        h = F.relu(pre_acts - self.threshold)
        return h

5.3 TopK SAE

核心思想:始终只保留激活值最大的K个特征。

class TopKSAE(nn.Module):
    """
    TopK SAE:强制每个样本只有K个活跃特征
    """
    def __init__(self, d_in, d_sae, k=32):
        super().__init__()
        self.k = k
        self.W_enc = nn.Linear(d_in, d_sae, bias=True)
        self.W_dec = nn.Parameter(torch.zeros(d_in, d_sae))
        
    def encode(self, x):
        h_pre = self.W_enc(x)
        # TopK稀疏化
        h = self._topk_sparse(h_pre)
        return h
    
    def _topk_sparse(self, x, k=None):
        if k is None:
            k = self.k
        # 取最大的k个值,其余置为0
        values, indices = torch.topk(x, k, dim=-1)
        output = torch.zeros_like(x)
        output.scatter_(-1, indices, values)
        return output

5.4 变体对比

变体稀疏方式优点缺点
Vanilla SAEL1正则化简单、稳定Shrinkage问题
Gated SAESigmoid门控无Shrinkage,精度高额外参数
JumpReLU阈值函数极稀疏阈值敏感
TopK SAETop-K选择固定稀疏度K需手动设置

6. 特征分析实践

6.1 特征激活分析

训练完SAE后,需要分析每个特征对应什么含义。

class FeatureAnalyzer:
    """SAE特征分析工具"""
    
    def __init__(self, sae, model, tokenizer):
        self.sae = sae
        self.model = model
        self.tokenizer = tokenizer
        
    def get_feature_activations(self, text):
        """获取文本触发的特征"""
        inputs = self.tokenizer(text, return_tensors='pt').to('cuda')
        
        with torch.no_grad():
            # 获取模型激活
            _, cache = self.model.run_with_cache(
                inputs.input_ids,
                names_filter=lambda n: 'mlp' in n.lower()
            )
            
            # 通过SAE编码
            mlp_act = cache['block.mlp.hook_post']
            h = self.sae.encode(mlp_act)
            
        return h
    
    def find_top_activating_examples(self, feature_idx, dataset, n=20):
        """找到激活特定特征最强的例子"""
        activations = []
        
        for text in dataset:
            h = self.get_feature_activations(text)
            # 取该特征的平均激活值
            act = h[:, feature_idx].mean().item()
            activations.append((text, act))
        
        # 排序
        activations.sort(key=lambda x: x[1], reverse=True)
        return activations[:n]

6.2 特征验证方法

方法一:激活修补(Activation Patching)

验证特征因果作用的标准方法:

def patch_activation_evaluation(feature_idx, sae, model, clean_text, corrupted_text):
    """
    激活修补实验:验证特征对输出的因果影响
    """
    # 1. 获取clean和corrupted的激活
    clean_h = sae.encode(get_activations(model, clean_text))
    corrupted_h = sae.encode(get_activations(model, corrupted_text))
    
    # 2. 在目标位置修补特定特征
    def patch_fn(activations):
        h = sae.encode(activations)
        h[:, feature_idx] = clean_h[:, feature_idx]  # 修补
        return sae.decode(h)
    
    # 3. 对比修补前后的输出变化
    original_output = model(corrupted_text)
    patched_output = model(corrupted_text, hook_patch=patch_fn)
    
    return original_output, patched_output

方法二:特征引导生成

def feature_directed_generation(sae, model, tokenizer, feature_idx, 
                                 base_prompt, boost_strength=5.0):
    """
    通过增强特定特征的激活来引导生成
    """
    def hook_fn(activations, hook):
        h = sae.encode(activations)
        h[:, feature_idx] *= boost_strength  # 增强特征
        return sae.decode(h)
    
    output_ids = model.generate(
        tokenizer(base_prompt, return_tensors='pt').input_ids,
        hooks=[("block.mlp.hook_post", hook_fn)]
    )
    
    return tokenizer.decode(output_ids[0])

6.3 自动可解释性评估

使用辅助模型评估特征的可解释性:

class AutoInterpreter:
    """使用GPT-4等模型自动生成特征解释"""
    
    def __init__(self, openai_api_key):
        self.client = OpenAI(api_key=openai_api_key)
        
    def explain_feature(self, feature_idx, sae, dataset, top_k=10):
        # 收集最强激活的例子
        examples = self.collect_top_examples(feature_idx, dataset, top_k)
        
        # 构建prompt
        prompt = f"""
        以下是对某个SAE特征的激活示例:
        
        高激活例子:
        {examples}
        
        请分析这个特征可能代表什么概念或模式。
        请给出:
        1. 可能的特征名称
        2. 简洁的解释
        3. 置信度(高/中/低)
        """
        
        response = self.client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}]
        )
        
        return response.choices[0].message.content

6.4 常见特征类型

根据Anthropic和EleutherAI的研究,SAE在LLM中发现的特征包括:

类别示例特征验证方式
语法句号、逗号、换行符位置激活位置统计
语义”蓝色”、“国王”、“编程”文本替换测试
概念情感、主题、实体类型相关性分析
行为引用、解释、代码块上下文模式
安全恶意软件签名、有害内容红队测试

7. 与其他方法对比

7.1 方法概览

方法类别可解释性保真度计算成本
Activation Patching因果验证
Probing监督探测
LIME局部近似
SHAP特征重要性
SAE无监督分解中-高

7.2 Activation Patching

Activation Patching(也称因果追踪)通过修补特定位置的激活来验证其因果作用。4

优点

  • 明确的因果解释
  • 验证特征必要性

缺点

  • 需要人工指定要测试的位置
  • 无法自动发现特征

7.3 LIME

Local Interpretable Model-agnostic Explanations 通过局部扰动数据拟合简单模型。

from lime.lime_text import LimeTextExplainer
 
explainer = LimeTextExplainer()
exp = explainer.explain_instance(
    text, 
    model.predict_proba,
    num_features=10
)
exp.show_in_notebook()

局限性

  • 只适用于输入级别的解释
  • 无法解释模型内部表示

7.4 SHAP

SHapley Additive exPlanations 基于博弈论的特征重要性方法。

import shap
 
explainer = shap.Explainer(model, tokenizer)
shap_values = explainer(texts)
shap.plots.text(shap_values)

局限性

  • 计算成本高(指数级)
  • 通常只应用于输入层

7.5 SAE的优势

  1. 无监督:不需要标注数据
  2. 层次化:可应用于任意中间层
  3. 可组合:特征可叠加重建激活
  4. 稀疏:每个样本只有少数活跃特征
  5. 可扩展:已应用于数十亿参数模型

8. 局限性

8.1 重建保真度问题

SAE的重建质量可能不完美:

典型问题

  • 细节丢失
  • 非线性交互难以重建
  • 某些激活模式无法用线性解码器捕获

缓解方法

  • 增加 维度
  • 使用更复杂的解码器架构
  • 训练更长时间

8.2 特征完整性问题

SAE可能无法发现所有特征:

  1. 死亡特征:训练后某些特征从未激活
  2. 特征重叠:不同特征方向过于接近
  3. 特征碎片化:单一语义概念被分解为多个特征

8.3 可靠性验证困难

Ground Truth缺失:在真实模型中没有特征的真实标注。

评估方法

  1. 前向重构一致性:同一文本多次推理是否产生一致的特征激活
  2. 反事实稳定性:改变无关部分是否保持特征激活不变
  3. 功能验证:通过激活修补验证特征的因果作用

8.4 扩展性问题

挑战描述
计算成本大模型需要更多SAE维度
特征数量百万级特征难以全部人工分析
层次差异不同层的特征类型不同
动态特性特征可能随模型微调变化

9. 实践工具与资源

9.1 开源实现

来源特点
SAELensTransformerlens社区完整训练+分析工具
GatedSAEDeepMindGated SAE实现
SAEEleutherAI大规模训练框架

9.2 预训练SAE模型

模型规模来源
Gemma-2 SAEs2B, 9BGoogle
Pythia SAEs70M-12BEleutherAI
Claude SAEsClaude 3 SonnetAnthropic(未公开)

9.3 可视化工具

# 使用SAELens进行特征可视化
from sae_lens import SAE, ActivationsStore
 
# 加载预训练SAE
sae, cfg = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.7.hook_resid_post"
)
 
# 分析特征
vis = sae.feature_visualization(feature_idx=42)

10. 参考文献


相关主题

Footnotes

  1. Huberman et al., “Sparse Autoencoders Find Highly Interpretable Features in Language Models”, ICLR 2024 2 3

  2. Elhage et al., “Toy Models of Superposition”, arXiv:2209.10652, Anthropic 2022 2

  3. Bricken et al., “Towards Monosemanticity: Decomposing Language Models With Dictionary Learning”, Anthropic 2023

  4. Gao et al., “Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 small”, ICLR 2024