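Advanced Architectures for Sparse Autoencoders
Overview
Sparse autoencoders (SAEs) have become a core tool for understanding and steering the internal representations of large language models (LLMs) [1]. A standard SAE learns to reconstruct activation vectors while encouraging sparse feature activations, which addresses the superposition problem in neural networks [2]. The standard architecture has limitations, however, and researchers have proposed several improved architectures in recent years to address them.
This document surveys the main advanced SAE variants and analyzes their innovations, strengths, and suitable use cases.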
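Background: Why Advanced Architectures?
A standard SAE consists of an encoder $f(x) = \mathrm{ReLU}(W_{\text{enc}} x + b_{\text{enc}})$ and a decoder $\hat{x} = W_{\text{dec}} f(x) + b_{\text{dec}}$, trained by minimizing a reconstruction loss plus a sparsity penalty:
$$\mathcal{L}(x) = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f(x) \rVert_1$$
where $x$ is the original activation vector and $\lambda$ is the sparsity penalty coefficient.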
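To make the loss concrete, the following is a minimal sketch in plain Python (no PyTorch, so it runs anywhere). The toy weights, the helper names `matvec` and `sae_loss`, and the value of `lam` are illustrative assumptions, not values from any trained SAE.

```python
def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def sae_loss(x, W_enc, b_enc, W_dec, b_dec, lam):
    """Return (loss, features) for a single activation vector x."""
    pre = [p + b for p, b in zip(matvec(W_enc, x), b_enc)]
    f = [max(0.0, p) for p in pre]                       # ReLU encoder output
    x_hat = [r + b for r, b in zip(matvec(W_dec, f), b_dec)]
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))  # squared L2 error
    sparsity = sum(abs(a) for a in f)                    # L1 norm of features
    return recon + lam * sparsity, f


# Toy setup (hypothetical): 2-dimensional activations, 3 features
W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]   # shape 3 x 2
b_enc = [0.0, 0.0, 0.0]
W_dec = [[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]     # shape 2 x 3
b_dec = [0.0, 0.0]

x = [2.0, 0.0]
loss, features = sae_loss(x, W_enc, b_enc, W_dec, b_dec, lam=0.1)
print(features, loss)
```

Here only the first feature fires, the reconstruction is exact, and the whole loss comes from the sparsity term $\lambda \lVert f(x) \rVert_1$.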
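The standard SAE nevertheless suffers from the following problems:
| Problem | Description | Impact |
|---|---|---|
| Uneven sparsity | Sparsity distributions differ widely across layers | Incomplete feature discovery |
| Reconstruction-sparsity trade-off | Increasing sparsity degrades reconstruction quality | Information loss |
| Feature hierarchy | Hierarchical structure among features is not captured | No semantic organization |
| Dead neurons | Many neurons never activate | Wasted capacity |
Advanced architectures address these problems through different design choices.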
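Two of the problems in the table, dead neurons and sparsity level, can be measured directly from a batch of feature activations. The sketch below is one simple way to do so in plain Python; the function name `feature_stats` and the toy activations are assumptions made for illustration.

```python
def feature_stats(acts):
    """Summarize a batch of activations (one row of feature values per input)."""
    n_features = len(acts[0])
    # How many inputs each feature fired on
    fire_counts = [sum(1 for row in acts if row[j] > 0) for j in range(n_features)]
    dead = [j for j, count in enumerate(fire_counts) if count == 0]
    # Average number of active features per input (often called L0)
    mean_l0 = sum(sum(1 for a in row if a > 0) for row in acts) / len(acts)
    return {
        "dead_features": dead,
        "dead_fraction": len(dead) / n_features,
        "mean_l0": mean_l0,
    }


# Toy batch (hypothetical): 3 inputs, 4 features
acts = [
    [0.0, 1.2, 0.0, 0.0],
    [0.5, 0.0, 0.0, 0.0],
    [0.0, 0.7, 0.0, 0.3],
]
stats = feature_stats(acts)
print(stats)
```

In practice the dead-feature check would be run over many batches, since a feature that is silent on one small batch may simply be rare.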
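1. JumpReLU SAEs
1.1 Core Idea
JumpReLU SAEs were introduced by Google DeepMind researchers (Rajamanoharan et al., 2024) and adopted at scale in Gemma Scope [3]. They replace the standard ReLU with a sparse activation function built from a shifted Heaviside step function.
1.2 Mathematical Form
The standard ReLU activation is:
$$\mathrm{ReLU}(z) = \max(0, z)$$
The JumpReLU activation is:
$$\mathrm{JumpReLU}_\theta(z) = z \cdot H(z - \theta)$$
where $H$ is the Heaviside step function and $\theta$ is a learnable **jump threshold** parameter.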
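1.3 Intuition
JumpReLU can be understood as a ReLU with an added "jump": the output stays at zero until the input exceeds the threshold, then jumps up to the identity line.

    Standard ReLU:
              |   /
              |  /
              | /
              |/
    ----------+-------→ z
              |

    JumpReLU:
              |      /
              |     /
              |    /
              |   |
              |   |
    ----------+---+---→ z
                  θ

Key properties:
- When $z \le \theta$, the activation is exactly 0
- When $z > \theta$, the activation equals $z$ (the input passes through unchanged)
- The parameter $\theta$ determines where the "jump" occurs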
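The difference between JumpReLU and a ReLU that is merely shifted by the threshold is easy to miss, so here is a small plain-Python comparison. The function names and the threshold value are chosen for illustration only.

```python
def relu(z):
    return max(0.0, z)


def shifted_relu(z, theta):
    # Subtracts the threshold: outputs start at 0 just above theta
    return max(0.0, z - theta)


def jump_relu(z, theta):
    # z * H(z - theta): outputs jump straight to z once z exceeds theta
    return z if z > theta else 0.0


theta = 1.0
for z in (-0.5, 0.5, 1.0, 1.5, 3.0):
    print(f"z={z:5.1f}  relu={relu(z):4.1f}  "
          f"shifted={shifted_relu(z, theta):4.1f}  jump={jump_relu(z, theta):4.1f}")
```

For inputs above the threshold, JumpReLU returns the input itself, whereas the shifted ReLU shrinks every active value by $\theta$.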
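1.4 Advantages
| Property | JumpReLU advantage |
|---|---|
| Sparsity control | The threshold $\theta$ is learned, which makes the sparsity level easier to control |
| Information preservation | Activations above the threshold are "cleaner": weak pre-activations below it are zeroed out |
| Fewer dead neurons | Each neuron has an explicit activation condition |
| Interpretability | The threshold can be read as the level at which a feature counts as "present" |
1.5 PyTorch Implementation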

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class JumpReLU(nn.Module):
        """JumpReLU activation z * H(z - threshold) with a learnable threshold."""

        def __init__(self, threshold_init: float = 1.0):
            super().__init__()
            # One scalar threshold shared by all features (a simplification;
            # the JumpReLU paper learns one threshold per feature).
            self.threshold = nn.Parameter(torch.full((), threshold_init))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Values above the threshold pass through unchanged; the rest become 0.
            # The hard gate carries no gradient to the threshold. The paper trains
            # it with straight-through estimators, which are omitted here.
            return x * (x > self.threshold).to(x.dtype)


    class JumpReLUSAE(nn.Module):
        """Sparse autoencoder with JumpReLU activation."""

        def __init__(
            self,
            d_model: int,
            n_features: int,
            threshold_init: float = 1.0,
            dead_feature_penalty: float = 1.0,
        ):
            super().__init__()
            self.W_enc = nn.Linear(d_model, n_features, bias=False)
            self.b_enc = nn.Parameter(torch.zeros(n_features))
            self.W_dec = nn.Linear(n_features, d_model, bias=False)
            self.b_dec = nn.Parameter(torch.zeros(d_model))
            self.jump_relu = JumpReLU(threshold_init)
            self.dead_feature_penalty = dead_feature_penalty
            # Running count of how often each feature has fired
            self.register_buffer("feature_usage_counts", torch.zeros(n_features))

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            """Encode input to sparse features."""
            pre_acts = self.W_enc(x) + self.b_enc
            return self.jump_relu(pre_acts)

        def decode(self, acts: torch.Tensor) -> torch.Tensor:
            """Decode sparse features back to the original space."""
            return self.W_dec(acts) + self.b_dec

        def forward(self, x: torch.Tensor, update_usage: bool = True) -> dict:
            """Forward pass returning reconstruction, activations, and losses."""
            pre_acts = self.W_enc(x) + self.b_enc
            acts = self.jump_relu(pre_acts)
            recon = self.decode(acts)

            recon_loss = F.mse_loss(recon, x)
            l1_loss = acts.abs().mean()  # L1 sparsity penalty on activations

            # Flatten leading dimensions so usage is tracked per feature
            fired = (acts > 0).reshape(-1, acts.shape[-1])
            dead_penalty = recon_loss.new_zeros(())
            if update_usage:
                with torch.no_grad():
                    self.feature_usage_counts += fired.float().sum(0)
                # Simplified stand-in for ghost gradients (an assumption, not the
                # published method): features that never fired in this batch get
                # a hinge loss pulling their pre-activations toward the threshold.
                dead = ~fired.any(dim=0)
                if dead.any():
                    dead_pre = pre_acts.reshape(-1, pre_acts.shape[-1])[:, dead]
                    dead_penalty = F.relu(self.jump_relu.threshold - dead_pre).mean()

            total_loss = (
                recon_loss
                + l1_loss
                + self.dead_feature_penalty * dead_penalty
            )
            return {
                "reconstruction": recon,
                "features": acts,
                "recon_loss": recon_loss,
                "l1_loss": l1_loss,
                "total_loss": total_loss,
                "active_features": (acts > 0).float().sum(-1).mean().item(),
            }

        def get_active_features(self, x: torch.Tensor) -> torch.Tensor:
            """Get feature indices of all active (non-zero) activations."""
            acts = self.encode(x)
            return (acts > 0).nonzero(as_tuple=True)[-1]

1.6 Applications in Gemma Scope
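Gemma Scope trained JumpReLU SAEs of various sizes on the Gemma 2 models. It provides SAEs at several widths and sparsity levels per model, so the feature counts and sparsity values below should be read as rough examples:
| Model | Hidden dimension | Number of features | Sparsity |
|---|---|---|---|
| Gemma 2 2B | 2304 | 16,384 (≈7×) | ~2% |
| Gemma 2 9B | 3584 | 32,768 (≈9×) | ~3% |
| Gemma 2 27B | 4608 | 65,536 (≈14×) | ~4% |
Key findings:
- Hierarchical organization: different layers capture features at different levels of abstraction
- Concept correspondence: many features correspond to interpretable concepts
- Cross-model consistency: similar features appear across models of different sizes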
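2. Matryoshka SAEs
2.1 Core Idea
Matryoshka SAEs, proposed by Bussmann et al. at ICML 2025 [4], take their inspiration from **Russian nesting dolls (Matryoshka)**. Rather than training several independent SAEs of different sizes, a single SAE is trained to learn features at multiple levels of abstraction simultaneously.
2.2 Formal Definition
Let $m_1 < m_2 < \dots < m_K$ be a nested sequence of feature dimensions. The encoder of a Matryoshka SAE outputs a vector $f(x)$ of length $m_K$ that can be truncated to any length $m_k$:
$$f^{(k)}(x) = \big(f_1(x), \dots, f_{m_k}(x)\big) \in \mathbb{R}^{m_k}$$
where $f^{(k)}$ denotes the features of the $k$-th nesting level.
2.3 Training Objective
The Matryoshka SAE loss combines reconstruction losses across levels:
$$\mathcal{L}(x) = \sum_{k=1}^{K} \left( \alpha_k \lVert x - \hat{x}^{(k)} \rVert_2^2 + \beta_k \lVert f^{(k)}(x) \rVert_1 \right)$$
where $\hat{x}^{(k)}$ is the reconstruction decoded from $f^{(k)}(x)$, and $\alpha_k$ and $\beta_k$ are level weights, typically set to:
$$\alpha_k = \frac{1}{k}, \qquad \beta_k = \frac{\lambda}{k}$$
This ensures that coarser levels contribute more to the loss.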
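The nested objective can be sketched in a few lines of plain Python. This is an illustration under stated assumptions: the weights $\alpha_k = 1/k$ and $\beta_k = \lambda/k$, the toy decoder directions in `D`, and the use of a single decoder shared across prefixes are choices made for this example rather than details confirmed by the paper.

```python
def matryoshka_loss(x, features, levels, decode, lam=0.01):
    """Sum weighted reconstruction and L1 terms over nested feature prefixes."""
    total = 0.0
    for k, m in enumerate(levels, start=1):
        alpha_k, beta_k = 1.0 / k, lam / k      # coarser levels weigh more
        f_k = features[:m]                      # truncate to the first m features
        x_hat = decode(f_k)
        recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))
        l1 = sum(abs(a) for a in f_k)
        total += alpha_k * recon + beta_k * l1
    return total


# Hypothetical decoder: row j of D is the direction written by feature j
D = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]


def decode(f_k):
    return [sum(f * D[j][i] for j, f in enumerate(f_k)) for i in range(2)]


x = [1.0, 2.0]
features = [1.0, 2.0, 0.0]
loss = matryoshka_loss(x, features, levels=[1, 2, 3], decode=decode, lam=0.01)
print(loss)
```

With one feature the reconstruction misses the second coordinate entirely, so the coarsest level dominates the loss; the two larger prefixes reconstruct `x` exactly and contribute only their sparsity terms.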
2.4 Architecture

    Input x (d-dimensional)
            │
            ▼
    ┌─────────────────────────────────┐
    │        Encoder (shared)         │
    │         W_enc: d → m_K          │
    └─────────────────────────────────┘
            │
            ▼
    ┌─────────────────────────────────┐
    │       JumpReLU activation       │
    └─────────────────────────────────┘
            │
            ├──→ [m_1] coarse features (base level)
            ├──→ [m_2] medium-grained features
            ├──→ ...
            └──→ [m_K] fine-grained features (full)

2.5 PyTorch Implementation

    class MatryoshkaSAE(nn.Module):
        """Matryoshka SAE with nested feature hierarchies."""

        def __init__(
            self,
            d_model: int,
            feature_levels: list[int],  # e.g., [1024, 4096, 16384]
            threshold_init: float = 1.0,
        ):
            super().__init__()
            self.feature_levels = sorted(feature_levels)
            self.max_features = self.feature_levels[-1]
            self.d_model = d_model
            # Shared encoder producing all m_K features
            self.W_enc = nn.Linear(d_model, self.max_features, bias=False)
            self.b_enc = nn.Parameter(torch.zeros(self.max_features))
            # One decoder and bias per level, each reading that level's prefix
            self.decoders = nn.ModuleDict({
                f"level_{k}": nn.Linear(mk, d_model, bias=False)
                for k, mk in enumerate(self.feature_levels)
            })
            self.biases = nn.ParameterDict({
                f"level_{k}": nn.Parameter(torch.zeros(d_model))
                for k in range(len(self.feature_levels))
            })
            self.activation = JumpReLU(threshold_init)  # defined in section 1.5

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            """Encode to the full feature space."""
            pre_acts = self.W_enc(x) + self.b_enc
            return self.activation(pre_acts)

        def decode(self, features: torch.Tensor, level: int) -> torch.Tensor:
            """Decode from the features of a specific level (0-indexed)."""
            level_key = f"level_{level}"
            return self.decoders[level_key](features) + self.biases[level_key]

        def forward(self, x: torch.Tensor) -> dict:
            """Multi-level forward pass."""
            full_features = self.encode(x)
            results = {"features": {}}
            total_loss = x.new_zeros(())
            # Compute the reconstruction at each level
            for k, mk in enumerate(self.feature_levels):
                # Truncate to the first mk features
                truncated_features = full_features[..., :mk]
                recon = self.decode(truncated_features, k)
                recon_loss = F.mse_loss(recon, x)
                l1_loss = truncated_features.abs().mean()
                # Level weight: coarser levels have higher weight
                level_weight = 1.0 / (k + 1)
                level_loss = level_weight * (recon_loss + 0.01 * l1_loss)
                total_loss = total_loss + level_loss
                results["features"][f"level_{k}"] = truncated_features
                results[f"recon_level_{k}"] = recon
                results[f"loss_level_{k}"] = level_loss
            results["total_loss"] = total_loss
            results["full_features"] = full_features
            return results

        def get_coarse_features(self, x: torch.Tensor, k: int) -> torch.Tensor:
            """Get the features at level k (0-indexed)."""
            full_features = self.encode(x)
            return full_features[..., :self.feature_levels[k]]

2.6 Advantages and Applications
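| Advantage | Description |
|---|---|
| Computational efficiency | A single forward pass yields outputs at multiple granularities |
| Flexible deployment | A level can be chosen according to accuracy and speed requirements |
| Hierarchical organization | The semantic hierarchy of features is captured naturally |
| Representation learning | Coarse features guide the learning of fine-grained ones |
Use cases:
- Real-time systems (coarse features)
- Offline analysis (fine-grained features)
- Interpretability research (comparison across granularities)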
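3. Tree SAEs
3.1 Core Idea
Tree SAEs [5] organize SAE features into an explicit tree hierarchy that captures how features compose. This is particularly effective for modeling the natural hierarchy of complex semantic concepts.
3.2 Tree Structure
In a Tree SAE, each feature belongs to a node of the tree:

                   [ROOT]
                  /      \
        [Concept A]      [Concept B]
         /      \         /      \
      [f1]     [f2]    [f3]     [f4]

Features $f_1$ and $f_2$ are children of Concept A; $f_3$ and $f_4$ are children of Concept B.
3.3 Mathematical Form
Let the tree structure be defined by a parent relation $\pi$, where $\pi(c)$ is the parent of node $c$. The Tree SAE loss has three terms:
- Reconstruction loss: $\mathcal{L}_{\text{recon}} = \lVert x - \hat{x} \rVert_2^2$
- Sparsity loss: $\mathcal{L}_{\text{sparse}} = \lVert f(x) \rVert_1$
- Tree structure loss (new): $\mathcal{L}_{\text{tree}} = \sum_{p} \Big( f_p(x) - \sum_{c:\, \pi(c) = p} f_c(x) \Big)^2$
The tree term encourages each parent feature's activation to stay close to the sum of its children's activations, which yields a compositional representation of features.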
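A worked example of the tree structure term, in plain Python, may help. The index layout (0 and 1 for the two concept nodes, 2 to 5 for $f_1$ to $f_4$) and the activation values are assumptions made for this sketch.

```python
def tree_structure_loss(features, children_of):
    """Sum of squared gaps between each parent and the sum of its children."""
    loss = 0.0
    for parent, children in children_of.items():
        child_sum = sum(features[c] for c in children)
        loss += (features[parent] - child_sum) ** 2
    return loss


# Hypothetical layout: 0 = Concept A, 1 = Concept B, 2..5 = f1..f4
children_of = {0: [2, 3], 1: [4, 5]}

consistent = [1.5, 0.0, 1.0, 0.5, 0.0, 0.0]    # A equals f1 + f2, B is silent
inconsistent = [0.0, 2.0, 1.0, 0.5, 0.0, 0.0]  # A silent despite active children

print(tree_structure_loss(consistent, children_of))
print(tree_structure_loss(inconsistent, children_of))
```

The second vector is penalized twice: Concept A is off while its children are active, and Concept B is on while its children are off.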
3.4 PyTorch Implementation

    class TreeSAE(nn.Module):
        """Tree-structured sparse autoencoder."""

        def __init__(
            self,
            d_model: int,
            tree_structure: dict[int, list[int]],  # parent index -> child indices
            feature_dim: int,
        ):
            super().__init__()
            self.d_model = d_model
            self.tree = tree_structure
            self.n_features = feature_dim
            self.W_enc = nn.Linear(d_model, feature_dim, bias=False)
            self.b_enc = nn.Parameter(torch.zeros(feature_dim))
            self.W_dec = nn.Linear(feature_dim, d_model, bias=False)
            self.b_dec = nn.Parameter(torch.zeros(d_model))
            self.activation = nn.ReLU()
            # Identify leaf nodes (children that are never parents) and internal nodes
            all_children = set()
            for children in tree_structure.values():
                all_children.update(children)
            self.leaf_nodes = all_children - set(tree_structure.keys())
            self.internal_nodes = set(tree_structure.keys())

        def tree_structure_loss(self, features: torch.Tensor) -> torch.Tensor:
            """Penalize parents whose activation differs from the sum of their children."""
            loss = features.new_zeros(())
            for parent, children in self.tree.items():
                parent_feat = features[..., parent]
                children_sum = features[..., children].sum(dim=-1)
                loss = loss + (parent_feat - children_sum).pow(2).mean()
            return loss

        def forward(self, x: torch.Tensor, tree_weight: float = 0.1) -> dict:
            """Forward pass with tree structure loss."""
            # Encode
            pre_acts = self.W_enc(x) + self.b_enc
            acts = self.activation(pre_acts)
            # Decode
            recon = self.W_dec(acts) + self.b_dec
            # Losses
            recon_loss = F.mse_loss(recon, x)
            l1_loss = acts.abs().mean()
            tree_loss = self.tree_structure_loss(acts)
            total_loss = recon_loss + 0.01 * l1_loss + tree_weight * tree_loss
            return {
                "reconstruction": recon,
                "features": acts,
                "recon_loss": recon_loss,
                "l1_loss": l1_loss,
                "tree_loss": tree_loss,
                "total_loss": total_loss,
            }

4. SoftSAE
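4.1 Core Idea
SoftSAE [6] uses a differentiable Top-K selection mechanism to control sparsity dynamically during training. Instead of a hard threshold cutoff, it applies a softened Top-K operation.
4.2 Soft Top-K Operation
The standard Top-K operation keeps the $K$ largest entries and zeroes the rest:
$$\mathrm{TopK}(z)_i = \begin{cases} z_i & \text{if } z_i \text{ is among the } K \text{ largest entries of } z \\ 0 & \text{otherwise} \end{cases}$$
SoftSAE uses a differentiable approximation. One common relaxation of this kind, given here as an illustrative form rather than the paper's exact definition, gates each entry with a sigmoid $\sigma$ centered on a threshold $t_K$ lying between the $K$-th and $(K+1)$-th largest entries:
$$\mathrm{SoftTopK}(z)_i = z_i \cdot \sigma\!\left(\frac{z_i - t_K}{\tau}\right)$$
where $\tau$ is a temperature parameter that controls the degree of softening.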
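The effect of the temperature can be seen in a small plain-Python sketch. It implements a sigmoid-gated relaxation with the gate centered midway between the $K$-th and $(K+1)$-th largest entries; this particular relaxation is an assumption chosen for illustration and may differ from the formulation in the SoftSAE paper.

```python
import math


def sigmoid(t):
    # Numerically stable logistic function
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def hard_top_k(z, k):
    """Keep the k largest entries of z and zero the rest."""
    keep = set(sorted(range(len(z)), key=lambda i: z[i], reverse=True)[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(z)]


def soft_top_k(z, k, tau):
    """Sigmoid-gated relaxation of top-k (requires k < len(z))."""
    ordered = sorted(z, reverse=True)
    threshold = 0.5 * (ordered[k - 1] + ordered[k])  # between k-th and (k+1)-th
    return [v * sigmoid((v - threshold) / tau) for v in z]


z = [3.0, 1.0, 2.0, 0.5]
print(hard_top_k(z, 2))
print(soft_top_k(z, 2, tau=1.0))    # every entry keeps some weight
print(soft_top_k(z, 2, tau=0.01))   # close to the hard selection
```

A high temperature lets every entry keep some weight, so gradients reach all features; as $\tau \to 0$ the gate sharpens and the output approaches the hard Top-K result.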
4.3 Training Dynamics

    class SoftTopK(torch.autograd.Function):
        """Top-K selection with a straight-through backward pass.

        Simplified: the forward pass is a hard top-k, and `temperature` is
        accepted but not used, so no softening is applied here.
        """

        @staticmethod
        def forward(ctx, x: torch.Tensor, k: int, temperature: float = 1.0):
            # Hard top-k mask in the forward pass
            _, indices = torch.topk(x, k, dim=-1)
            mask = torch.zeros_like(x).scatter_(-1, indices, 1.0)
            ctx.save_for_backward(mask)
            return x * mask

        @staticmethod
        def backward(ctx, grad_output):
            (mask,) = ctx.saved_tensors
            # Gradients flow only through the selected entries;
            # k and temperature receive no gradient.
            return grad_output * mask, None, None


    class SoftSAE(nn.Module):
        """SAE with top-k feature selection."""

        def __init__(self, d_model: int, n_features: int, k: int = 32):
            super().__init__()
            self.k = k  # Target number of active features
            # Kept as a parameter, but it receives no gradient in this simplified version
            self.temperature = nn.Parameter(torch.tensor(1.0))
            self.W_enc = nn.Linear(d_model, n_features, bias=False)
            self.b_enc = nn.Parameter(torch.zeros(n_features))
            self.W_dec = nn.Linear(n_features, d_model, bias=False)
            self.b_dec = nn.Parameter(torch.zeros(d_model))

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            pre_acts = self.W_enc(x) + self.b_enc
            # Never request more entries than there are features
            k_actual = min(self.k, pre_acts.shape[-1])
            return SoftTopK.apply(pre_acts, k_actual, self.temperature.item())

        def decode(self, acts: torch.Tensor) -> torch.Tensor:
            # Added: the original forward() called decode() without defining it
            return self.W_dec(acts) + self.b_dec

        def forward(self, x: torch.Tensor) -> dict:
            acts = self.encode(x)
            recon = self.decode(acts)
            return {
                "features": acts,
                "reconstruction": recon,
                "loss": F.mse_loss(recon, x) + 0.01 * acts.abs().mean(),
            }

5. Architecture Comparison
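5.1 Feature Comparison
| Architecture | Sparsity control | Hierarchy | Implementation complexity | Best suited for |
|---|---|---|---|---|
| JumpReLU | Learned threshold | None | Low | General use; recommended first choice |
| Matryoshka | Nested multi-granularity | Nested levels | Medium | Resource-constrained systems |
| Tree | L1 + tree loss | Explicit tree | High | Compositional concept modeling |
| Soft | Soft Top-K | None | Medium | Settings that need smooth sparsity |
5.2 Performance Comparison (based on Gemma Scope analysis)
| Metric | JumpReLU | Matryoshka | Tree | Soft |
|---|---|---|---|---|
| Reconstruction quality (L2) | ★★★★★ | ★★★★☆ | ★★★☆☆ | ★★★★☆ |
| Feature interpretability | ★★★★☆ | ★★★★★ | ★★★★★ | ★★★☆☆ |
| Sparsity uniformity | ★★★★☆ | ★★★★☆ | ★★★☆☆ | ★★★★★ |
| Computational efficiency | ★★★★★ | ★★★★☆ | ★★★☆☆ | ★★★★☆ |
| Dead-neuron rate (more stars = fewer dead neurons) | ★★★★★ | ★★★★☆ | ★★★☆☆ | ★★★★☆ |
5.3 Selection Guide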
Decision tree:

                  ┌──────────────────────────────┐
                  │ Need hierarchical structure? │
                  └───────────────┬──────────────┘
                                  │
                  ┌───────────────┴───────────────┐
                  ▼                               ▼
                 Yes                              No
                  │                               │
    ┌─────────────┴──────────────┐    ┌───────────┴────────────┐
    │ Need explicit composition? │    │ Need uniform sparsity? │
    └─────────────┬──────────────┘    └───────────┬────────────┘
                  │                               │
          ┌───────┴───────┐               ┌───────┴───────┐
          ▼               ▼               ▼               ▼
      Tree SAE       Matryoshka       Soft SAE        JumpReLU
                         SAE           (soft)         (default)

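The decision tree can also be written as a small helper function. This is only a sketch of the guide above: the function name and argument names are hypothetical, and the left branch of each question is read as "yes".

```python
def choose_sae_architecture(
    needs_hierarchy: bool,
    needs_explicit_composition: bool = False,
    needs_uniform_sparsity: bool = False,
) -> str:
    """Return the architecture suggested by the decision tree."""
    if needs_hierarchy:
        # Hierarchical branch: explicit trees vs. nested prefixes
        return "Tree SAE" if needs_explicit_composition else "Matryoshka SAE"
    # Flat branch: soft top-k vs. the default choice
    return "Soft SAE" if needs_uniform_sparsity else "JumpReLU SAE"


print(choose_sae_architecture(needs_hierarchy=False))
print(choose_sae_architecture(needs_hierarchy=True, needs_explicit_composition=True))
```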
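6. Practical Recommendations
6.1 Choosing an Architecture
| Scenario | Recommended architecture | Reason |
|---|---|---|
| General research | JumpReLU | Mature, effective, easy to implement |
| Resource-constrained deployment | Matryoshka | One model, multiple granularities |
| Compositional concept analysis | Tree | Explicit hierarchy modeling |
| Smooth activation requirements | Soft | No hard cutoff |
6.2 Hyperparameter Settings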

    # Recommended configurations
    d_model = 768  # example value; set this to your model's hidden size

    # JumpReLU SAE
    jumprelu_config = {
        "n_features": 4 * d_model,  # typically 4-8x d_model
        "threshold_init": 1.0,
        "dead_feature_penalty": 1.0,
        "learning_rate": 1e-4,
    }

    # Matryoshka SAE
    matryoshka_config = {
        "feature_levels": [d_model, 4 * d_model, 16 * d_model],
        "threshold_init": 1.0,
    }

    # Tree SAE
    tree_config = {
        "tree_structure": {...},  # placeholder: supply your own parent -> children mapping
        "tree_weight": 0.1,
        "sparsity_weight": 0.01,
    }

6.3 Training Tips
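- Ghost gradients: revive dead neurons
- Warm-up: keep the L1 penalty switched off for the first 1,000 steps
- Learning-rate schedule: cosine decay
- Layer normalization: add a LayerNorm before the encoder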
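The warm-up and cosine-decay tips can be expressed as two small schedule functions. This is a minimal sketch in plain Python; the function names, the target L1 coefficient, and the base learning rate are illustrative assumptions, and a real training loop might ramp the penalty gradually instead of switching it on at once.

```python
import math


def l1_coefficient(step, target=0.01, warmup_steps=1000):
    """L1 penalty weight: off during warm-up, then at its target value."""
    return 0.0 if step < warmup_steps else target


def cosine_lr(step, total_steps, base_lr=1e-4, min_lr=0.0):
    """Cosine decay from base_lr at step 0 to min_lr at total_steps."""
    progress = min(max(step / total_steps, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


total_steps = 10_000
for step in (0, 999, 1000, 5000, 10_000):
    print(step, l1_coefficient(step), cosine_lr(step, total_steps))
```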
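7. Future Directions
7.1 Architectural Evolution
| Direction | Description | Research status |
|---|---|---|
| MoE-SAE | Mixture-of-experts SAE | Early stage |
| Dynamic architectures | Adapt the number of features to the input | Exploratory |
| Cross-modal SAE | Unified multimodal feature space | In progress |
| Causal SAE | Explicitly model causal relations among features | Theoretical |
7.2 Application Areas
- Real-time interpretable AI: deployment on edge devices
- Safety and alignment: feature-level interventions
- Knowledge editing: precise modification of model knowledge
- Scientific discovery: automatic identification of scientific concepts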
8. References
1. Bricken et al. "Towards Monosemanticity: Decomposing Language Models With Dictionary Learning." 2023.
2. Elhage et al. "A Mathematical Framework for Transformer Circuits." 2021.
3. Lieberum et al. "Gemma Scope: Open Sparse Autoencoders Everywhere All at Once on Gemma 2." 2024.
4. Bussmann et al. "Matryoshka Sparse Autoencoders." ICML 2025.
5. "Tree Sparse Autoencoders." arXiv:2605.07922.
6. arXiv:2605.06610.