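Interpretability of Protein Language Models
Overview
Applying sparse autoencoders (SAEs) to protein language models (PLMs) is an emerging and promising research direction [1][2]. PLMs have achieved breakthroughs in protein structure prediction, function annotation, and design, and SAEs can help reveal the biological knowledge these models have learned.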
Core value:
- Reveal the biological concepts PLMs have learned
- Support scientific discovery
- Improve model interpretability and trustworthiness
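1. Background on Protein Language Models
1.1 What Is a PLM?
Protein language models treat amino-acid sequences as a "language" and are pretrained with methods similar to those used for NLP models: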
```
Amino-acid sequence: M K T V I L A V L G A A V P V S T G L...
↓
Tokenization
↓
[Met, Lys, Thr, Val, ...]
↓
Transformer encoding
↓
Protein representation
```
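The tokenization step can be sketched without dependencies. This version uses a made-up vocabulary: special tokens come first, then the 20 standard amino acids. Real tokenizers such as ESM's use their own ids.

The sketch also shows why, once a start token is prepended, residue `pos` sits at token index `pos + 1`. That offset matters whenever per-token activations are mapped back onto residues. `tokenize` and `residue_to_token_index` are hypothetical helpers.

```python
# Hypothetical toy vocabulary: special tokens first, then the 20 standard amino acids.
# Real PLM tokenizers (e.g. ESM's) define their own vocabularies and ids.
SPECIALS = ["<cls>", "<pad>", "<eos>", "<unk>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {tok: i for i, tok in enumerate(SPECIALS + list(AMINO_ACIDS))}


def tokenize(sequence: str) -> list[int]:
    """One token per residue, wrapped in <cls> ... <eos>; unknown residues map to <unk>."""
    body = [VOCAB.get(aa, VOCAB["<unk>"]) for aa in sequence.upper()]
    return [VOCAB["<cls>"]] + body + [VOCAB["<eos>"]]


def residue_to_token_index(pos: int) -> int:
    """Residue `pos` (0-based) is shifted by one because of the leading <cls> token."""
    return pos + 1


ids = tokenize("MKTV")
print(ids)  # [0, 14, 12, 20, 21, 2]
```

Because each residue becomes exactly one token, per-residue analyses only need to strip the special-token positions. No subword merging has to be undone.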
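1.2 Representative PLMs
| Model | Parameters | Pretraining data | Main applications |
|---|---|---|---|
| ESM-2 | 8M-15B | UniRef50/UniRef90 sequences | Structure prediction, function annotation |
| ProtBert | 420M | UniRef100 | Function prediction |
| AlphaFold2 | - | PDB structures + MSAs (an MSA-based structure predictor rather than a PLM) | Structure prediction |
| p-IgGen | - | Antibody sequences | Antibody design |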
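1.3 Challenges Specific to PLMs
Unlike text LLMs, PLMs face their own challenges:
| Challenge | Description |
|---|---|
| Sequence length | Proteins are typically 100-1,000 amino acids long, shorter than many text inputs |
| Structural constraints | The sequence determines the 3D structure |
| Evolutionary constraints | Sequences are constrained by natural selection |
| Functional diversity | The same structure can serve different functions |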
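2. SAEs for Protein Language Models
2.1 Why Use SAEs on PLMs?
PLMs show the same superposition problem as text LLMs:
- Functional-site superposition: multiple functional sites share the same representation
- Evolutionary superposition: conserved and variable positions are superimposed
- Structure-function superposition: structural and functional information are mixed together
A PLM can encode more concepts than its hidden dimension has orthogonal directions. An SAE responds to this by learning an overcomplete dictionary of sparsely activating feature directions. Decomposing each residue's activation vector into that dictionary can disentangle the superposed signals and reveal the biological knowledge the PLM has learned.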
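The following toy numpy example illustrates superposition. It uses random feature directions rather than anything learned by a PLM. Six "concepts" share a three-dimensional space, so reading one concept back out also picks up nonzero interference on the others. The SAE's overcomplete dictionary and sparsity penalty are meant to separate mixtures like this.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, d_model = 6, 3  # more concepts than dimensions

# Assumption: each concept gets a random unit-norm direction in the 3-D activation space
W = rng.normal(size=(n_features, d_model))
W /= np.linalg.norm(W, axis=1, keepdims=True)

# A residue where only concept 4 is present (a sparse activation)
f_true = np.zeros(n_features)
f_true[4] = 1.0
x = f_true @ W  # the superposed activation vector the model would carry

# Naive linear readout: project x back onto every concept direction
readout = W @ x
print(np.round(readout, 2))  # concept 4 reads 1.0; the others show interference
```

The active concept still wins because a unit vector's projection onto itself is 1, and all other projections are cosines with smaller magnitude. With many concepts active at once, this interference accumulates. That is why a sparse, overcomplete decomposition is needed.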
2.2 Protein SAE Architecture
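For reference, here is a common SAE formulation. The notation is assumed here rather than taken from the code: $x$ is one residue's PLM activation, $f(x)$ its sparse feature vector, $\sigma$ a ReLU or JumpReLU, $\lambda$ the L1 coefficient, and $H$ the Heaviside step function.

$$
\begin{aligned}
f(x) &= \sigma\big(W_{\text{enc}}(x - b_{\text{dec}}) + b_{\text{enc}}\big) \\
\hat{x} &= W_{\text{dec}}\, f(x) + b_{\text{dec}} \\
\mathcal{L}(x) &= \lVert x - \hat{x} \rVert_2^2 + \lambda\, \lVert f(x) \rVert_1 \\
\mathrm{JumpReLU}_\theta(z) &= z \cdot H(z - \theta)
\end{aligned}
$$

Subtracting $b_{\text{dec}}$ before encoding is a common convention, not a requirement. JumpReLU passes $z$ through unchanged above the threshold. This differs from a shifted ReLU, $\max(z - \theta, 0)$, which subtracts $\theta$ from every surviving activation.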
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class JumpReLU(nn.Module):
    """JumpReLU activation: keeps x where x > threshold, outputs 0 elsewhere."""

    def __init__(self, init_threshold: float = 1.0):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(init_threshold))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x * H(x - threshold). Unlike a shifted ReLU, x is not reduced by the threshold.
        # The step has zero gradient w.r.t. the threshold, so learning it requires a
        # straight-through estimator (not implemented here).
        return x * (x > self.threshold).to(x.dtype)


class ProteinSAE(nn.Module):
    """
    Sparse autoencoder for protein language model activations
    """

    def __init__(
        self,
        d_model: int,
        n_features: int,
        use_layer_norm: bool = True,
        activation: str = "relu",  # "relu" or "jump"
        l1_coeff: float = 1e-3,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_features = n_features
        self.l1_coeff = l1_coeff
        # Optional input normalization without learned affine parameters;
        # when enabled, the SAE reconstructs the normalized activations
        self.norm = (
            nn.LayerNorm(d_model, elementwise_affine=False) if use_layer_norm else nn.Identity()
        )
        # Encoder: a single linear map, so each feature is one direction in activation space
        self.W_enc = nn.Linear(d_model, n_features, bias=False)
        # Decoder
        self.W_dec = nn.Linear(n_features, d_model, bias=False)
        # Biases
        self.b_enc = nn.Parameter(torch.zeros(n_features))
        self.b_dec = nn.Parameter(torch.zeros(d_model))
        # Sparsity-inducing activation, applied exactly once after the encoder
        self.activation = JumpReLU() if activation == "jump" else nn.ReLU()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode protein representations.
        Args:
            x: protein activations [batch, seq_len, d_model]
        Returns:
            sparse features [batch, seq_len, n_features]
        """
        h = self.W_enc(self.norm(x) - self.b_dec) + self.b_enc
        return self.activation(h)

    def decode(self, features: torch.Tensor) -> torch.Tensor:
        """
        Decode back to the (normalized) activation space.
        Args:
            features: sparse features [batch, seq_len, n_features]
        Returns:
            reconstruction [batch, seq_len, d_model]
        """
        return self.W_dec(features) + self.b_dec

    def forward(self, x: torch.Tensor) -> dict:
        """
        Full forward pass, including losses.
        """
        target = self.norm(x)
        features = self.encode(x)
        recon = self.decode(features)
        # Losses: MSE reconstruction + L1 sparsity (summed over features, averaged over residues)
        recon_loss = F.mse_loss(recon, target)
        l1_loss = features.abs().sum(dim=-1).mean()
        total_loss = recon_loss + self.l1_coeff * l1_loss
        return {
            "features": features,
            "reconstruction": recon,
            "recon_loss": recon_loss,
            "l1_loss": l1_loss,
            "total_loss": total_loss,
            # Average number of active features per residue (L0)
            "l0": (features > 0).float().sum(dim=-1).mean().item(),
        }
```
3. Feature Analysis of Protein SAEs
3.1 Functional Site Identification
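The function below summarizes each feature by where it fires and on which residues. Those two statistics can be sketched in isolation with the standard library.

The activation list is hypothetical. `top_aa_share` is an assumed alternative way to express amino-acid preference as a fraction, not part of the original method.

```python
import math
from collections import Counter


def position_entropy(positions: list[int]) -> float:
    """Shannon entropy (nats) of where a feature fires; higher means more spread out."""
    counts = Counter(positions)
    n = len(positions)
    return -sum((c / n) * math.log(c / n) for c in counts.values())


def top_aa_share(amino_acids: list[str], k: int = 3) -> float:
    """Fraction of activations that land on the k most frequent amino acids (hypothetical helper)."""
    counts = Counter(amino_acids)
    return sum(c for _, c in counts.most_common(k)) / len(amino_acids)


# Hypothetical activations of one feature: (position, residue) pairs pooled over proteins
hits = [(12, "H"), (45, "H"), (88, "C"), (130, "H"), (12, "C"), (201, "H")]
positions = [p for p, _ in hits]
aas = [a for _, a in hits]
print(round(position_entropy(positions), 3), top_aa_share(aas))  # 1.561 1.0
```

A feature that fires on only a few residue types, at many different positions, fits the "residue-specific, not position-specific" profile that the heuristic below looks for.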
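```python
from collections import defaultdict

import numpy as np
import pandas as pd
import torch


def identify_functional_sites(
    sae: ProteinSAE,
    plm,  # protein language model; assumed to return an object with .last_hidden_state
    protein_dataset: list[torch.Tensor],  # token ids, one 1-D tensor per protein
    protein_sequences: list[str],
    n_special_prefix: int = 1,  # tokens before the first residue (ESM tokenizers prepend <cls>)
):
    """
    Use the SAE to identify functional sites in proteins.
    Functional sites typically include:
    - enzyme active sites
    - protein binding sites
    - post-translational modification sites
    """
    results = defaultdict(list)
    for i, (protein, seq) in enumerate(zip(protein_dataset, protein_sequences)):
        with torch.no_grad():
            # PLM representation; drop special tokens so row `pos` matches residue seq[pos]
            representation = plm(protein.unsqueeze(0)).last_hidden_state[0]
            representation = representation[n_special_prefix : n_special_prefix + len(seq)]
            # SAE encoding: [len(seq), n_features]
            features = sae.encode(representation)
        # Record every active feature at every position
        for pos in range(len(seq)):
            pos_features = features[pos]
            # nonzero(as_tuple=True)[0] is always 1-D; .squeeze() gives a 0-d tensor for a single hit
            for feat_idx in pos_features.nonzero(as_tuple=True)[0].tolist():
                results[feat_idx].append({
                    "position": pos,
                    "amino_acid": seq[pos],
                    "activation": pos_features[feat_idx].item(),
                    "protein_idx": i,
                })

    # Summarize each feature's activation pattern
    feature_stats = []
    for feat_idx, activations in results.items():
        df = pd.DataFrame(activations)
        # Amino-acid preference
        top_aas = df["amino_acid"].value_counts().head(5).to_dict()
        # Entropy of the position distribution (positions are not aligned across proteins)
        p = df["position"].value_counts(normalize=True).to_numpy()
        position_entropy = float(-np.sum(p * np.log(p + 1e-10)))
        # Heuristic test for a functional-site feature
        is_functional = (
            len(df) > 100  # enough activating residues
            and len(top_aas) < 5  # fires on fewer than 5 distinct amino acids
            and position_entropy > 2.0  # spread across many positions
        )
        feature_stats.append({
            "feature_idx": feat_idx,
            "n_activations": len(df),
            "top_amino_acids": top_aas,
            "position_entropy": position_entropy,
            "is_functional_site": is_functional,
        })
    return feature_stats
```
3.2 Structure-Related Features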
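A natural way to test whether a feature tracks a secondary-structure element is the point-biserial correlation. It is simply Pearson's r between per-residue activations and a 0/1 label such as "residue is in a helix". `scipy.stats.pearsonr` gives the same value.

The numpy sketch below uses hypothetical labels and activations.

```python
import numpy as np


def point_biserial(activations: np.ndarray, is_element: np.ndarray) -> float:
    """Pearson r between a feature's per-residue activations and a 0/1 structure label."""
    a = activations - activations.mean()
    y = is_element.astype(float) - is_element.mean()
    denom = np.sqrt((a ** 2).sum() * (y ** 2).sum())
    # Undefined for a constant feature or a constant label; report 0 in that case
    return float((a * y).sum() / denom) if denom > 0 else 0.0


# Hypothetical labels for 8 residues and one feature's activations
labels = np.array(["alpha_helix"] * 4 + ["coil"] * 4)
acts = np.array([1.2, 0.9, 1.1, 0.8, 0.0, 0.1, 0.0, 0.0])
r_helix = point_biserial(acts, labels == "alpha_helix")
r_coil = point_biserial(acts, labels == "coil")
print(round(r_helix, 3), round(r_coil, 3))
```

Unlike the spread of a feature's activation values, r is signed and bounded in [-1, 1]. It can therefore be compared across elements and thresholded. Note that for two complementary labels it simply flips sign.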
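```python
import numpy as np
import torch
from scipy.stats import pearsonr


def analyze_structure_related_features(
    sae: ProteinSAE,
    plm,  # assumed to return an object with .last_hidden_state
    proteins_with_structures: list[dict],  # {sequence, structure, input_ids}
    n_special_prefix: int = 1,  # tokens before the first residue (e.g. <cls>)
    min_residues: int = 30,
):
    """
    Analyze SAE features related to protein structure.
    Uses per-residue secondary-structure labels (e.g. assigned from PDB structures)
    to check whether SAE features correspond to structural elements.
    """
    structure_elements = ["alpha_helix", "beta_sheet", "coil", "turn"]

    # Encode each protein once and pool all residues: features [N, n_features], labels [N]
    all_features, all_labels = [], []
    for protein in proteins_with_structures:
        seq_len = len(protein["sequence"])
        with torch.no_grad():
            reps = plm(torch.as_tensor(protein["input_ids"]).unsqueeze(0)).last_hidden_state[0]
            reps = reps[n_special_prefix : n_special_prefix + seq_len]
            all_features.append(sae.encode(reps).cpu().numpy())
        # One label per residue, e.g. "alpha_helix"
        all_labels.extend(protein["structure"][:seq_len])
    features = np.concatenate(all_features)
    labels = np.asarray(all_labels)

    results = []
    for feat_idx in range(sae.n_features):
        acts = features[:, feat_idx]
        correlations = {}
        if acts.std() > 0:  # Pearson r is undefined for a constant (e.g. dead) feature
            for element in structure_elements:
                is_element = (labels == element).astype(float)
                n_in = int(is_element.sum())
                # Require enough residues both inside and outside the element
                if min(n_in, len(labels) - n_in) >= min_residues:
                    # Point-biserial correlation = Pearson r against a 0/1 label
                    corr, _ = pearsonr(acts, is_element)
                    correlations[element] = float(corr)
        related_element = max(correlations, key=correlations.get) if correlations else None
        max_corr = correlations[related_element] if related_element else 0.0
        results.append({
            "feature_idx": feat_idx,
            "structure_correlations": correlations,
            "max_correlation": max_corr,
            "related_structure": related_element,
            "is_structure_related": max_corr > 0.3,  # heuristic threshold
        })
    return results
```
3.3 Evolutionary Conservation Analysis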
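Comparing conservation with per-residue SAE activations requires one conservation score per query residue. That in turn means mapping MSA columns onto the query and dropping columns where the query has a gap.

The stdlib sketch below assumes the first aligned sequence is the query and that `-` is the gap character. Entropy is an inverse conservation score: 0 means fully conserved. The MSA is a hypothetical example.

```python
import math
from collections import Counter


def column_entropy(column: str, gap: str = "-") -> float:
    """Shannon entropy (nats) of one MSA column, ignoring gaps. 0 = fully conserved."""
    residues = [aa for aa in column if aa != gap]
    if not residues:
        return float("nan")  # all-gap column: conservation undefined
    n = len(residues)
    return -sum((c / n) * math.log(c / n) for c in Counter(residues).values())


def query_residue_entropies(seqs: list[str], gap: str = "-") -> list[float]:
    """One entropy per query residue (seqs[0] is the query); query-gap columns are skipped."""
    query = seqs[0]
    return [
        column_entropy("".join(s[i] for s in seqs), gap)
        for i, aa in enumerate(query)
        if aa != gap
    ]


msa = ["MK-V", "MRAV", "MKAV", "MQ-V"]  # hypothetical alignment, all rows equal length
print([round(h, 3) for h in query_residue_entropies(msa)])  # [0.0, 1.04, 0.0]
```

The output has one value per ungapped query residue, so it lines up row for row with the query's per-residue SAE activations.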
```python
import numpy as np
import pandas as pd
from scipy.stats import pearsonr


def analyze_evolutionary_conservation(
    sae: ProteinSAE,
    alignments: list[dict],  # MSA results: {"sequences": [...]}, sequences[0] is the query
    sae_activations: list[np.ndarray],  # per MSA: SAE features of the query, [query_len, n_features]
    min_residues: int = 10,
):
    """
    Analyze the relationship between SAE features and evolutionary conservation.
    Conserved positions are usually functionally important.
    """
    # Per-residue column entropy (higher entropy = LESS conserved)
    entropies = []
    for msa in alignments:
        seqs = msa["sequences"]
        for pos, query_aa in enumerate(seqs[0]):
            if query_aa == "-":  # column has no residue in the query
                continue
            aas = [seq[pos] for seq in seqs if pos < len(seq) and seq[pos] != "-"]  # ignore gaps
            p = pd.Series(aas).value_counts(normalize=True).to_numpy()
            entropies.append(float(-np.sum(p * np.log(p + 1e-10))))
    entropies = np.asarray(entropies)
    features = np.concatenate(sae_activations)  # rows align with `entropies`
    assert len(features) == len(entropies), "activations must cover exactly the ungapped query residues"

    # Correlate each feature's per-residue activation with column entropy
    feature_conservation = []
    for feat_idx in range(sae.n_features):
        acts = features[:, feat_idx]
        if len(acts) <= min_residues or acts.std() == 0:
            continue  # too few residues, or a dead feature (correlation undefined)
        corr, p_value = pearsonr(acts, entropies)
        feature_conservation.append({
            "feature_idx": feat_idx,
            "correlation": float(corr),
            "p_value": float(p_value),
            "is_conservation_related": abs(corr) > 0.3 and p_value < 0.05,
            # Negative r: fires on low-entropy (conserved) residues; positive r: on variable ones
            "interpretation": "conserved" if corr < 0 else "variable",
        })
    return feature_conservation
```
4. Applications to Scientific Discovery
4.1 Antibody Programmability Analysis
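Pooled region averages show how strongly a region lights up overall, but not which features are specific to it. One assumed way to shortlist candidate "programmable" features is to rank each SAE feature by its mean activation in the CDRs minus its mean activation in the framework.

The numpy sketch below uses a hypothetical helper, `cdr_enriched_features`, and toy data.

```python
import numpy as np


def cdr_enriched_features(
    features: np.ndarray,  # [seq_len, n_features] SAE activations for one antibody
    cdr_positions: list[int],
    framework_positions: list[int],
    top_k: int = 5,
) -> list[int]:
    """Rank SAE features by mean CDR activation minus mean framework activation."""
    diff = features[cdr_positions].mean(axis=0) - features[framework_positions].mean(axis=0)
    return np.argsort(diff)[::-1][:top_k].tolist()


# Hypothetical: 10 residues x 4 SAE features; residues 3-5 form a CDR
feats = np.zeros((10, 4))
feats[3:6, 2] = 1.0  # feature 2 fires only inside the CDR
feats[:, 0] = 0.5    # feature 0 fires everywhere, so it is not CDR-specific
cdr = [3, 4, 5]
fw = [0, 1, 2, 6, 7, 8, 9]
print(cdr_enriched_features(feats, cdr, fw, top_k=1))  # [2]
```

Feature 0 has a high CDR mean but no CDR specificity, which is exactly what the difference filters out. Averaging the difference over many antibodies would give a more stable ranking.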
```python
import torch


def analyze_antibody_programmability(
    sae: ProteinSAE,
    plm,  # assumed wrapper: sequence string -> per-residue representations [seq_len, d_model]
    antibody_dataset: list[dict],
):
    """
    Analyze programmable features in an antibody language model.
    Identifies features that could be used for antibody engineering.
    """
    results = []
    for antibody in antibody_dataset:
        with torch.no_grad():
            # Representations and SAE features: [seq_len, n_features]
            representations = plm(antibody["sequence"])
            features = sae.encode(representations)
        # Feature activations in the CDR regions
        cdr_regions = antibody["cdr_regions"]  # {"cdr1": [positions], "cdr2": [...], "cdr3": [...]}
        cdr_activations = {}
        for cdr_name, positions in cdr_regions.items():
            region_features = features[positions]
            # Statistics for this region
            cdr_activations[cdr_name] = {
                "mean_activation": region_features.mean().item(),
                "max_activation": region_features.max().item(),
                # Mean number of active features per residue
                "active_features": (region_features > 0).sum(dim=-1).float().mean().item(),
            }
        # Feature activations in the framework regions
        fw_features = features[antibody["framework_positions"]]
        results.append({
            "antibody_id": antibody["id"],
            "cdr_activations": cdr_activations,
            "framework_activation": {
                "mean": fw_features.mean().item(),
                "diversity": fw_features.std().item(),  # std over all framework residues and features
            },
        })
    return results
```
4.2 Predicting Functional Transfer
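Two design choices drive the similarity ranking. Mean pooling turns proteins of different lengths into vectors of the same size. Cosine similarity then ignores overall activation magnitude and compares only the feature profile.

The numpy-only sketch below computes the same quantity as sklearn's `cosine_similarity` on pooled vectors. `pooled_cosine` and the toy arrays are hypothetical.

```python
import numpy as np


def pooled_cosine(source: np.ndarray, targets: list[np.ndarray]) -> np.ndarray:
    """Cosine similarity between mean-pooled SAE feature vectors.

    source: [seq_len, n_features]; targets: list of [seq_len_i, n_features] (lengths may differ).
    """
    s = source.mean(axis=0)
    T = np.stack([t.mean(axis=0) for t in targets])
    norms = np.linalg.norm(T, axis=1) * np.linalg.norm(s)
    # Guard against all-zero (no active feature) proteins
    return (T @ s) / np.where(norms > 0, norms, 1.0)


src = np.array([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
same_profile = np.array([[2.0, 0.0, 0.0]])             # different length and scale
disjoint = np.array([[0.0, 1.0, 0.0], [0.0, 0.0, 3.0]])
print(pooled_cosine(src, [same_profile, disjoint]))     # [1. 0.]
```

Mean pooling discards where in the sequence a feature fires. Proteins whose similarity lies in a local motif may therefore score lower than their functional relationship would suggest.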
```python
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity


def predict_functional_transfer(
    sae: ProteinSAE,
    plm,  # assumed wrapper: sequence string -> per-residue representations [seq_len, d_model]
    source_protein: str,
    target_proteins: list[str],
    top_k: int = 10,
):
    """
    Use SAE features to predict functional transfer between proteins.
    If two proteins have similar SAE features, they may have similar functions.
    """
    with torch.no_grad():
        # Mean-pool per-residue features into one vector per protein
        source_features = sae.encode(plm(source_protein)).mean(dim=0).cpu().numpy()
        target_features = np.stack([
            sae.encode(plm(protein)).mean(dim=0).cpu().numpy()
            for protein in target_proteins
        ])
    # Cosine similarity between the source and every target
    similarities = cosine_similarity(source_features[None, :], target_features)[0]
    # Rank targets from most to least similar
    sorted_indices = np.argsort(similarities)[::-1]
    return {
        "source_protein": source_protein,
        "top_predictions": [
            {
                "protein": target_proteins[i],
                "similarity": float(similarities[i]),
                "rank": rank + 1,
            }
            for rank, i in enumerate(sorted_indices[:top_k])
        ],
    }
```
5. The InterPLM Framework
5.1 Framework Overview
InterPLM [2] is a systematic interpretability framework for PLMs:
```
InterPLM framework:

[Data layer]
  Protein sequences | Structure annotations | Function annotations | Evolutionary information
        │
        ▼
[Representation learning layer]
  PLM encoding | Multi-scale representations | Structure embeddings
        │
        ▼
[Interpretability analysis layer]
  SAE decomposition | Concept discovery | Function attribution | Evolutionary analysis
        │
        ▼
[Scientific discovery layer]
  Function prediction | Site identification | Mutation effects | Design guidance
```
5.2 Core Components
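The last step of the class below groups candidate sites into clusters. A site joins the current cluster when it lies within `max_distance` residues of the previous site, and only clusters with at least two sites are kept. This amounts to single-linkage clustering along the sequence.

The stdlib sketch below applies the same rule to plain positions. `cluster_positions` is a hypothetical helper; unlike the method in the class, it sorts its input first.

```python
def cluster_positions(positions: list[int], max_distance: int = 5, min_size: int = 2) -> list[list[int]]:
    """Group sorted positions into runs whose neighbouring gaps are <= max_distance."""
    clusters, current = [], []
    for pos in sorted(positions):
        # A gap larger than max_distance closes the current run
        if current and pos - current[-1] > max_distance:
            if len(current) >= min_size:
                clusters.append(current)
            current = []
        current.append(pos)
    if len(current) >= min_size:
        clusters.append(current)
    return clusters


print(cluster_positions([3, 5, 9, 30, 60, 62]))  # [[3, 5, 9], [60, 62]]
```

Because the rule chains neighbours, a cluster can span more than `max_distance` end to end (3 to 9 above). Isolated sites such as 30 are dropped.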
```python
import torch


class InterPLM:
    """InterPLM-style analysis pipeline (illustrative sketch)"""

    def __init__(
        self,
        plm_model,  # assumed wrapper: sequence string -> per-residue representations [seq_len, d_model]
        sae_model,
        structure_predictor=None,
    ):
        self.plm = plm_model
        self.sae = sae_model
        self.structure_predictor = structure_predictor
        # Predefined biological concepts, each approximated by an amino-acid class
        self.concept_definitions = {
            # Structure-related
            "hydrophobic_core": ["A", "V", "I", "L", "M", "F", "W"],
            "alpha_helix": ["A", "E", "L", "M"],
            "beta_sheet": ["V", "I", "Y", "F"],
            # Function-related
            "metal_binding": ["H", "C", "D", "E"],
            "disulfide_bond": ["C"],
            "phosphorylation": ["S", "T", "Y"],
            # Evolution-related
            "highly_conserved": None,  # computed from an MSA
            "variable_region": None,
        }

    def analyze_protein(self, sequence: str) -> dict:
        """
        Full analysis of a single protein.
        """
        # 1. Representations and SAE features
        with torch.no_grad():
            representations = self.plm(sequence)
            features = self.sae.encode(representations)
        # 2. Structure prediction (optional)
        if self.structure_predictor:
            structure = self.structure_predictor.predict(sequence)
        else:
            structure = None
        # 3. Concept activation analysis
        concept_activations = self._analyze_concepts(sequence, features)
        # 4. Site analysis
        functional_sites = self._identify_functional_sites(
            sequence, features, structure
        )
        # 5. Build the report
        report = {
            "sequence": sequence,
            "concept_activations": concept_activations,
            "functional_sites": functional_sites,
            "structure": structure,
            "summary": self._generate_summary(
                concept_activations, functional_sites
            ),
        }
        return report

    def _analyze_concepts(
        self,
        sequence: str,
        features: torch.Tensor,
    ) -> dict:
        """Compute activation statistics for each predefined concept."""
        results = {}
        for concept_name, aa_list in self.concept_definitions.items():
            if aa_list is None:  # concept needs extra data (e.g. an MSA)
                continue
            # Positions whose residue belongs to the concept's amino-acid class
            positions = [
                i for i, aa in enumerate(sequence)
                if aa in aa_list and i < features.shape[0]
            ]
            if positions:
                # Statistics pooled over these residues and all SAE features
                concept_features = features[positions]
                results[concept_name] = {
                    "mean_activation": concept_features.mean().item(),
                    "max_activation": concept_features.max().item(),
                    "n_positions": len(positions),
                    "positions": positions,
                }
        return results

    def _identify_functional_sites(
        self,
        sequence: str,
        features: torch.Tensor,
        structure: dict | None = None,
    ) -> list[dict]:
        """Identify candidate functional sites."""
        sites = []
        # Inspect each position
        for pos in range(min(len(sequence), features.shape[0])):
            pos_features = features[pos]
            # Features above this residue's 90th-percentile activation. For sparse
            # features the quantile is often 0, so this reduces to "any active feature".
            threshold = pos_features.quantile(0.9)
            active_features = (pos_features > threshold).nonzero(as_tuple=True)[0]
            if len(active_features) > 0:
                site = {
                    "position": pos,
                    "amino_acid": sequence[pos],
                    "active_features": active_features.tolist(),  # always a list, even for one hit
                    "max_activation": pos_features.max().item(),
                }
                # Add structural context when available (assumed format: {position: label})
                if structure:
                    site["secondary_structure"] = structure.get(pos)
                sites.append(site)
        return sites

    def _generate_summary(
        self,
        concept_activations: dict,
        functional_sites: list,
    ) -> str:
        """Generate a text summary of the analysis."""
        lines = []
        # Summarize concept activations
        high_activation_concepts = [
            (k, v["mean_activation"])
            for k, v in concept_activations.items()
            if v["mean_activation"] > 0.5
        ]
        high_activation_concepts.sort(key=lambda x: x[1], reverse=True)
        if high_activation_concepts:
            lines.append("Top concepts:")
            for concept, act in high_activation_concepts[:3]:
                lines.append(f"  - {concept}: {act:.3f}")
        # Summarize functional sites
        if functional_sites:
            lines.append(f"\nFound {len(functional_sites)} candidate functional sites")
            # Cluster nearby sites
            clusters = self._cluster_nearby_sites(functional_sites)
            if clusters:
                lines.append("\nSite clusters:")
                for i, cluster in enumerate(clusters[:5]):
                    positions = [s["position"] for s in cluster]
                    aa_seq = "".join(s["amino_acid"] for s in cluster)
                    lines.append(
                        f"  Cluster {i + 1}: positions {positions[0]}-{positions[-1]}, residues: {aa_seq}"
                    )
        return "\n".join(lines)

    def _cluster_nearby_sites(
        self,
        sites: list[dict],
        max_distance: int = 5,
    ) -> list[list[dict]]:
        """Cluster nearby functional sites (sites must be sorted by position)."""
        if not sites:
            return []
        clusters = []
        current_cluster = [sites[0]]
        for site in sites[1:]:
            if site["position"] - current_cluster[-1]["position"] <= max_distance:
                current_cluster.append(site)
            else:
                if len(current_cluster) >= 2:
                    clusters.append(current_cluster)
                current_cluster = [site]
        if len(current_cluster) >= 2:
            clusters.append(current_cluster)
        return clusters
```
6. Experimental Results
6.1 Functional Site Identification
| Method | Precision | Recall | F1 |
|---|---|---|---|
| Conservation-based | 0.45 | 0.52 | 0.48 |
| Attention-based | 0.58 | 0.61 | 0.59 |
| SAE-based | 0.72 | 0.68 | 0.70 |
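F1 is the harmonic mean of precision and recall, so the F1 column can be checked from the other two. The short sketch below reproduces it from the table's own values.

```python
def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


rows = {
    "Conservation-based": (0.45, 0.52),
    "Attention-based": (0.58, 0.61),
    "SAE-based": (0.72, 0.68),
}
for name, (p, r) in rows.items():
    print(f"{name}: F1 = {f1(p, r):.2f}")  # 0.48, 0.59, 0.70
```

Because the harmonic mean is pulled toward the smaller value, a method cannot reach a high F1 by trading recall for precision alone.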
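6.2 Protein Function Prediction
| Method | Accuracy | AUC |
|---|---|---|
| ESM embeddings + logistic regression (LR) | 0.76 | 0.82 |
| ESM embeddings + random forest (RF) | 0.79 | 0.85 |
| SAE features + logistic regression (LR) | 0.83 | 0.89 |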
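The AUC column is the area under the ROC curve. It equals the probability that a randomly chosen positive example scores higher than a randomly chosen negative one, with ties counting one half.

The stdlib sketch below computes it directly from that definition. Its cost is quadratic in the number of examples, so it is meant for illustration rather than large evaluations.

```python
def roc_auc(scores: list[float], labels: list[int]) -> float:
    """AUC = P(random positive scores above random negative); ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


print(roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

Unlike accuracy, AUC does not depend on a decision threshold. This is why the two columns can rank classifiers slightly differently.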
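6.3 Scientific Discovery Case Studies
Case 1: Discovering a new metal-binding site
- SAE feature #2345 is strongly activated in metalloproteins
- Its activating positions are enriched in His-X-His motifs
- The site is predicted to be a possible zinc-binding site
- Experimental validation: mutating the site reduced metal-binding activity by 50%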
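A first check of the claim that a feature's activations are "enriched in His-X-His motifs" is to locate the motif with a regular expression. One then measures what share of the feature's active residues falls inside a motif occurrence.

The sketch below uses a hypothetical sequence and hypothetical active positions. A proper enrichment test would also compare this share against a background rate, for example the share of all residues covered by the motif.

```python
import re


def motif_starts(sequence: str, pattern: str = r"H.H") -> list[int]:
    """0-based start positions of a motif; the lookahead also catches overlapping hits."""
    return [m.start() for m in re.finditer(f"(?={pattern})", sequence)]


def fraction_in_motif(active_positions: list[int], sequence: str,
                      pattern: str = r"H.H", width: int = 3) -> float:
    """Share of a (non-empty) list of active residues that falls inside any motif occurrence."""
    covered = {s + k for s in motif_starts(sequence, pattern) for k in range(width)}
    return sum(p in covered for p in active_positions) / len(active_positions)


seq = "MAHAHAGCKHLHE"  # hypothetical sequence
print(motif_starts(seq))                         # [2, 9]
print(fraction_in_motif([2, 4, 7, 11], seq))     # 0.75
```

The zero-width lookahead matters for runs like `HAHAH`, where two His-X-His occurrences share a histidine. A plain `re.finditer("H.H", ...)` would report only the first.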
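Case 2: Features of the antibody CDR3 region
- SAE features identify a distinctive activation pattern in the CDR3 region
- Feature combinations associated with antigen-binding affinity were found
- These features are used to guide antibody affinity maturation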
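7. Future Directions
7.1 Current Limitations
| Limitation | Description |
|---|---|
| Concept definition | Domain experts are needed to define meaningful concepts |
| Experimental validation | Features identified by SAEs require experimental validation |
| Scaling | Training SAEs on large PLMs is expensive |
7.2 Open Directions
| Direction | Description |
|---|---|
| Automated concept discovery | Use LLMs to extract biological concepts from SAE features automatically |
| Multimodal integration | Combine sequence, structure, and function data for joint analysis |
| Mutation effect prediction | Use SAE features to predict the functional impact of mutations |
| Generative design | Use SAEs to guide protein design |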