Advanced Sparse Autoencoder Architectures

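Overview

Sparse autoencoders (SAEs) have become a core tool for understanding and manipulating the internal representations of large language models (LLMs) [1]. A standard SAE learns to reconstruct activation vectors while encouraging sparse feature activations, which addresses the superposition problem in neural networks [2]. The standard architecture has several limitations, however, and in recent years researchers have proposed a number of improved architectures to address them.

This document surveys the main advanced SAE variants and analyzes what each one changes, its strengths, and the scenarios it suits.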

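Background: why are advanced architectures needed?

A standard SAE consists of an encoder $f(x) = \mathrm{ReLU}(W_{\text{enc}} x + b_{\text{enc}})$ and a decoder $\hat{x} = W_{\text{dec}} f(x) + b_{\text{dec}}$, trained by minimizing the reconstruction loss plus a sparsity penalty:

$$\mathcal{L}(x) = \lVert x - \hat{x} \rVert_2^2 + \lambda \lVert f(x) \rVert_1$$

where $x$ is the original activation vector and $\lambda$ is the sparsity penalty coefficient.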
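As a rough illustration of the objective described here, the sketch below computes a reconstruction term plus an L1 sparsity term for one batch. The function name `standard_sae_loss`, the tensor shapes, and the default `lam=0.01` are assumptions made for this example, not part of any particular library.

```python
import torch
import torch.nn.functional as F

def standard_sae_loss(x, W_enc, b_enc, W_dec, b_dec, lam=0.01):
    """Reconstruction + L1 sparsity loss for one batch of activations.

    Assumed shapes: x (batch, d), W_enc (d, n), W_dec (n, d).
    """
    f = F.relu(x @ W_enc + b_enc)               # sparse feature activations f(x)
    x_hat = f @ W_dec + b_dec                   # reconstruction of x
    recon = (x - x_hat).pow(2).sum(-1).mean()   # squared L2 error, averaged over the batch
    sparsity = f.abs().sum(-1).mean()           # L1 norm of the features, averaged over the batch
    return recon + lam * sparsity, f

torch.manual_seed(0)
d, n = 8, 32
x = torch.randn(4, d)
W_enc = torch.randn(d, n) * 0.1
W_dec = torch.randn(n, d) * 0.1
loss, f = standard_sae_loss(x, W_enc, torch.zeros(n), W_dec, torch.zeros(d))
print(loss.item(), (f > 0).float().mean().item())
```

Raising `lam` trades reconstruction quality for sparser features, which is the tension the rest of this document returns to.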

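Standard SAEs nevertheless suffer from the following problems:

| Problem | Description | Impact |
|---|---|---|
| Uneven sparsity | Sparsity levels vary widely across layers | Incomplete feature discovery |
| Reconstruction-sparsity trade-off | Increasing sparsity degrades reconstruction quality | Information loss |
| Feature hierarchy | Cannot capture hierarchical structure among features | No semantic organization |
| Dead neurons | Many neurons stay silent on every input | Wasted capacity |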
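Two of these problems, uneven sparsity and dead neurons, can be measured directly from a matrix of feature activations. The helper below is a minimal sketch; the name `sparsity_stats` and the toy activations are made up for illustration.

```python
import torch

def sparsity_stats(acts: torch.Tensor) -> dict:
    """acts: (n_samples, n_features) non-negative SAE feature activations."""
    active = acts > 0
    return {
        # average number of active features per sample (the L0 "norm")
        "mean_l0": active.sum(dim=-1).float().mean().item(),
        # fraction of features that never fire on any sample
        "dead_fraction": (~active.any(dim=0)).float().mean().item(),
    }

acts = torch.tensor([
    [0.0, 2.0, 0.0, 0.0],
    [0.0, 1.5, 0.3, 0.0],
    [0.0, 0.0, 0.7, 0.0],
])
stats = sparsity_stats(acts)
print(stats)
```

In this toy batch, features 0 and 3 never fire, so half of the dictionary is dead.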

Advanced architectures address these problems through different design choices.


1. JumpReLU SAEs

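1.1 Core idea

JumpReLU SAEs were introduced by Google DeepMind (Rajamanoharan et al., 2024) and are the architecture used for Gemma Scope [3]. They replace the standard ReLU with a sparse activation function that gates each pre-activation through a shifted Heaviside step function.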

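1.2 Mathematical form

The standard ReLU activation is:

$$\mathrm{ReLU}(z) = \max(0, z)$$

The JumpReLU activation is:

$$\mathrm{JumpReLU}_\theta(z) = z \cdot H(z - \theta) = \begin{cases} z, & z > \theta \\ 0, & z \le \theta \end{cases}$$

where $H$ is the Heaviside step function and $\theta > 0$ is a learnable **jump threshold** parameter.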

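1.3 Intuition

JumpReLU can be read as a ReLU with an added "jump": the output stays at zero until the input crosses the threshold θ, then jumps straight up to the identity line.

Standard ReLU:  |      /
                |    /
                |  /
                |/
      ----------+-------→ z
                0

JumpReLU:       |        /
                |      /
                |    /
                |    :
                |    :
      ----------+----+---→ z
                0    θ

Key properties:

  • For z ≤ θ, the activation is exactly 0
  • For z > θ, the activation equals z itself (not z − θ)
  • The parameter θ determines where the jump occurs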
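The difference between a plain ReLU, a shifted ReLU, and JumpReLU is easiest to see on a few sample values. The snippet below is a minimal sketch with a fixed (not learned) threshold; the helper name `jump_relu` is chosen for this example.

```python
import torch

def jump_relu(z: torch.Tensor, theta: float) -> torch.Tensor:
    """Pass z through unchanged where z > theta, output 0 elsewhere."""
    return torch.where(z > theta, z, torch.zeros_like(z))

z = torch.tensor([-1.0, 0.5, 1.0, 1.5, 3.0])
theta = 1.0

relu_out = torch.relu(z)
shifted_out = torch.relu(z - theta)   # shifted ReLU, shown only for contrast
jump_out = jump_relu(z, theta)

print(relu_out.tolist())     # [0.0, 0.5, 1.0, 1.5, 3.0]
print(shifted_out.tolist())  # [0.0, 0.0, 0.0, 0.5, 2.0]
print(jump_out.tolist())     # [0.0, 0.0, 0.0, 1.5, 3.0]
```

Above the threshold JumpReLU keeps the full magnitude, whereas the shifted ReLU subtracts the threshold from every surviving activation.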

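1.4 Advantages

| Property | JumpReLU advantage |
|---|---|
| Sparsity control | The threshold is learned, which makes the sparsity level easier to control |
| Information retention | Activations above the threshold pass through at full magnitude, so they are "cleaner" |
| Fewer dead neurons | Each neuron has an explicit activation condition |
| Better interpretability | The threshold can be read as the level at which a feature counts as "present" |

1.5 PyTorch implementation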

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
 
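class JumpReLU(nn.Module):
    """JumpReLU activation: z * H(z - threshold), with a learnable threshold."""
    
    def __init__(self, threshold_init: float = 1.0, bandwidth: float = 0.1):
        super().__init__()
        # Initialize threshold as a learnable parameter (one scalar shared by
        # all features here; a per-feature vector would also work)
        self.threshold = nn.Parameter(
            torch.full((), threshold_init)
        )
        # Width of the sigmoid surrogate used only in the backward pass
        self.bandwidth = bandwidth
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Hard gate: inputs above the threshold pass through unchanged
        hard = (x > self.threshold).to(x.dtype)
        # The hard gate has zero gradient w.r.t. the threshold, so a sigmoid
        # straight-through surrogate supplies one. This surrogate is an
        # assumption of this sketch, not the estimator from the JumpReLU paper.
        soft = torch.sigmoid((x.detach() - self.threshold) / self.bandwidth)
        gate = hard + (soft - soft.detach())
        return x * gate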
 
class JumpReLUSAE(nn.Module):
    """Sparse Autoencoder with JumpReLU activation."""
    
    def __init__(
        self,
        d_model: int,
        n_features: int,
        threshold_init: float = 1.0,
        dead_feature_penalty: float = 1.0,
    ):
        super().__init__()
        self.W_enc = nn.Linear(d_model, n_features, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        
        self.W_dec = nn.Linear(n_features, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        
        self.jump_relu = JumpReLU(threshold_init)
        self.dead_feature_penalty = dead_feature_penalty
        
        # Running count of how often each feature fires (used to find dead features)
        self.register_buffer(
            "feature_usage_counts", 
            torch.zeros(n_features)
        )
    
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode input to sparse features."""
        pre_acts = self.W_enc(x) + self.b_enc
        acts = self.jump_relu(pre_acts)
        return acts
    
    def decode(self, acts: torch.Tensor) -> torch.Tensor:
        """Decode sparse features back to original space."""
        return self.W_dec(acts) + self.b_dec
    
    def forward(
        self, 
        x: torch.Tensor,
        update_usage: bool = True
    ) -> dict:
        """
        Forward pass with loss computation.
        
        Returns dict with reconstruction, loss, and activations.
        """
        # Encode (pre-activations are kept for the dead-feature term below)
        pre_acts = self.W_enc(x) + self.b_enc
        acts = self.jump_relu(pre_acts)
        
        # Decode
        recon = self.decode(acts)
        
        # Compute reconstruction loss
        recon_loss = F.mse_loss(recon, x)
        
        # Compute sparsity loss (L1 on activations)
        l1_loss = acts.abs().mean()
        
        # Track how often each feature has fired (x is assumed to be 2-D:
        # batch x d_model)
        if update_usage:
            with torch.no_grad():
                self.feature_usage_counts += (acts > 0).float().sum(0)
        
        # Dead-feature penalty. NOTE: a simplified stand-in for ghost
        # gradients, not that method itself. Features that have never fired
        # get a hinge loss pulling their pre-activations up to the threshold.
        dead = self.feature_usage_counts == 0
        if dead.any():
            gap = self.jump_relu.threshold.detach() - pre_acts[:, dead]
            dead_penalty = F.relu(gap).mean()
        else:
            dead_penalty = torch.zeros((), device=x.device)
        
        # Total loss
        total_loss = (
            recon_loss + 
            l1_loss + 
            self.dead_feature_penalty * dead_penalty
        )
        
        return {
            "reconstruction": recon,
            "features": acts,
            "recon_loss": recon_loss,
            "l1_loss": l1_loss,
            "total_loss": total_loss,
            "active_features": (acts > 0).float().sum(-1).mean().item(),
        }
    
    def get_active_features(self, x: torch.Tensor) -> torch.Tensor:
        """Get indices of active (non-zero) features."""
        acts = self.encode(x)
        return (acts > 0).nonzero(as_tuple=True)[1]

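1.6 Application in Gemma Scope

Gemma Scope trained JumpReLU SAEs at several scales on the Gemma 2 models:

| Model | Hidden dimension | Number of features | Sparsity |
|---|---|---|---|
| Gemma 2 2B | 2304 | 16,384 (≈7.1×) | ~2% |
| Gemma 2 9B | 3584 | 32,768 (≈9.1×) | ~3% |
| Gemma 2 27B | 4608 | 65,536 (≈14.2×) | ~4% |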
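To make the sparsity column concrete, the short calculation below converts each percentage into an approximate count of active features. It assumes that "sparsity" here means the fraction of dictionary features that fire on a given input; the table does not state this explicitly, so treat the reading as an assumption.

```python
# Dictionary sizes and approximate sparsity levels copied from the table above.
rows = [
    ("Gemma 2 2B", 16_384, 0.02),
    ("Gemma 2 9B", 32_768, 0.03),
    ("Gemma 2 27B", 65_536, 0.04),
]

def expected_active(n_features: int, active_fraction: float) -> int:
    """Approximate number of features firing per input (assumed interpretation)."""
    return round(n_features * active_fraction)

for name, n_features, frac in rows:
    print(f"{name}: ~{expected_active(n_features, frac)} of {n_features} features active")
```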

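Key findings:

  • Layer-wise organization: different layers capture features at different levels of abstraction
  • Concept correspondence: many features correspond to interpretable concepts
  • Cross-model consistency: similar features appear in models of different sizes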

2. Matryoshka SAEs

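2.1 Core idea

Matryoshka SAEs were proposed by Bussmann et al. at ICML 2025 [4] and take their name from **Russian nesting dolls (Matryoshka)**. Instead of training several independent SAEs of different sizes, a single SAE is trained to learn features at several levels of abstraction at once.

2.2 Formal definition

Let $m_1 < m_2 < \dots < m_K$ be a nested sequence of feature counts. The encoder of a Matryoshka SAE outputs a vector of length $m_K$ that can be truncated to any prefix of length $m_k$:

$$f^{(k)}(x) = f(x)_{1:m_k}, \qquad k = 1, \dots, K$$

where $f^{(k)}(x)$ denotes the features of the $k$-th nesting level.

2.3 Training objective

The Matryoshka SAE loss combines the reconstruction losses of all levels:

$$\mathcal{L}(x) = \sum_{k=1}^{K} w_k \left( \lVert x - \hat{x}^{(k)} \rVert_2^2 + \lambda \lVert f^{(k)}(x) \rVert_1 \right)$$

where $\hat{x}^{(k)}$ is the reconstruction decoded from $f^{(k)}(x)$ and $w_k$ is the weight of level $k$, typically set to:

$$w_k = \frac{1}{k}$$

so that coarser levels contribute more to the loss.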
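The multi-level objective can be sketched in a few lines. The version below assumes a single decoder matrix shared by all levels, with level k using only its first m_k rows; this is one possible design and differs from an implementation that keeps a separate decoder per level. The function name, shapes, and `lam=0.01` are assumptions of this example, and mean-squared error stands in for the squared L2 norm.

```python
import torch
import torch.nn.functional as F

def nested_prefix_loss(x, feats, W_dec, b_dec, levels, lam=0.01):
    """Weighted sum of per-level losses over nested prefixes of `feats`.

    Assumed shapes: x (batch, d), feats (batch, m_K), W_dec (m_K, d), b_dec (d,).
    """
    total = x.new_zeros(())
    for k, m_k in enumerate(sorted(levels), start=1):
        prefix = feats[:, :m_k]                # first m_k features (level k)
        recon = prefix @ W_dec[:m_k] + b_dec   # decode with the matching rows only
        w_k = 1.0 / k                          # coarser levels weigh more
        level_loss = F.mse_loss(recon, x) + lam * prefix.abs().mean()
        total = total + w_k * level_loss
    return total

torch.manual_seed(0)
x = torch.randn(5, 8)
feats = torch.relu(torch.randn(5, 16))
W_dec = torch.randn(16, 8) * 0.1
b_dec = torch.zeros(8)
loss = nested_prefix_loss(x, feats, W_dec, b_dec, levels=[4, 8, 16])
print(loss.item())
```

Because every prefix must reconstruct the input on its own, the earliest features are pushed toward general, high-value directions.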

2.4 Architecture design

Input x (d-dim)
    │
    ▼
┌─────────────────────────────────┐
│        Encoder (shared)         │
│    W_enc: d → m_K               │
└─────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────┐
│    JumpReLU activation          │
└─────────────────────────────────┘
    │
    ├──→ [m_1] coarse features (base level)
    ├──→ [m_2] medium-granularity features
    ├──→ ...
    └──→ [m_K] fine-grained features (full)

2.5 PyTorch implementation

class MatryoshkaSAE(nn.Module):
    """Matryoshka SAE with nested feature hierarchies."""
    
    def __init__(
        self,
        d_model: int,
        feature_levels: list[int],  # e.g., [1024, 4096, 16384]
        threshold_init: float = 1.0,
    ):
        super().__init__()
        self.feature_levels = sorted(feature_levels)
        self.max_features = max(feature_levels)
        self.d_model = d_model
        
        # Shared encoder
        self.W_enc = nn.Linear(d_model, self.max_features, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(self.max_features))
        
        # Decoders for each level
        self.decoders = nn.ModuleDict({
            f"level_{k}": nn.Linear(mk, d_model, bias=False)
            for k, mk in enumerate(self.feature_levels)
        })
        self.biases = nn.ParameterDict({
            f"level_{k}": nn.Parameter(torch.zeros(d_model))
            for k, mk in enumerate(self.feature_levels)
        })
        
        self.activation = JumpReLU(threshold_init)
        
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode to full feature space."""
        pre_acts = self.W_enc(x) + self.b_enc
        return self.activation(pre_acts)
    
    def decode(
        self, 
        features: torch.Tensor, 
        level: int
    ) -> torch.Tensor:
        """Decode from a specific level's features."""
        level_key = f"level_{level}"
        decoder = self.decoders[level_key]
        bias = self.biases[level_key]
        return decoder(features) + bias
    
    def forward(self, x: torch.Tensor) -> dict:
        """Multi-level forward pass."""
        full_features = self.encode(x)
        
        results = {"features": {}}
        
        # Compute reconstruction at each level
        total_loss = 0.0
        for k, mk in enumerate(self.feature_levels):
            # Truncate to level k's feature count
            truncated_features = full_features[:, :mk]
            
            # Decode
            recon = self.decode(truncated_features, k)
            
            # Compute losses with level-dependent weights
            recon_loss = F.mse_loss(recon, x)
            l1_loss = truncated_features.abs().mean()
            
            # Level weight (coarser levels have higher weight)
            level_weight = 1.0 / (k + 1)
            
            level_loss = level_weight * (recon_loss + 0.01 * l1_loss)
            total_loss = total_loss + level_loss
            
            results["features"][f"level_{k}"] = truncated_features
            results[f"recon_level_{k}"] = recon
            results[f"loss_level_{k}"] = level_loss
        
        results["total_loss"] = total_loss
        results["full_features"] = full_features
        
        return results
    
    def get_coarse_features(self, x: torch.Tensor, k: int) -> torch.Tensor:
        """Get features at level k (0-indexed)."""
        full_features = self.encode(x)
        return full_features[:, :self.feature_levels[k]]

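2.6 Advantages and applications

| Advantage | Description |
|---|---|
| Compute efficiency | A single forward pass yields outputs at several granularities |
| Flexible deployment | Pick a level according to accuracy and speed requirements |
| Hierarchical organization | Naturally captures the semantic hierarchy of features |
| Representation learning | Coarse features guide the learning of fine-grained ones |

Application scenarios:

  • Real-time systems (use the coarse levels)
  • Offline analysis (use the fine-grained levels)
  • Interpretability research (compare across granularities)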

3. Tree SAEs

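3.1 Core idea

Tree SAEs [5] organize SAE features into an explicit tree hierarchy that captures how features compose. This is particularly useful for modeling the natural hierarchy of complex semantic concepts.

3.2 Tree structure design

In a Tree SAE, every feature belongs to a node of the tree: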

                    [ROOT]
                    /    \
              [Concept A] [Concept B]
              /    \        /    \
         [f1]  [f2]    [f3]  [f4]

Features f1 and f2 are children of Concept A; f3 and f4 are children of Concept B.
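A tree like the one drawn above can be written as a parent-to-children dictionary of feature indices. The index assignment and the `validate_tree` helper below are hypothetical, shown only to illustrate the sanity checks such a structure usually needs before training.

```python
# Hypothetical index assignment for the diagram above:
# ROOT=0, Concept A=1, Concept B=2, f1..f4 = 3..6.
tree_structure = {
    0: [1, 2],   # ROOT -> Concept A, Concept B
    1: [3, 4],   # Concept A -> f1, f2
    2: [5, 6],   # Concept B -> f3, f4
}

def validate_tree(tree: dict[int, list[int]], n_features: int) -> dict[int, int]:
    """Check the parent -> children mapping and return a child -> parent map."""
    parent_of: dict[int, int] = {}
    for parent, children in tree.items():
        for child in children:
            if not (0 <= parent < n_features and 0 <= child < n_features):
                raise ValueError(f"index out of range: {parent} -> {child}")
            if child in parent_of:
                raise ValueError(f"feature {child} has more than one parent")
            parent_of[child] = parent
    # Follow parent links upward from every node to rule out cycles
    for node in parent_of:
        seen = {node}
        while node in parent_of:
            node = parent_of[node]
            if node in seen:
                raise ValueError("cycle detected in tree structure")
            seen.add(node)
    return parent_of

parent_of = validate_tree(tree_structure, n_features=7)
print(parent_of)  # {1: 0, 2: 0, 3: 1, 4: 1, 5: 2, 6: 2}
```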

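3.3 Mathematical form

Let the tree be defined by a parent relation $\mathrm{pa}(\cdot)$, and write $\mathrm{ch}(p) = \{c : \mathrm{pa}(c) = p\}$ for the children of node $p$. The Tree SAE loss contains:

  1. Reconstruction loss

$$\mathcal{L}_{\text{recon}} = \lVert x - \hat{x} \rVert_2^2$$

  2. Sparsity loss

$$\mathcal{L}_{\text{sparse}} = \lambda \lVert f(x) \rVert_1$$

  3. Tree structure loss (new):

$$\mathcal{L}_{\text{tree}} = \sum_{p} \Big( f_p(x) - \sum_{c \in \mathrm{ch}(p)} f_c(x) \Big)^2$$

where the outer sum runs over all internal nodes $p$. This term encourages the activation of each parent feature to stay close to the sum of its children's activations, which yields a compositional representation of features.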
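A toy example helps to see what the tree structure term rewards. The sketch below evaluates the parent-versus-children penalty on two hand-written activation vectors; the function name and the tiny one-parent tree are illustrative assumptions.

```python
import torch

def tree_consistency_loss(feats: torch.Tensor, tree: dict[int, list[int]]) -> torch.Tensor:
    """Batch mean of (parent - sum of children)^2, summed over all parents."""
    loss = feats.new_zeros(())
    for parent, children in tree.items():
        child_sum = feats[:, children].sum(dim=1)
        loss = loss + (feats[:, parent] - child_sum).pow(2).mean()
    return loss

tree = {0: [1, 2]}  # feature 0 is the parent of features 1 and 2

consistent = torch.tensor([[3.0, 1.0, 2.0]])    # parent 3 equals 1 + 2
inconsistent = torch.tensor([[3.0, 1.0, 0.0]])  # parent 3 vs 1 + 0, a gap of 2

print(tree_consistency_loss(consistent, tree).item())    # 0.0
print(tree_consistency_loss(inconsistent, tree).item())  # 4.0
```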

3.4 PyTorch implementation

class TreeSAE(nn.Module):
    """Tree-structured Sparse Autoencoder."""
    
    def __init__(
        self,
        d_model: int,
        tree_structure: dict,  # parent -> [children]
        feature_dim: int,
    ):
        super().__init__()
        self.d_model = d_model
        self.tree = tree_structure
        self.n_features = feature_dim
        
        self.W_enc = nn.Linear(d_model, feature_dim, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(feature_dim))
        self.W_dec = nn.Linear(feature_dim, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        
        self.activation = nn.ReLU()
        
        # Identify leaf and internal nodes
        all_children = set()
        for children in tree_structure.values():
            all_children.update(children)
        self.leaf_nodes = all_children - set(tree_structure.keys())
        self.internal_nodes = set(tree_structure.keys())
        
    def tree_structure_loss(self, features: torch.Tensor) -> torch.Tensor:
        """Enforce hierarchical structure in features."""
        loss = 0.0
        
        for parent, children in self.tree.items():
            # Parent feature
            parent_feat = features[:, parent]
            
            # Sum of children's features
            children_sum = features[:, children].sum(dim=1)
            
            # Penalize deviation
            loss = loss + (parent_feat - children_sum).pow(2).mean()
        
        return loss
    
    def forward(self, x: torch.Tensor, tree_weight: float = 0.1) -> dict:
        """Forward pass with tree structure loss."""
        # Encode
        pre_acts = self.W_enc(x) + self.b_enc
        acts = self.activation(pre_acts)
        
        # Decode
        recon = self.W_dec(acts) + self.b_dec
        
        # Losses
        recon_loss = F.mse_loss(recon, x)
        l1_loss = acts.abs().mean()
        tree_loss = self.tree_structure_loss(acts)
        
        total_loss = recon_loss + 0.01 * l1_loss + tree_weight * tree_loss
        
        return {
            "reconstruction": recon,
            "features": acts,
            "recon_loss": recon_loss,
            "l1_loss": l1_loss,
            "tree_loss": tree_loss,
            "total_loss": total_loss,
        }

4. SoftSAE

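4.1 Core idea

SoftSAE [6] uses a differentiable Top-K selection mechanism to control sparsity dynamically during training. Instead of a hard threshold cutoff, it applies a softened Top-K operation.

4.2 Soft Top-K operation

The standard Top-K operation keeps the $k$ largest pre-activations and zeroes the rest:

$$\mathrm{TopK}(z)_i = \begin{cases} z_i, & z_i \text{ is among the } k \text{ largest entries of } z \\ 0, & \text{otherwise} \end{cases}$$

SoftSAE uses a differentiable approximation. One relaxation of this kind, written here with $t_k(z)$ for a cutoff between the $k$-th and $(k+1)$-th largest entries, is:

$$\mathrm{SoftTopK}_\tau(z)_i = z_i \cdot \sigma\!\left( \frac{z_i - t_k(z)}{\tau} \right)$$

where $\tau$ is the temperature parameter that controls the degree of softening; as $\tau \to 0$ the operation approaches hard Top-K.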
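One way to make Top-K differentiable with respect to a temperature is to replace the hard membership test with a sigmoid around a cutoff. The sketch below is an illustrative relaxation of that kind and is not claimed to be the exact operator used by SoftSAE; the function name `soft_top_k` and the midpoint cutoff are assumptions.

```python
import torch

def soft_top_k(z: torch.Tensor, k: int, tau: torch.Tensor) -> torch.Tensor:
    """Sigmoid-relaxed top-k along the last dimension (requires k < z.shape[-1])."""
    top = torch.topk(z, k + 1, dim=-1).values           # sorted, largest first
    # Cutoff halfway between the k-th and (k+1)-th largest entries
    cutoff = 0.5 * (top[..., k - 1] + top[..., k]).unsqueeze(-1)
    gate = torch.sigmoid((z - cutoff.detach()) / tau)   # soft membership in the top k
    return z * gate

z = torch.tensor([[4.0, 1.0, 3.0, 2.0]])

# A small temperature behaves almost like hard top-2: keeps 4 and 3
sharp = soft_top_k(z, k=2, tau=torch.tensor(0.01))
print(sharp)

# With a tensor-valued temperature, the temperature itself receives a gradient
tau = torch.tensor(1.0, requires_grad=True)
soft_top_k(z, k=2, tau=tau).sum().backward()
print(tau.grad)
```

Passing the temperature as a tensor rather than a Python float is what allows it to be trained alongside the other parameters.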

4.3 Training dynamics

class SoftTopK(torch.autograd.Function):
    """Top-K selection with a pass-through gradient for the selected entries.

    Note: the forward pass is a hard top-k; `temperature` is accepted but
    not used in this implementation.
    """
    
    @staticmethod
    def forward(ctx, x: torch.Tensor, k: int, temperature: float = 1.0):
        # Hard top-k indices during forward
        _, indices = torch.topk(x, k, dim=-1)
        mask = torch.zeros_like(x).scatter_(-1, indices, 1.0)
        ctx.save_for_backward(mask)
        return x * mask
    
    @staticmethod
    def backward(ctx, grad_output):
        (mask,) = ctx.saved_tensors
        
        # Gradients flow only to the k selected entries; k and temperature
        # are non-differentiable inputs, hence the two None values
        return grad_output * mask, None, None
 
class SoftSAE(nn.Module):
    """SAE with differentiable top-k selection."""
    
    def __init__(self, d_model: int, n_features: int, k: int = 32):
        super().__init__()
        self.k = k  # Target number of active features
        self.temperature = nn.Parameter(torch.tensor(1.0))  # Learnable temp
        
        self.W_enc = nn.Linear(d_model, n_features, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        self.W_dec = nn.Linear(n_features, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        
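    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre_acts = self.W_enc(x) + self.b_enc
        
        # Apply top-k selection (temperature is passed as a plain float, so
        # it receives no gradient here)
        k_actual = min(self.k, pre_acts.shape[-1])
        acts = SoftTopK.apply(pre_acts, k_actual, self.temperature.item())
        
        return acts
    
    def decode(self, acts: torch.Tensor) -> torch.Tensor:
        """Decode sparse features back to the input space."""
        return self.W_dec(acts) + self.b_dec
    
    def forward(self, x: torch.Tensor) -> dict:
        acts = self.encode(x)
        recon = self.decode(acts)
        
        return {
            "features": acts,
            "reconstruction": recon,
            "loss": F.mse_loss(recon, x) + 0.01 * acts.abs().mean(),
        }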

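5. Architecture comparison

5.1 Feature comparison table

| Architecture | Sparsity control | Hierarchy | Implementation complexity | Use case |
|---|---|---|---|---|
| JumpReLU | Learned threshold | None | | General purpose, recommended default |
| Matryoshka | Multi-granularity nesting | Nested levels | | Resource-constrained systems |
| Tree | L1 + tree loss | Explicit tree | | Compositional concept modeling |
| Soft | Soft Top-K | None | | Smooth sparsity requirements |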

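5.2 Performance comparison (based on Gemma Scope analysis)

| Metric | JumpReLU | Matryoshka | Tree | Soft |
|---|---|---|---|---|
| Reconstruction quality (L2) | ★★★★★ | ★★★★☆ | ★★★☆☆ | ★★★★☆ |
| Feature interpretability | ★★★★☆ | ★★★★★ | ★★★★★ | ★★★☆☆ |
| Sparsity uniformity | ★★★★☆ | ★★★★☆ | ★★★☆☆ | ★★★★★ |
| Compute efficiency | ★★★★★ | ★★★★☆ | ★★★☆☆ | ★★★★☆ |
| Dead-neuron rate (more stars = fewer dead neurons) | ★★★★★ | ★★★★☆ | ★★★☆☆ | ★★★★☆ |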

5.3 Selection guide

Decision tree:

                    ┌─────────────────────┐
                    │   Need hierarchy?   │
                    └──────────┬──────────┘
                               │
               ┌───────────────┴───────────────┐
               ▼                               ▼
              Yes                              No
               │                               │
    ┌──────────┴──────────┐         ┌──────────┴──────────┐
    │ Explicit composition│         │  Uniform sparsity   │
    │      modeling?      │         │       needed?       │
    └──────────┬──────────┘         └──────────┬──────────┘
               │                               │
       ┌───────┴───────┐               ┌───────┴───────┐
       ▼               ▼               ▼               ▼
      Yes              No             Yes              No
   Tree SAE     Matryoshka SAE     Soft SAE        JumpReLU
                                    (soft)         (default)
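The decision tree above can also be written as a small function, which is convenient when the choice is made inside an experiment script. The function name and argument names are hypothetical; the left-hand branch of each question is read as "yes".

```python
def choose_sae_architecture(
    needs_hierarchy: bool,
    explicit_composition: bool = False,
    uniform_sparsity: bool = False,
) -> str:
    """Walk the decision tree and return the suggested architecture."""
    if needs_hierarchy:
        # Hierarchy branch: explicit trees vs. nested granularity levels
        return "Tree SAE" if explicit_composition else "Matryoshka SAE"
    # No hierarchy: soft top-k for uniform sparsity, otherwise the default
    return "Soft SAE" if uniform_sparsity else "JumpReLU SAE"

print(choose_sae_architecture(needs_hierarchy=True, explicit_composition=True))  # Tree SAE
print(choose_sae_architecture(needs_hierarchy=False))                            # JumpReLU SAE
```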

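6. Practical recommendations

6.1 Choosing an architecture

| Scenario | Recommended architecture | Reason |
|---|---|---|
| General research | JumpReLU | Mature, effective, easy to implement |
| Resource-constrained deployment | Matryoshka | One model, multiple granularities |
| Compositional concept analysis | Tree | Explicit hierarchy modeling |
| Smooth activation requirements | Soft | No hard cutoff |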

6.2 Hyperparameter settings

# Recommended configurations

d_model = 768  # example value only: width of the activations being modeled

# JumpReLU SAE
jumprelu_config = {
    "n_features": 4 * d_model,  # typically 4-8x d_model
    "threshold_init": 1.0,
    "dead_feature_penalty": 1.0,
    "learning_rate": 1e-4,
}
 
# Matryoshka SAE
matryoshka_config = {
    "feature_levels": [d_model, 4*d_model, 16*d_model],
    "threshold_init": 1.0,
    "tree_weight": 0.1,  # level weight
}
 
# Tree SAE
tree_config = {
    "tree_structure": {...},  # custom tree
    "tree_weight": 0.1,
    "sparsity_weight": 0.01,
}

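6.3 Training tips

  1. Ghost gradients: revive dead neurons
  2. Warm-up: keep the L1 penalty switched off for the first 1,000 steps
  3. Learning-rate schedule: cosine decay
  4. Layer normalization: add a LayerNorm in front of the encoder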
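Tips 2 and 3 reduce to two small schedule functions. The sketch below is written in plain Python so that it can be plugged into any training loop; the function names, the target L1 weight of 0.01, and the total of 10,000 steps are assumptions chosen for illustration.

```python
import math

def l1_coefficient(step: int, target: float = 0.01, warmup_steps: int = 1000) -> float:
    """L1 penalty weight: switched off during warm-up, then the target value."""
    return 0.0 if step < warmup_steps else target

def cosine_lr(step: int, total_steps: int, base_lr: float = 1e-4, min_lr: float = 0.0) -> float:
    """Cosine decay from base_lr at step 0 down to min_lr at total_steps."""
    progress = min(step, total_steps) / total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

for step in (0, 999, 1000, 5000, 10000):
    print(step, l1_coefficient(step), cosine_lr(step, total_steps=10000))
```

A gradual ramp of the L1 weight after warm-up is a common alternative to the hard switch shown here.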

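7. Future directions

7.1 Architecture evolution

| Direction | Description | Research status |
|---|---|---|
| MoE-SAE | Mixture-of-experts SAE | Early |
| Dynamic architectures | Adapt the number of features to the input | Exploratory |
| Cross-modal SAE | Unified multimodal feature space | In progress |
| Causal SAE | Explicitly model causal relations between features | Theoretical |

7.2 Broader applications

  • Real-time interpretable AI: deployment on edge devices
  • Safety alignment: feature-level interventions
  • Knowledge editing: precise modification of model knowledge
  • Scientific discovery: automatic identification of scientific concepts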

8. References

  1. Bricken et al. “Towards Monosemanticity: Decomposing Language Models With Dictionary Learning.” 2023.

  2. Elhage et al. “A Mathematical Framework for Transformer Circuits.” 2021.

  3. Lieberum et al. “Gemma Scope: Open Sparse Autoencoders Everywhere All at Once on Gemma 2.” 2024.

  4. Bussmann et al. “Matryoshka Sparse Autoencoders.” ICML 2025.

  5. arXiv:2605.07922 “Tree Sparse Autoencoders”

  6. arXiv:2605.06610