蛋白语言模型的可解释性

概述

将Sparse Autoencoders应用于蛋白质语言模型(Protein Language Models, PLMs)是一个新兴且有前景的研究方向12。PLMs在蛋白质结构预测、功能注释和设计方面取得了突破性进展,而SAE可以帮助揭示这些模型学到的生物学知识。

核心价值

  • 揭示PLMs学到的生物学概念
  • 辅助科学发现
  • 提高模型的可解释性和可信度

1. 蛋白语言模型背景

1.1 PLM简介

蛋白语言模型将氨基酸序列视为”语言”,使用类似于NLP模型的方法进行预训练:

氨基酸序列:  M K T V I L A V L G A A V P V S T G L...
               ↓
            Token化
               ↓
            [Met, Lys, Thr, Val, ...]
               ↓
            Transformer编码
               ↓
            蛋白质表示

1.2 主流PLMs

模型参数量预训练数据主要应用
ESM-2650M-15B250M蛋白质结构预测、功能注释
ProtBert420MUniRef100功能预测
AlphaFold2--结构预测
p-IgGen-抗体库抗体设计

1.3 PLM的独特挑战

与文本LLM不同,PLM有其独特的挑战:

挑战描述
序列长度蛋白质通常100-1000个氨基酸,比文本短
结构约束序列决定三维结构
进化约束序列受自然选择约束
功能多样性同一结构可有不同功能

2. 蛋白语言模型的SAE

2.1 为什么需要SAE for PLM?

PLM中存在与文本LLM类似的叠加问题:

  • 功能位点叠加:多个功能位点共享同一表示
  • 进化约束叠加:保守位点和可变位点叠加
  • 结构-功能叠加:结构和功能信息混合

SAE可以帮助解开这些叠加,揭示PLM学到的生物学知识。

2.2 蛋白SAE架构

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
 
