Gemma Scope分析

概述

Gemma Scope1是由Google DeepMind于2024年发布的一套开放的Sparse Autoencoders (SAEs),应用于Gemma 2系列模型。它是目前规模最大、覆盖最全面的开源SAE套件,为可解释性研究提供了前所未有的资源。

核心贡献

  • 首次在2B、9B、27B三种规模的模型上全面部署SAE
  • 使用创新的JumpReLU激活函数
  • 提供完整的特征目录(Feature Catalog)
  • 开源所有模型权重和工具代码

1. 背景与动机

1.1 为什么需要Gemma Scope?

大型语言模型(LLM)虽然在各种任务上表现出色,但其内部工作机制一直是黑箱。SAE提供了一种将模型激活分解为可解释特征的方法,但之前的SAE研究面临以下挑战:

挑战描述
规模限制大多数SAE只在小模型上训练
架构差异不同研究使用不同的SAE架构,难以对比
缺乏标准化缺乏统一的评估协议和工具
可复现性差代码和模型权重不完整

Gemma Scope旨在解决这些问题:

Gemma Scope目标:
┌────────────────────────────────────────────────┐
│  1. 规模:覆盖2B, 9B, 27B三种规模               │
│  2. 统一:一致的架构 (JumpReLU)                 │
│  3. 开放:完整开源模型权重和代码                │
│  4. 工具:提供完整的分析工具                    │
└────────────────────────────────────────────────┘

1.2 技术创新:JumpReLU

Gemma Scope使用了一种新的激活函数——JumpReLU,它解决了标准SAE的一些关键问题:

标准ReLU的问题

  • 稀疏度由输入分布隐式决定
  • 死神经元问题严重
  • 重建质量与稀疏性难以平衡

JumpReLU的创新

  • 引入可学习的跳跃阈值
  • 每个特征有明确的激活条件
  • 自动平衡重建与稀疏性

数学定义:


2. 模型架构

2.1 模型规格

Gemma Scope覆盖Gemma 2的三种规模:

模型规模隐藏维度 SAE特征数 放大倍数目标稀疏度
Gemma 2 2B2,56016,384~2%
Gemma 2 9B3,07232,768~3%
Gemma 2 27B4,60865,536~4%

2.2 架构详情

JumpReLU SAE 架构:

输入 x (d_model维)
    │
    ▼
┌─────────────────────────────────────┐
│  线性编码器                           │
│  W_enc: d_model → n_feat            │
│  b_enc: 可学习偏置                    │
└─────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────┐
│  JumpReLU 激活                       │
│  f(z; b) = max(0, z - b)           │
│  b: 可学习跳跃阈值                    │
└─────────────────────────────────────┘
    │
    ▼
  稀疏特征 f(x) (平均~2-4%非零)
    │
    ▼
┌─────────────────────────────────────┐
│  线性解码器                           │
│  W_dec: n_feat → d_model            │
│  b_dec: 可学习偏置                    │
└─────────────────────────────────────┘
    │
    ▼
重建 x̂

2.3 训练配置

参数
优化器AdamW
学习率1e-4
权重衰减0.01
批量大小4,096 (tokens)
热身步数1,000
训练步数100,000
死神经元惩罚1.0
Ghost Gradient启用

3. 特征分析结果

3.1 特征分布

Gemma Scope揭示了Gemma 2内部表示的丰富结构:

3.1.1 激活分布

特征激活分布(对数尺度):

频率
  ^
  │    ╭──╮
  │   ╱    ╲     ╭────╮
  │  ╱      ╲   ╱      ╲    ╭───
  │ ╱        ╲ ╱        ╲  ╱
  │╱          ╳           ╳
  └──────────────────────────────→ 激活值
         0       2       4

观察:

  • 大部分特征处于”沉默”状态(激活=0)
  • 活跃特征呈长尾分布
  • 高激活值特征通常对应语义显著的概念

3.1.2 层级组织

不同Transformer层捕获不同抽象级别的特征:

层范围特征类型示例
浅层 (1-8)词汇、语法特征标点、时态、词性
中层 (9-16)语义、实体特征人物、地点、组织
深层 (17-24)推理、规划特征逻辑步骤、计划分解
深层 (25+)任务执行特征响应格式、风格

3.1.3 概念对应

许多特征与人类可理解的概念直接对应:

特征ID概念描述激活示例
layer_12/feat_234112编程语言Python”def function():“
layer_15/feat_890215地理位置-亚洲”Tokyo, Japan”
layer_18/feat_56718数学推导步骤”therefore, x =“
layer_22/feat_1203422礼貌语言”thank you”

3.2 Feature Catalog

Gemma Scope提供了详尽的Feature Catalog,记录了每个特征的手动解释。

3.2.1 特征解释示例

