AlphaFold蛋白质结构预测

1. 引言

蛋白质结构预测是生物学中最重要的问题之一。1972年Christian Anfinsen提出安芬森假说:蛋白质的三维结构完全由其氨基酸序列决定。这一假说经过数十年验证,成为现代结构生物信息学的基础。然而,从序列预测结构(即”蛋白质折叠问题”)一直是科学界的圣杯。

2020年,DeepMind的AlphaFold2在CASP14竞赛中取得突破性进展,预测精度达到实验水平。2021年开源的AlphaFold2和2024年的AlphaFold3进一步扩展了能力边界。1


2. 问题定义

2.1 输入与输出

输入输出
氨基酸序列 (长度 )原子坐标
多序列比对 (MSA)置信度指标 (pLDDT)
模板结构 (可选)预测局部质量

2.2 评估指标

TM-score(模板建模分数):

其中 是归一化因子。

GDT(全局距离测试):衡量在给定阈值下正确预测的残基比例。


3. AlphaFold2架构

3.1 整体架构

氨基酸序列
    ↓
Input Embedding
    ↓
┌───────────────────────────────────────┐
│           Evoformer Stack (48层)       │
├───────────────────────────────────────┤
│  • MSA模块 (行/列注意力)               │
│  • 成对注意力 (Pairformer)            │
│  • 三角乘法/加法更新                   │
│  • 结构模块输入更新                   │
└───────────────────────────────────────┘
    ↓
Structure Module (8层)
    ↓
输出:原子坐标 + pLDDT

3.2 Input Embedding

class AlphaFoldInputEmbedding(nn.Module):
    """
    输入嵌入层
    """
    def __init__(self, config):
        super().__init__()
        
        # 序列嵌入
        self.seq_embedding = nn.Embedding(
            num_embeddings=21,  # 20种氨基酸 + 未知
            embedding_dim=config.embed_dim
        )
        
        # MSA轮廓嵌入
        self.msa_profile = nn.Linear(23, config.embed_dim)  # 23 = 20aa + gap + unknown
        
        # 模板嵌入(可选)
        self.template_embedding = TemplateEmbedding(config)
    
    def forward(self, sequence, msa, template=None):
        # 序列特征
        seq_emb = self.seq_embedding(sequence)
        
        # MSA特征 (shape: [N_seq, N_res, d])
        msa_emb = self.msa_profile(msa)
        
        # 成对特征 (shape: [N_res, N_res, d_pair])
        pair_emb = self.pairwise_embedding(sequence)
        
        return {
            "single": seq_emb,
            "msa": msa_emb,
            "pair": pair_emb
        }

3.3 Evoformer核心机制

Evoformer是AlphaFold的核心创新,将MSA信息和成对残基关系联合建模:

3.3.1 MSA行注意力

class MSARowAttention(nn.Module):
    """
    MSA行注意力:同一列(同一位置)在不同序列间的注意力
    """
    def __init__(self, num_heads, d_model):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True
        )
    
    def forward(self, msa_repr, msa_mask):
        """
        Args:
            msa_repr: [N_seq, N_res, d] MSA表示
            msa_mask: [N_seq, N_res] 掩码
        """
        # 行注意力:每个位置关注其他序列
        attn_output, _ = self.attention(
            query=msa_repr,
            key=msa_repr,
            value=msa_repr,
            key_padding_mask=msa_mask.bool()
        )
        
        return attn_output

3.3.2 MSA列注意力

class MSAColumnAttention(nn.Module):
    """
    MSA列注意力:同一序列在不同位置间的注意力
    """
    def forward(self, msa_repr, pair_repr):
        # 转置后应用注意力
        msa_T = msa_repr.transpose(0, 1)  # [N_res, N_seq, d]
        
        # 列注意力
        attn_output, _ = self.attention(msa_T, msa_T, msa_T)
        
        return attn_output.transpose(0, 1)  # 恢复形状

3.3.3 成对注意力 (Pairformer)

class Pairformer(nn.Module):
    """
    成对表示更新
    """
    def forward(self, msa_repr, pair_repr):
        # 1. MSA到成对的通信
        pair_update = self.msa_to_pair(msa_repr)  # 聚合MSA信息到成对
        
        # 2. 三角更新(防止简单线性组合导致信息丢失)
        pair_repr = self.triangle_update(pair_repr)
        
        # 3. 更新MSA(使用成对信息)
        msa_repr = self.pair_to_msa(msa_repr, pair_repr)
        
        return msa_repr, pair_repr

3.3.4 三角注意力

三角注意力解决图神经网络中的”过度平滑”问题:

