AlphaFold系列深度解析
1. 引言
AlphaFold系列是DeepMind开发的蛋白质结构预测系统,从2018年的AlphaFold1到2024年的AlphaFold3,代表了深度学习在科学领域应用的里程碑式突破。1[^2]2
版本演进概览
| 版本 | 年份 | CASP | 核心创新 | 预测精度 |
|---|---|---|---|---|
| AlphaFold1 | 2018 | CASP13 | 距离预测+端到端学习 | 中等 |
| AlphaFold2 | 2020 | CASP14 | Evoformer+MSA融合 | 实验级 |
| AlphaFold3 | 2024 | — | PairFormer+扩散架构 | 全原子 |
2. AlphaFold1 (2018)
2.1 参赛背景
CASP13(2018年)是蛋白质结构预测领域的转折点。传统方法(如I-TASSER、RoseTTAFold)依赖同源建模和 threading,精度提升缓慢。AlphaFold1作为DeepMind的首次尝试,引入了深度学习+距离预测的新范式。
2.2 核心方法
距离预测范式
AlphaFold1的创新在于将结构预测问题转化为距离预测问题:
class AlphaFold1DistancePrediction(nn.Module):
"""
AlphaFold1的距离预测架构
"""
def __init__(self, config):
super().__init__()
# 1. CNN特征提取
self.residual_cnn = ResidualCNN(config)
# 2. 距离图预测头
self.distance_head = nn.Conv2d(
in_channels=config.d_model,
out_channels=64, # 距离区间数
kernel_size=1
)
# 3. 蛋白质骨架生成(基于预测的距离约束)
self.structure_module = BackboneGenerationModule(config)
def forward(self, msa, template=None):
# CNN特征提取
features = self.residual_cnn(msa)
# 预测距离分布
distance_logits = self.distance_head(features)
distance_probs = F.softmax(distance_logits, dim=1)
# 从距离分布采样并生成结构
backbone = self.structure_module(distance_probs)
return backbone2.3 创新点
- 端到端学习:直接从MSA预测3D结构,无需手工特征
- 距离分布预测:预测蛋白质残基对之间的距离分布,而非单一距离值
- 梯度可微的结构生成:通过TorsionNet实现可微的骨架生成
2.4 局限性
- 预测精度受限于当时的网络架构
- 结构模块相对简单,未能达到实验精度
- 对长序列蛋白质效果有限
3. AlphaFold2 (2020-2021)
3.1 整体架构
氨基酸序列
↓
序列嵌入 + MSA嵌入 + 成对嵌入
↓
┌─────────────────────────────────────┐
│ Evoformer Stack (48层) │
├─────────────────────────────────────┤
│ • MSA行注意力 │
│ • MSA列注意力 │
│ • 成对注意力(三角更新) │
│ • 外部循环连接 │
└─────────────────────────────────────┘
↓
Structure Module (8层)
↓
Invariant Point Attention (IPA)
↓
原子坐标 + pLDDT置信度
3.2 Evoformer架构详解
Evoformer是AlphaFold2的核心创新,实现了MSA信息与成对关系的高效融合。
3.2.1 MSA表示与成对表示
class EvoformerBlock(nn.Module):
"""
Evoformer基本单元
"""
def __init__(self, config):
super().__init__()
# MSA处理模块
self.msa_row_attention = MSARowAttention(config)
self.msa_column_attention = MSAColumnAttention(config)
self.msa_transition = Transition(config) # LayerNorm + MLP
# 成对处理模块
self.pair_attention = PairAttention(config)
self.triangle_multiplication_outgoing = TriangleMultiplication(
config, direction='outgoing'
)
self.triangle_multiplication_incoming = TriangleMultiplication(
config, direction='incoming'
)
self.triangle_attention_starting = TriangleAttention(
config, axis='starting'
)
self.triangle_attention_ending = TriangleAttention(
config, axis='ending'
)
# 通信模块
self.msa_to_pair = MSAToPair(config)
self.pair_to_msa = PairToMSA(config)
def forward(self, msa_repr, pair_repr, msa_mask=None):
# MSA处理
msa_repr = msa_repr + self.msa_row_attention(msa_repr, msa_mask)
msa_repr = msa_repr + self.msa_column_attention(msa_repr, pair_repr)
msa_repr = msa_repr + self.msa_transition(msa_repr)
# 成对处理
pair_repr = pair_repr + self.msa_to_pair(msa_repr)
pair_repr = pair_repr + self.pair_attention(pair_repr)
pair_repr = pair_repr + self.triangle_multiplication_outgoing(pair_repr)
pair_repr = pair_repr + self.triangle_multiplication_incoming(pair_repr)
pair_repr = pair_repr + self.triangle_attention_starting(pair_repr)
pair_repr = pair_repr + self.triangle_attention_ending(pair_repr)
# MSA更新(利用成对信息)
msa_repr = msa_repr + self.pair_to_msa(msa_repr, pair_repr)
return msa_repr, pair_repr3.2.2 MSA行注意力
class MSARowAttention(nn.Module):
"""
MSA行注意力:同一位置在不同序列间的注意力
类似于self-attention,但作用于MSA维度
"""
def __init__(self, config):
super().__init__()
self.num_heads = config.num_heads
self.d_head = config.d_model // config.num_heads
self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
self.o_proj = nn.Linear(config.d_model, config.d_model)
def forward(self, msa_repr, mask=None):
"""
Args:
msa_repr: [B, N_seq, N_res, d_model] MSA表示
mask: [B, N_seq, N_res] 掩码
Returns:
更新后的MSA表示 [B, N_seq, N_res, d_model]
"""
B, N_seq, N_res, D = msa_repr.shape
# QKV投影
qkv = self.qkv_proj(msa_repr)
q, k, v = qkv.chunk(3, dim=-1)
# Reshape for multi-head attention
q = q.view(B, N_seq, N_res, self.num_heads, self.d_head).transpose(2, 3)
k = k.view(B, N_seq, N_res, self.num_heads, self.d_head).transpose(2, 3)
v = v.view(B, N_seq, N_res, self.num_heads, self.d_head).transpose(2, 3)
# Attention
scale = self.d_head ** -0.5
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
if mask is not None:
mask = mask.unsqueeze(2).unsqueeze(3) # [B, N_seq, 1, 1, N_res]
attn = attn.masked_fill(mask == 0, -1e9)
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, v)
# Merge heads
out = out.transpose(2, 3).contiguous().view(B, N_seq, N_res, D)
out = self.o_proj(out)
return out3.2.3 MSA列注意力
class MSAColumnAttention(nn.Module):
"""
MSA列注意力:同一序列在不同位置间的注意力
捕获同源序列中的共进化信息
"""
def forward(self, msa_repr, pair_repr=None):
"""
Args:
msa_repr: [B, N_seq, N_res, d_model]
pair_repr: [B, N_res, N_res, d_pair] 成对表示(可选)
Returns:
更新后的MSA表示 [B, N_seq, N_res, d_model]
"""
# 转置:N_seq变为序列维度
msa_T = msa_repr.transpose(1, 2) # [B, N_res, N_seq, d]
# 标准self-attention
attn_output = self.self_attention(msa_T)
# 恢复形状
return attn_output.transpose(1, 2)3.2.4 三角更新机制
三角更新是防止成对表示陷入简单线性组合的关键:
class TriangleMultiplication(nn.Module):
"""
三角乘法更新
核心思想:利用三角约束更新成对表示
[i,j] = sum_k [i,k] * [k,j]
"""
def __init__(self, config, direction='outgoing'):
super().__init__()
self.direction = direction
d = config.d_pair
self.num_heads = config.num_heads
self.d_head = d // config.num_heads
# 传入边更新
self.left_norm = nn.LayerNorm(d)
self.left_proj = nn.Linear(d, 2 * d) # 用于gating
# 传出边更新
self.right_norm = nn.LayerNorm(d)
self.right_proj = nn.Linear(d, 2 * d)
self.output_norm = nn.LayerNorm(d)
self.output_proj = nn.Linear(d, d)
def forward(self, pair_repr):
"""
Args:
pair_repr: [B, N_res, N_res, d_pair]
Returns:
更新后的成对表示 [B, N_res, N_res, d_pair]
"""
B, N, _, D = pair_repr.shape
# 传入边
left = self.left_norm(pair_repr)
left_gating, left_values = self.left_proj(left).split(D, dim=-1)
# 传出边
right = self.right_norm(pair_repr)
right_gating, right_values = self.right_proj(right).split(D, dim=-1)
# 矩阵乘法:[i,j] = sum_k [i,k] * [k,j]
if self.direction == 'outgoing':
# 左乘
left_reshape = left_values.view(B, N, N, self.num_heads, self.d_head)
right_reshape = right_values.view(B, N, N, self.num_heads, self.d_head)
# [i,j] = sum_k left[i,k] * right[k,j]
out = torch.einsum('ijhnd,khd->ijhnd', left_reshape, right_reshape.mean(dim=1))
else:
# incoming: [i,j] = sum_k left[k,i] * right[j,k]
out = torch.einsum('kjhnd,jhd->ijhnd', left_reshape, right_reshape.mean(dim=0))
# Gate
gate = torch.sigmoid(left_gating + right_gating)
out = out * gate.unsqueeze(-1)
return self.output_proj(self.output_norm(out))3.3 MSA处理与序列比对
3.3.1 MSA构建
MSA(Multiple Sequence Alignment,多序列比对)是AlphaFold2的关键输入:
class MSAProcessor:
"""
MSA处理流程
"""
def __init__(self, config):
self.config = config
def build_msa(self, sequence, homologs=None):
"""
构建MSA
"""
if homologs is None:
# 使用HHBlits/JACKHMMER搜索数据库
homologs = self.sequence_search(sequence)
# 比对构建
msa = self.multiple_alignment([sequence] + homologs)
# 剪裁:限制MSA深度
if len(msa) > self.config.max_msa:
msa = self.sample_msa(msa, self.config.max_msa)
return msa
def sequence_search(self, sequence):
"""
使用HHBlits或Jackhmmer搜索同源序列
"""
# 搜索UniClust30、MGnify等数据库
hits = hhblits.search(sequence, database='uniclust30')
return hits
def sample_msa(self, msa, max_sequences):
"""
均匀采样以保持多样性
"""
n_total = len(msa)
indices = torch.linspace(0, n_total-1, max_sequences).long()
return msa[indices]3.4 距离预测到概率分布
AlphaFold2将距离预测转化为概率分布预测:
class DistogramHead(nn.Module):
"""
距离分布预测头
输出:每个残基对的距离区间概率
"""
def __init__(self, config, num_bins=64, min_dist=2.0, max_dist=22.0):
super().__init__()
self.num_bins = num_bins
self.min_dist = min_dist
self.max_dist = max_dist
self.bin_width = (max_dist - min_dist) / num_bins
self.projection = nn.Linear(config.d_pair, num_bins)
def forward(self, pair_repr):
"""
Args:
pair_repr: [B, N, N, d_pair]
Returns:
distogram: [B, N, N, num_bins] 距离概率分布
"""
logits = self.projection(pair_repr)
probs = F.softmax(logits, dim=-1)
# 期望距离
bin_centers = torch.linspace(
self.min_dist + self.bin_width/2,
self.max_dist - self.bin_width/2,
self.num_bins,
device=probs.device
)
expected_dist = torch.sum(probs * bin_centers, dim=-1)
return {
'probs': probs,
'expected': expected_dist,
'logits': logits
}
def compute_loss(self, pred_dist, true_dist):
"""
计算距离分布损失
"""
# 真实距离映射到bin
bin_idx = ((true_dist - self.min_dist) / self.bin_width).long()
bin_idx = torch.clamp(bin_idx, 0, self.num_bins - 1)
# 交叉熵损失
return F.cross_entropy(
pred_dist['logits'].view(-1, self.num_bins),
bin_idx.view(-1)
)3.5 置信度估计
3.5.1 pLDDT (per-residue Local Distance Difference Test)
class pLDDTHead(nn.Module):
"""
局部置信度预测
pLDDT: 预测LDDT分数,表示单残基级别的预测可靠性
"""
def __init__(self, config):
super().__init__()
self.dense1 = nn.Linear(config.d_single, 128)
self.dense2 = nn.Linear(128, 37) # 37个LDDT bin
def forward(self, single_repr):
"""
Returns:
plddt_logits: [B, N, 37] 每个残基的LDDT预测
"""
x = F.relu(self.dense1(single_repr))
return self.dense2(x)
def compute_plddt(self, plddt_logits):
"""
将logits转换为pLDDT分数
"""
# Softmax得到概率
probs = F.softmax(plddt_logits, dim=-1)
# LDDT bin中心值: [0.1, 0.2, ..., 1.0, ...] (更细的分桶)
bins = torch.linspace(0.1, 1.0, 37, device=probs.device)
# 期望值
plddt = torch.sum(probs * bins, dim=-1)
return pLDDT * 100 # 转换为0-1003.5.2 PAE (Predicted Alignment Error)
class PAEHead(nn.Module):
"""
预测对齐误差
PAE[i,j]: 如果将残基i对齐到残基j,预期误差是多少
"""
def __init__(self, config):
super().__init__()
self.projection = nn.Linear(config.d_pair, 64) # 误差区间数
def forward(self, pair_repr):
"""
Returns:
pae: [B, N, N, 64] PAE概率分布
"""
return self.projection(pair_repr)
def compute_expected_pae(self, pae_logits):
"""
计算期望PAE
"""
bins = torch.linspace(0, 32, 64, device=pae_logits.device) # 0-32Å
probs = F.softmax(pae_logits, dim=-1)
expected_pae = torch.sum(probs * bins, dim=-1)
return expected_pae3.6 关键技术详解
3.6.1 MSARefinement机制
class MSARefinementIteration(nn.Module):
"""
MSA精炼迭代
AlphaFold2的核心循环机制
"""
def __init__(self, config):
super().__init__()
self.evoformer = EvoformerBlock(config)
self.structure_module = StructureModule(config)
def forward(self, msa_repr, pair_repr, msa_mask):
"""
一次迭代
"""
# Evoformer更新
msa_repr, pair_repr = self.evoformer(msa_repr, pair_repr, msa_mask)
# 结构模块更新
single_repr = msa_repr[0] # 取第一行(目标序列)
structure_output = self.structure_module(single_repr, pair_repr)
# 将结构信息注入MSA(关键创新)
msa_repr = self.inject_structure_to_msa(
msa_repr, structure_output['backbone']
)
return msa_repr, pair_repr, structure_output3.6.2 Invariant Point Attention (IPA)
class InvariantPointAttention(nn.Module):
"""
不变点注意力
核心:在3D空间中计算注意力,但保持旋转-平移不变性
"""
def __init__(self, config):
super().__init__()
self.num_heads = config.num_heads
self.num_points = 4 # 4个参考点
# 标量注意力
self.scalar_attention = ScalarAttention(config, self.num_heads)
# 几何注意力(等变)
self.geometric_attention = GeometricAttention(
config, self.num_heads, self.num_points
)
def forward(self, single_repr, pair_repr, backbone, mask=None):
"""
Args:
single_repr: [B, N, d] 单残基表示
pair_repr: [B, N, N, d_pair] 成对表示
backbone: 包含旋转和平移的骨架
"""
# 标量注意力权重
scalar_weights = self.scalar_attention(single_repr, pair_repr)
# 几何特征
points_local = self.compute_local_points(single_repr)
points_global = backbone.transform(points_local) # 等变变换
# 几何注意力
geometric_output = self.geometric_attention(
points_global, scalar_weights, mask
)
return geometric_output3.7 Nature 2021论文核心要点
“Highly accurate protein structure prediction with AlphaFold”
Jumper et al., Nature 596, 583-589 (2021)
关键贡献:
- 端到端Evoformer:首次将MSA和成对关系联合建模
- 突破性精度:在CASP14中达到实验级精度
- 置信度估计:pLDDT和PAE提供可靠的预测质量评估
- 完整开源:AlphaFold2和AlphaFold-Multimer代码开源
4. AlphaFold3 (2024)
4.1 核心创新概述
AlphaFold3在AlphaFold2基础上进行了架构重构:
| 组件 | AlphaFold2 | AlphaFold3 |
|---|---|---|
| 核心架构 | Evoformer | PairFormer |
| 结构生成 | 结构模块+IPA | 扩散模型(Diffusion) |
| MSA处理 | 深度集成 | 边缘化/简化 |
| 输入范围 | 蛋白质 | 全原子(蛋白+核酸+配体+离子) |
4.2 PairFormer架构
PairFormer移除了显式的MSA表示,专注于成对交互:
class PairformerBlock(nn.Module):
"""
PairFormer:简化版的Evoformer
"""
def __init__(self, config):
super().__init__()
# Self-attention on pairs
self.pair_self_attention = PairSelfAttention(config)
# 三角更新(保留核心机制)
self.triangle_update = TriangleUpdate(config)
# Transition
self.transition = Transition(config)
# 单体表示更新(简化的IPA)
self.single_update = SingleRepresentationUpdate(config)
def forward(self, single_repr, pair_repr):
# 成对表示更新
pair_repr = pair_repr + self.pair_self_attention(pair_repr)
pair_repr = pair_repr + self.triangle_update(pair_repr)
pair_repr = pair_repr + self.transition(pair_repr)
# 单体表示更新(利用成对信息)
single_repr = single_repr + self.single_update(single_repr, pair_repr)
return single_repr, pair_repr4.3 Diffusion Head架构
class DiffusionHead(nn.Module):
"""
AlphaFold3的扩散模型结构生成
"""
def __init__(self, config):
super().__init__()
# 噪声调度
self.noise_schedule = NoiseSchedule(
num_steps=1000,
beta_start=1e-6,
beta_end=0.01
)
# 去噪骨干网络
self.denoising_backbone = DiffusivePairformer(config)
# 条件编码器
self.condition_encoder = ConditionEncoder(config)
def forward(self, noisy_coords, timestep, conditions):
"""
Args:
noisy_coords: [B, N_atoms, 3] 带噪声的原子坐标
timestep: t 扩散时间步
conditions: 序列、MSA、模板等条件
Returns:
noise_pred: 预测的噪声
"""
# 编码条件
cond = self.condition_encoder(conditions)
# 添加时间嵌入
t_emb = self.get_timestep_embedding(timestep)
# 去噪预测
noise_pred = self.denoising_backbone(
noisy_coords, cond, t_emb
)
return noise_pred
def training_loss(self, clean_coords, conditions):
"""
简化的训练损失
"""
B = clean_coords.shape[0]
# 采样时间步
t = torch.randint(0, self.noise_schedule.num_steps, (B,))
# 添加噪声
noise = torch.randn_like(clean_coords)
noisy_coords = self.noise_schedule.add_noise(clean_coords, noise, t)
# 预测噪声
noise_pred = self.forward(noisy_coords, t, conditions)
return F.mse_loss(noise_pred, noise)
def sampling(self, conditions, num_steps=1000):
"""
从纯噪声生成结构
"""
coords = torch.randn(
conditions['num_atoms'], 3,
device=conditions['device']
).unsqueeze(0)
for t in reversed(range(num_steps)):
timestep = torch.tensor([t], device=coords.device)
noise_pred = self.forward(coords, timestep, conditions)
coords = self.noise_schedule.step(coords, noise_pred, t)
return coords4.4 全原子预测能力
AlphaFold3支持预测:
| 分子类型 | 支持情况 |
|---|---|
| 蛋白质 | ✅ 单体、复合物 |
| DNA | ✅ 双链、单链 |
| RNA | ✅ 发夹、假结 |
| 小分子配体 | ✅ 药物分子、金属离子 |
| 共价修饰 | ✅ 磷酸化、糖基化 |
| 离子 | ✅ Mg²⁺、Zn²⁺、Fe²⁺ |
4.5 配体与抗体支持
class LigandProcessor:
"""
配体处理模块
"""
def __init__(self):
self.element_embedding = nn.Embedding(100, 128) # 元素周期表
def process_ligand(self, ligand_smiles, ligand_coords):
"""
处理小分子配体
"""
# 从SMILES构建分子图
mol = Chem.MolFromSmiles(ligand_smiles)
# 提取原子特征
atom_features = self.extract_atom_features(mol)
# 原子类型嵌入
atom_emb = self.element_embedding(atom_features['atomic_num'])
# 3D坐标初始嵌入
coord_emb = self.encode_coords(ligand_coords)
return {
'atoms': atom_emb,
'coords': coord_emb,
'bonds': self.extract_bond_features(mol)
}
def extract_atom_features(self, mol):
"""
提取原子特征:原子序数、电负性、杂化类型等
"""
features = {
'atomic_num': [],
'formal_charge': [],
'hybridization': [],
'num_h': [],
}
for atom in mol.GetAtoms():
features['atomic_num'].append(atom.GetAtomicNum())
features['formal_charge'].append(atom.GetFormalCharge())
features['hybridization'].append(int(atom.GetHybridization()))
features['num_h'].append(atom.GetTotalNumHs())
return {k: torch.tensor(v) for k, v in features.items()}4.6 Nature 2024论文核心要点
“Accurate structure prediction of biomolecular interactions with AlphaFold 3”
Abramson et al., Nature 630, 1150-1160 (2024)
关键贡献:
- 统一框架:首次将蛋白质、核酸、配体统一到一个模型
- 扩散生成:用扩散模型替代迭代结构更新
- 大幅简化:移除显式MSA表示,降低计算复杂度
- API开放:AlphaFold Server向学术机构开放
5. 技术细节
5.1 注意力机制在结构预测中的应用
┌────────────────────────────────────────────────────┐
│ 注意力类型对比 │
├────────────────────────────────────────────────────┤
│ MSA行注意力 │ 同位置多序列信息融合 │
│ MSA列注意力 │ 同序列多位置共进化 │
│ 成对注意力 │ 残基对相互作用建模 │
│ IPA │ 3D几何结构感知(等变) │
│ 交叉注意力 │ MSA → 成对 → 结构 信息流动 │
└────────────────────────────────────────────────────┘
5.2 距离图与3D坐标生成
距离分布预测 ──→ 约束优化 ──→ 骨架坐标 ──→ 原子坐标
│ │ │
▼ ▼ ▼
distogram Rosetta-like 侧链旋转异构体
loss refinement
5.3 置信度预测机制
| 指标 | 定义 | 用途 |
|---|---|---|
| pLDDT | 单残基LDDT预测 | 判断局部质量,识别无序区 |
| PAE | 预测对齐误差矩阵 | 评估域间/链间相对位置 |
| pTM | 成对TM-score预测 | 整体预测质量 |
| ipTM | 接口TM-score | 复合物界面质量 |
6. 开源资源
6.1 AlphaFold Server
# 通过AlphaFold Server在线预测
# 访问:https://alphafold.ebi.ac.uk/
# API调用示例(如果有)
curl -X POST "https://alphafold.ebi.ac.uk/api/predict/{uniprot_id}" \
-H "accept: application/json"6.2 AlphaFold Database
| 数据库 | URL | 内容 |
|---|---|---|
| AlphaFold DB | https://alphafold.ebi.ac.uk | UniProt蛋白质 |
| EMBL-EBI | alphafold.ebi.ac.uk | 预计算结构 |
| Google Cloud | gs://public-diamond-af | 完整数据集 |
6.3 ColabFold
# 安装ColabFold
pip install colabfold
# 快速预测
colabfold_search input.fasta output_dir --host 1
# 批量预测
colabfold_batch input.fasta output_dir \
--num-recycle 3 \
--msa-mode mmseqs2_uniref_env6.4 本地部署
# 使用Docker部署完整AlphaFold2
docker pull ghcr.io/deepmind/alphafold:2.3.1
# 运行预测
docker run -v $(pwd)/data:/data \
ghcr.io/deepmind/alphafold:2.3.1 \
--fasta_paths=/data/input.fasta \
--output_dir=/data/output \
--max_template_date=2024-01-01 \
--db_preset=full_dbs7. 参考论文
AlphaFold1
AlphaFold2
AlphaFold3
相关主题
最后更新: 2026-05-15