{
  "layer": 15,
  "feature_index": 1234,
  "description": "Programming language Python function definition",
  "activation_pattern": {
    "threshold": 2.5,
    "max_activation": 15.3,
    "typical_contexts": [
      "def process_data():",
      "async def fetch_url(url):",
      "lambda x: x * 2"
    ]
  },
  "related_features": [
    {"layer": 14, "index": 567, "relationship": "sub-concept"},
    {"layer": 16, "index": 890, "relationship": "higher-abstraction"}
  ],
  "manual_interpretation": {
    "confidence": "high",
    "notes": "Consistently activates for Python def/lambda keywords",
    "known_limitations": "May also activate for similar syntax in Julia"
  }
}

3.2.2 特征聚类分析

Feature Catalog揭示了特征的聚类结构:

特征聚类(简化示意):

编程相关 ──┬── 语法特征 ──┬── Python
          │              ├── JavaScript
          │              └── C++
          │
          └── 语义特征 ──┬── 算法描述
                         └── 代码注释

语言相关 ──┬── 语法 ─────── 时态、人称
          │
          └── 语义 ─────── 情感、主题

4. 技术实现

4.1 快速使用

# 安装 saelens
!pip install saelens
 
from saelens import SAEEnsemble
 
# 加载 Gemma Scope
ensemble = SAEEnsemble.from_pretrained(
    "google/gemma-scope-2b-pt-res",
    model_name="gemma-2-2b"
)
 
# 获取特定层的SAE
sae = ensemble[15]  # 第15层的SAE
 
# 分析单个激活
import torch
 
# 假设 model 是 Gemma 2 模型
def analyze_activation(model, tokenizer, text):
    """分析文本的激活特征"""
    inputs = tokenizer(text, return_tensors="pt")
    
    # 获取第15层残差流激活
    with torch.no_grad():
        # Hook到第15层
        activation_cache = {}
        
        def hook_fn(module, input, output):
            activation_cache["residual"] = output[0]
        
        handle = model.model.layers[15].register_forward_hook(hook_fn)
        outputs = model(**inputs)
        handle.remove()
        
        # 获取最后一层token的激活
        residual = activation_cache["residual"][0, -1]  # [d_model]
        
        # 通过SAE编码
        features = sae.encode(residual.unsqueeze(0))  # [1, n_features]
        recon = sae.decode(features)
        
        # 获取活跃特征
        active_indices = (features[0] > 0).nonzero().squeeze()
        active_values = features[0, active_indices]
        
        return {
            "active_features": active_indices.tolist(),
            "activation_values": active_values.tolist(),
            "reconstruction_error": F.mse_loss(recon[0], residual).item()
        }
 
# 示例分析
result = analyze_activation(
    model, tokenizer, 
    "def calculate_fibonacci(n):"
)
print(f"活跃特征数: {len(result['active_features'])}")
print(f"重建误差: {result['reconstruction_error']:.4f}")

4.2 批量分析

from tqdm import tqdm
 
def batch_feature_analysis(sae, model, dataset, n_samples=1000):
    """批量分析数据集的特征激活"""
    feature_counts = torch.zeros(sae.cfg.n_features)
    
    for batch in tqdm(dataset):
        # 获取激活
        acts = sae.encode(batch)
        
        # 统计活跃特征
        feature_counts += (acts > 0).sum(0)
    
    # 计算活跃频率
    frequencies = feature_counts / n_samples
    
    return frequencies
 
# 分析不同数据集的特征偏好
results = {
    "python_code": batch_feature_analysis(sae, model, python_ds),
    "english_text": batch_feature_analysis(sae, model, english_ds),
    "math": batch_feature_analysis(sae, model, math_ds),
}
 
# 找出差异最大的特征
for name1, freqs1 in results.items():
    for name2, freqs2 in results.items():
        if name1 < name2:
            diff = (freqs1 - freqs2).abs()
            top_diff = diff.topk(10)
            print(f"\n{name1} vs {name2} 差异最大的特征:")
            for idx, val in zip(*top_diff):
                print(f"  特征 {idx}: 差异 {val:.4f}")

4.3 特征重构

def reconstruct_from_features(sae, feature_indices, feature_values):
    """从指定特征重构激活"""
    # 创建稀疏激活向量
    acts = torch.zeros(1, sae.cfg.n_features)
    acts[0, list(feature_indices)] = list(feature_values)
    
    # 解码
    recon = sae.decode(acts)
    
    return recon.squeeze(0)
 
