Concept Bottleneck SAEs

Overview

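Concept Bottleneck Sparse Autoencoders (CB-SAE)[1] are an SAE architecture that combines the idea of a concept bottleneck with sparse autoencoders to produce feature representations that are both interpretable and steerable.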

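Key Innovations

  • Features correspond directly to human-understandable concepts
  • Supports intuitive model control (steering)
  • Substantially improves interpretability and steerability on LVLM and image generation tasks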

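1. Background and Motivation

1.1 Limitations of Standard SAEs

Although standard SAEs can decompose superposed representations into sparse features, they suffer from two key problems:

| Problem | Description | Impact |
|---|---|---|
| Features are hard to interpret | The correspondence between features and concepts is unclear | Model behavior is hard to understand |
| Steering is unintuitive | The effect of changing a feature value is hard to predict | Limits safety applications |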

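1.2 The CB-SAE Solution

CB-SAE addresses these problems by introducing a concept bottleneck layer:

Standard SAE:
  activation x → [encoder] → sparse features f → [decoder] → reconstruction x̂
                                     ↓
                           no semantic constraint

CB-SAE:
  activation x → [encoder] → concept features c → [decoder] → reconstruction x̂
                                     ↓
                        human-understandable concepts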

2. Technical Details

2.1 Architecture Design

The core idea of CB-SAE is to insert a concept bottleneck layer between the encoder and the decoder.

CB-SAE architecture:

Input activation x (d-dim)
    │
    ▼
┌─────────────────────────────────────┐
│  Linear encoder E                   │
│  E: d → m                           │
└─────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────┐
│  Concept bottleneck layer C         │
│  C: m → k (k << m)                  │
│  concept features = [c_1, ..., c_k] │
└─────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────┐
│  Concept-augmented decoder D        │
│  D: k → d                           │
│  reconstructs x from the concepts   │
└─────────────────────────────────────┘
    │
    ▼
Reconstruction x̂

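2.2 Mathematical Formulation

Encoding

$$
f = \mathrm{ReLU}(W_{\text{enc}} x + b_{\text{enc}}), \qquad c = \sigma(W_c f + b_c)
$$

where $f \in \mathbb{R}^m$ is the standard SAE encoding and $c \in (0, 1)^k$ is the vector of concept features.

Decoding

$$
\hat{x} = W_{\text{dec}} c + b_{\text{dec}}, \qquad
\mathcal{L}_{\text{dec}} = \lVert x - \hat{x} \rVert_2^2 + \lambda_{\text{cons}} \lVert c - P f \rVert_2^2
$$

The second term is the consistency loss: it ties the concept activations $c$ to a linear projection $P f$ of the full sparse feature vector (`residual_proj` in the implementation below), so the bottleneck stays consistent with the residual, non-concept information in $f$. The full training objective also adds the concept supervision loss against the (possibly masked) concept labels $y$ and an L1 sparsity penalty on $f$:

$$
\mathcal{L} = \lambda_{\text{recon}} \lVert x - \hat{x} \rVert_2^2 + \lambda_{\text{concept}} \, \mathrm{BCE}(c, y) + \lambda_{\text{cons}} \lVert c - P f \rVert_2^2 + \lambda_{1} \lVert f \rVert_1
$$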
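The encoding and decoding steps can be traced by hand on a tiny example. The sketch below uses plain Python (no PyTorch) so every number is checkable; it mirrors the `ConceptBottleneckSAE` code in 2.4 (ReLU encoder, sigmoid bottleneck, linear decoder, and a bias-free projection `P` in the role of `residual_proj`). The sizes (d = 2, m = 3, k = 1) and all weights are arbitrary illustrative values, not trained parameters.

```python
import math


def relu(v):
    return [max(0.0, t) for t in v]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-t)) for t in v]


def matvec(W, v, b):
    # W is a list of rows; computes W @ v + b
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) + b_i for row, b_i in zip(W, b)]


def mse(a, b):
    # Mean squared error, like F.mse_loss with reduction="mean"
    return sum((p - q) ** 2 for p, q in zip(a, b)) / len(a)


def cb_sae_forward(x, W_enc, b_enc, W_c, b_c, W_dec, b_dec, P):
    f = relu(matvec(W_enc, x, b_enc))      # sparse features, length m
    c = sigmoid(matvec(W_c, f, b_c))       # concept activations in (0, 1), length k
    x_hat = matvec(W_dec, c, b_dec)        # reconstruction, length d
    recon = mse(x, x_hat)
    consistency = mse(c, matvec(P, f, [0.0] * len(c)))  # P has no bias
    return f, c, x_hat, recon, consistency


# Toy sizes: d = 2, m = 3, k = 1 (illustrative weights)
x = [1.0, -1.0]
W_enc = [[1.0, 0.0], [0.0, 1.0], [1.0, -1.0]]
b_enc = [0.0, 0.0, 0.0]
W_c = [[0.5, 0.5, 0.5]]
b_c = [0.0]
W_dec = [[2.0], [-2.0]]
b_dec = [0.0, 0.0]
P = [[0.0, 0.0, 0.25]]

f, c, x_hat, recon_loss, consistency_loss = cb_sae_forward(
    x, W_enc, b_enc, W_c, b_c, W_dec, b_dec, P
)
print(f, c, x_hat, recon_loss, consistency_loss)
```

Here f = [1, 0, 2] and c = σ(1.5) ≈ 0.818. Because k = 1, the reconstruction can only move along the single decoder column [2, -2], which is the bottleneck constraint in miniature.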

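2.3 Concept Supervision

CB-SAE supports different levels of concept supervision:

2.3.1 Weak Supervision

Only some of the concept labels are provided: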

weak_supervision = {
    "concepts": ["is_python_code", "has_entity", "is_sentiment"],
    "labels": [1, 0, 1],  # partial labels (the 0 at the masked position is a placeholder)
    "mask": [True, False, True],  # False = unlabeled, excluded from the loss
}

2.3.2 Strong Supervision

All concept labels are provided:

strong_supervision = {
    "concepts": ["is_python_code", "has_entity", "is_sentiment"],
    "labels": [1, 1, 0],  # complete labels
    "mask": [True, True, True],
}
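To see how the mask enters the supervision loss, here is a plain-Python sketch of a masked binary cross-entropy applied to the two example dicts above. It plays the same role as the masked `F.binary_cross_entropy` call in the implementation in 2.4; the predicted probabilities are made-up values, and returning 0.0 when nothing is labeled is a choice of this sketch.

```python
import math


def masked_bce(probs, labels, mask, eps=1e-7):
    """Binary cross-entropy averaged over the entries where mask is True."""
    terms = []
    for p, y, m in zip(probs, labels, mask):
        if not m:
            continue  # unlabeled concept: contributes nothing to the loss
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        terms.append(-(y * math.log(p) + (1 - y) * math.log(1 - p)))
    if not terms:
        return 0.0  # no labeled concepts in this example
    return sum(terms) / len(terms)


# Hypothetical predicted probabilities for the three concepts above
predicted = [0.9, 0.2, 0.6]

weak = {"labels": [1, 0, 1], "mask": [True, False, True]}
strong = {"labels": [1, 1, 0], "mask": [True, True, True]}

weak_loss = masked_bce(predicted, weak["labels"], weak["mask"])
strong_loss = masked_bce(predicted, strong["labels"], strong["mask"])
print(weak_loss, strong_loss)
```

Under weak supervision the value stored at a masked position has no effect on the loss, so it can be any placeholder.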

2.4 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
 
class ConceptBottleneckSAE(nn.Module):
    """Concept Bottleneck Sparse Autoencoder"""
    
    def __init__(
        self,
        d_model: int,
        n_features: int,
        n_concepts: int,
        concept_names: list[str],
        lambda_recon: float = 1.0,
        lambda_concept: float = 0.1,
        lambda_consistency: float = 0.01,
    ):
        super().__init__()
        
        self.d_model = d_model
        self.n_features = n_features
        self.n_concepts = n_concepts
        self.concept_names = concept_names
        
        self.lambda_recon = lambda_recon
        self.lambda_concept = lambda_concept
        self.lambda_consistency = lambda_consistency
        
        # Standard SAE encoder: d_model -> n_features
        self.W_enc = nn.Linear(d_model, n_features, bias=False)
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        
        # Concept bottleneck layer: n_features -> n_concepts
        self.W_concept = nn.Linear(n_features, n_concepts, bias=False)
        self.b_concept = nn.Parameter(torch.zeros(n_concepts))
        
        # Concept decoder: n_concepts -> d_model
        self.W_dec = nn.Linear(n_concepts, d_model, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        
        # Residual projection (used only by the consistency loss)
        self.residual_proj = nn.Linear(n_features, n_concepts, bias=False)
        
        # Activation function
        self.activation = nn.ReLU()
        
    def encode_to_features(self, x: torch.Tensor) -> torch.Tensor:
        """Encode activations into non-negative sparse features."""
        pre_acts = self.W_enc(x) + self.b_enc
        return self.activation(pre_acts)
    
    def features_to_concepts(self, features: torch.Tensor) -> torch.Tensor:
        """Map sparse features to concept activations in (0, 1)."""
        return torch.sigmoid(self.W_concept(features) + self.b_concept)
    
    def concepts_to_output(self, concepts: torch.Tensor) -> torch.Tensor:
        """Decode concept activations back into the model's activation space."""
        return self.W_dec(concepts) + self.b_dec
    
    def forward(
        self,
        x: torch.Tensor,
        concept_labels: Optional[torch.Tensor] = None,
        concept_mask: Optional[torch.Tensor] = None,
    ) -> dict:
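        """
        Forward pass.
        
        Args:
            x: input activations [batch_size, d_model]
            concept_labels: concept labels in {0, 1} [batch_size, n_concepts]
            concept_mask: boolean concept mask [batch_size, n_concepts];
                True marks labeled entries
        
        Returns:
            dict with the reconstruction, intermediate activations and losses
        """
        # Encoding stage
        features = self.encode_to_features(x)
        
        # Concept bottleneck
        concepts = self.features_to_concepts(features)
        
        # Decoding stage
        recon = self.concepts_to_output(concepts)
        
        # Reconstruction loss
        recon_loss = F.mse_loss(recon, x)
        
        # Concept supervision loss (only if labels are given)
        if concept_labels is not None:
            concept_labels = concept_labels.float()  # BCE requires float targets
            if concept_mask is not None:
                concept_mask = concept_mask.bool()
                if concept_mask.any():
                    # Only labeled entries contribute to the loss
                    concept_loss = F.binary_cross_entropy(
                        concepts[concept_mask],
                        concept_labels[concept_mask]
                    )
                else:
                    # Nothing labeled in this batch (avoids a NaN from an empty mean)
                    concept_loss = torch.zeros((), device=x.device)
            else:
                concept_loss = F.binary_cross_entropy(concepts, concept_labels)
        else:
            concept_loss = torch.tensor(0.0, device=x.device)
        
        # Consistency loss: concepts should agree with a linear projection of the features
        projected_features = self.residual_proj(features)
        consistency_loss = F.mse_loss(concepts, projected_features)
        
        # L1 sparsity loss on the features
        l1_loss = features.abs().mean()
        
        # Total loss (the L1 weight is hard-coded to 0.001 here)
        total_loss = (
            self.lambda_recon * recon_loss +
            self.lambda_concept * concept_loss +
            self.lambda_consistency * consistency_loss +
            0.001 * l1_loss
        )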
        
        return {
            "reconstruction": recon,
            "features": features,
            "concepts": concepts,
            "recon_loss": recon_loss,
            "concept_loss": concept_loss,
            "consistency_loss": consistency_loss,
            "l1_loss": l1_loss,
            "total_loss": total_loss,
        }
    
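    @torch.no_grad()
    def steer_concepts(
        self,
        x: torch.Tensor,
        concept_indices: list[int],
        concept_values: list[float],
    ) -> torch.Tensor:
        """
        Steer the output by overwriting selected concept values.
        
        Args:
            x: input activations [batch_size, d_model]
            concept_indices: indices of the concepts to modify
            concept_values: target value for each selected concept
        
        Returns:
            the steered reconstruction [batch_size, d_model]
        """
        # Encode (clone so the in-place edit does not touch the sigmoid output)
        features = self.encode_to_features(x)
        concepts = self.features_to_concepts(features).clone()
        
        # Overwrite the selected concepts (same value for every example in the batch)
        for idx, val in zip(concept_indices, concept_values):
            concepts[:, idx] = val
        
        # Decode
        recon = self.concepts_to_output(concepts)
        
        return recon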
    
    def get_concept_explanations(
        self,
        concepts: torch.Tensor,
        threshold: float = 0.5,
    ) -> list[list[dict]]:
        """
        Return the active concepts of each example as text explanations.
        
        Args:
            concepts: concept activations [batch_size, n_concepts]
            threshold: activation above which a concept counts as active
        
        Returns:
            one list of {"concept", "value"} dicts per example
        """
        explanations = []
        
        for row in concepts.detach().cpu().tolist():
            active_concepts = [
                {"concept": name, "value": val}
                for name, val in zip(self.concept_names, row)
                if val > threshold
            ]
            explanations.append(active_concepts)
        
        return explanations
 
 
class CB_SAE_With_Steering(nn.Module):
    """
    Extended CB-SAE with more flexible steering via learned concept directions.
    """
    
    def __init__(
        self,
        d_model: int,
        n_features: int,
        n_concepts: int,
        concept_names: list[str],
        concept_descriptions: dict[str, str],
    ):
        super().__init__()
        
        self.base_sae = ConceptBottleneckSAE(
            d_model, n_features, n_concepts, concept_names
        )
        self.concept_descriptions = concept_descriptions
        
        # Concept directions: one learned steering vector per concept
        self.concept_directions = nn.Parameter(
            torch.randn(n_concepts, d_model) * 0.02
        )
    
    def concept_directions_steering(
        self,
        x: torch.Tensor,
        concept_indices: list[int],
        strengths: list[float],
    ) -> torch.Tensor:
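        """
        Fine-grained steering with learned concept directions.
        
        Args:
            x: input activations [batch_size, d_model]
            concept_indices: indices of the concepts to steer
            strengths: steering strengths (positive amplifies, negative suppresses)
        
        Returns:
            the steered reconstruction [batch_size, d_model]
        """
        # Base reconstruction
        outputs = self.base_sae(x)
        recon = outputs["reconstruction"].clone()
        concepts = outputs["concepts"]
        
        # Shift along each concept direction, scaled by strength and current activation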
        for idx, strength in zip(concept_indices, strengths):
            direction = self.concept_directions[idx]
            recon = recon + strength * direction * concepts[:, idx:idx+1]
        
        return recon
    
    def generate_steering_description(
        self,
        concept_idx: int,
        direction: str,  # "increase" or "decrease"
    ) -> str:
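        """Generate a human-readable description of a steering action."""
        concept_name = self.base_sae.concept_names[concept_idx]
        description = self.concept_descriptions.get(concept_name, "Unknown concept")
        
        if direction == "increase":
            return f"Increase the activation of {concept_name} ({description})"
        elif direction == "decrease":
            return f"Suppress the activation of {concept_name} ({description})"
        raise ValueError(f"direction must be 'increase' or 'decrease', got {direction!r}")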
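
Because `concepts_to_output` is a single linear map, overwriting concept j with value v in `steer_concepts` shifts the reconstruction by exactly (v - c_j) times column j of W_dec, independently of the other concepts. This is what makes concept-level steering predictable in this design. The plain-Python sketch below checks the identity with toy weights (illustrative values, not from the paper). By contrast, `concept_directions_steering` adds a separately learned vector scaled by the concept's current activation, so its effect depends on how those directions were trained.

```python
def decode(W_dec, b_dec, c):
    # x_hat = W_dec @ c + b_dec, with W_dec given as d rows of length k
    return [sum(w * cj for w, cj in zip(row, c)) + b for row, b in zip(W_dec, b_dec)]


def steer(W_dec, b_dec, c, idx, value):
    """Overwrite concept `idx` with `value` and decode (what steer_concepts does)."""
    c_new = list(c)
    c_new[idx] = value
    return decode(W_dec, b_dec, c_new)


def steering_delta(W_dec, c, idx, value):
    """Predicted change in x_hat: (value - c[idx]) times decoder column idx."""
    return [(value - c[idx]) * row[idx] for row in W_dec]


W_dec = [[1.0, 0.5], [0.0, -2.0], [3.0, 1.0]]  # d = 3, k = 2 (toy values)
b_dec = [0.1, 0.2, 0.3]
c = [0.8, 0.3]

before = decode(W_dec, b_dec, c)
after = steer(W_dec, b_dec, c, idx=1, value=1.0)
delta = steering_delta(W_dec, c, idx=1, value=1.0)
print(before, after, delta)
```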

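3. Comparison with Standard SAEs

3.1 Feature Comparison

| Property | Standard SAE | CB-SAE |
|---|---|---|
| Feature interpretability | | High (direct concept correspondence) |
| Steering intuitiveness | | High (concept-level steering) |
| Requires concept supervision | | Yes (optional) |
| Reconstruction quality | | Slightly lower (bottleneck constraint) |
| Scaling difficulty | | Medium (requires concept definitions) |

3.2 Quantitative Comparison

Based on the experimental results reported in the paper:

| Metric | Standard SAE | CB-SAE | Change (relative) |
|---|---|---|---|
| Concept correspondence accuracy | 45% | 78% | +73% |
| Steering success rate | 62% | 89% | +44% |
| Loss Recovered | 0.85 | 0.82 | -3.5% |
| User understanding score | 3.2/5 | 4.1/5 | +28% |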
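The last column appears to be a relative change, (new - old) / old, rather than a difference in percentage points: concept correspondence accuracy rises by 33 points, which is +73% relative to 45%. The arithmetic below reproduces the column up to rounding; the row labels are just the table's metrics.

```python
def relative_change(baseline: float, new: float) -> float:
    """Relative change of `new` with respect to `baseline`, in percent."""
    return (new - baseline) / baseline * 100.0


rows = {
    "Concept correspondence accuracy": (45, 78),
    "Steering success rate": (62, 89),
    "Loss Recovered": (0.85, 0.82),
    "User understanding score": (3.2, 4.1),
}

for name, (sae, cb_sae) in rows.items():
    print(f"{name}: {relative_change(sae, cb_sae):+.1f}%")
```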

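3.3 Use-Case Comparison

Decision guide:

                    ┌──────────────────────────────┐
                    │ Need concept-level steering? │
                    └───────────────┬──────────────┘
                                    │
                  ┌─────────────────┴─────────────────┐
                  ▼                                   ▼
                 Yes                                  No
                  │                                   │
┌─────────────────┴─────────────────┐   ┌─────────────┴──────────────┐
│ Have concept labels/descriptions? │   │ Standard SAE is sufficient │
└─────────────────┬─────────────────┘   └────────────────────────────┘
                  │
        ┌─────────┴─────────┐
        ▼                   ▼
       Yes                  No
        │                   │
     CB-SAE               CB-SAE
(strong supervision)  (weak/no supervision)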
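The decision guide reduces to two yes/no questions. As a minimal sketch, it can be written as a function; the function name and the returned strings are illustrative, and `has_concept_annotations` stands for "concept labels or descriptions are available".

```python
def choose_sae_variant(needs_concept_steering: bool, has_concept_annotations: bool) -> str:
    """Return the SAE variant suggested by the decision guide.

    has_concept_annotations: True if concept labels or concept descriptions exist.
    """
    if not needs_concept_steering:
        return "standard SAE"
    if has_concept_annotations:
        return "CB-SAE (strong supervision)"
    return "CB-SAE (weak/no supervision)"


for steer in (True, False):
    for annotated in (True, False):
        print(steer, annotated, "->", choose_sae_variant(steer, annotated))
```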

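4. Application Examples

4.1 Mitigating Visual Hallucinations in LVLMs

import torch


def mitigate_hallucination_with_cb_sae(
    model,
    cb_sae,
    image_features: torch.Tensor,
    prompt: str,
    threshold: float = 0.5,
    damping: float = 0.3,
) -> str:
    """
    Use CB-SAE to reduce visual hallucinations in an LVLM.
    
    Assumes cb_sae was trained with the concepts "object_not_in_image" and
    "wrong_attribute", and that model.generate accepts image_features= and
    text= keyword arguments (a hypothetical interface; adapt to your model).
    
    Args:
        model: vision-language model
        cb_sae: trained ConceptBottleneckSAE
        image_features: image features [batch_size, d_model]
        prompt: text prompt
        threshold: activation above which a hallucination concept is damped
        damping: multiplier applied to a detected hallucination concept
    
    Returns:
        response with reduced hallucination
    """
    hallucination_concepts = [
        cb_sae.concept_names.index("object_not_in_image"),
        cb_sae.concept_names.index("wrong_attribute"),
    ]
    
    with torch.no_grad():
        # Encode the image features into concepts
        features = cb_sae.encode_to_features(image_features)
        concepts = cb_sae.features_to_concepts(features)
        
        # Damp each hallucination concept wherever it exceeds the threshold
        for h_concept in hallucination_concepts:
            active = concepts[:, h_concept] > threshold
            concepts[active, h_concept] *= damping
        
        # Reconstruct the modified image features
        modified_features = cb_sae.concepts_to_output(concepts)
    
    # Generate the response from the modified features
    response = model.generate(
        image_features=modified_features,
        text=prompt,
    )
    
    return response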

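4.2 Controlling Image Generation

import torch
from PIL import Image


def control_image_generation_with_cb_sae(
    diffusion_model,
    cb_sae,
    concept_interventions: dict[str, float],
) -> Image.Image:
    """
    Control image generation through CB-SAE concept interventions.
    
    Assumes diffusion_model.generate accepts an intervention_hook callback
    (a hypothetical interface) and that the latents passed to the hook have
    d_model as their last dimension, matching what cb_sae was trained on.
    
    Args:
        diffusion_model: diffusion image generation model
        cb_sae: trained ConceptBottleneckSAE used for feature control
        concept_interventions: mapping from concept name to target value,
            e.g. {"is_outdoor": 1.0, "is_daytime": 0.0}
    
    Returns:
        the generated image
    """
    # Resolve concept names to indices (unknown names are skipped)
    concept_indices = []
    concept_values = []
    
    for concept_name, value in concept_interventions.items():
        if concept_name in cb_sae.concept_names:
            idx = cb_sae.concept_names.index(concept_name)
            concept_indices.append(idx)
            concept_values.append(value)
    
    # Hook that injects the concept intervention during diffusion
    @torch.no_grad()
    def concept_intervention_hook(latents, timestep):
        # Encode through the CB-SAE
        features = cb_sae.encode_to_features(latents)
        concepts = cb_sae.features_to_concepts(features)
        
        # Apply the intervention (works for any number of leading dimensions)
        for idx, val in zip(concept_indices, concept_values):
            concepts[..., idx] = val
        
        # Reconstruct the modified latents
        return cb_sae.concepts_to_output(concepts)
    
    # Generate the image
    image = diffusion_model.generate(
        intervention_hook=concept_intervention_hook
    )
    
    return image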

4.3 Safety Content Filtering

import torch


class CB_SAE_SafetyFilter:
    """Safety content filter based on CB-SAE."""
    
    def __init__(self, cb_sae, safety_threshold: float = 0.7):
        self.cb_sae = cb_sae
        self.safety_threshold = safety_threshold
        
        # Safety-related concepts
        self.safety_concepts = [
            "contains_harmful_content",
            "promotes_violence",
            "involves_illegal_activity",
        ]
        
        # (name, index) pairs for the safety concepts this CB-SAE actually has;
        # keeping them paired avoids a misaligned zip when some are missing
        self.safety_pairs = [
            (c, cb_sae.concept_names.index(c))
            for c in self.safety_concepts
            if c in cb_sae.concept_names
        ]
    
    def check_safety(self, text_features: torch.Tensor) -> dict:
        """
        Check whether the content is safe.
        
        Args:
            text_features: text activations [1, d_model] (only the first row is scored)
        
        Returns:
            safety assessment
        """
        with torch.no_grad():
            concepts = self.cb_sae.features_to_concepts(
                self.cb_sae.encode_to_features(text_features)
            )
        
        safety_scores = {
            concept: concepts[0, idx].item()
            for concept, idx in self.safety_pairs
        }
        
        is_safe = all(score < self.safety_threshold for score in safety_scores.values())
        
        return {
            "is_safe": is_safe,
            "safety_scores": safety_scores,
            "flagged_concepts": [
                concept for concept, score in safety_scores.items()
                if score >= self.safety_threshold
            ],
        }
    
    def filter_content(
        self,
        text_features: torch.Tensor,
        model_output: str,
    ) -> str:
        """Replace the model output with a refusal if the content is flagged as unsafe."""
        safety_check = self.check_safety(text_features)
        
        if not safety_check["is_safe"]:
            # Return a safe fallback response
            return "I cannot assist with this request as it may involve harmful content."
        
        return model_output

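5. Experimental Results

5.1 Interpretability Evaluation

| Method | Concept correspondence accuracy | User understanding score |
|---|---|---|
| Standard SAE | 45% | 3.2/5 |
| CB-SAE (unsupervised) | 61% | 3.6/5 |
| CB-SAE (weakly supervised) | 78% | 4.1/5 |

5.2 Steerability Evaluation

| Task | Standard SAE success rate | CB-SAE success rate |
|---|---|---|
| Sentiment amplification | 58% | 87% |
| Style transfer | 52% | 82% |
| Factual correction | 64% | 91% |
| Safety intervention | 71% | 95% |

5.3 Reconstruction Quality

| Model | Loss Recovered | Concept consistency |
|---|---|---|
| Standard SAE | 0.85 | N/A |
| CB-SAE | 0.82 | 0.91 |
| CB-SAE+ | 0.84 | 0.94 |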
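The table does not define Loss Recovered. In the SAE literature it is commonly computed as the fraction of the language-model loss gap, between ablating the activations and keeping them intact, that is closed when the SAE reconstruction is spliced in. A sketch under that assumption (the paper's ablation baseline, for example zero or mean ablation, may differ, and the loss values below are hypothetical):

```python
def loss_recovered(loss_clean: float, loss_sae: float, loss_ablated: float) -> float:
    """Fraction of the ablation loss gap closed by the SAE reconstruction.

    loss_clean:   LM loss with the original activations
    loss_sae:     LM loss with the activations replaced by the SAE reconstruction
    loss_ablated: LM loss with the activations ablated (e.g. zeroed)
    1.0 means the reconstruction is as good as the original activations;
    0.0 means it is no better than ablating them.
    """
    gap = loss_ablated - loss_clean
    if gap <= 0:
        raise ValueError("ablated loss must be larger than clean loss")
    return (loss_ablated - loss_sae) / gap


# Hypothetical losses, chosen only to illustrate the scale of 0.85 vs. 0.82
print(loss_recovered(loss_clean=2.0, loss_sae=2.6, loss_ablated=6.0))
print(loss_recovered(loss_clean=2.0, loss_sae=2.72, loss_ablated=6.0))
```

Under this definition, the drop from 0.85 to 0.82 means the CB-SAE reconstruction closes about 3 percentage points less of the ablation gap than the standard SAE.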

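6. Practical Guide

6.1 Defining Concepts

Defining meaningful concepts is the key to making CB-SAE work:

# Example concept definitions
concept_definitions = {
    # Semantic concepts
    "is_python_code": "The text is Python source code",
    "is_english": "The text is in English",
    "contains_date": "The text contains date information",
    
    # Sentiment concepts
    "has_positive_sentiment": "The text expresses positive sentiment",
    "has_negative_sentiment": "The text expresses negative sentiment",
    
    # Safety concepts
    "is_safe_content": "The content is safe",
    "requires_carefulness": "The content needs careful handling",
    
    # Task concepts
    "is_question": "The text is a question",
    "is_factual": "The text is a factual statement",
    "is_opinion": "The text is a personal opinion",
}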
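A definitions dictionary like this can be split into the two constructor arguments that `CB_SAE_With_Steering` in 2.4 expects: an ordered `concept_names` list, whose order fixes each concept's index in the bottleneck, and a `concept_descriptions` mapping. A minimal sketch; the function name is hypothetical, and the snake_case check is an extra convention assumed here, not something CB-SAE requires.

```python
import re


def build_concept_vocab(definitions: dict[str, str]) -> tuple[list[str], dict[str, str]]:
    """Split {name: description} into (concept_names, concept_descriptions).

    The order of concept_names fixes each concept's index in the bottleneck.
    """
    names = list(definitions)  # dicts keep insertion order (Python 3.7+)
    for name in names:
        if not re.fullmatch(r"[a-z][a-z0-9_]*", name):
            raise ValueError(f"concept name {name!r} is not snake_case")
        if not definitions[name].strip():
            raise ValueError(f"concept {name!r} has an empty description")
    return names, dict(definitions)


names, descriptions = build_concept_vocab({
    "is_python_code": "The text is Python source code",
    "is_question": "The text is a question",
})
n_concepts = len(names)  # must match the n_concepts passed to the model
print(names, n_concepts)
```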

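6.2 Training Configuration

# Recommended training configuration
training_config = {
    # Data / optimization settings
    "batch_size": 64,
    "n_epochs": 100,
    "learning_rate": 1e-4,
    
    # Concept bottleneck settings
    "n_concepts": 256,  # number of concepts k
    "concept_ratio": 0.25,  # concepts / features ratio k / m
    
    # Loss weights
    "lambda_recon": 1.0,
    "lambda_concept": 0.1,  # concept supervision weight
    "lambda_consistency": 0.01,
    "lambda_l1": 0.001,
    
    # Learning-rate schedule
    "warmup_steps": 1000,
    "cosine_decay": True,
}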
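Two quantities are implied by this config but not spelled out: the dictionary width m (from k / m = 0.25, so m = 1024 for k = 256) and the shape of the warmup-plus-cosine schedule. The sketch below assumes linear warmup followed by cosine decay to zero; `total_steps` depends on the dataset size and is left as a parameter. In PyTorch, `lr_at_step(step, ...) / base_lr` could serve as the multiplier returned by the `lr_lambda` of `torch.optim.lr_scheduler.LambdaLR`. Note that the implementation in 2.4 hard-codes the L1 weight at 0.001, which matches `lambda_l1`.

```python
import math


def n_features_from_config(n_concepts: int, concept_ratio: float) -> int:
    """Dictionary width m implied by k = concept_ratio * m."""
    return round(n_concepts / concept_ratio)


def lr_at_step(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup to base_lr, then cosine decay to 0 at total_steps (assumed schedule)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))


m = n_features_from_config(256, 0.25)
schedule = [lr_at_step(s, 1e-4, 1000, 10000) for s in (0, 500, 1000, 5500, 10000)]
print(m, schedule)
```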

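6.3 Automatic Concept Discovery

import json

from anthropic import Anthropic


def discover_concepts_with_llm(
    sae,
    dataset,
    n_concepts: int = 256,
    n_samples_per_concept: int = 50,
) -> dict[str, dict]:
    """
    Automatically propose concept names and descriptions with an LLM.
    
    collect_top_activations is a helper assumed to exist: it should return
    {feature_idx: [text, ...]} with the top-activating snippets per feature.
    """
    client = Anthropic()
    
    # Collect the text snippets that most strongly activate each feature
    feature_texts = collect_top_activations(sae, dataset, n_concepts, n_samples_per_concept)
    
    # Ask the LLM for a concept description of each feature
    concepts = {}
    
    for feat_idx, texts in feature_texts.items():
        if not texts:
            continue  # feature never fired; nothing to describe
        prompt = f"""Analyze these text examples that activate the same feature:

{chr(10).join(texts)}

What single semantic concept do these examples share?
Provide a brief, descriptive name for this concept.

Respond in JSON format only:
{{
    "concept_name": "descriptive_name",
    "description": "brief description"
}}"""
        
        response = client.messages.create(
            model="claude-sonnet-4-20250514",
            max_tokens=200,
            messages=[{"role": "user", "content": prompt}],
        )
        
        try:
            result = json.loads(response.content[0].text)
        except json.JSONDecodeError:
            continue  # skip features whose reply is not valid JSON
        concepts[f"feature_{feat_idx}"] = result
    
    return concepts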
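`collect_top_activations` is used above but never defined. One possible shape for it is sketched below, assuming the SAE feature activations have already been computed for every text snippet (for example by running `encode_to_features` over the dataset). The name `top_activating_texts` and its signature are hypothetical, not part of any library. Features that never fire come back as empty lists, which are worth skipping before calling the LLM.

```python
import heapq


def top_activating_texts(activations, texts, n_features, n_samples_per_feature):
    """Hypothetical stand-in for collect_top_activations.

    activations: one row per text snippet, each a list of n_features activations
    texts:       snippets aligned with the rows of `activations`
    Returns {feature_idx: [text, ...]}, strongest activation first.
    """
    result = {}
    for feat in range(n_features):
        # Only snippets on which the feature actually fires
        scored = [(row[feat], i) for i, row in enumerate(activations) if row[feat] > 0]
        top = heapq.nlargest(n_samples_per_feature, scored)
        result[feat] = [texts[i] for _, i in top]
    return result


acts = [[0.0, 2.0], [1.5, 0.0], [0.5, 3.0]]
snippets = ["def f(x):", "Paris is in France", "import numpy as np"]
print(top_activating_texts(acts, snippets, n_features=2, n_samples_per_feature=2))
```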

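7. Limitations and Future Directions

7.1 Current Limitations

| Limitation | Description | Impact |
|---|---|---|
| Subjective concept definitions | Concepts have to be designed by hand | Limits scalability |
| Supervision cost | Requires concept-labeled data | High data acquisition cost |
| Concept overlap | Concepts may overlap with one another | Steering may be imprecise |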
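Concept overlap can be checked directly on a trained model: if two concepts have nearly parallel (or anti-parallel) decoder directions, steering one will also move the other. The plain-Python sketch below flags such pairs by cosine similarity. The 0.8 threshold is an arbitrary choice, and with the PyTorch model in 2.4 the per-concept directions would be the rows of `cb_sae.W_dec.weight.detach().T.tolist()` (the weight of `nn.Linear(n_concepts, d_model)` has shape [d_model, n_concepts]). The example directions are made up.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu and nv else 0.0


def overlapping_concepts(directions, names, threshold=0.8):
    """Return (name_i, name_j, cosine) for pairs with |cosine| >= threshold."""
    pairs = []
    for i in range(len(directions)):
        for j in range(i + 1, len(directions)):
            sim = cosine(directions[i], directions[j])
            if abs(sim) >= threshold:
                pairs.append((names[i], names[j], round(sim, 3)))
    return pairs


pairs = overlapping_concepts(
    [[1.0, 0.0, 0.0], [-0.9, 0.1, 0.0], [0.0, 0.0, 1.0]],
    ["has_positive_sentiment", "has_negative_sentiment", "is_question"],
)
print(pairs)
```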

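7.2 Future Directions

| Direction | Description |
|---|---|
| Automatic concept discovery | Use LLMs to extract concepts from features automatically |
| Hierarchical concepts | Build hierarchical structures of concepts |
| Multimodal concepts | Concept correspondence across modalities |
| Dynamic concepts | Concepts that adapt to the context |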

8. References


Related Resources

Footnotes

  1. “Interpretable and Steerable Concept Bottleneck Sparse Autoencoders.” arXiv:2512.10805, 2025.