RFdiffusion:扩散模型蛋白质设计

1. 引言

蛋白质设计是逆向蛋白质折叠问题:不是从序列预测结构,而是从目标结构/功能生成序列。传统方法如RosettaDesign依赖能量函数优化,计算量大且成功率有限。

RFdiffusion是华盛顿大学David Baker实验室的工作,将扩散模型引入蛋白质设计,实现了从随机噪声生成功能性蛋白质结构的突破。1

核心贡献

  1. 无条件生成:从随机噪声生成新颖蛋白质结构
  2. Motif脚手架:在给定功能位点约束下设计蛋白质
  3. 对称性处理:生成具有对称性的蛋白质寡聚体
  4. 靶向结合:从头设计高亲和力结合蛋白

2. 扩散模型基础

2.1 前向过程(Forward Process)

在蛋白质设计中,我们对原子坐标 添加噪声:

其中 是噪声调度。

经过 步后, 接近各向同性高斯分布:

2.2 反向过程(Reverse Process)

学习反向过程

2.3 训练目标

简化版本:直接预测去噪后的坐标或噪声:

其中


3. RoseTTAFold架构

3.1 核心组件

RFdiffusion基于RoseTTAFold的架构,包括三条信息轨道:

┌─────────────────────────────────────────────────────┐
│                 1D轨道:序列特征                     │
│    氨基酸类型、进化信息、溶剂可及性                  │
├─────────────────────────────────────────────────────┤
│                 2D轨道:成对关系                     │
│    距离图、残基间接触、相互作用倾向                  │
├─────────────────────────────────────────────────────┤
│                 3D轨道:三维结构                     │
│    主链坐标、原子类型、方向信息                      │
└─────────────────────────────────────────────────────┘

3.2 Sequence Tokenization

class SequenceEncoder(nn.Module):
    """
    序列编码器
    """
    def __init__(self, vocab_size, embed_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, 512, embed_dim))
    
    def forward(self, sequence):
        """
        Args:
            sequence: [B, L] 氨基酸序列
        Returns:
            seq_tokens: [B, L, d] 序列token
        """
        tokens = self.embedding(sequence)
        tokens = tokens + self.pos_embedding[:, :tokens.size(1)]
        return tokens

3.3 Pairformer层

class PairformerBlock(nn.Module):
    """
    RoseTTAFold的Pairformer块
    """
    def __init__(self, d_model, num_heads, d_pair):
        super().__init__()
        
        # 序列自注意力
        self.seq_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        
        # 成对注意力
        self.pair_attn = nn.MultiheadAttention(d_pair, num_heads, batch_first=True)
        
        # 交叉注意力
        self.cross_attn_seq2pair = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.cross_attn_pair2seq = nn.MultiheadAttention(d_pair, num_heads, batch_first=True)
        
        # FFN
        self.seq_ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.pair_ffn = nn.Sequential(
            nn.Linear(d_pair, d_pair * 4),
            nn.ReLU(),
            nn.Linear(d_pair * 4, d_pair)
        )
    
    def forward(self, seq_repr, pair_repr):
        # 序列自注意力 + 残差
        seq_out, _ = self.seq_attn(seq_repr, seq_repr, seq_repr)
        seq_repr = seq_repr + seq_out
        
        # 成对注意力
        pair_out, _ = self.pair_attn(pair_repr, pair_repr, pair_repr)
        pair_repr = pair_repr + pair_out
        
        # 交叉注意力
        seq2pair, _ = self.cross_attn_seq2pair(pair_repr, seq_repr, seq_repr)
        pair_repr = pair_repr + seq2pair
        
        # FFN
        seq_repr = seq_repr + self.seq_ffn(seq_repr)
        pair_repr = pair_repr + self.pair_ffn(pair_repr)
        
        return seq_repr, pair_repr

3.4 3D结构模块