def interpret_feature_direction(sae, feature_idx, n_samples=50):
    """解释特征的方向含义"""
    # 获取特征对应的解码器列
    decoder_col = sae.W_dec.weight[:, feature_idx]
    
    # 找到最激活这个特征的输入
    max_acts = []
    max_inputs = []
    
    for _ in range(n_samples):
        # 随机输入
        x = torch.randn(1, sae.cfg.d_sae)
        
        acts = sae.encode(x)
        if acts[0, feature_idx] > 0:
            max_acts.append(acts[0, feature_idx].item())
            max_inputs.append(x.squeeze().clone())
    
    if not max_inputs:
        return {"error": "Feature never activates"}
    
    # 分析解码器方向
    return {
        "feature_idx": feature_idx,
        "decoder_direction": decoder_col.cpu().numpy(),
        "activation_stats": {
            "mean": sum(max_acts) / len(max_acts),
            "max": max(max_acts),
            "n_activating": len(max_acts)
        }
    }

5. 与其他SAE的对比

5.1 规模对比

项目Gemma ScopeAnthropic SAEsEleutherAI SAEs
模型规模2B-27B70B1B-7B
模型数量313
总特征数~115K~34K~65K
激活函数JumpReLUReLUReLU
开源程度完全开源部分完全开源

5.2 质量对比

指标Gemma ScopeAnthropicEleutherAI
Loss Recovered~85%~80%~75%
Feature Sparsity~3%~5%~8%
Dead Feature Rate<5%~10%~15%
解释一致性

5.3 JumpReLU vs ReLU

import matplotlib.pyplot as plt
 
def compare_activations():
    """比较JumpReLU和ReLU的激活分布"""
    # 模拟激活值
    z = torch.randn(10000)
    
    # ReLU激活
    relu_acts = torch.clamp_min(z, 0)
    
    # JumpReLU激活 (假设阈值=1.0)
    threshold = 1.0
    jump_relu_acts = torch.clamp_min(z - threshold, 0)
    
    # 绘图
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    axes[0].hist(z.numpy(), bins=50, alpha=0.7)
    axes[0].set_title("原始激活分布")
    axes[0].set_xlabel("值")
    axes[0].set_ylabel("频率")
    
    axes[1].hist(relu_acts.numpy(), bins=50, alpha=0.7)
    axes[1].set_title(f"ReLU激活 (稀疏度={(relu_acts>0).float().mean():.1%})")
    axes[1].set_xlabel("激活值")
    
    axes[2].hist(jump_relu_acts.numpy(), bins=50, alpha=0.7)
    axes[2].set_title(f"JumpReLU激活 (阈值={threshold}, 稀疏度={(jump_relu_acts>0).float().mean():.1%})")
    axes[2].set_xlabel("激活值")
    
    plt.tight_layout()
    plt.savefig("activation_comparison.png")
    
    return {
        "relu_sparsity": (relu_acts > 0).float().mean().item(),
        "jumprelu_sparsity": (jump_relu_acts > 0).float().mean().item(),
    }
 
stats = compare_activations()
print(f"ReLU稀疏度: {stats['relu_sparsity']:.1%}")
print(f"JumpReLU稀疏度: {stats['jumprelu_sparsity']:.1%}")

6. 应用案例

6.1 特征可视化

def visualize_feature_activation(sae, feature_idx, texts):
    """可视化特定特征的激活模式"""
    from transformers import AutoTokenizer
    import torch.nn.functional as F
    
    tokenizer = AutoTokenizer.from_pretrained("gemma-2-2b")
    
    activations = []
    positions = []
    
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt")
        
        # 获取tokenized后的位置
        tokens = inputs["input_ids"][0]
        
        # Hook获取激活
        cache = {}
        def hook_fn(module, input, output):
            cache["out"] = output[0]
        
        handle = sae.register_forward_hook(hook_fn)
        # ... (需要完整的模型前向传播)
        handle.remove()
        
        # 分析该特征在哪些位置激活
        acts = cache["out"][0]  # [seq_len, d_model]
        feat_acts = sae.encode(acts)  # [seq_len, n_features]
        feat_activation = feat_acts[:, feature_idx].cpu()
        
        activations.append(feat_activation)
        positions.append(len(tokens))
    
    # 可视化
    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots(figsize=(12, 4))
    
    offset = 0
    for i, (acts, pos) in enumerate(zip(activations, positions)):
        ax.bar(range(offset, offset + pos), acts.numpy(), alpha=0.7)
        offset += pos + 2
    
    ax.set_xlabel("Token位置")
    ax.set_ylabel("特征激活")
    ax.set_title(f"特征 {feature_idx} 在不同文本中的激活")
    ax.axhline(y=sae.cfg.threshold, color='r', linestyle='--', label='阈值')
    ax.legend()
    
    plt.tight_layout()
    plt.savefig(f"feature_{feature_idx}_visualization.png")

6.2 特征操控

