Gemma Scope分析
概述
Gemma Scope1是由Google DeepMind于2024年发布的一套开放的Sparse Autoencoders (SAEs),应用于Gemma 2系列模型。它是目前规模最大、覆盖最全面的开源SAE套件,为可解释性研究提供了前所未有的资源。
核心贡献:
- 首次在2B、9B、27B三种规模的模型上全面部署SAE
- 使用创新的JumpReLU激活函数
- 提供完整的特征目录(Feature Catalog)
- 开源所有模型权重和工具代码
1. 背景与动机
1.1 为什么需要Gemma Scope?
大型语言模型(LLM)虽然在各种任务上表现出色,但其内部工作机制一直是黑箱。SAE提供了一种将模型激活分解为可解释特征的方法,但之前的SAE研究面临以下挑战:
| 挑战 | 描述 |
|---|---|
| 规模限制 | 大多数SAE只在小模型上训练 |
| 架构差异 | 不同研究使用不同的SAE架构,难以对比 |
| 缺乏标准化 | 缺乏统一的评估协议和工具 |
| 可复现性差 | 代码和模型权重不完整 |
Gemma Scope旨在解决这些问题:
Gemma Scope目标:
┌────────────────────────────────────────────────┐
│ 1. 规模:覆盖2B, 9B, 27B三种规模 │
│ 2. 统一:一致的架构 (JumpReLU) │
│ 3. 开放:完整开源模型权重和代码 │
│ 4. 工具:提供完整的分析工具 │
└────────────────────────────────────────────────┘
1.2 技术创新:JumpReLU
Gemma Scope使用了一种新的激活函数——JumpReLU,它解决了标准SAE的一些关键问题:
标准ReLU的问题:
- 稀疏度由输入分布隐式决定
- 死神经元问题严重
- 重建质量与稀疏性难以平衡
JumpReLU的创新:
- 引入可学习的跳跃阈值
- 每个特征有明确的激活条件
- 自动平衡重建与稀疏性
数学定义:
2. 模型架构
2.1 模型规格
Gemma Scope覆盖Gemma 2的三种规模:
| 模型规模 | 隐藏维度 | SAE特征数 | 放大倍数 | 目标稀疏度 |
|---|---|---|---|---|
| Gemma 2 2B | 2,560 | 16,384 | 4× | ~2% |
| Gemma 2 9B | 3,072 | 32,768 | 8× | ~3% |
| Gemma 2 27B | 4,608 | 65,536 | 8× | ~4% |
2.2 架构详情
JumpReLU SAE 架构:
输入 x (d_model维)
│
▼
┌─────────────────────────────────────┐
│ 线性编码器 │
│ W_enc: d_model → n_feat │
│ b_enc: 可学习偏置 │
└─────────────────────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ JumpReLU 激活 │
│ f(z; b) = max(0, z - b) │
│ b: 可学习跳跃阈值 │
└─────────────────────────────────────┘
│
▼
稀疏特征 f(x) (平均~2-4%非零)
│
▼
┌─────────────────────────────────────┐
│ 线性解码器 │
│ W_dec: n_feat → d_model │
│ b_dec: 可学习偏置 │
└─────────────────────────────────────┘
│
▼
重建 x̂
2.3 训练配置
| 参数 | 值 |
|---|---|
| 优化器 | AdamW |
| 学习率 | 1e-4 |
| 权重衰减 | 0.01 |
| 批量大小 | 4,096 (tokens) |
| 热身步数 | 1,000 |
| 训练步数 | 100,000 |
| 死神经元惩罚 | 1.0 |
| Ghost Gradient | 启用 |
3. 特征分析结果
3.1 特征分布
Gemma Scope揭示了Gemma 2内部表示的丰富结构:
3.1.1 激活分布
特征激活分布(对数尺度):
频率
^
│ ╭──╮
│ ╱ ╲ ╭────╮
│ ╱ ╲ ╱ ╲ ╭───
│ ╱ ╲ ╱ ╲ ╱
│╱ ╳ ╳
└──────────────────────────────→ 激活值
0 2 4
观察:
- 大部分特征处于”沉默”状态(激活=0)
- 活跃特征呈长尾分布
- 高激活值特征通常对应语义显著的概念
3.1.2 层级组织
不同Transformer层捕获不同抽象级别的特征:
| 层范围 | 特征类型 | 示例 |
|---|---|---|
| 浅层 (1-8) | 词汇、语法特征 | 标点、时态、词性 |
| 中层 (9-16) | 语义、实体特征 | 人物、地点、组织 |
| 深层 (17-24) | 推理、规划特征 | 逻辑步骤、计划分解 |
| 深层 (25+) | 任务执行特征 | 响应格式、风格 |
3.1.3 概念对应
许多特征与人类可理解的概念直接对应:
| 特征ID | 层 | 概念描述 | 激活示例 |
|---|---|---|---|
layer_12/feat_2341 | 12 | 编程语言Python | ”def function():“ |
layer_15/feat_8902 | 15 | 地理位置-亚洲 | ”Tokyo, Japan” |
layer_18/feat_567 | 18 | 数学推导步骤 | ”therefore, x =“ |
layer_22/feat_12034 | 22 | 礼貌语言 | ”thank you” |
3.2 Feature Catalog
Gemma Scope提供了详尽的Feature Catalog,记录了每个特征的手动解释。
3.2.1 特征解释示例
{
"layer": 15,
"feature_index": 1234,
"description": "Programming language Python function definition",
"activation_pattern": {
"threshold": 2.5,
"max_activation": 15.3,
"typical_contexts": [
"def process_data():",
"async def fetch_url(url):",
"lambda x: x * 2"
]
},
"related_features": [
{"layer": 14, "index": 567, "relationship": "sub-concept"},
{"layer": 16, "index": 890, "relationship": "higher-abstraction"}
],
"manual_interpretation": {
"confidence": "high",
"notes": "Consistently activates for Python def/lambda keywords",
"known_limitations": "May also activate for similar syntax in Julia"
}
}3.2.2 特征聚类分析
Feature Catalog揭示了特征的聚类结构:
特征聚类(简化示意):
编程相关 ──┬── 语法特征 ──┬── Python
│ ├── JavaScript
│ └── C++
│
└── 语义特征 ──┬── 算法描述
└── 代码注释
语言相关 ──┬── 语法 ─────── 时态、人称
│
└── 语义 ─────── 情感、主题
4. 技术实现
4.1 快速使用
# 安装 saelens
!pip install saelens
from saelens import SAEEnsemble
# 加载 Gemma Scope
ensemble = SAEEnsemble.from_pretrained(
"google/gemma-scope-2b-pt-res",
model_name="gemma-2-2b"
)
# 获取特定层的SAE
sae = ensemble[15] # 第15层的SAE
# 分析单个激活
import torch
# 假设 model 是 Gemma 2 模型
def analyze_activation(model, tokenizer, text):
"""分析文本的激活特征"""
inputs = tokenizer(text, return_tensors="pt")
# 获取第15层残差流激活
with torch.no_grad():
# Hook到第15层
activation_cache = {}
def hook_fn(module, input, output):
activation_cache["residual"] = output[0]
handle = model.model.layers[15].register_forward_hook(hook_fn)
outputs = model(**inputs)
handle.remove()
# 获取最后一层token的激活
residual = activation_cache["residual"][0, -1] # [d_model]
# 通过SAE编码
features = sae.encode(residual.unsqueeze(0)) # [1, n_features]
recon = sae.decode(features)
# 获取活跃特征
active_indices = (features[0] > 0).nonzero().squeeze()
active_values = features[0, active_indices]
return {
"active_features": active_indices.tolist(),
"activation_values": active_values.tolist(),
"reconstruction_error": F.mse_loss(recon[0], residual).item()
}
# 示例分析
result = analyze_activation(
model, tokenizer,
"def calculate_fibonacci(n):"
)
print(f"活跃特征数: {len(result['active_features'])}")
print(f"重建误差: {result['reconstruction_error']:.4f}")4.2 批量分析
from tqdm import tqdm
def batch_feature_analysis(sae, model, dataset, n_samples=1000):
"""批量分析数据集的特征激活"""
feature_counts = torch.zeros(sae.cfg.n_features)
for batch in tqdm(dataset):
# 获取激活
acts = sae.encode(batch)
# 统计活跃特征
feature_counts += (acts > 0).sum(0)
# 计算活跃频率
frequencies = feature_counts / n_samples
return frequencies
# 分析不同数据集的特征偏好
results = {
"python_code": batch_feature_analysis(sae, model, python_ds),
"english_text": batch_feature_analysis(sae, model, english_ds),
"math": batch_feature_analysis(sae, model, math_ds),
}
# 找出差异最大的特征
for name1, freqs1 in results.items():
for name2, freqs2 in results.items():
if name1 < name2:
diff = (freqs1 - freqs2).abs()
top_diff = diff.topk(10)
print(f"\n{name1} vs {name2} 差异最大的特征:")
for idx, val in zip(*top_diff):
print(f" 特征 {idx}: 差异 {val:.4f}")4.3 特征重构
def reconstruct_from_features(sae, feature_indices, feature_values):
"""从指定特征重构激活"""
# 创建稀疏激活向量
acts = torch.zeros(1, sae.cfg.n_features)
acts[0, list(feature_indices)] = list(feature_values)
# 解码
recon = sae.decode(acts)
return recon.squeeze(0)
def interpret_feature_direction(sae, feature_idx, n_samples=50):
"""解释特征的方向含义"""
# 获取特征对应的解码器列
decoder_col = sae.W_dec.weight[:, feature_idx]
# 找到最激活这个特征的输入
max_acts = []
max_inputs = []
for _ in range(n_samples):
# 随机输入
x = torch.randn(1, sae.cfg.d_sae)
acts = sae.encode(x)
if acts[0, feature_idx] > 0:
max_acts.append(acts[0, feature_idx].item())
max_inputs.append(x.squeeze().clone())
if not max_inputs:
return {"error": "Feature never activates"}
# 分析解码器方向
return {
"feature_idx": feature_idx,
"decoder_direction": decoder_col.cpu().numpy(),
"activation_stats": {
"mean": sum(max_acts) / len(max_acts),
"max": max(max_acts),
"n_activating": len(max_acts)
}
}5. 与其他SAE的对比
5.1 规模对比
| 项目 | Gemma Scope | Anthropic SAEs | EleutherAI SAEs |
|---|---|---|---|
| 模型规模 | 2B-27B | 70B | 1B-7B |
| 模型数量 | 3 | 1 | 3 |
| 总特征数 | ~115K | ~34K | ~65K |
| 激活函数 | JumpReLU | ReLU | ReLU |
| 开源程度 | 完全开源 | 部分 | 完全开源 |
5.2 质量对比
| 指标 | Gemma Scope | Anthropic | EleutherAI |
|---|---|---|---|
| Loss Recovered | ~85% | ~80% | ~75% |
| Feature Sparsity | ~3% | ~5% | ~8% |
| Dead Feature Rate | <5% | ~10% | ~15% |
| 解释一致性 | 高 | 高 | 中 |
5.3 JumpReLU vs ReLU
import matplotlib.pyplot as plt
def compare_activations():
"""比较JumpReLU和ReLU的激活分布"""
# 模拟激活值
z = torch.randn(10000)
# ReLU激活
relu_acts = torch.clamp_min(z, 0)
# JumpReLU激活 (假设阈值=1.0)
threshold = 1.0
jump_relu_acts = torch.clamp_min(z - threshold, 0)
# 绘图
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
axes[0].hist(z.numpy(), bins=50, alpha=0.7)
axes[0].set_title("原始激活分布")
axes[0].set_xlabel("值")
axes[0].set_ylabel("频率")
axes[1].hist(relu_acts.numpy(), bins=50, alpha=0.7)
axes[1].set_title(f"ReLU激活 (稀疏度={(relu_acts>0).float().mean():.1%})")
axes[1].set_xlabel("激活值")
axes[2].hist(jump_relu_acts.numpy(), bins=50, alpha=0.7)
axes[2].set_title(f"JumpReLU激活 (阈值={threshold}, 稀疏度={(jump_relu_acts>0).float().mean():.1%})")
axes[2].set_xlabel("激活值")
plt.tight_layout()
plt.savefig("activation_comparison.png")
return {
"relu_sparsity": (relu_acts > 0).float().mean().item(),
"jumprelu_sparsity": (jump_relu_acts > 0).float().mean().item(),
}
stats = compare_activations()
print(f"ReLU稀疏度: {stats['relu_sparsity']:.1%}")
print(f"JumpReLU稀疏度: {stats['jumprelu_sparsity']:.1%}")6. 应用案例
6.1 特征可视化
def visualize_feature_activation(sae, feature_idx, texts):
"""可视化特定特征的激活模式"""
from transformers import AutoTokenizer
import torch.nn.functional as F
tokenizer = AutoTokenizer.from_pretrained("gemma-2-2b")
activations = []
positions = []
for text in texts:
inputs = tokenizer(text, return_tensors="pt")
# 获取tokenized后的位置
tokens = inputs["input_ids"][0]
# Hook获取激活
cache = {}
def hook_fn(module, input, output):
cache["out"] = output[0]
handle = sae.register_forward_hook(hook_fn)
# ... (需要完整的模型前向传播)
handle.remove()
# 分析该特征在哪些位置激活
acts = cache["out"][0] # [seq_len, d_model]
feat_acts = sae.encode(acts) # [seq_len, n_features]
feat_activation = feat_acts[:, feature_idx].cpu()
activations.append(feat_activation)
positions.append(len(tokens))
# 可视化
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(12, 4))
offset = 0
for i, (acts, pos) in enumerate(zip(activations, positions)):
ax.bar(range(offset, offset + pos), acts.numpy(), alpha=0.7)
offset += pos + 2
ax.set_xlabel("Token位置")
ax.set_ylabel("特征激活")
ax.set_title(f"特征 {feature_idx} 在不同文本中的激活")
ax.axhline(y=sae.cfg.threshold, color='r', linestyle='--', label='阈值')
ax.legend()
plt.tight_layout()
plt.savefig(f"feature_{feature_idx}_visualization.png")6.2 特征操控
def steer_model_with_feature(sae, model, feature_idx, direction, strength=2.0):
"""
使用特定特征方向操控模型输出
Args:
sae: 训练好的SAE
model: 原始语言模型
feature_idx: 要操控的特征索引
direction: 操控方向 (+1 增加, -1 抑制)
strength: 操控强度
"""
def modified_hook(module, input, output):
# 原始激活
acts = output[0]
# 通过SAE编码
sparse_acts = sae.encode(acts)
# 修改目标特征
sparse_acts[:, feature_idx] += direction * strength
# 通过SAE解码
modified_acts = sae.decode(sparse_acts)
return (modified_acts,) + output[1:]
return modified_hook
# 使用示例:增加Python代码生成倾向
python_feature_idx = 1234 # 需要通过分析确定
hook = steer_model_with_feature(
sae, model,
feature_idx=python_feature_idx,
direction=+1,
strength=3.0
)
# 注册hook
handle = model.layers[15].register_forward_hook(hook)
# 生成文本
output = model.generate(**inputs, max_new_tokens=100)
# 移除hook
handle.remove()
print(tokenizer.decode(output[0]))6.3 电路发现
def find_circuit_for_feature(sae, model, feature_idx, layer_range=(0, 26)):
"""
找到与特定特征相关的注意力头和残差连接
"""
from collections import defaultdict
contributions = defaultdict(list)
# Hook每个注意力头和MLP
hooks = {}
for layer in range(*layer_range):
# 注意力头输出
def make_attn_hook(l):
def hook_fn(module, input, output):
# output: (output, None, None, None) for attention
attn_output = output[0]
# 计算每个头对目标特征的贡献
acts = sae.encode(attn_output)
contrib = acts[:, feature_idx].mean()
contributions[f"layer_{l}_attn"].append(contrib.item())
return hook_fn
hooks[f"attn_{layer}"] = make_attn_hook(layer)
# MLP输出
def make_mlp_hook(l):
def hook_fn(module, input, output):
acts = sae.encode(output)
contrib = acts[:, feature_idx].mean()
contributions[f"layer_{l}_mlp"].append(contrib.item())
return hook_fn
hooks[f"mlp_{layer}"] = make_mlp_hook(layer)
# 注册所有hooks
hook_handles = []
for name, hook_fn in hooks.items():
if "attn" in name:
layer = int(name.split("_")[1])
handle = model.model.layers[layer].self_attn.register_forward_hook(hook_fn)
else:
layer = int(name.split("_")[1])
handle = model.model.layers[layer].mlp.register_forward_hook(hook_fn)
hook_handles.append(handle)
# 运行推理
with torch.no_grad():
model(**inputs)
# 清理
for handle in hook_handles:
handle.remove()
# 找出贡献最大的组件
circuit = {}
for name, contribs in contributions.items():
if contribs:
circuit[name] = sum(contribs) / len(contribs)
# 排序
sorted_circuit = sorted(circuit.items(), key=lambda x: abs(x[1]), reverse=True)
return sorted_circuit[:10] # 返回前10个最重要组件7. 局限性与未来方向
7.1 当前局限性
| 局限性 | 描述 | 影响 |
|---|---|---|
| 非单一特征 | 许多特征并非”原子” | 可解释性打折 |
| 上下文依赖 | 特征激活依赖于上下文 | 分析复杂 |
| 计算成本 | 编码所有层成本高 | 大规模分析受限 |
| 覆盖范围 | 仍有未解释的激活 | 信息丢失 |
7.2 正在进行的研究
| 研究方向 | 描述 |
|---|---|
| 特征层次聚类 | 自动发现特征的自然分组 |
| 跨层追踪 | 追踪信息在层间的流动 |
| 因果干预 | 使用激活 patching 验证因果关系 |
| 组合特征 | 建模特征的交互和组合 |
8. 使用建议
8.1 最佳实践
- 从浅层开始:浅层特征更直观,易于理解
- 使用活跃特征:只分析激活频率 >0.1% 的特征
- 交叉验证:用多个输入验证特征解释
- 结合注意力分析:SAE特征 + 注意力模式 = 更完整理解
8.2 常见陷阱
| 陷阱 | 说明 | 避免方法 |
|---|---|---|
| 过度解读 | 将随机模式解读为有意义的特征 | 多样本验证 |
| 忽略上下文 | 忽视激活的上下文依赖 | 分析激活位置 |
| 单一方向 | 只关注正激活 | 同时分析正负激活 |
| 孤立分析 | 不考虑特征间的关系 | 构建特征图 |
8.3 推荐工具链
# 推荐的工具组合
tools = {
"sae_training": "saelens", # SAELens库
"feature_exploration": "neuronpedia", # 在线可视化
"circuit_analysis": "transformerlens", # TransformerLens
"interpretability": "circuits", # 自定义分析
"visualization": "matplotlib", # 绘图
"clustering": "sklearn", # 特征聚类
}9. 参考文献
相关资源
| 资源 | 链接 |
|---|---|
| Hugging Face模型 | gemma-scope-2b-pt-res |
| SAELens库 | github.com/jbloomGIT/SAELens |
| Neuronpedia | neuronpedia.org |
| 官方博客 | deepmind.google/blog/gemma-scope |
| Feature Catalog | huggingface.co/google/gemma-scope-feature-catalog |
Footnotes
-
Templeton et al. “Gemma Scope: Open Sparse Autoencoders Everywhere All at Once on Gemma 2.” Google DeepMind, 2024. ↩