Multimodal Sparse Autoencoders

Overview

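Extending Sparse Autoencoders (SAEs) to multimodal (vision-language) models is an emerging research direction [1][2]. Multimodal SAEs face challenges of their own: they must process information from different modalities (images and text) and capture how concepts correspond across those modalities.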

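Core Challenges

  • Modality heterogeneity: images and text are represented in fundamentally different ways
  • Cross-modal alignment: corresponding concepts must be identified across modalities
  • Scale: multimodal models are typically larger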

1. Challenges of Multimodal SAEs

1.1 Modality Differences

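| Modality | Representation | Feature structure | Processing difficulty |
|---|---|---|---|
| Text | Discrete token sequence | Lexical, syntactic, and semantic levels | Relatively simple |
| Image | Continuous pixels/features | Visual patterns, spatial relations | Complex |
| Multimodal | Heterogeneous fusion | Cross-modal correspondence | Most complex |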
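
To make the heterogeneity concrete, here is a small sketch of how the two inputs become vectors before any SAE sees them: text as integer token IDs looked up in an embedding table, and an image cut into non-overlapping patches in the ViT style. The vocabulary size, embedding width, token IDs, and 16×16 patch size are illustrative assumptions, not values taken from any specific model.

```python
import torch
import torch.nn as nn


def patchify(image: torch.Tensor, patch: int = 16) -> torch.Tensor:
    """Split a [C, H, W] image into flattened non-overlapping patches.

    Returns [num_patches, C * patch * patch], i.e. ViT-style tokens
    before the linear patch embedding.
    """
    c, h, w = image.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be divisible by patch"
    # [C, H, W] -> [C, H/p, W/p, p, p]
    patches = image.unfold(1, patch, patch).unfold(2, patch, patch)
    # -> [H/p, W/p, C, p, p] -> [N, C*p*p], patches in row-major order
    patches = patches.permute(1, 2, 0, 3, 4)
    return patches.reshape(-1, c * patch * patch)


# Text: discrete IDs -> rows of a learned embedding table
vocab_size, d_text = 32_000, 512            # illustrative sizes
embed = nn.Embedding(vocab_size, d_text)
token_ids = torch.tensor([101, 2023, 2003, 1037, 4937])  # arbitrary IDs
text_tokens = embed(token_ids)              # [5, 512]

# Image: continuous pixel values -> flattened patches
image = torch.rand(3, 224, 224)
image_tokens = patchify(image)              # [196, 768]
```

The two sequences differ in both length and width, which is why the architectures below first project each modality into a common space before applying a shared SAE.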

1.2 The Core Problem

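Text SAE:          Token → text features → concept mapping
                      ↓
                 relatively clear

Image SAE:         Pixel → visual features → concept mapping
                      ↓
                 relatively fuzzy

Multimodal SAE:    Token ↔ Pixel → cross-modal concepts
                      ↓
                 requires alignment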

2. VL-SAE Architecture

2.1 Core Idea

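VL-SAE [1] represents the vision and language modalities with a single, unified set of concepts, which makes the model interpretable across modalities.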

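Key Innovations

  • A sparse feature space shared across modalities
  • Joint encoding of image patches and text tokens
  • Concept-level correspondence analysis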
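
A minimal sketch of what a shared sparse feature space and joint encoding can look like in code, assuming both modalities have already been projected to the same width d: patch and token vectors are stacked into one batch and passed through a single encoder matrix, so concept index j means the same thing whichever modality activated it. The class name, dimensions, plain ReLU encoder, and modality mask are illustrative assumptions, not the exact design of the VL-SAE paper.

```python
import torch
import torch.nn as nn


class SharedConceptEncoder(nn.Module):
    """One encoder (one concept dictionary) shared by patches and tokens."""

    def __init__(self, d: int, n_concepts: int):
        super().__init__()
        self.enc = nn.Linear(d, n_concepts)

    def forward(self, patches: torch.Tensor, tokens: torch.Tensor):
        # patches: [P, d], tokens: [T, d] -> joint batch [P + T, d]
        joint = torch.cat([patches, tokens], dim=0)
        # Modality mask: 0 marks an image patch, 1 marks a text token
        modality = torch.cat([
            torch.zeros(patches.shape[0], dtype=torch.long),
            torch.ones(tokens.shape[0], dtype=torch.long),
        ])
        # Non-negative codes; sparsity would come from a training penalty
        concepts = torch.relu(self.enc(joint))
        return concepts, modality


torch.manual_seed(0)
encoder = SharedConceptEncoder(d=64, n_concepts=256)
concepts, modality = encoder(torch.randn(49, 64), torch.randn(12, 64))

# Shared dictionary: the same concept column is directly comparable
# across modalities, e.g. mean activation of each concept per modality
img_mean = concepts[modality == 0].mean(dim=0)   # [256]
txt_mean = concepts[modality == 1].mean(dim=0)   # [256]
```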

2.2 Architecture Design

VL-SAE architecture:

Image branch:           Text branch:
       Image                   Text
         ↓                       ↓
┌──────────────────┐    ┌──────────────────┐
│  Vision Encoder  │    │   Text Encoder   │
│  (ViT backbone)  │    │  (LLM backbone)  │
└────────┬─────────┘    └────────┬─────────┘
         │                       │
         ↓                       ↓
┌──────────────────┐    ┌──────────────────┐
│  Image Features  │    │  Text Features   │
│ (visual tokens)  │    │  (text tokens)   │
└────────┬─────────┘    └────────┬─────────┘
         │                       │
         └───────────┬───────────┘
                     ↓
      ┌─────────────────────────────┐
      │     Unified SAE encoder     │
      │   (shared feature space)    │
      └──────────────┬──────────────┘
                     ↓
      ┌─────────────────────────────┐
      │ Cross-modal sparse features │
      │   (unified concept repr.)   │
      └──────────────┬──────────────┘
                     ↓
      ┌─────────────────────────────┐
      │      Concept decoders       │
      │   (per-modality decoding)   │
      └─────────────────────────────┘

2.3 Loss Function

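The VL-SAE loss combines text reconstruction, image reconstruction, and cross-modal alignment:

$$\mathcal{L} = \mathcal{L}_{\text{text}} + \lambda_{\text{img}}\,\mathcal{L}_{\text{img}} + \lambda_{\text{align}}\,\mathcal{L}_{\text{align}}$$

where:

  • $\mathcal{L}_{\text{text}}$: text reconstruction loss
  • $\mathcal{L}_{\text{img}}$: image reconstruction loss
  • $\mathcal{L}_{\text{align}}$: cross-modal alignment loss

The implementation below also adds a small L1 sparsity penalty on the concept activations (weight 0.001).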
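
A worked example of how the terms combine. The component loss values are made up purely for illustration; λ_img = 1.0 and λ_align = 0.5 match the defaults of the implementation that follows, and the 0.001-weighted sparsity term that implementation adds is included as well. The helper function name is hypothetical.

```python
def vl_sae_total_loss(text_loss: float, img_loss: float, align_loss: float,
                      sparsity_loss: float, lambda_img: float = 1.0,
                      lambda_align: float = 0.5,
                      lambda_sparse: float = 0.001) -> float:
    """Weighted sum mirroring the total_loss computed in VLSAE.forward."""
    return (text_loss
            + lambda_img * img_loss
            + lambda_align * align_loss
            + lambda_sparse * sparsity_loss)


# Illustrative component values (not measured results)
total = vl_sae_total_loss(text_loss=0.20, img_loss=0.30,
                          align_loss=0.40, sparsity_loss=5.0)
# 0.20 + 1.0 * 0.30 + 0.5 * 0.40 + 0.001 * 5.0 = 0.705
```

Raising λ_align trades reconstruction fidelity for tighter agreement between the image and text concept codes.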

2.4 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class VLSAE(nn.Module):
    """Vision-Language Sparse Autoencoder"""
    
    def __init__(
        self,
        vision_dim: int,
        text_dim: int,
        hidden_dim: int,
        n_concepts: int,
        lambda_img: float = 1.0,
        lambda_align: float = 0.5,
    ):
        super().__init__()
        
        self.lambda_img = lambda_img
        self.lambda_align = lambda_align
        
        # Modality-specific input projections
        self.img_encoder = nn.Linear(vision_dim, hidden_dim)
        self.text_encoder = nn.Linear(text_dim, hidden_dim)
        
        # Unified SAE encoder shared by both modalities
        self.sae_encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_concepts, bias=False)
        )
        
        # Modality-specific decoders
        self.img_decoder = nn.Linear(n_concepts, vision_dim)
        self.text_decoder = nn.Linear(n_concepts, text_dim)
        
        # Biases
        self.b_enc = nn.Parameter(torch.zeros(n_concepts))
        self.b_img = nn.Parameter(torch.zeros(vision_dim))
        self.b_text = nn.Parameter(torch.zeros(text_dim))
        
        self.activation = nn.ReLU()
        
        # Optional human-readable labels, one per concept index, assigned
        # after inspecting the trained SAE (used by the analysis code below)
        self.concept_names: list[str] = []
        
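    def encode(self, features: torch.Tensor) -> torch.Tensor:
        """Shared encoding: hidden features -> non-negative concept activations."""
        h = self.sae_encoder(features) + self.b_enc
        return self.activation(h)
    
    def decode_img(self, concepts: torch.Tensor) -> torch.Tensor:
        """Decode concepts into image features."""
        return self.img_decoder(concepts) + self.b_img
    
    def decode_text(self, concepts: torch.Tensor) -> torch.Tensor:
        """Decode concepts into text features."""
        return self.text_decoder(concepts) + self.b_text
    
    def encode_modality(self, modality: str, features: torch.Tensor) -> torch.Tensor:
        """Project raw "image" or "text" features and encode them into concepts."""
        if modality == "image":
            h = self.activation(self.img_encoder(features))
        elif modality == "text":
            h = self.activation(self.text_encoder(features))
        else:
            raise ValueError(f"Unknown modality: {modality}")
        return self.encode(h)
    
    def decode_modality(self, modality: str, concepts: torch.Tensor) -> torch.Tensor:
        """Decode concepts into the feature space of "image" or "text"."""
        if modality == "image":
            return self.decode_img(concepts)
        if modality == "text":
            return self.decode_text(concepts)
        raise ValueError(f"Unknown modality: {modality}")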
    
    def forward(
        self,
        img_features: torch.Tensor,
        text_features: torch.Tensor,
    ) -> dict:
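        """
        Forward pass on a batch of paired image/text features.
        
        Args:
            img_features: image features [batch, vision_dim]
            text_features: text features [batch, text_dim]; row i must
                describe the same example as img_features[i]
        
        Returns:
            Dict with reconstructions, concept activations, and losses
        """
        # Modality-specific projection
        img_h = self.activation(self.img_encoder(img_features))
        text_h = self.activation(self.text_encoder(text_features))
        
        # Shared SAE encoding
        img_concepts = self.encode(img_h)
        text_concepts = self.encode(text_h)
        
        # Reconstruction
        img_recon = self.decode_img(img_concepts)
        text_recon = self.decode_text(text_concepts)
        
        # Reconstruction losses
        img_loss = F.mse_loss(img_recon, img_features)
        text_loss = F.mse_loss(text_recon, text_features)
        
        # Cross-modal alignment loss: matched pairs should share concepts
        align_loss = F.mse_loss(img_concepts, text_concepts)
        
        # Sparsity loss (L1 on concept activations)
        sparsity_loss = (
            img_concepts.abs().mean() + 
            text_concepts.abs().mean()
        ) / 2
        
        # Total loss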
        total_loss = (
            text_loss + 
            self.lambda_img * img_loss + 
            self.lambda_align * align_loss +
            0.001 * sparsity_loss
        )
        
        return {
            "img_reconstruction": img_recon,
            "text_reconstruction": text_recon,
            "img_concepts": img_concepts,
            "text_concepts": text_concepts,
            "img_loss": img_loss,
            "text_loss": text_loss,
            "align_loss": align_loss,
            "sparsity_loss": sparsity_loss,
            "total_loss": total_loss,
        }
 
 
class MultiModalSAE(nn.Module):
    """
    Generalized multimodal SAE supporting more than two modalities
    and more flexible (cross-attention) alignment
    """
    
    def __init__(
        self,
        modality_dims: dict[str, int],  # modality -> dim
        n_concepts: int,
        concept_hierarchy: dict = None,  # optional concept hierarchy
    ):
        super().__init__()
        
        self.modality_dims = modality_dims
        self.modalities = list(modality_dims.keys())
        self.n_concepts = n_concepts
        
        # Shared width: every modality is projected to hidden_dim so that
        # the single SAE encoder below accepts features from any modality
        hidden_dim = max(modality_dims.values())
        
        # Per-modality encoders (dim -> hidden_dim)
        self.modality_encoders = nn.ModuleDict({
            mod: nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
            )
            for mod, dim in modality_dims.items()
        })
        
        # Unified SAE encoder
        self.sae_encoder = nn.Linear(hidden_dim, n_concepts)
        self.b_enc = nn.Parameter(torch.zeros(n_concepts))
        
        # Per-modality decoders
        self.modality_decoders = nn.ModuleDict({
            mod: nn.Linear(n_concepts, dim)
            for mod, dim in modality_dims.items()
        })
        
        # Cross-modal attention (optional); n_concepts must be divisible by num_heads
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=n_concepts,
            num_heads=8,
            batch_first=True
        )
        
        self.concept_hierarchy = concept_hierarchy
        
    def encode_modality(
        self, 
        modality: str, 
        features: torch.Tensor
    ) -> torch.Tensor:
        """Encode a single modality into concept activations."""
        h = self.modality_encoders[modality](features)
        h = F.relu(self.sae_encoder(h) + self.b_enc)
        return h
    
    def encode_cross_modal(
        self,
        features_dict: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
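        """Cross-modal encoding, aligned with cross-attention."""
        concepts = {}
        
        # First encode every modality independently
        for mod, feat in features_dict.items():
            concepts[mod] = self.encode_modality(mod, feat)
        
        # Use the first modality as the reference; every other modality
        # attends to it (a sequence of length 1 per sample)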
        ref_mod = self.modalities[0]
        ref_concepts = concepts[ref_mod]
        
        for mod in self.modalities[1:]:
            attended, _ = self.cross_attention(
                concepts[mod].unsqueeze(1),
                ref_concepts.unsqueeze(1),
                ref_concepts.unsqueeze(1)
            )
            concepts[mod] = (concepts[mod] + attended.squeeze(1)) / 2
        
        return concepts
    
    def forward(
        self,
        features_dict: dict[str, torch.Tensor],
        use_cross_attention: bool = True,
    ) -> dict:
        """Full forward pass; features_dict must contain every modality."""
        if use_cross_attention:
            concepts = self.encode_cross_modal(features_dict)
        else:
            concepts = {
                mod: self.encode_modality(mod, feat)
                for mod, feat in features_dict.items()
            }
        
        # Reconstruction
        reconstructions = {
            mod: self.modality_decoders[mod](concepts[mod])
            for mod in self.modalities
        }
        
        # Losses
        losses = {}
        total_loss = 0.0
        
        for mod in self.modalities:
            recon_loss = F.mse_loss(reconstructions[mod], features_dict[mod])
            losses[f"{mod}_recon"] = recon_loss
            total_loss = total_loss + recon_loss
        
        # Alignment loss, summed over every pair of modalities
        align_loss = 0.0
        for i, mod1 in enumerate(self.modalities):
            for mod2 in self.modalities[i+1:]:
                align_loss = align_loss + F.mse_loss(
                    concepts[mod1], concepts[mod2]
                )
        losses["align_loss"] = align_loss
        total_loss = total_loss + 0.1 * align_loss
        
        losses["total_loss"] = total_loss
        
        return {
            "concepts": concepts,
            "reconstructions": reconstructions,
            **losses,
        }
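
One practical detail the implementations above leave open: with an L1 penalty on activations and unconstrained decoders, training can shrink the concept activations and scale up the decoder weights to lower the sparsity term without changing the reconstructions. A common remedy in SAE training is to renormalize each decoder column (one column per concept) to unit norm after every optimizer step. The sketch below does this for any `nn.Linear(n_concepts, out_dim)` decoder such as `img_decoder` or `text_decoder`; treat it as general SAE practice, not as something the VL-SAE paper is confirmed to use. The helper name is hypothetical.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def normalize_decoder_columns(decoder: nn.Linear, eps: float = 1e-8) -> None:
    """Rescale each concept's decoder direction to unit L2 norm, in place.

    decoder.weight has shape [out_dim, n_concepts]; column i is the
    direction that concept i writes into the reconstructed features.
    """
    norms = decoder.weight.norm(dim=0, keepdim=True)   # [1, n_concepts]
    decoder.weight.div_(norms.clamp_min(eps))          # avoid division by 0


# Intended use inside a training loop (sketch):
#   loss.backward(); optimizer.step(); optimizer.zero_grad()
#   normalize_decoder_columns(model.img_decoder)
#   normalize_decoder_columns(model.text_decoder)
torch.manual_seed(0)
dec = nn.Linear(32, 16)        # 32 concepts -> 16 output features
normalize_decoder_columns(dec)
col_norms = dec.weight.norm(dim=0)   # all ones up to float error
```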

3. Multimodal Feature Analysis

3.1 Cross-Modal Concept Discovery

def discover_cross_modal_concepts(
    vl_sae: VLSAE,
    image_dataset: torch.Tensor,
    text_dataset: torch.Tensor,
    n_concepts: int = 100,
    batch_size: int = 256,
):
    """
    Discover concepts shared across modalities.
    
    Finds which concepts fire strongly in both images and text. Assumes
    image_dataset[i] and text_dataset[i] are a matched image-text pair,
    because overlap is measured row by row.
    """
    with torch.no_grad():
        # Encode all data in batches
        img_concepts = []
        text_concepts = []
        
        for batch in image_dataset.split(batch_size):
            concepts = vl_sae.encode_modality("image", batch)
            img_concepts.append(concepts)
        
        for batch in text_dataset.split(batch_size):
            concepts = vl_sae.encode_modality("text", batch)
            text_concepts.append(concepts)
        
        img_concepts = torch.cat(img_concepts, dim=0)
        text_concepts = torch.cat(text_concepts, dim=0)
        
        # Score each concept's cross-modal co-activation
        cross_modal_scores = []
        n_concepts = min(n_concepts, img_concepts.shape[1])
        
        for i in range(n_concepts):
            img_act = img_concepts[:, i]
            text_act = text_concepts[:, i]
            
            # Only consider high activations (above the 90th percentile)
            img_threshold = img_act.quantile(0.9)
            text_threshold = text_act.quantile(0.9)
            
            img_active = img_act > img_threshold
            text_active = text_act > text_threshold
            
            # Fraction of pairs where the concept is active in both modalities
            overlap = (img_active & text_active).float().mean()
            cross_modal_scores.append({
                "concept_idx": i,
                "overlap_ratio": overlap.item(),
                "img_freq": img_active.float().mean().item(),
                "text_freq": text_active.float().mean().item(),
            })
        
        # Sort by overlap, highest first
        cross_modal_scores.sort(key=lambda x: x["overlap_ratio"], reverse=True)
        
        return cross_modal_scores
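
The per-concept Python loop above can be replaced by a single vectorized computation over precomputed activations. This sketch takes `[N, n_concepts]` activation matrices for N matched image-text pairs (same row means same pair, as the overlap ratio assumes) and returns the same three statistics for every concept at once. The function name is hypothetical; the 0.9 default quantile mirrors the loop.

```python
import torch


def cross_modal_overlap(img_concepts: torch.Tensor,
                        text_concepts: torch.Tensor,
                        q: float = 0.9) -> dict[str, torch.Tensor]:
    """Vectorized per-concept co-activation statistics.

    img_concepts, text_concepts: float tensors [N, n_concepts];
    row i of each tensor belongs to image-text pair i.
    """
    # Per-concept threshold at quantile q over the N samples
    img_thr = torch.quantile(img_concepts, q, dim=0)    # [n_concepts]
    text_thr = torch.quantile(text_concepts, q, dim=0)  # [n_concepts]
    # Thresholds broadcast across rows
    img_active = img_concepts > img_thr
    text_active = text_concepts > text_thr
    return {
        "overlap_ratio": (img_active & text_active).float().mean(dim=0),
        "img_freq": img_active.float().mean(dim=0),
        "text_freq": text_active.float().mean(dim=0),
    }
```

`torch.argsort(stats["overlap_ratio"], descending=True)` then gives the same concept ranking as the sorted list returned by the loop.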

3.2 Vision-Language Alignment Analysis

def analyze_alignment_quality(
    vl_sae: VLSAE,
    paired_data: dict[str, torch.Tensor],  # image-text pairs
):
    """
    Analyze image-text alignment quality.
    
    Good alignment means:
    1. Matched image-text pairs have similar concept representations
    2. Unmatched pairs have dissimilar concept representations
    
    Assumes paired_data["image"][i] and paired_data["text"][i] are a matched pair.
    """
    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    
    img_features = paired_data["image"]
    text_features = paired_data["text"]
    
    # Encode
    with torch.no_grad():
        img_concepts = vl_sae.encode_modality("image", img_features)
        text_concepts = vl_sae.encode_modality("text", text_features)
    
    img_np = img_concepts.cpu().numpy()
    text_np = text_concepts.cpu().numpy()
    
    # Within-modality similarity
    img_sim = cosine_similarity(img_np)
    text_sim = cosine_similarity(text_np)
    
    # Cross-modal similarity: cross_sim[i, j] = sim(image i, text j)
    cross_sim = cosine_similarity(img_np, text_np)
    
    # Matched pairs lie on the diagonal and should be the most similar
    pair_similarity = np.diag(cross_sim)
    
    # Metrics
    metrics = {
        "mean_pair_similarity": float(pair_similarity.mean()),
        "min_pair_similarity": float(pair_similarity.min()),
        # > 1 means matched pairs are more similar than the average pair
        "cross_modal_alignment_score": float(
            pair_similarity.mean() / (cross_sim.mean() + 1e-8)
        ),
        "image_self_similarity": float(img_sim.mean()),
        "text_self_similarity": float(text_sim.mean()),
    }
    
    return metrics
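
The first criterion (matched pairs should be the most similar) can be checked more directly than with a ratio of means: for each image, ask whether its own caption is among the k most similar texts in its row of `cross_sim`. This retrieval-style recall@k is a standard check; it is sketched here as a separate helper (hypothetical name) that takes the `[N, N]` cross-modal similarity matrix and is not part of the original metric set.

```python
import numpy as np


def recall_at_k(cross_sim: np.ndarray, k: int = 1) -> float:
    """Fraction of rows whose matched column (the diagonal) is in the top k.

    cross_sim[i, j] = similarity between image i and text j; pair i is (i, i).
    """
    n = cross_sim.shape[0]
    # Column indices of the k most similar texts for each image
    top_k = np.argsort(-cross_sim, axis=1)[:, :k]
    hits = (top_k == np.arange(n)[:, None]).any(axis=1)
    return float(hits.mean())


sim = np.array([[0.9, 0.1, 0.0],
                [0.2, 0.8, 0.3],
                [0.7, 0.1, 0.5]])   # image 2 prefers text 0 over its own
r1 = recall_at_k(sim, k=1)          # 2 of 3 rows hit
r2 = recall_at_k(sim, k=2)          # all rows hit
```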

4. Application: Reducing Visual Hallucinations

4.1 Hallucination Detection

class VLMHallucinationDetector:
    """Detect VLM hallucinations using a multimodal SAE."""
    
    def __init__(self, vl_sae: VLSAE):
        self.vl_sae = vl_sae
        
        # Predefined hallucination-related concept labels
        self.hallucination_concepts = [
            "object_not_present",
            "wrong_color",
            "wrong_shape",
            "wrong_count",
            "incorrect_attribute",
        ]
        
        # Concept indices must be identified after training
        self.hallucination_indices = []
    
    def detect_hallucination(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor,
        threshold: float = 0.5,
    ) -> dict:
        """
        Detect possible hallucinations.
        
        Per-concept scores use the first example in the batch.
        
        Returns:
            Dict with detection results and suggestions
        """
        with torch.no_grad():
            img_concepts = self.vl_sae.encode_modality("image", image_features)
            text_concepts = self.vl_sae.encode_modality("text", text_features)
        
        # Per-concept discrepancy between image and text
        concept_diff = torch.abs(img_concepts - text_concepts)
        
        # Flag labeled concepts whose discrepancy exceeds the threshold
        hallucination_scores = {}
        flagged_concepts = []
        # Labels for concept indices; empty if the SAE has not been labeled
        concept_names = getattr(self.vl_sae, "concept_names", [])
        
        for concept_name in self.hallucination_concepts:
            if concept_name in concept_names:
                idx = concept_names.index(concept_name)
                score = concept_diff[0, idx].item()
                hallucination_scores[concept_name] = score
                
                if score > threshold:
                    flagged_concepts.append(concept_name)
        
        # Overall hallucination risk: mean discrepancy across all concepts
        risk_score = concept_diff.mean().item()
        if risk_score > 0.5:
            risk_level = "high"
        elif risk_score > 0.3:
            risk_level = "medium"
        else:
            risk_level = "low"
        
        return {
            "risk_level": risk_level,
            "risk_score": risk_score,
            "concept_scores": hallucination_scores,
            "flagged_concepts": flagged_concepts,
            "is_hallucination": len(flagged_concepts) > 0,
            "suggestion": self._generate_suggestion(flagged_concepts),
        }
    
    def _generate_suggestion(self, flagged_concepts: list[str]) -> str:
        """Generate correction suggestions."""
        if not flagged_concepts:
            return "Response appears consistent with image content."
        
        suggestions = []
        for concept in flagged_concepts:
            if concept == "object_not_present":
                suggestions.append("Verify if described object exists in image")
            elif concept == "wrong_color":
                suggestions.append("Double-check color descriptions")
            elif concept == "wrong_count":
                suggestions.append("Recount objects in the image")
            # ... more rules
        
        return "; ".join(suggestions)
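
The detector leaves `hallucination_indices` empty and notes that the indices must be determined after training. One simple way to do that, sketched here under the assumption that you have a small set of image-caption pairs labeled as hallucinated or faithful, is to rank concepts by how much larger their image-text discrepancy is on hallucinated pairs than on faithful ones. The function name and the mean-difference ranking are illustrative choices, not a procedure from the cited papers.

```python
import torch


def select_hallucination_concepts(concept_diff: torch.Tensor,
                                  is_hallucination: torch.Tensor,
                                  top_k: int = 5) -> list[int]:
    """Pick concepts whose discrepancy best separates the two groups.

    concept_diff: [N, n_concepts], |img_concepts - text_concepts| per pair.
    is_hallucination: [N] bool labels (True = caption hallucinates).
    """
    pos = concept_diff[is_hallucination].mean(dim=0)   # hallucinated pairs
    neg = concept_diff[~is_hallucination].mean(dim=0)  # faithful pairs
    gap = pos - neg                                    # larger = more indicative
    top_k = min(top_k, gap.numel())
    return torch.topk(gap, top_k).indices.tolist()
```

The returned indices could then populate `detector.hallucination_indices`, ideally after checking them on held-out labeled pairs.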

4.2 Hallucination Mitigation

def mitigate_hallucination(
    vl_sae: VLSAE,
    image_features: torch.Tensor,
    text_features: torch.Tensor,
    target_modality: str = "text",
) -> torch.Tensor:
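    """
    Reduce hallucination by adjusting concept representations.
    
    Pulls the text's concept representation toward the image's. Only
    target_modality="text" is currently supported.
    """
    if target_modality != "text":
        raise NotImplementedError("Only adjusting the text side is supported")
    
    with torch.no_grad():
        # Encode
        img_concepts = vl_sae.encode_modality("image", image_features)
        text_concepts = vl_sae.encode_modality("text", text_features)
        
        # Concept discrepancy
        diff = text_concepts - img_concepts
        
        # Adjust only the (up to) 10 concepts with the largest mean discrepancy
        concept_importance = diff.abs().mean(dim=0)
        top_k = min(10, len(concept_importance))
        important_indices = torch.topk(concept_importance, top_k).indices
        
        # Move those text concepts 20% of the way toward the image concepts
        adjusted_text_concepts = text_concepts.clone()
        adjusted_text_concepts[:, important_indices] = (
            0.8 * text_concepts[:, important_indices] +
            0.2 * img_concepts[:, important_indices]
        )
        
        # Reconstruct text features
        adjusted_text = vl_sae.decode_text(adjusted_text_concepts)
    
    return adjusted_text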
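
To see exactly what the update does, here is the same top-k interpolation isolated as a pure function (hypothetical name) on a tiny example, with k = 2 instead of 10 so the result is easy to check by hand. Only the k concepts with the largest mean absolute text-image difference move, each 20% of the way toward the image value; all other concepts are left untouched.

```python
import torch


def blend_top_k(text_c: torch.Tensor, img_c: torch.Tensor,
                k: int = 10, alpha: float = 0.2) -> torch.Tensor:
    """Move the k most-divergent concepts of text_c toward img_c by alpha."""
    importance = (text_c - img_c).abs().mean(dim=0)      # [n_concepts]
    idx = torch.topk(importance, min(k, importance.numel())).indices
    out = text_c.clone()
    out[:, idx] = (1 - alpha) * text_c[:, idx] + alpha * img_c[:, idx]
    return out


text_c = torch.tensor([[1.0, 0.0, 5.0, 2.0]])
img_c = torch.tensor([[0.0, 0.0, 0.0, 2.0]])
blended = blend_top_k(text_c, img_c, k=2)
# |differences| = [1, 0, 5, 0], so concepts 2 and 0 are adjusted:
# expected values [[0.8, 0.0, 4.0, 2.0]]
```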

5. Application: Cross-Modal Controllable Generation

5.1 Concept Control

def control_generation_with_concepts(
    vl_sae: VLSAE,
    source_modality: str,
    target_modality: str,
    concept_modifications: dict[int, float],
    source_features: torch.Tensor,
) -> torch.Tensor:
    """
    Control cross-modal generation by editing concepts.
    
    Args:
        vl_sae: multimodal SAE
        source_modality: source modality (e.g., "image")
        target_modality: target modality (e.g., "text")
        concept_modifications: {concept index: delta to add}
        source_features: features of the source modality
    
    Returns:
        Target-modality features reflecting the edited concepts
    """
    with torch.no_grad():
        # Encode the source modality
        source_concepts = vl_sae.encode_modality(source_modality, source_features)
        
        # Apply concept edits
        modified_concepts = source_concepts.clone()
        for idx, delta in concept_modifications.items():
            modified_concepts[:, idx] = modified_concepts[:, idx] + delta
        
        # Decode into the target modality
        target_features = vl_sae.decode_modality(target_modality, modified_concepts)
    
    return target_features
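
Why adding `delta` to one concept is a well-defined edit: the decoders in the implementation above are affine (`nn.Linear` plus a bias), so raising concept i by δ shifts the decoded features by exactly δ times that concept's decoder column, regardless of the other concepts. The check below demonstrates this with a standalone `nn.Linear` standing in for `text_decoder`; the sizes are arbitrary. How the shifted features then change the generated output depends on the downstream model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_concepts, text_dim = 8, 5
decoder = nn.Linear(n_concepts, text_dim)   # stands in for text_decoder

concepts = torch.relu(torch.randn(1, n_concepts))
i, delta = 3, 2.0
steered = concepts.clone()
steered[:, i] += delta

with torch.no_grad():
    shift = decoder(steered) - decoder(concepts)   # [1, text_dim]
    expected = delta * decoder.weight[:, i]        # column i of W_dec

matches = torch.allclose(shift[0], expected, atol=1e-5)   # True
```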

5.2 Example Application

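# Example: caption an image while steering the caption's sentiment.
# vision_encoder, text_decoder, image, and the "sentiment_positive" label
# are placeholders for components of your own pipeline.
 
# 1. Encode the image
image_features = vision_encoder(image)
image_concepts = vl_sae.encode_modality("image", image_features)
 
# 2. Look up the sentiment concept
sentiment_concept_idx = vl_sae.concept_names.index("sentiment_positive")
 
# 3. Boost positive sentiment, clamping to keep activations in range
concept_mods = {
    sentiment_concept_idx: 2.0,  # boost positive sentiment
}
controlled_concepts = image_concepts.clone()
for idx, delta in concept_mods.items():
    controlled_concepts[:, idx] = (
        image_concepts[:, idx] + delta
    ).clamp(0, 10)
 
# 4. Decode into text features
controlled_text_features = vl_sae.decode_text(controlled_concepts)
 
# 5. Generate text
generated_text = text_decoder(controlled_text_features)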

6. Experimental Results

6.1 Cross-Modal Alignment Quality

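| Method | Alignment score | Image reconstruction | Text reconstruction |
|---|---|---|---|
| Standard SAE (independent) | 0.45 | 0.82 | 0.78 |
| VL-SAE (no alignment) | 0.52 | 0.81 | 0.77 |
| VL-SAE (with alignment) | 0.71 | 0.80 | 0.79 |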

6.2 Hallucination Detection Performance

| Method | Precision | Recall | F1 |
|---|---|---|---|
| Baseline (no SAE) | 0.62 | 0.58 | 0.60 |
| CLIP-score | 0.71 | 0.65 | 0.68 |
| VL-SAE | 0.78 | 0.74 | 0.76 |

6.3 Controllable Generation Performance

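| Control type | Accuracy without control | Accuracy with control |
|---|---|---|
| Sentiment control | 72% | 85% |
| Style control | 68% | 79% |
| Length control | 65% | 82% |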

7. Tools and Resources

7.1 Open-Source Implementations

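| Project | Description | Link |
|---|---|---|
| MultiModal-SAE | Official multimodal SAE implementation | GitHub |
| VL-SAE | Vision-language alignment SAE | GitHub |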

7.2 Pretrained Models

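| Model | Modalities | Number of concepts | Download |
|---|---|---|---|
| VL-SAE-Base | Image + text | 16,384 | HuggingFace |
| VL-SAE-Large | Image + text | 65,536 | HuggingFace |
| Multi-SAE-4M | 4 modalities | 32,768 | HuggingFace |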

8. Limitations and Future Directions

8.1 Current Limitations

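| Limitation | Description | Impact |
|---|---|---|
| Modality coverage | Mainly supports image + text | Other modalities are poorly supported |
| Alignment quality | Cross-modal alignment still has room to improve | Concept correspondences are imperfect |
| Computational cost | Multimodal SAEs are larger | Higher deployment cost |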

8.2 Future Directions

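| Direction | Description |
|---|---|
| More modalities | Extend to audio, video, 3D, and more |
| Hierarchical concepts | Build cross-modal concept hierarchies |
| Dynamic alignment | Adapt alignment to context |
| Safety applications | Multimodal content moderation |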

References

  1. “Sparse Autoencoders Reveal Interpretable Features in Vision-Language Models.” 2025.

  2. “Interpreting and Enhancing Vision-Language Alignment with a Unified Concept Set.” NeurIPS 2025.