class TriangleMultiplication(nn.Module):
    """
    三角乘法更新
    核心思想:利用三角约束更新成对表示
    """
    def __init__(self, d_pair, num_heads):
        super().__init__()
        # 传入乘法
        self.left_norm = nn.LayerNorm(d_pair)
        self.left_proj = nn.Linear(d_pair, num_heads * d_pair)
        
        # 传出乘法
        self.right_norm = nn.LayerNorm(d_pair)
        self.right_proj = nn.Linear(d_pair, num_heads * d_pair)
        
        # 门控
        self.gate_proj = nn.Linear(d_pair, num_heads)
    
    def forward(self, pair_repr):
        """
        pair_repr[i,j] 表示残基i和j之间的关系
        """
        # 传入边更新
        left = self.left_proj(self.left_norm(pair_repr))  # [B, N, N, H*d]
        right = self.right_proj(self.right_norm(pair_repr))  # [B, N, N, H*d]
        
        # 乘法 + 归一化
        left = left.view(*left.shape[:3], -1, d_pair)
        right = right.view(*right.shape[:3], -1, d_pair)
        
        # [i,j] = sum_k [i,k] * [k,j]
        out = torch.einsum('ijkh,klh->ijlh', left, right)
        
        # 门控
        gate = torch.sigmoid(self.gate_proj(self.left_norm(pair_repr)))
        out = out * gate.unsqueeze(-1)
        
        return out

4. 结构模块

4.1 坐标生成

结构模块从Evoformer的输出生成三维坐标:

class StructureModule(nn.Module):
    """
    结构模块:将抽象表示转换为3D坐标
    """
    def __init__(self, config):
        super().__init__()
        
        # IPA (Invariant Point Attention)
        self.ipa = InvariantPointAttention(config)
        
        # 骨架更新
        self.transition = nn.Sequential(
            nn.LayerNorm(config.d_single),
            nn.Linear(config.d_single, config.d_hidden),
            nn.ReLU(),
            nn.Linear(config.d_hidden, config.d_hidden),
            nn.ReLU(),
            nn.Linear(config.d_hidden, 6 * 37)  # 37 = 3旋转 + 3平移
        )
        
        # 原子位置预测头
        self.final_layer_norm = nn.LayerNorm(config.d_single)
        self.plddt_head = nn.Linear(config.d_single, 37)
    
    def forward(self, single_repr, pair_repr):
        """
        生成3D结构
        """
        # 初始化骨架
        backbone = BackboneUpdate()
        
        for _ in range(8):
            # 1. IPA更新(旋转不变注意力)
            single_repr = single_repr + self.ipa(
                query_repr=single_repr,
                pair_repr=pair_repr,
                backbone=backbone
            )
            
            # 2. 过渡层更新
            update = self.transition(single_repr)
            backbone = backbone.rigid_rotate(update[..., :6])
            
            # 3. 侧链旋转(可选)
            chi_angles = self.predict_sidechains(single_repr)
        
        # 输出
        return {
            "atoms": backbone.atom_positions,
            "plddt": self.plddt_head(self.final_layer_norm(single_repr))
        }

4.2 Invariant Point Attention (IPA)

IPA是关键创新:即使结构变化,注意力权重保持不变:

class InvariantPointAttention(nn.Module):
    """
    不变点注意力
    核心:在3D空间中计算注意力,但不破坏等变性
    """
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.num_points = config.num_points  # 通常为4
        
        # 注意力参数
        self.attention_proj = nn.Linear(config.d_single, self.num_heads * 2 + 3 * self.num_points)
        self.head_proj = nn.Linear(self.num_heads * config.d_single, config.d_single)
    
    def forward(self, query_repr, pair_repr, backbone):
        """
        Args:
            query_repr: [B, N, d] 查询表示
            pair_repr: [B, N, N, d] 成对表示
            backbone: 3D骨架(旋转+平移)
        """
        # 计算注意力权重(标量部分)
        attn_weights = self.attention_proj(query_repr)  # [B, N, H*2]
        attn_logits = attn_weights[..., :self.num_heads]  # [B, N, H]
        attn_weights = F.softmax(attn_logits, dim=1)
        
        # 计算3D偏移(等变部分)
        # 将点从局部坐标变换到全局坐标
        global_points = backbone.rotate(query_repr[..., -3*self.num_points:])
        offsets = backbone.translate(global_points)
        
        # 聚合
        output = torch.einsum('bnh,bnhd->bnd', attn_weights, offsets)
        
        return self.head_proj(output)

5. 训练与损失函数

5.1 多尺度损失