class StructureModule(nn.Module):
    """
    3D结构生成模块
    """
    def __init__(self, d_model, d_pair):
        super().__init__()
        
        # 骨架表示:旋转矩阵 + 平移向量
        self.backbone_update = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 6)  # 3旋转 + 3平移
        )
        
        # 等变点注意力 (IPA)
        self.ipa = InvariantPointAttention(d_model, d_pair, num_heads=8, num_points=4)
        
        # 层归一化
        self.ln = nn.LayerNorm(d_model)
    
    def forward(self, seq_repr, pair_repr, backbone):
        """
        Args:
            seq_repr: [B, L, d_model] 序列表示
            pair_repr: [B, L, L, d_pair] 成对表示
            backbone: 当前骨架状态
        """
        # IPA更新
        ipa_out = self.ipa(
            query_repr=self.ln(seq_repr),
            pair_repr=pair_repr,
            backbone=backbone
        )
        
        # 更新序列表示
        seq_repr = seq_repr + ipa_out
        
        # 骨架更新
        update = self.backbone_update(self.ln(seq_repr))
        backbone = backbone.compose(update)
        
        return seq_repr, backbone

4. RFdiffusion扩散机制

4.1 噪声添加策略

RFdiffusion在骨架坐标上添加噪声:

class DiffusiveRF(nn.Module):
    """
    RFdiffusion的核心
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # 轨道Transformer(3条信息轨道)
        self.track_transformer = TrackTransformer(config)
        
        # 条件编码器(用于motif约束)
        self.condition_encoder = ConditionEncoder(config)
        
        # 去噪头
        self.denoise_head = nn.Linear(config.d_model, 3)  # 预测3D偏移
    
    def forward_diffusion(self, coords, t):
        """
        前向扩散:添加噪声
        """
        noise = torch.randn_like(coords)
        
        # 计算t时刻的noisy坐标
        alpha_t = self.noise_schedule.alpha[t]
        sigma_t = self.noise_schedule.sigma[t]
        
        noisy_coords = alpha_t * coords + sigma_t * noise
        
        return noisy_coords, noise
    
    def forward(self, noisy_coords, t, conditions=None, mask=None):
        """
        反向去噪
        """
        # 1. 时间嵌入
        t_embed = self.time_embedding(t)
        
        # 2. 条件编码(如果有motif约束)
        if conditions is not None:
            cond_embed = self.condition_encoder(conditions)
        else:
            cond_embed = None
        
        # 3. 轨道Transformer处理
        seq_repr, pair_repr = self.track_transformer(
            noisy_coords,
            t_embed,
            conditions=cond_embed,
            mask=mask
        )
        
        # 4. 预测噪声或去噪坐标
        output = self.denoise_head(seq_repr)
        
        return output

4.2 条件生成(Motif脚手架)

RFdiffusion支持在给定部分结构约束下生成:

class MotifScaffolding:
    """
    Motif脚手架生成
    """
    def __init__(self, model):
        self.model = model
    
    def generate_with_motif(
        self,
        motif_coords,      # 已知功能位点的坐标
        motif_mask,        # 哪些残基是固定的
        scaffold_length,   # 脚手架长度
        num_samples=10
    ):
        """
        生成带有motif约束的蛋白质
        """
        results = []
        
        for i in range(num_samples):
            # 初始化:motif部分使用真实坐标,其余随机
            coords = torch.randn(1, scaffold_length + len(motif_coords), 3)
            
            # 固定motif区域
            coords[:, motif_mask] = motif_coords
            
            # 扩散采样
            for t in reversed(range(self.model.num_steps)):
                with torch.no_grad():
                    # 预测噪声
                    noise_pred = self.model(coords, t)
                    
                    # 计算条件梯度(保持motif固定)
                    if t > 0:
                        coords = self.ddim_step(
                            coords, noise_pred, t, 
                            motif_mask=motif_mask
                        )
                    else:
                        coords = self.final_step(coords, noise_pred)
            
            results.append(coords)
        
        return results
    
    def ddim_step(self, coords, noise_pred, t, motif_mask=None):
        """
        DDIM采样步骤
        """
        alpha_t = self.noise_schedule.alpha[t]
        sigma_t = self.noise_schedule.sigma[t]
        
        # 预测干净坐标
        pred_coords = (coords - sigma_t * noise_pred) / alpha_t
        
        # 保持motif固定
        if motif_mask is not None:
            pred_coords[:, motif_mask] = coords[:, motif_mask]
        
        # 添加少量噪声
        if t > 0:
            alpha_prev = self.noise_schedule.alpha[t-1]
            sigma_prev = self.noise_schedule.sigma[t-1]
            
            coords = alpha_prev * pred_coords + sigma_prev * torch.randn_like(coords)
            
            # 重新应用motif约束
            coords[:, motif_mask] = motif_coords[:, motif_mask]
        else:
            coords = pred_coords
        
        return coords

5. 对称性处理

5.1 对称扩散

对于具有对称性的蛋白质(如环形、十聚体),RFdiffusion支持对称扩散:

class SymmetricDiffusion:
    """
    对称蛋白质扩散
    """
    def __init__(self, symmetry_type, n_subunits):
        self.symmetry_type = symmetry_type  # 'c_n', 'd_n', ...
        self.n_subunits = n_subunits
    
    def symmetrize_coords(self, coords):
        """
        将非对称坐标转换为对称形式
        """
        if self.symmetry_type == 'c_n':
            # 循环对称
            base_coords = coords[:, :len(coords)//self.n_subunits]
            
            symmetrized = []
            for i in range(self.n_subunits):
                angle = 2 * np.pi * i / self.n_subunits
                rot_matrix = self.rotation_matrix_z(angle)
                rotated = torch.matmul(base_coords, rot_matrix.T)
                symmetrized.append(rotated)
            
            return torch.cat(symmetrized, dim=1)
        
        return coords
    
    def apply_symmetric_noise(self, coords, t):
        """
        对称地添加噪声
        """
        noise = torch.randn_like(coords)
        
        # 对称化噪声
        if self.symmetry_type == 'c_n':
            # 只对基础单元添加噪声,然后旋转复制
            noise_base = noise[:, :len(coords)//self.n_subunits]
            symmetrized_noise = self.symmetrize_coords(noise_base)
        else:
            symmetrized_noise = noise
        
        # 添加到坐标
        alpha_t = self.noise_schedule.alpha[t]
        sigma_t = self.noise_schedule.sigma[t]
        
        return alpha_t * coords + sigma_t * symmetrized_noise

5.2 对称蛋白质设计示例

# 设计C5对称纳米颗粒
python -m rfdiffusion.inference \
    input_pdb="motif.pdb" \
    contig="A:50-80" \
    symmetry="c5" \
    iterations=200 \
    out_prefix="designed_c5_nanoparticle"

6. ProteinMPNN序列设计

6.1 序列设计模型

RFdiffusion生成结构后,使用ProteinMPNN设计对应序列:

class ProteinMPNN(nn.Module):
    """
    图神经网络序列设计
    """
    def __init__(self, hidden_dim=128, num_layers=3):
        super().__init__()
        
        # 消息传递网络
        self.encoders = nn.ModuleList([
            MessagePassingLayer(hidden_dim) for _ in range(num_layers)
        ])
        
        # 序列解码器
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 20)  # 20种氨基酸
        )
        
        # 温度调度
        self.temperature = 0.1
    
    def forward(self, coords, omit_AAs=[], bias_AA=None):
        """
        Args:
            coords: [N, 3] 原子坐标
            omit_AAs: 不使用的氨基酸列表
            bias_AA: 氨基酸偏好
        """
        # 1. 图构建
        edge_index = self.knn_graph(coords, k=30)
        
        # 2. 编码
        node_features = self.coord_encoder(coords)
        
        # 3. 消息传递
        for encoder in self.encoders:
            node_features = encoder(node_features, edge_index)
        
        # 4. 解码为序列
        logits = self.decoder(node_features)  # [N, 20]
        
        # 5. 应用约束
        if omit_AAs:
            for AA in omit_AAs:
                logits[:, self.aa_to_idx[AA]] = -inf
        
        if bias_AA:
            logits = logits + bias_AA
        
        # 6. 采样
        probs = F.softmax(logits / self.temperature, dim=-1)
        sequence = torch.multinomial(probs, 1).squeeze(-1)
        
        return sequence
    
    def training_loss(self, sequence, coords):
        """
        对数似然损失
        """
        logits = self.forward(coords)
        loss = F.cross_entropy(logits, sequence)
        return loss

6.2 端到端设计流程

class ProteinDesignPipeline:
    """
    端到端蛋白质设计流程
    """
    def __init__(self):
        self.diffusion_model = DiffusiveRF(checkpoint="rfdiffusion.pt")
        self.sequence_model = ProteinMPNN(checkpoint="proteinmpnn.pt")
    
    def design_protein(
        self,
        target_function=None,
        scaffold_length=150,
        num_designs=100
    ):
        """
        完整设计流程
        """
        designs = []
        
        for i in range(num_designs):
            # 1. 结构扩散生成
            structure = self.diffusion_model.sample(
                length=scaffold_length,
                conditions=target_function
            )
            
            # 2. 序列设计
            sequence = self.sequence_model.sample(structure)
            
            # 3. 结构验证(Rosetta)
            plddt = self.validate_structure(structure, sequence)
            
            if plddt > 0.8:
                designs.append({
                    "structure": structure,
                    "sequence": sequence,
                    "plddt": plddt
                })
        
        return designs
    
    def validate_structure(self, structure, sequence):
        """
        使用AlphaFold验证设计的序列是否折叠成目标结构
        """
        # AlphaFold结构预测
        predicted = alphafold_predict(sequence)
        
        # 计算TM-score
        tm_score = compute_tm_score(structure, predicted)
        
        return tm_score

7. 应用案例

7.1 功能性蛋白质设计

应用案例成功率
酶设计催化特定化学反应的人工酶~30%
结合蛋白高亲和力靶向特定蛋白的结合剂~60%
纳米颗粒自组装的纳米笼/纳米管~70%
疫苗设计稳定的纳米颗粒疫苗平台~80%

7.2 设计流程示例

# 设计结合HER2受体的纳米抗体
from rfdiffusion import RFdiffusionRunner
 
# 初始化
runner = RFdiffusionRunner(
    inference_config="config.yaml",
    model_path="rfdiffusion_weights.pt"
)
 
# 定义motif:HER2结合表位
motif = {
    "coords": load_pdb("her2_epitope.pdb"),
    "chain": "A",
    "residue_indices": [30, 31, 32, 33, 34]
}
 
# 生成设计
designs = runner.design(
    motif=motif,
    scaffold_length=120,
    num_iterations=200,
    num_seeds=10
)
 
# 筛选最佳设计
best = max(designs, key=lambda x: x["tm_score"])
 
# 序列设计
from proteinmpnn import ProteinMPNNRunner
seq_designer = ProteinMPNNRunner()
sequence = seq_designer.sequence_design(best["structure"])
 
print(f"Designed sequence: {sequence}")

8. 局限性与未来方向

局限性当前解决方案未来方向
主链精度迭代优化更精确的扩散模型
侧链设计ProteinMPNN端到端侧链生成
动态构象多状态设计动态扩散模型
配体复合物限制性设计全原子扩散

参考文献


相关主题AlphaFold蛋白质结构, 扩散模型, 等变图神经网络

Footnotes

  1. Watson et al. “De novo design of protein structure and function with RFdiffusion.” Nature 620, 1089-1100 (2023).