Concept Bottleneck SAEs
概述
Concept Bottleneck Sparse Autoencoders (CB-SAE)[1] 是一种创新的 SAE 架构,它将概念瓶颈(Concept Bottleneck)的思想与 SAE 相结合,产生既可解释又可操控的特征表示。
核心创新:
- 特征与人类可理解的概念直接对应
- 支持直观的模型控制
- 在LVLMs和图像生成任务上显著提升可解释性和可操控性
1. 背景与动机
1.1 标准SAE的局限性
标准SAE虽然能够将叠加的表示分解为稀疏特征,但存在两个关键问题:
| 问题 | 描述 | 影响 |
|---|---|---|
| 特征解释困难 | 特征与概念的对应不明确 | 难以理解模型行为 |
| 操控不直观 | 修改特征值的效果难以预测 | 限制了安全应用 |
1.2 CB-SAE的解决方案
CB-SAE通过引入概念瓶颈层来解决这些问题:
标准SAE:
激活 x → [编码器] → 稀疏特征 f → [解码器] → 重建 x̂
↓
无语义约束
CB-SAE:
激活 x → [编码器] → 概念特征 c → [解码器] → 重建 x̂
↓
人类可理解的概念
2. 技术细节
2.1 架构设计
CB-SAE的核心思想是在编码器和解码器之间引入概念瓶颈层:
CB-SAE 架构:
输入激活 x (d维)
│
▼
┌─────────────────────────────────────┐
│ 线性编码器 E │
│ E: d → m │
└─────────────────────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ 概念瓶颈层 C │
│ C: m → k (k << m) │
│ 概念特征 = [c_1, ..., c_k] │
└─────────────────────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ 概念增强解码器 D │
│ D: k → d │
│ 使用概念信息指导重建 │
└─────────────────────────────────────┘
│
▼
重建 x̂
2.2 数学形式
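编码阶段:
$$f = \mathrm{ReLU}(W_{\text{enc}}\,x + b_{\text{enc}}), \qquad c = \sigma(W_{c}\,f + b_{c})$$
其中 $f \in \mathbb{R}_{\ge 0}^{m}$ 是标准SAE编码(稀疏特征), $c \in (0, 1)^{k}$ 是概念特征。
解码阶段与训练目标:
$$\hat{x} = W_{\text{dec}}\,c + b_{\text{dec}}$$
$$\mathcal{L} = \lambda_{\text{recon}}\,\mathrm{MSE}(x, \hat{x}) + \lambda_{\text{cons}}\,\mathrm{MSE}\big(c,\, W_{\text{res}}\,f\big) + \lambda_{\text{concept}}\,\mathrm{BCE}(c, y) + \lambda_{1}\,\tfrac{1}{m}\lVert f \rVert_{1}$$
第二项是一致性损失,鼓励概念特征 $c$ 与稀疏特征的线性投影 $W_{\text{res}}\,f$ 保持一致,使概念信息与残差(特征)信息相互对齐;第三项是概念监督损失($y$ 为概念标签,仅在提供标签时计算,并可通过掩码只监督部分概念);最后一项是 L1 稀疏性惩罚。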
2.3 概念监督
CB-SAE支持不同级别的概念监督:
2.3.1 弱监督
只提供部分概念标签:
weak_supervision = {
    "concepts": ["is_python_code", "has_entity", "is_sentiment"],
    # Partial labels: entries whose mask is False carry no supervision,
    # so their value here is only a placeholder.
    "labels": [1, 0, 1],
    # True = this concept's label is observed and should be supervised.
    "mask": [True, False, True],
}
# 2.3.2 强监督
提供完整的概念标签:
strong_supervision = {
    "concepts": ["is_python_code", "has_entity", "is_sentiment"],
    # Full labels: every concept has an observed value.
    "labels": [1, 1, 0],
    # All True: every concept is supervised.
    "mask": [True, True, True],
}
# 2.4 PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class ConceptBottleneckSAE(nn.Module):
    """Concept Bottleneck Sparse Autoencoder.

    Data flow::

        x --W_enc, ReLU--> features --W_concept, sigmoid--> concepts --W_dec--> recon

    Each concept unit has a human-readable name (``concept_names``) and can
    optionally be supervised with binary labels.
    """

    def __init__(
        self,
        d_model: int,
        n_features: int,
        n_concepts: int,
        concept_names: list[str],
        lambda_recon: float = 1.0,
        lambda_concept: float = 0.1,
        lambda_consistency: float = 0.01,
        lambda_l1: float = 0.001,
    ):
        """Build the encoder, concept bottleneck, decoder and consistency projection.

        Args:
            d_model: Width of the input activations.
            n_features: Number of sparse SAE features.
            n_concepts: Number of concept units in the bottleneck.
            concept_names: One name per concept unit (same order as the units).
            lambda_recon: Weight of the reconstruction (MSE) loss.
            lambda_concept: Weight of the concept-supervision (BCE) loss.
            lambda_consistency: Weight of the concept/feature consistency loss.
            lambda_l1: Weight of the L1 sparsity penalty on the features.

        Raises:
            ValueError: If ``concept_names`` does not have ``n_concepts`` entries.
        """
        super().__init__()
        if len(concept_names) != n_concepts:
            raise ValueError(
                f"concept_names has {len(concept_names)} entries, "
                f"expected n_concepts={n_concepts}"
            )
        self.d_model = d_model
        self.n_features = n_features
        self.n_concepts = n_concepts
        self.concept_names = concept_names
        self.lambda_recon = lambda_recon
        self.lambda_concept = lambda_concept
        self.lambda_consistency = lambda_consistency
        self.lambda_l1 = lambda_l1
        # Sparse SAE encoder
        self.W_enc = nn.Linear(d_model, n_features, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        # Concept bottleneck
        self.W_concept = nn.Linear(n_features, n_concepts, bias=False)
        self.b_concept = nn.Parameter(torch.zeros(n_concepts))
        # Decoder: concepts -> activation space
        self.W_dec = nn.Linear(n_concepts, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        # Linear projection of the features; target of the consistency loss
        self.residual_proj = nn.Linear(n_features, n_concepts, bias=False)
        self.activation = nn.ReLU()

    def encode_to_features(self, x: torch.Tensor) -> torch.Tensor:
        """Encode activations into non-negative sparse features: ReLU(W_enc x + b_enc)."""
        return self.activation(self.W_enc(x) + self.b_enc)

    def _concept_logits(self, features: torch.Tensor) -> torch.Tensor:
        """Pre-sigmoid concept scores: W_concept f + b_concept."""
        return self.W_concept(features) + self.b_concept

    def features_to_concepts(self, features: torch.Tensor) -> torch.Tensor:
        """Map sparse features to concept activations in (0, 1)."""
        return torch.sigmoid(self._concept_logits(features))

    def concepts_to_output(self, concepts: torch.Tensor) -> torch.Tensor:
        """Decode concept activations back to the activation space."""
        return self.W_dec(concepts) + self.b_dec

    @staticmethod
    def _concept_loss(
        logits: torch.Tensor,
        concept_labels,
        concept_mask,
    ) -> torch.Tensor:
        """Binary cross-entropy between concept logits and labels.

        Labels are cast to the logits' dtype, so integer labels work. Labels and
        mask only need to be broadcastable to ``logits`` (for example, a
        per-concept 1-D mask). Any non-bool mask is treated as nonzero -> True.
        """
        labels = torch.as_tensor(
            concept_labels, dtype=logits.dtype, device=logits.device
        ).expand_as(logits)
        if concept_mask is None:
            return F.binary_cross_entropy_with_logits(logits, labels)
        mask = torch.as_tensor(
            concept_mask, dtype=torch.bool, device=logits.device
        ).expand_as(logits)
        if not bool(mask.any()):
            # Nothing is supervised: a mean over zero elements would be NaN.
            return logits.new_zeros(())
        return F.binary_cross_entropy_with_logits(logits[mask], labels[mask])

    def forward(
        self,
        x: torch.Tensor,
        concept_labels: Optional[torch.Tensor] = None,
        concept_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        """Encode, pass through the concept bottleneck, decode, and compute losses.

        Args:
            x: Input activations [batch_size, d_model].
            concept_labels: Optional binary concept labels, broadcastable to
                [batch_size, n_concepts]. Integer labels are accepted.
            concept_mask: Optional mask broadcastable to [batch_size, n_concepts].
                True marks labels that are supervised. Ignored when no labels
                are given.

        Returns:
            Dict with "reconstruction", "features", "concepts" and the
            individual and weighted-total losses.
        """
        features = self.encode_to_features(x)
        concept_logits = self._concept_logits(features)
        concepts = torch.sigmoid(concept_logits)
        recon = self.concepts_to_output(concepts)

        recon_loss = F.mse_loss(recon, x)
        if concept_labels is None:
            concept_loss = x.new_zeros(())
        else:
            # BCE on logits is numerically stabler than BCE on sigmoid outputs.
            concept_loss = self._concept_loss(concept_logits, concept_labels, concept_mask)
        # Consistency: concepts should agree with a linear view of the features.
        consistency_loss = F.mse_loss(concepts, self.residual_proj(features))
        l1_loss = features.abs().mean()

        total_loss = (
            self.lambda_recon * recon_loss
            + self.lambda_concept * concept_loss
            + self.lambda_consistency * consistency_loss
            + self.lambda_l1 * l1_loss
        )
        return {
            "reconstruction": recon,
            "features": features,
            "concepts": concepts,
            "recon_loss": recon_loss,
            "concept_loss": concept_loss,
            "consistency_loss": consistency_loss,
            "l1_loss": l1_loss,
            "total_loss": total_loss,
        }

    def steer_concepts(
        self,
        x: torch.Tensor,
        concept_indices: list[int],
        concept_values: list[float],
    ) -> torch.Tensor:
        """Reconstruct ``x`` after overwriting selected concept activations.

        Args:
            x: Input activations (concepts live on the last dimension).
            concept_indices: Concept units to overwrite.
            concept_values: Value to write for each index (same length).

        Returns:
            The steered reconstruction.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(concept_indices) != len(concept_values):
            raise ValueError(
                f"got {len(concept_indices)} concept indices but "
                f"{len(concept_values)} concept values"
            )
        features = self.encode_to_features(x)
        # Clone: sigmoid's backward reuses its output, so editing it in place
        # would make any later backward pass fail.
        concepts = self.features_to_concepts(features).clone()
        for idx, val in zip(concept_indices, concept_values):
            concepts[..., idx] = val
        return self.concepts_to_output(concepts)

    def get_concept_explanations(
        self,
        concepts: torch.Tensor,
        threshold: float = 0.5,
    ) -> list[list[dict]]:
        """List, per sample, the named concepts whose activation exceeds ``threshold``.

        Args:
            concepts: Concept activations [batch_size, n_concepts].
            threshold: Strict lower bound for a concept to count as active.

        Returns:
            One list per row of ``{"concept": name, "value": float}`` dicts.
        """
        # One device transfer for the whole batch instead of one per row.
        rows = concepts.detach().cpu().tolist()
        return [
            [
                {"concept": name, "value": val}
                for name, val in zip(self.concept_names, row)
                if val > threshold
            ]
            for row in rows
        ]
class CB_SAE_With_Steering(nn.Module):
    """CB-SAE wrapper that adds one learnable steering direction per concept.

    Wraps a ``ConceptBottleneckSAE`` (``self.base_sae``). It also holds an
    ``(n_concepts, d_model)`` parameter of directions. Steering adds these
    directions to the reconstruction, scaled by a user-chosen strength and by
    the concept's current activation.
    """

    def __init__(
        self,
        d_model: int,
        n_features: int,
        n_concepts: int,
        concept_names: list[str],
        concept_descriptions: dict[str, str],
    ):
        super().__init__()
        self.base_sae = ConceptBottleneckSAE(
            d_model, n_features, n_concepts, concept_names
        )
        # Concept name -> human-readable description (used only in messages).
        self.concept_descriptions = concept_descriptions
        # Steering directions, one row per concept, small random init.
        self.concept_directions = nn.Parameter(
            torch.randn(n_concepts, d_model) * 0.02
        )

    def forward(
        self,
        x: torch.Tensor,
        concept_labels: Optional[torch.Tensor] = None,
        concept_mask: Optional[torch.Tensor] = None,
    ) -> dict:
        """Run the wrapped CB-SAE and return its output/loss dict unchanged."""
        return self.base_sae(x, concept_labels, concept_mask)

    def concept_directions_steering(
        self,
        x: torch.Tensor,
        concept_indices: list[int],
        strengths: list[float],
    ) -> torch.Tensor:
        """Steer the reconstruction along learned per-concept directions.

        Args:
            x: Input activations.
            concept_indices: Concepts to steer.
            strengths: One strength per index (positive amplifies, negative
                suppresses).

        Returns:
            The steered reconstruction.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(concept_indices) != len(strengths):
            raise ValueError(
                f"got {len(concept_indices)} concept indices but "
                f"{len(strengths)} strengths"
            )
        outputs = self.base_sae(x)
        recon = outputs["reconstruction"]
        concepts = outputs["concepts"]
        for idx, strength in zip(concept_indices, strengths):
            # The offset is scaled by the concept's own activation, so a concept
            # that is inactive on this input is barely moved.
            recon = recon + strength * self.concept_directions[idx] * concepts[..., idx:idx + 1]
        return recon

    def generate_steering_description(
        self,
        concept_idx: int,
        direction: str,  # "increase" or "decrease"
    ) -> str:
        """Build a human-readable description of a steering action.

        Raises:
            ValueError: If ``direction`` is neither "increase" nor "decrease".
        """
        if direction not in ("increase", "decrease"):
            raise ValueError(
                f"direction must be 'increase' or 'decrease', got {direction!r}"
            )
        concept_name = self.base_sae.concept_names[concept_idx]
        description = self.concept_descriptions.get(concept_name, "Unknown concept")
        if direction == "increase":
            return f"增强 {concept_name} ({description}) 的激活"
        return f"抑制 {concept_name} ({description}) 的激活"
# 3. 与标准SAE的对比
3.1 特性对比
| 特性 | 标准SAE | CB-SAE |
|---|---|---|
| 特征可解释性 | 低 | 高(概念直接对应) |
| 操控直观性 | 低 | 高(概念级操控) |
| 需要概念监督 | 否 | 可选(支持强/弱/无监督) |
| 重建质量 | 高 | 略低(瓶颈约束) |
| 规模化难度 | 中 | 中(需要概念定义) |
3.2 定量对比
基于论文实验结果:
| 指标 | 标准SAE | CB-SAE | 相对变化 |
|---|---|---|---|
| 概念对应准确率 | 45% | 78% | +73% |
| 操控成功率 | 62% | 89% | +44% |
| Loss Recovered | 0.85 | 0.82 | -3.5% |
| 用户理解得分 | 3.2/5 | 4.1/5 | +28% |
3.3 适用场景对比
选择决策:
┌─────────────────────┐
│ 需要概念级操控吗? │
└──────────┬──────────┘
│
┌───────────────┴───────────────┐
▼ ▼
Yes No
│ │
┌─────────┴─────────┐ ┌───────┴───────┐
│ 有概念标签/描述吗? │ │ 标准SAE足够 │
└─────────┬─────────┘ └───────────────┘
│
┌─────────┴─────────┐
▼ ▼
Yes No
│ │
CB-SAE CB-SAE with
(强监督) 弱监督/无监督
4. 应用案例
4.1 LVLM幻觉缓解
def mitigate_hallucination_with_cb_sae(
    model,
    cb_sae,
    image_features: torch.Tensor,
    prompt: str,
    *,
    hallucination_concepts: tuple[str, ...] = ("object_not_in_image", "wrong_attribute"),
    threshold: float = 0.5,
    damping: float = 0.3,
) -> str:
    """Reduce LVLM visual hallucination by damping hallucination concepts.

    Args:
        model: Vision-language model exposing ``generate(image_features=..., text=...)``.
        cb_sae: CB-SAE exposing ``concept_names``, ``encode_to_features``,
            ``features_to_concepts`` and ``concepts_to_output``.
        image_features: Image features to encode (concepts on the last dim).
        prompt: Text prompt.
        hallucination_concepts: Concept names to damp. Names the CB-SAE does
            not define are skipped.
        threshold: A concept is damped only where its activation exceeds this.
        damping: Multiplier applied to activations above ``threshold``.

    Returns:
        Whatever ``model.generate`` returns for the modified features.
    """
    with torch.no_grad():
        features = cb_sae.encode_to_features(image_features)
        # Clone so the in-place edits below never alias the SAE's output or input.
        concepts = cb_sae.features_to_concepts(features).clone()

        indices = [
            cb_sae.concept_names.index(name)
            for name in hallucination_concepts
            if name in cb_sae.concept_names
        ]
        for idx in indices:
            column = concepts[..., idx]
            # Damp every sample where the concept fires, not just the first one.
            concepts[..., idx] = torch.where(column > threshold, column * damping, column)

    modified_features = cb_sae.concepts_to_output(concepts)
    return model.generate(
        image_features=modified_features,
        text=prompt,
    )
# 4.2 图像生成控制
def control_image_generation_with_cb_sae(
    diffusion_model,
    cb_sae,
    concept_interventions: dict[str, float],
) -> "Image.Image":
    """Control image generation by fixing CB-SAE concept values during diffusion.

    Args:
        diffusion_model: Generator exposing ``generate(intervention_hook=...)``.
            The hook is called as ``hook(latents, timestep)`` and returns new
            latents.
        cb_sae: CB-SAE used to encode and decode the latents.
        concept_interventions: Concept name -> value to impose, for example
            ``{"is_outdoor": 1.0, "is_daytime": 0.0}``. Names the CB-SAE does
            not define are ignored with a warning.

    Returns:
        Whatever ``diffusion_model.generate`` returns (presumably a PIL image;
        the annotation is a string because PIL is not imported here).
    """
    import warnings

    concept_indices = []
    concept_values = []
    unknown = []
    for concept_name, value in concept_interventions.items():
        if concept_name in cb_sae.concept_names:
            concept_indices.append(cb_sae.concept_names.index(concept_name))
            concept_values.append(value)
        else:
            unknown.append(concept_name)
    if unknown:
        warnings.warn(
            f"ignoring concepts not defined by the CB-SAE: {unknown}",
            stacklevel=2,
        )

    def concept_intervention_hook(latents, timestep):
        # NOTE(review): assumes the latents' last dim matches the SAE's d_model;
        # confirm for the diffusion model in use.
        features = cb_sae.encode_to_features(latents)
        # Clone before editing so the sigmoid output is never modified in place.
        concepts = cb_sae.features_to_concepts(features).clone()
        for idx, val in zip(concept_indices, concept_values):
            concepts[..., idx] = val
        return cb_sae.concepts_to_output(concepts)

    return diffusion_model.generate(
        intervention_hook=concept_intervention_hook
    )
# 4.3 安全内容过滤
class CB_SAE_SafetyFilter:
    """Safety filter that flags inputs whose CB-SAE safety concepts fire."""

    def __init__(self, cb_sae, safety_threshold: float = 0.7):
        """
        Args:
            cb_sae: CB-SAE exposing ``concept_names``, ``encode_to_features``
                and ``features_to_concepts``.
            safety_threshold: A safety concept scoring at or above this value is
                flagged.
        """
        import warnings

        self.cb_sae = cb_sae
        self.safety_threshold = safety_threshold
        # Safety concepts to look for, in reporting order.
        self.safety_concepts = [
            "contains_harmful_content",
            "promotes_violence",
            "involves_illegal_activity",
        ]
        # (name, concept index) for the safety concepts this CB-SAE defines.
        # Name and index are kept together so that a missing concept cannot
        # shift the labels of the others.
        self._active_safety_concepts = [
            (name, cb_sae.concept_names.index(name))
            for name in self.safety_concepts
            if name in cb_sae.concept_names
        ]
        self.safety_indices = [idx for _, idx in self._active_safety_concepts]
        if not self._active_safety_concepts:
            warnings.warn(
                "CB-SAE defines none of the safety concepts; "
                "every input will be reported as safe",
                stacklevel=2,
            )

    def check_safety(self, text_features: torch.Tensor) -> dict:
        """Score the safety concepts for ``text_features``.

        All leading dimensions (batch/tokens) are collapsed. Each concept's
        score is its maximum over those rows, so one unsafe row flags the input.

        Returns:
            Dict with "is_safe", "safety_scores" (name -> score) and
            "flagged_concepts".
        """
        with torch.no_grad():
            concepts = self.cb_sae.features_to_concepts(
                self.cb_sae.encode_to_features(text_features)
            )
        rows = concepts.reshape(-1, concepts.shape[-1])
        safety_scores = {
            name: rows[:, idx].max().item()
            for name, idx in self._active_safety_concepts
        }
        # Written as "all below threshold" so a NaN score counts as unsafe.
        is_safe = all(score < self.safety_threshold for score in safety_scores.values())
        return {
            "is_safe": is_safe,
            "safety_scores": safety_scores,
            "flagged_concepts": [
                concept for concept, score in safety_scores.items()
                if score >= self.safety_threshold
            ],
        }

    def filter_content(
        self,
        text_features: torch.Tensor,
        model_output: str,
    ) -> str:
        """Return ``model_output``, or a refusal message if the input is unsafe."""
        if not self.check_safety(text_features)["is_safe"]:
            return "I cannot assist with this request as it may involve harmful content."
        return model_output
# 5. 实验结果
5.1 可解释性评估
| 方法 | 概念对应准确率 | 用户理解得分 |
|---|---|---|
| 标准SAE | 45% | 3.2/5 |
| CB-SAE (无监督) | 61% | 3.6/5 |
| CB-SAE (弱监督) | 78% | 4.1/5 |
5.2 可操控性评估
| 任务 | 标准SAE成功率 | CB-SAE成功率 |
|---|---|---|
| 情感增强 | 58% | 87% |
| 风格转换 | 52% | 82% |
| 事实修正 | 64% | 91% |
| 安全干预 | 71% | 95% |
5.3 重建质量
| 模型 | Loss Recovered | 概念一致性 |
|---|---|---|
| 标准SAE | 0.85 | N/A |
| CB-SAE | 0.82 | 0.91 |
| CB-SAE+ | 0.84 | 0.94 |
6. 实践指南
6.1 概念定义
定义有意义的概念是CB-SAE成功的关键:
# Example concept definitions: concept name -> natural-language description.
concept_definitions = {
    # Semantic concepts
    "is_python_code": "文本是Python编程语言代码",
    "is_english": "文本是英文",
    "contains_date": "文本包含日期信息",
    # Sentiment concepts
    "has_positive_sentiment": "文本表达正面情感",
    "has_negative_sentiment": "文本表达负面情感",
    # Safety concepts
    "is_safe_content": "内容是安全的",
    "requires_carefulness": "内容需要谨慎处理",
    # Task concepts
    "is_question": "文本是问句",
    "is_factual": "文本是事实陈述",
    "is_opinion": "文本是个人观点",
}
# 6.2 训练配置
# Recommended training configuration.
training_config = {
    # Optimization
    "batch_size": 64,
    "n_epochs": 100,
    "learning_rate": 1e-4,
    # Concept bottleneck
    "n_concepts": 256,  # number of concept units
    "concept_ratio": 0.25,  # n_concepts / n_features (i.e. 1024 features here)
    # Loss weights
    "lambda_recon": 1.0,
    "lambda_concept": 0.1,  # concept-supervision weight
    "lambda_consistency": 0.01,
    "lambda_l1": 0.001,
    # Learning-rate schedule
    "warmup_steps": 1000,
    "cosine_decay": True,
}
# 6.3 概念自动发现
def discover_concepts_with_llm(
    sae,
    dataset,
    n_concepts: int = 256,
    n_samples_per_concept: int = 50,
    *,
    model: str = "claude-sonnet-4-20250514",
    max_tokens: int = 100,
) -> dict[str, dict]:
    """Ask an LLM to name the concept shared by each feature's top activations.

    Args:
        sae: SAE whose features are analysed.
        dataset: Data used to collect top-activating text snippets.
        n_concepts: Number of features to analyse.
        n_samples_per_concept: Snippets collected per feature.
        model: Anthropic model id used for labelling.
        max_tokens: Response token budget. Raise it if replies get truncated.

    Returns:
        ``{"feature_<idx>": parsed JSON dict}``, normally holding
        "concept_name" and "description". Features whose reply cannot be
        parsed as a JSON object are skipped with a warning instead of aborting
        the whole run.
    """
    import warnings

    from anthropic import Anthropic

    client = Anthropic()
    # NOTE(review): collect_top_activations is not defined in this file view;
    # presumably returns {feature_idx: [text, ...]} - confirm.
    feature_texts = collect_top_activations(sae, dataset, n_concepts, n_samples_per_concept)

    concepts = {}
    for feat_idx, texts in feature_texts.items():
        examples = "\n".join(texts)
        prompt = (
            "Analyze these text examples that activate the same feature:\n"
            f"{examples}\n"
            "What single semantic concept do these examples share?\n"
            "Provide a brief, descriptive name for this concept.\n"
            "Respond in JSON format:\n"
            "{\n"
            '"concept_name": "descriptive_name",\n'
            '"description": "brief description"\n'
            "}"
        )
        response = client.messages.create(
            model=model,
            max_tokens=max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        result = _parse_concept_json(response.content[0].text)
        if result is None:
            warnings.warn(
                f"feature {feat_idx}: could not parse a JSON object from the reply; skipping",
                stacklevel=2,
            )
            continue
        concepts[f"feature_{feat_idx}"] = result
    return concepts


def _parse_concept_json(text: str) -> Optional[dict]:
    """Extract the outermost JSON object from an LLM reply, or None if there is none.

    Tolerates surrounding prose and Markdown code fences.
    """
    import json

    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None
# 7. 局限性与未来方向
7.1 当前局限性
| 局限性 | 描述 | 影响 |
|---|---|---|
| 概念定义主观 | 概念定义需要人工设计 | 可扩展性受限 |
| 监督成本 | 需要概念标签数据 | 获取成本高 |
| 概念重叠 | 概念之间可能有重叠 | 操控可能不精确 |
7.2 未来方向
| 方向 | 描述 |
|---|---|
| 自动概念发现 | 使用LLM自动从特征中提取概念 |
| 层次概念 | 构建概念的层次结构 |
| 多模态概念 | 跨模态的概念对应 |
| 动态概念 | 根据上下文自适应概念 |
8. 参考
相关资源
Footnotes
[1] “Interpretable and Steerable Concept Bottleneck Sparse Autoencoders.” arXiv:2512.10805, 2025.