AlphaFold使用多个损失项的组合:

def alpha_fold_loss(outputs, targets):
    """
    AlphaFold训练损失
    """
    losses = {}
    
    # 1. FAPE损失(基于帧的点距离误差)
    losses["fape"] = compute_fape(
        predicted_positions=outputs["atoms"],
        true_positions=targets["atoms"],
        break_loop_at=10.0
    )
    
    # 2. MSA修复损失
    losses["msa"] = F.cross_entropy(
        outputs["predicted_lddt"].view(-1, 37),
        targets["all_atom_mask"].long()
    )
    
    # 3. 距离预测损失
    distogram_loss = compute_distogram_loss(
        pred_dist=outputs["distogram"],
        true_dist=targets["distogram"]
    )
    losses["distogram"] = distogram_loss
    
    # 4. 辅助损失(violation)
    losses["violation"] = compute_violation_loss(
        predicted_atoms=outputs["atoms"],
        mask=targets["all_atom_mask"]
    )
    
    # 加权求和
    total_loss = sum(w * losses[k] for k, w in WEIGHTS.items())
    
    return total_loss, losses

5.2 FAPE损失

FAPE(Frame Aligned Point Error)计算预测结构与真实结构的对齐误差:

def compute_fape(pred_pos, true_pos, backbone_frames, max_loss=10.0):
    """
    FAPE损失
    """
    # 1. 对齐到公共骨架帧
    pred_local = backbone_frames.invert()(pred_pos)
    true_local = backbone_frames.invert()(true_pos)
    
    # 2. 计算距离
    dist = torch.norm(pred_local - true_local, dim=-1)
    
    # 3. 截断和平均
    loss = torch.clamp(dist, max=max_loss).mean()
    
    return loss

6. AlphaFold3进展

2024年发布的AlphaFold3将预测范围扩展到:

能力AlphaFold2AlphaFold3
单体蛋白质
蛋白质复合物
核酸
小分子配体
共价修饰
DNA/RNA

6.1 Diffusion-based架构

AlphaFold3使用扩散模型生成结构:

class AlphaFold3Diffusion(nn.Module):
    """
    AlphaFold3的扩散架构
    """
    def __init__(self, config):
        super().__init__()
        
        # 噪声计划
        self.noise_schedule = NoiseSchedule()
        
        # 骨干网络(类似Evoformer)
        self.backbone = DiffusiveEvoformer(config)
        
        # 条件注入
        self.condition_encoder = ConditionEncoder(config)
    
    def forward(self, noisy_coords, timestep, conditions):
        """
        去噪过程
        """
        # 编码条件
        cond = self.condition_encoder(conditions)
        
        # 去噪特征
        features = self.backbone(noisy_coords, cond)
        
        # 预测噪声
        noise_pred = self.noise_head(features)
        
        return noise_pred
    
    def training_loss(self, clean_coords, conditions):
        """
        训练:预测添加的噪声
        """
        # 添加噪声
        t = torch.randint(0, self.T, (clean_coords.shape[0],))
        noise = torch.randn_like(clean_coords)
        noisy_coords = self.noise_schedule.add_noise(clean_coords, noise, t)
        
        # 预测噪声
        noise_pred = self.forward(noisy_coords, t, conditions)
        
        return F.mse_loss(noise_pred, noise)

7. 数据库与工具

7.1 AlphaFold数据库

DeepMind提供预计算的蛋白质结构数据库:

数据库规模用途
AlphaFold DB (UniProt)~200M 蛋白质全局搜索
AlphaFold EBI~1M 蛋白质人类蛋白质组
RCSB PDB20K实验验证

7.2 本地部署

# 使用ColabFold(轻量版)
pip install colabfold
 
# 运行预测
colabfold_search input.fasta output_dir
colabfold_batch input.fasta output_dir
 
# 使用完整AlphaFold2
docker pull ghcr.io/deepmind/alphafold
docker run -v $(pwd)/data:/data ghcr.io/deepmind/alphafold \
    --fasta_paths=/data/input.fasta \
    --output_dir=/data/output \
    --max_template_date=2024-01-01

8. 应用场景

  1. 药物发现:靶点结构解析、先导化合物优化
  2. 蛋白质设计:从头设计功能性蛋白质
  3. 功能注释:基于结构的蛋白质功能预测
  4. 进化分析:共进化分析、结构比对

参考文献


相关主题等变图神经网络, 分子动力学模拟, RFdiffusion蛋白质设计

Footnotes

  1. Jumper et al. “Highly accurate protein structure prediction with AlphaFold.” Nature 596, 583-589 (2021).