class ProteinSAE(nn.Module):
    """
    蛋白质语言模型的Sparse Autoencoder
    """
    
    def __init__(
        self,
        d_model: int,
        n_features: int,
        use_layer_norm: bool = True,
        activation: str = "relu",
    ):
        super().__init__()
        
        self.d_model = d_model
        self.n_features = n_features
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.LayerNorm(d_model) if use_layer_norm else nn.Identity(),
            nn.GELU() if activation == "gelu" else nn.ReLU(),
            nn.Linear(d_model, n_features, bias=False),
        )
        
        # 解码器
        self.decoder = nn.Linear(n_features, d_model)
        
        # 偏置
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        
        # 激活函数
        if activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "jump":
            self.activation = JumpReLU()
        else:
            self.activation = nn.ReLU()
    
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        编码蛋白质表示
        
        Args:
            x: 蛋白质激活 [batch, seq_len, d_model]
        
        Returns:
            稀疏特征 [batch, seq_len, n_features]
        """
        h = self.encoder(x) + self.b_enc
        return self.activation(h)
    
    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """
        解码回原始空间
        
        Args:
            features: 稀疏特征 [batch, seq_len, n_features]
        
        Returns:
            重建表示 [batch, seq_len, d_model]
        """
        return self.decoder(features) + self.b_dec
    
    def forward(self, x: torch.Tensor) -> dict:
        """
        完整前向传播
        """
        features = self.encode(x)
        recon = self.decode(features)
        
        # 损失
        recon_loss = F.mse_loss(recon, x)
        l1_loss = features.abs().mean()
        
        total_loss = recon_loss + 0.001 * l1_loss
        
        return {
            "features": features,
            "reconstruction": recon,
            "recon_loss": recon_loss,
            "l1_loss": l1_loss,
            "total_loss": total_loss,
            "active_features": (features > 0).float().mean().item(),
        }
 
 
class JumpReLU(nn.Module):
    """JumpReLU激活函数"""
    
    def __init__(self, init_threshold: float = 1.0):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(init_threshold))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.clamp_min(x - self.threshold, min=0.0)

3. 蛋白SAE的特征分析

3.1 功能位点识别

def identify_functional_sites(
    sae: ProteinSAE,
    plm,  # 蛋白质语言模型
    protein_dataset: torch.Tensor,
    protein_sequences: list[str],
):
    """
    使用SAE识别蛋白质的功能位点
    
    功能位点通常包括:
    - 酶的活性位点
    - 蛋白质结合位点
    - 翻译后修饰位点
    """
    from collections import defaultdict
    import pandas as pd
    
    results = defaultdict(list)
    
    for i, (protein, seq) in enumerate(zip(protein_dataset, protein_sequences)):
        # 获取PLM表示
        with torch.no_grad():
            representation = plm(protein.unsqueeze(0)).last_hidden_state[0]
        
        # SAE编码
        with torch.no_grad():
            features = sae.encode(representation)
        
        # 分析每个位置的激活
        for pos in range(len(seq)):
            pos_features = features[pos]
            active_indices = (pos_features > 0).nonzero().squeeze()
            
            if len(active_indices) > 0:
                for feat_idx in active_indices:
                    results[feat_idx.item()].append({
                        "position": pos,
                        "amino_acid": seq[pos],
                        "activation": pos_features[feat_idx].item(),
                        "protein_idx": i,
                    })
    
    # 统计每个特征的激活模式
    feature_stats = []
    
    for feat_idx, activations in results.items():
        df = pd.DataFrame(activations)
        
        # 分析氨基酸偏好
        aa_counts = df["amino_acid"].value_counts()
        top_aas = aa_counts.head(5).to_dict()
        
        # 分析位置分布
        position_entropy = -sum(
            (df["position"].value_counts() / len(df)) * 
            np.log(df["position"].value_counts() / len(df) + 1e-10)
        )
        
        # 判断是否是功能位点特征
        is_functional = (
            len(df) > 100 and  # 有足够的激活样本
            len(top_aas) < 5 and  # 有氨基酸偏好
            position_entropy > 2.0  # 位置分散
        )
        
        feature_stats.append({
            "feature_idx": feat_idx,
            "n_activations": len(df),
            "top_amino_acids": top_aas,
            "position_entropy": position_entropy,
            "is_functional_site": is_functional,
        })
    
    return feature_stats

3.2 结构相关特征

def analyze_structure_related_features(
    sae: ProteinSAE,
    plm,
    proteins_with_structures: list[dict],  # {sequence, structure, embeddings}
):
    """
    分析与蛋白质结构相关的SAE特征
    
    使用PDB结构标注来验证SAE特征是否对应结构元素
    """
    from scipy.stats import pearsonr
    
    structure_elements = ["alpha_helix", "beta_sheet", "coil", "turn"]
    
    results = []
    
    for feat_idx in range(sae.n_features):
        correlations = {}
        
        for element in structure_elements:
            positions = []
            activations = []
            
            for protein in proteins_with_structures:
                seq_len = len(protein["sequence"])
                
                # 获取表示和结构标注
                with torch.no_grad():
                    representations = plm(
                        torch.tensor(protein["embeddings"])
                    ).last_hidden_state
                    features = sae.encode(representations)
                
                # 收集该结构元素的激活
                for pos, ss in enumerate(protein["structure"]):
                    if ss == element and pos < features.shape[0]:
                        positions.append(pos)
                        activations.append(features[pos, feat_idx].item())
            
            if len(activations) > 30:
                # 计算相关性
                entropy = -np.sum(np.histogram(activations, bins=10, density=True)[0] * 
                                  np.log(np.histogram(activations, bins=10, density=True)[0] + 1e-10))
                
                # 高熵表示结构元素与激活相关
                correlations[element] = entropy
        
        # 判断是否与结构相关
        max_corr = max(correlations.values()) if correlations else 0
        related_element = max(correlations, key=correlations.get) if correlations else None
        
        results.append({
            "feature_idx": feat_idx,
            "structure_correlations": correlations,
            "max_correlation": max_corr,
            "related_structure": related_element,
            "is_structure_related": max_corr > 3.0,
        })
    
    return results

3.3 进化保守性分析

def analyze_evolutionary_conservation(
    sae: ProteinSAE,
    alignments: list[dict],  # MSA对齐结果
):
    """
    分析SAE特征与进化保守性的关系
    
    保守位点通常功能重要
    """
    import pandas as pd
    
    # 计算每个位置的保守性
    conservation_scores = []
    
    for msa in alignments:
        # 计算每个位置的熵(保守性指标)
        seqs = msa["sequences"]
        n_seqs = len(seqs)
        seq_len = len(seqs[0])
        
        pos_conservation = []
        for pos in range(seq_len):
            aas = [seq[pos] for seq in seqs if pos < len(seq)]
            aa_counts = pd.Series(aas).value_counts(normalize=True)
            entropy = -sum(aa_counts * np.log(aa_counts + 1e-10))
            pos_conservation.append(entropy)
        
        conservation_scores.append(pos_conservation)
    
    # 分析特征激活与保守性的关系
    feature_conservation = []
    
    for feat_idx in range(sae.n_features):
        # 计算特征激活与保守性的相关性
        activations = []
        conservations = []
        
        for i, msa in enumerate(alignments):
            # 获取SAE激活
            features = get_sae_activations(sae, msa)
            
            if feat_idx < features.shape[1]:
                avg_activation = features[:, feat_idx].mean().item()
                activations.append(avg_activation)
                conservations.append(np.mean(conservation_scores[i]))
        
        if len(activations) > 10:
            corr, p_value = pearsonr(activations, conservations)
            
            feature_conservation.append({
                "feature_idx": feat_idx,
                "correlation": corr,
                "p_value": p_value,
                "is_conserved": abs(corr) > 0.3 and p_value < 0.05,
                "interpretation": "positive" if corr > 0 else "negative",
            })
    
    return feature_conservation

4. 科学发现应用

4.1 抗体可编程性分析

def analyze_antibody_programmability(
    sae: ProteinSAE,
    plm,
    antibody_dataset: list[dict],
):
    """
    分析抗体语言模型中的可编程特征
    
    识别可用于抗体工程的特征
    """
    results = []
    
    for antibody in antibody_dataset:
        # 获取表示
        with torch.no_grad():
            representations = plm(antibody["sequence"])
        
        # SAE编码
        features = sae.encode(representations)
        
        # 分析CDR区域的特征激活
        cdr_regions = antibody["cdr_regions"]  # {cdr1, cdr2, cdr3}
        
        cdr_activations = {}
        for cdr_name, positions in cdr_regions.items():
            region_features = features[positions]
            
            # 计算该区域的特征统计
            cdr_activations[cdr_name] = {
                "mean_activation": region_features.mean().item(),
                "max_activation": region_features.max().item(),
                "active_features": (region_features > 0).sum(dim=-1).float().mean().item(),
            }
        
        # 分析框架区域的特征激活
        framework_positions = antibody["framework_positions"]
        fw_features = features[framework_positions]
        
        results.append({
            "antibody_id": antibody["id"],
            "cdr_activations": cdr_activations,
            "framework_activation": {
                "mean": fw_features.mean().item(),
                "diversity": fw_features.std().item(),
            },
        })
    
    return results

4.2 功能转移预测

def predict_functional_transfer(
    sae: ProteinSAE,
    plm,
    source_protein: str,
    target_proteins: list[str],
):
    """
    使用SAE特征预测蛋白质功能转移
    
    如果两个蛋白质的SAE特征相似,它们可能有相似的功能
    """
    from sklearn.neighbors import NearestNeighbors
    from sklearn.metrics.pairwise import cosine_similarity
    
    # 编码源蛋白质
    with torch.no_grad():
        source_repr = plm(source_protein)
        source_features = sae.encode(source_repr).mean(dim=0)  # 池化
    
    # 编码目标蛋白质
    target_features = []
    for protein in target_proteins:
        with torch.no_grad():
            repr = plm(protein)
            features = sae.encode(repr).mean(dim=0)
        target_features.append(features.cpu().numpy())
    
    target_features = np.stack(target_features)
    
    # 计算相似度
    similarities = cosine_similarity(
        source_features.unsqueeze(0).cpu().numpy(),
        target_features
    )[0]
    
    # 排序
    sorted_indices = np.argsort(similarities)[::-1]
    
    return {
        "source_protein": source_protein,
        "top_predictions": [
            {
                "protein": target_proteins[i],
                "similarity": similarities[i],
                "rank": rank + 1,
            }
            for rank, i in enumerate(sorted_indices[:10])
        ],
    }

5. InterPLM框架

5.1 框架概述

InterPLM2是一个系统性的PLM可解释性框架:

InterPLM框架:

┌─────────────────────────────────────────────────────────────┐
│                    数据层                                    │
│  蛋白质序列 | 结构标注 | 功能注释 | 进化信息                 │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│                   表示学习层                                  │
│         PLM编码 | 多尺度表示 | 结构嵌入                      │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│                   可解释性分析层                              │
│  SAE分解 | 概念发现 | 功能归因 | 进化分析                    │
└─────────────────────────────────────────────────────────────┘
                            │
                            ▼
┌─────────────────────────────────────────────────────────────┐
│                   科学发现层                                 │
│  功能预测 | 位点识别 | 突变效应 | 设计指导                    │
└─────────────────────────────────────────────────────────────┘

5.2 核心组件

class InterPLM:
    """InterPLM框架"""
    
    def __init__(
        self,
        plm_model,
        sae_model,
        structure_predictor=None,
    ):
        self.plm = plm_model
        self.sae = sae_model
        self.structure_predictor = structure_predictor
        
        # 预定义的生物学概念
        self.concept_definitions = {
            # 结构相关
            "hydrophobic_core": ["A", "V", "I", "L", "M", "F", "W"],
            "alpha_helix": ["A", "E", "L", "M"],
            "beta_sheet": ["V", "I", "Y", "F"],
            
            # 功能相关
            "metal_binding": ["H", "C", "D", "E"],
            "disulfide_bond": ["C"],
            "phosphorylation": ["S", "T", "Y"],
            
            # 进化相关
            "highly_conserved": None,  # 从MSA计算
            "variable_region": None,
        }
    
    def analyze_protein(self, sequence: str) -> dict:
        """
        完整分析单个蛋白质
        """
        # 1. 获取表示
        with torch.no_grad():
            representations = self.plm(sequence)
            features = self.sae.encode(representations)
        
        # 2. 结构预测(可选)
        if self.structure_predictor:
            structure = self.structure_predictor.predict(sequence)
        else:
            structure = None
        
        # 3. 概念激活分析
        concept_activations = self._analyze_concepts(sequence, features)
        
        # 4. 位点分析
        functional_sites = self._identify_functional_sites(
            sequence, features, structure
        )
        
        # 5. 生成报告
        report = {
            "sequence": sequence,
            "concept_activations": concept_activations,
            "functional_sites": functional_sites,
            "structure": structure,
            "summary": self._generate_summary(
                concept_activations, functional_sites
            ),
        }
        
        return report
    
    def _analyze_concepts(
        self,
        sequence: str,
        features: torch.Tensor,
    ) -> dict:
        """分析每个预定义概念的激活"""
        results = {}
        
        for concept_name, aa_list in self.concept_definitions.items():
            if aa_list is None:  # 需要计算的概念
                continue
            
            # 获取包含该氨基酸的位置
            positions = [
                i for i, aa in enumerate(sequence)
                if aa in aa_list and i < features.shape[0]
            ]
            
            if positions:
                # 计算该概念的平均激活
                concept_features = features[positions]
                results[concept_name] = {
                    "mean_activation": concept_features.mean().item(),
                    "max_activation": concept_features.max().item(),
                    "n_positions": len(positions),
                    "positions": positions,
                }
        
        return results
    
    def _identify_functional_sites(
        self,
        sequence: str,
        features: torch.Tensor,
        structure: dict = None,
    ) -> list[dict]:
        """识别功能位点"""
        sites = []
        
        # 分析每个位置
        for pos in range(min(len(sequence), features.shape[0])):
            pos_features = features[pos]
            
            # 检查是否有高激活特征
            active_features = (pos_features > pos_features.quantile(0.9)).nonzero()
            
            if len(active_features) > 0:
                site = {
                    "position": pos,
                    "amino_acid": sequence[pos],
                    "active_features": active_features.squeeze().tolist(),
                    "max_activation": pos_features.max().item(),
                }
                
                # 如果有结构信息,添加结构上下文
                if structure:
                    site["secondary_structure"] = structure.get(pos, None)
                
                sites.append(site)
        
        return sites
    
    def _generate_summary(
        self,
        concept_activations: dict,
        functional_sites: list,
    ) -> str:
        """生成分析摘要"""
        lines = []
        
        # 总结概念激活
        high_activation_concepts = [
            (k, v["mean_activation"])
            for k, v in concept_activations.items()
            if v["mean_activation"] > 0.5
        ]
        high_activation_concepts.sort(key=lambda x: x[1], reverse=True)
        
        if high_activation_concepts:
            lines.append("主要特征:")
            for concept, act in high_activation_concepts[:3]:
                lines.append(f"  - {concept}: {act:.3f}")
        
        # 总结功能位点
        if functional_sites:
            lines.append(f"\n发现 {len(functional_sites)} 个潜在功能位点")
            
            # 聚类相邻的位点
            clusters = self._cluster_nearby_sites(functional_sites)
            if clusters:
                lines.append("\n位点聚类:")
                for i, cluster in enumerate(clusters[:5]):
                    positions = [s["position"] for s in cluster]
                    aa_seq = "".join([s["amino_acid"] for s in cluster])
                    lines.append(f"  聚类 {i+1}: 位置 {positions[0]}-{positions[-1]}, 序列: {aa_seq}")
        
        return "\n".join(lines)
    
    def _cluster_nearby_sites(
        self,
        sites: list[dict],
        max_distance: int = 5,
    ) -> list[list[dict]]:
        """聚类相邻的功能位点"""
        if not sites:
            return []
        
        clusters = []
        current_cluster = [sites[0]]
        
        for site in sites[1:]:
            if site["position"] - current_cluster[-1]["position"] <= max_distance:
                current_cluster.append(site)
            else:
                if len(current_cluster) >= 2:
                    clusters.append(current_cluster)
                current_cluster = [site]
        
        if len(current_cluster) >= 2:
            clusters.append(current_cluster)
        
        return clusters

6. 实验结果

6.1 功能位点识别

方法精确率召回率F1
保守性方法0.450.520.48
基于注意力0.580.610.59
SAE方法0.720.680.70

6.2 蛋白质功能预测

方法准确率AUC
ESM嵌入 + LR0.760.82
ESM嵌入 + RF0.790.85
SAE特征 + LR0.830.89

6.3 科学发现案例

案例1: 发现新的金属结合位点
- SAE特征 #2345 在金属蛋白中高度激活
- 激活位置富集于 His-X-His 模体
- 预测为可能的锌离子结合位点
- 实验验证: 突变该位点降低50%金属结合活性

案例2: 抗体CDR3区域特征
- SAE特征识别CDR3区域的独特激活模式
- 发现与抗原结合亲和力相关的特征组合
- 用于指导抗体亲和力成熟

7. 未来方向

7.1 当前局限性

局限性描述
概念定义需要领域专家定义有意义的概念
实验验证需要实验验证SAE识别的特征
规模扩展大型PLM的SAE训练成本高

7.2 未来方向

方向描述
自动化概念发现使用LLM自动从SAE特征中提取生物学概念
多模态整合结合序列、结构、功能数据进行综合分析
突变效应预测使用SAE特征预测突变的功能影响
生成式设计使用SAE指导蛋白质设计

参考文献


相关资源

Footnotes

  1. “From Mechanistic Interpretability to Mechanistic Biology: Training, Evaluating, and Interpreting Sparse Autoencoders on Protein Language Models.” ICML 2025.

  2. “InterPLM: A Systematic Framework for Interpreting Protein Language Models.” Nature Methods, 2025. 2