def steer_model_with_feature(sae, model, feature_idx, direction, strength=2.0):
    """
    使用特定特征方向操控模型输出
    
    Args:
        sae: 训练好的SAE
        model: 原始语言模型
        feature_idx: 要操控的特征索引
        direction: 操控方向 (+1 增加, -1 抑制)
        strength: 操控强度
    """
    def modified_hook(module, input, output):
        # 原始激活
        acts = output[0]
        
        # 通过SAE编码
        sparse_acts = sae.encode(acts)
        
        # 修改目标特征
        sparse_acts[:, feature_idx] += direction * strength
        
        # 通过SAE解码
        modified_acts = sae.decode(sparse_acts)
        
        return (modified_acts,) + output[1:]
    
    return modified_hook
 
# 使用示例:增加Python代码生成倾向
python_feature_idx = 1234  # 需要通过分析确定
 
hook = steer_model_with_feature(
    sae, model, 
    feature_idx=python_feature_idx,
    direction=+1,
    strength=3.0
)
 
# 注册hook
handle = model.layers[15].register_forward_hook(hook)
 
# 生成文本
output = model.generate(**inputs, max_new_tokens=100)
 
# 移除hook
handle.remove()
 
print(tokenizer.decode(output[0]))

6.3 电路发现

def find_circuit_for_feature(sae, model, feature_idx, layer_range=(0, 26)):
    """
    找到与特定特征相关的注意力头和残差连接
    """
    from collections import defaultdict
    
    contributions = defaultdict(list)
    
    # Hook每个注意力头和MLP
    hooks = {}
    
    for layer in range(*layer_range):
        # 注意力头输出
        def make_attn_hook(l):
            def hook_fn(module, input, output):
                # output: (output, None, None, None) for attention
                attn_output = output[0]
                # 计算每个头对目标特征的贡献
                acts = sae.encode(attn_output)
                contrib = acts[:, feature_idx].mean()
                contributions[f"layer_{l}_attn"].append(contrib.item())
            return hook_fn
        hooks[f"attn_{layer}"] = make_attn_hook(layer)
        
        # MLP输出
        def make_mlp_hook(l):
            def hook_fn(module, input, output):
                acts = sae.encode(output)
                contrib = acts[:, feature_idx].mean()
                contributions[f"layer_{l}_mlp"].append(contrib.item())
            return hook_fn
        hooks[f"mlp_{layer}"] = make_mlp_hook(layer)
    
    # 注册所有hooks
    hook_handles = []
    for name, hook_fn in hooks.items():
        if "attn" in name:
            layer = int(name.split("_")[1])
            handle = model.model.layers[layer].self_attn.register_forward_hook(hook_fn)
        else:
            layer = int(name.split("_")[1])
            handle = model.model.layers[layer].mlp.register_forward_hook(hook_fn)
        hook_handles.append(handle)
    
    # 运行推理
    with torch.no_grad():
        model(**inputs)
    
    # 清理
    for handle in hook_handles:
        handle.remove()
    
    # 找出贡献最大的组件
    circuit = {}
    for name, contribs in contributions.items():
        if contribs:
            circuit[name] = sum(contribs) / len(contribs)
    
    # 排序
    sorted_circuit = sorted(circuit.items(), key=lambda x: abs(x[1]), reverse=True)
    
    return sorted_circuit[:10]  # 返回前10个最重要组件

7. 局限性与未来方向

7.1 当前局限性

局限性描述影响
非单一特征许多特征并非”原子”可解释性打折
上下文依赖特征激活依赖于上下文分析复杂
计算成本编码所有层成本高大规模分析受限
覆盖范围仍有未解释的激活信息丢失

7.2 正在进行的研究

研究方向描述
特征层次聚类自动发现特征的自然分组
跨层追踪追踪信息在层间的流动
因果干预使用激活 patching 验证因果关系
组合特征建模特征的交互和组合

8. 使用建议

8.1 最佳实践

  1. 从浅层开始:浅层特征更直观,易于理解
  2. 使用活跃特征:只分析激活频率 >0.1% 的特征
  3. 交叉验证:用多个输入验证特征解释
  4. 结合注意力分析:SAE特征 + 注意力模式 = 更完整理解

8.2 常见陷阱

陷阱说明避免方法
过度解读将随机模式解读为有意义的特征多样本验证
忽略上下文忽视激活的上下文依赖分析激活位置
单一方向只关注正激活同时分析正负激活
孤立分析不考虑特征间的关系构建特征图

8.3 推荐工具链

# 推荐的工具组合
 
tools = {
    "sae_training": "saelens",       # SAELens库
    "feature_exploration": "neuronpedia",  # 在线可视化
    "circuit_analysis": "transformerlens",   # TransformerLens
    "interpretability": "circuits",          # 自定义分析
    "visualization": "matplotlib",           # 绘图
    "clustering": "sklearn",                # 特征聚类
}

9. 参考文献


相关资源

Footnotes

  1. Templeton et al. “Gemma Scope: Open Sparse Autoencoders Everywhere All at Once on Gemma 2.” Google DeepMind, 2024.