Overview

Molecules are naturally represented as graphs: atoms as nodes, chemical bonds as edges. This representation makes graph neural networks an ideal tool for molecular modeling.[^1]

A molecular graph representation:

    H
    |
    C = O     Nodes: C, H, H, O
    |         Edges: C-H, C-H, C=O
    H
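As a minimal illustration (plain Python, no chemistry toolkit), the formaldehyde graph above can be stored as node and edge lists:

```python
# Formaldehyde (H2C=O) as a labeled graph: nodes are atoms,
# edges are (i, j, bond_order) triples. Indices are arbitrary labels.
atoms = ["C", "H", "H", "O"]
bonds = [(0, 1, 1), (0, 2, 1), (0, 3, 2)]  # C-H, C-H, C=O

# Degree of each atom = number of incident edges
degree = [0] * len(atoms)
for i, j, _ in bonds:
    degree[i] += 1
    degree[j] += 1

print(degree)  # the carbon (index 0) has degree 3
```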

Application Scenarios

| Task | Input | Output | Example |
|---|---|---|---|
| Property prediction | molecular graph | scalar/vector | toxicity, solubility |
| Drug-target prediction | drug graph + protein structure | affinity score | candidate drug screening |
| Molecular generation | latent variable / condition | molecular graph | de novo drug design |
| Reaction prediction | reactant graph | product graph | organic synthesis |

1. Molecular Graph Representation

1.1 Node Features

The feature vector of an atom node:

| Feature | Description | Dimension |
|---|---|---|
| Atom type | C, N, O, S, … | number of atom types |
| Degree | number of directly bonded atoms | 0-6 |
| Chirality | R / S / none | 3 |
| Formal charge | +1, 0, -1, … | integer |
| Hybridization | sp, sp², sp³, … | enum |
| Aromaticity | whether in an aromatic ring | 1 |
| Hydrogen count | number of attached H atoms | integer |
| Mass | atomic mass | real |
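Categorical features in the table are typically one-hot encoded before being concatenated into the node feature vector. A minimal sketch in plain Python (the atom-type vocabulary here is an illustrative assumption, not a fixed standard):

```python
ATOM_TYPES = ["C", "N", "O", "S", "F", "Cl"]  # illustrative vocabulary

def one_hot(value, choices):
    """One-hot encode `value`; unknown values map to the all-zero vector."""
    vec = [0] * len(choices)
    if value in choices:
        vec[choices.index(value)] = 1
    return vec

# Atom-type one-hot concatenated with two numeric features
feat = one_hot("O", ATOM_TYPES) + [2, 0]  # degree=2, formal charge=0
print(feat)  # [0, 0, 1, 0, 0, 0, 2, 0]
```

Mapping unknown atom types to the all-zero vector (rather than raising) is a common choice so that rare elements outside the vocabulary do not crash featurization.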

1.2 Edge Features

The feature vector of a chemical bond:

| Feature | Description | Dimension |
|---|---|---|
| Bond type | single / double / triple / aromatic | 4 |
| Conjugation | whether conjugated | 1 |
| Ring membership | whether in a ring | 1 |
| Stereochemistry | E / Z / none | 3 |
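The four bond features combine into a single 9-dimensional vector in the same way; a sketch with an assumed ordering of bond types and stereo labels:

```python
BOND_TYPES = ["single", "double", "triple", "aromatic"]  # assumed ordering
STEREO = ["none", "E", "Z"]

def bond_features(bond_type, conjugated, in_ring, stereo):
    vec = [int(bond_type == t) for t in BOND_TYPES]   # 4-dim one-hot
    vec += [int(conjugated), int(in_ring)]            # two binary flags
    vec += [int(stereo == s) for s in STEREO]         # 3-dim one-hot
    return vec

# The C=O double bond of formaldehyde: not conjugated, not in a ring
print(bond_features("double", False, False, "none"))
```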

1.3 Feature Extraction with RDKit

from rdkit import Chem
from rdkit.Chem import AllChem
import torch

def atom_to_feature(atom):
    """Convert an RDKit atom into a feature vector."""
    return [
        atom.GetAtomicNum(),             # atomic number
        atom.GetDegree(),                # degree
        atom.GetFormalCharge(),          # formal charge
        int(atom.GetHybridization()),    # hybridization (enum -> int)
        int(atom.GetIsAromatic()),       # aromaticity
        atom.GetTotalNumHs(),            # number of attached hydrogens
        int(atom.GetChiralTag()),        # chirality (enum -> int)
    ]

def bond_to_feature(bond):
    """Convert an RDKit bond into a feature vector."""
    return [
        int(bond.GetBondType()),      # bond type
        int(bond.GetIsConjugated()),  # conjugation
        int(bond.IsInRing()),         # ring membership (RDKit method is IsInRing)
        int(bond.GetStereo()),        # stereochemistry
    ]
 
def mol_to_graph(mol):
    """
    Convert an RDKit molecule into graph tensors.
    """
    # Atom features
    atom_features = []
    for atom in mol.GetAtoms():
        features = [
            atom.GetAtomicNum(),           # atomic number (often one-hot encoded downstream)
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetHybridization()),  # enum -> int
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),
            int(atom.GetChiralTag()),      # enum -> int
        ]
        atom_features.append(features)
    
    atom_features = torch.tensor(atom_features, dtype=torch.float)
    
    # Edge index and edge features (each bond becomes two directed edges)
    edge_indices = []
    edge_features = []
    
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        
        edge_indices.append([i, j])
        edge_indices.append([j, i])
        
        bond_features = [
            int(bond.GetBondType()),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),
        ]
        edge_features.append(bond_features)
        edge_features.append(bond_features)
    
    edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_features, dtype=torch.float)
    
    return atom_features, edge_index, edge_attr

2. Molecular Property Prediction

2.1 Task Definition

Given a molecular graph G = (V, E) with atoms V and bonds E, predict its properties y = f(G).

Common molecular properties include:

  • Quantum mechanical properties: HOMO-LUMO gap, dipole moment
  • Physicochemical properties: solubility, boiling point, melting point
  • Biological properties: toxicity, blood-brain barrier permeability

2.2 Common Datasets

| Dataset | # Molecules | # Properties | Description |
|---|---|---|---|
| QM9 | 134K | 12 | quantum mechanical properties |
| ZINC | 250K | ~20 | drug-likeness |
| Tox21 | 12K | 12 | toxicity |
| BBBP | 2K | 1 | blood-brain barrier penetration |

2.3 Molecular GNN Architectures

Message Passing Neural Networks (MPNN)

The MPNN framework is particularly effective on molecular graphs:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class MolecularMPNN(nn.Module):
    """MPNN for molecular property prediction."""
    def __init__(self, node_dim, edge_dim, hidden_dim, num_layers=4, num_tasks=1):
        super().__init__()
        self.num_layers = num_layers
        self.num_tasks = num_tasks
        
        # Node and edge embeddings
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)
        
        # Message-passing layers (input: h_i, h_j, e_ij)
        self.message_layers = nn.ModuleList([
            nn.Linear(hidden_dim * 3, hidden_dim)
            for _ in range(num_layers)
        ])
        
        # Update layers
        self.update_layers = nn.ModuleList([
            nn.GRUCell(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        
        # Readout (input: mean-pool ++ max-pool, hence 2 * hidden_dim)
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )
        
        # Task head
        self.predictor = nn.Linear(hidden_dim // 2, num_tasks)
    
    def forward(self, x, edge_index, edge_attr, batch_idx=None):
        """
        x: (N, node_dim) node features
        edge_index: (2, E) edge indices
        edge_attr: (E, edge_dim) edge features
        """
        # Embeddings
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)
        src, dst = edge_index[0], edge_index[1]
        
        # Message passing; each layer has its own parameters
        for l in range(self.num_layers):
            # Messages for all edges at once
            msg = F.relu(self.message_layers[l](
                torch.cat([h[src], h[dst], e], dim=-1)))
            
            # Sum-aggregate incoming messages at the destination nodes
            messages = torch.zeros_like(h)
            messages.index_add_(0, dst, msg)
            
            # GRU update for all nodes in a single call
            h = self.update_layers[l](messages, h)
            h = F.dropout(h, p=0.2, training=self.training)
        
        # Graph-level readout: mean ++ max pooling
        if batch_idx is None:
            # Single graph: global pooling
            graph_h = torch.cat([h.mean(dim=0), h.max(dim=0)[0]], dim=-1)
        else:
            # Mini-batch of graphs: pool per graph
            graph_h = []
            for b in range(int(batch_idx.max()) + 1):
                mask = batch_idx == b
                graph_h.append(torch.cat([
                    h[mask].mean(dim=0),
                    h[mask].max(dim=0)[0]
                ], dim=-1))
            graph_h = torch.stack(graph_h)
        
        # Prediction
        h = self.readout(graph_h)
        out = self.predictor(h)
        
        return out
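The aggregation at the core of the model can be sketched without any framework: messages flow along directed edges and are summed at each destination node (a plain-Python stand-in for `index_add_`, with scalar node states instead of hidden vectors):

```python
# Toy graph: 3 nodes, directed edges (src, dst)
h = [1.0, 2.0, 3.0]
edges = [(0, 1), (2, 1), (1, 0)]

def message(h_src, h_dst):
    # stand-in for the learned message network: here just the source state
    return h_src

# Sum-aggregate incoming messages per node
messages = [0.0] * len(h)
for src, dst in edges:
    messages[dst] += message(h[src], h[dst])

print(messages)  # node 1 receives h[0] + h[2] = 4.0
```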

2.4 GNNs with Edge Features

In molecular property prediction, edge features (e.g., bond type) are critical:

class EdgeConditionedGNN(nn.Module):
    """Edge-conditioned GNN."""
    def __init__(self, node_dim, edge_dim, hidden_dim, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        
        # Node and edge embeddings
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)
        
        # Message networks (input: source node ++ edge)
        self.message_nets = nn.ModuleList([
            nn.Linear(hidden_dim * 2, hidden_dim)
            for _ in range(num_layers)
        ])
        
        # Update networks
        self.update_nets = nn.ModuleList([
            nn.GRUCell(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])
    
    def forward(self, x, edge_index, edge_attr):
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)
        src, dst = edge_index[0], edge_index[1]
        
        for l in range(self.num_layers):
            # Per-edge messages conditioned on edge features
            msg = F.relu(self.message_nets[l](torch.cat([h[src], e], dim=-1)))
            
            # Sum-aggregate at destination nodes, then normalize by in-degree
            messages = torch.zeros_like(h)
            messages.index_add_(0, dst, msg)
            counts = torch.zeros(h.size(0), device=h.device)
            counts.index_add_(0, dst, torch.ones_like(dst, dtype=counts.dtype))
            messages = messages / counts.clamp(min=1).unsqueeze(-1)
            
            # GRU update for all nodes at once
            h = self.update_nets[l](messages, h)
        
        return h
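Mean aggregation replaces the raw sum with a degree-normalized average, so high-degree atoms do not receive systematically larger messages. The normalization step, sketched in plain Python:

```python
h = [1.0, 2.0, 3.0, 4.0]
edges = [(0, 2), (1, 2), (3, 2), (2, 0)]  # (src, dst)

sums = [0.0] * len(h)
counts = [0] * len(h)
for src, dst in edges:
    sums[dst] += h[src]
    counts[dst] += 1

# Divide by in-degree; nodes with no incoming edges keep a zero message (avoid 0/0)
means = [s / max(c, 1) for s, c in zip(sums, counts)]
print(means)  # node 2 averages h[0], h[1], h[3] -> 7/3
```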

3. Drug-Target Interaction Prediction

3.1 Task Definition

Predict the interaction strength (binding affinity) between a drug molecule and a protein target.[^2]

Drug graph + protein graph/sequence → affinity prediction

3.2 Protein Representation

Proteins can be represented in the following ways:

  1. Sequence representation: amino-acid sequence → protein language model
  2. Graph representation: residues as nodes, spatial contacts as edges

class ProteinGraphBuilder:
    """Build a protein contact graph."""
    def __init__(self, contact_threshold=8.0):
        self.contact_threshold = contact_threshold
    
    def build_from_structure(self, pdb_file):
        """Build a graph from a PDB structure file."""
        # Structures can come from RCSB PDB or AlphaFold
        # (extract_ca_coordinates: helper that parses C-alpha coordinates, not shown)
        coords = self.extract_ca_coordinates(pdb_file)
        
        # Pairwise distance matrix
        dist_matrix = torch.cdist(coords, coords)
        
        # Contact graph; mask the diagonal so residues get no self-edges
        contacts = dist_matrix < self.contact_threshold
        contacts.fill_diagonal_(False)
        edge_index = contacts.nonzero().t()
        
        return edge_index, coords
    
    def build_from_sequence(self, sequence):
        """Build a k-mer graph from a sequence."""
        k = 3  # k-mer size
        kmers = [sequence[i:i+k] for i in range(len(sequence) - k + 1)]
        
        # Map each k-mer string to an integer id (strings cannot go into a tensor)
        vocab = {kmer: idx for idx, kmer in enumerate(sorted(set(kmers)))}
        nodes = torch.tensor([vocab[kmer] for kmer in kmers])
        
        edges = []
        for i in range(len(kmers) - 1):
            edges.append([i, i + 1])  # consecutive k-mers are connected
            edges.append([i + 1, i])
        
        return nodes, torch.tensor(edges, dtype=torch.long).t()
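The contact-graph criterion (connect residues whose C-alpha atoms lie within a cutoff) is easy to check in isolation. A plain-Python sketch with made-up coordinates and the same 8 Å threshold:

```python
import math

# Toy C-alpha coordinates (angstroms) for four residues
coords = [(0, 0, 0), (3, 0, 0), (20, 0, 0), (3, 4, 0)]
threshold = 8.0

edges = []
for i in range(len(coords)):
    for j in range(len(coords)):
        # Exclude self-contacts; keep both directions (i, j) and (j, i)
        if i != j and math.dist(coords[i], coords[j]) < threshold:
            edges.append((i, j))

print(edges)  # residue 2 is far from everything and stays isolated
```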

3.3 Drug-Target Interaction Model

class DrugTargetInteractionModel(nn.Module):
    """Drug-target interaction prediction model."""
    def __init__(self, drug_node_dim, drug_edge_dim, target_dim, hidden_dim):
        super().__init__()
        
        # Drug encoder (molecular GNN); the task head doubles as a projection,
        # so the encoder outputs a hidden_dim graph embedding
        self.drug_encoder = MolecularMPNN(
            node_dim=drug_node_dim,
            edge_dim=drug_edge_dim,
            hidden_dim=hidden_dim,
            num_layers=4,
            num_tasks=hidden_dim
        )
        
        # Protein encoder (Transformer over one-hot amino acids)
        self.target_embed = nn.Linear(20, hidden_dim)  # 20 amino acids, one-hot
        self.target_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4,
                                       batch_first=True),
            num_layers=3
        )
        
        # Interaction predictor
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, drug_x, drug_edge_index, drug_edge_attr,
                target_seq, drug_batch=None):
        # Encode the drug: (B, hidden_dim) graph embeddings
        drug_emb = self.drug_encoder(drug_x, drug_edge_index, drug_edge_attr, drug_batch)
        
        # Encode the target: (B, L, hidden_dim), then mean-pool over residues
        target_emb = self.target_encoder(self.target_embed(target_seq)).mean(dim=1)
        
        # Interaction features: concatenation plus element-wise product
        interaction = torch.cat([
            drug_emb,
            target_emb,
            drug_emb * target_emb
        ], dim=-1)
        
        affinity = self.predictor(interaction)
        
        return affinity

4. Molecular Generation

4.1 Task Definition

Generate molecular graphs with desired properties: learn a distribution over molecules and sample novel, chemically valid graphs from it, optionally conditioned on target properties.

4.2 VAE Approach: JT-VAE

The Junction Tree Variational Autoencoder (JT-VAE) decomposes a molecule into two parts: a junction tree of chemical substructures and the molecular graph itself.[^3]

Molecule → decompose → junction tree → encode → latent vector → decode → junction tree → assemble → molecule

class JTVAE(nn.Module):
    """JT-VAE: junction tree variational autoencoder.

    TreeLSTM, TreeDecoder, junction_tree_decompose, assemble_molecule and
    compute_recon_loss stand in for components not shown here.
    """
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        
        # Tree encoder (Tree LSTM)
        self.tree_encoder = TreeLSTM(vocab_size, hidden_dim)
        
        # Graph encoder (MPNN)
        self.graph_encoder = MolecularMPNN(..., hidden_dim)
        
        # Latent variable heads
        self.mean_net = nn.Linear(hidden_dim, hidden_dim)
        self.logvar_net = nn.Linear(hidden_dim, hidden_dim)
        
        # Decoder
        self.tree_decoder = TreeDecoder(hidden_dim)
    
    def encode(self, mol):
        """Encode a molecule."""
        # Decompose into junction tree and graph
        tree, graph = self.junction_tree_decompose(mol)
        
        # Encode both views
        tree_h = self.tree_encoder(tree)
        graph_h = self.graph_encoder(graph)
        
        # Combine
        combined_h = tree_h + graph_h
        
        # Latent distribution parameters
        mean = self.mean_net(combined_h)
        logvar = self.logvar_net(combined_h)
        
        return mean, logvar
    
    def decode(self, z):
        """Decode a molecule from a latent vector."""
        # Decode the junction tree from z directly
        # (sampling already happened in the reparameterization step)
        tree = self.tree_decoder(z)
        
        # Expand the tree nodes back into a molecular graph
        mol = self.assemble_molecule(tree)
        
        return mol
    
    def forward(self, mol):
        mean, logvar = self.encode(mol)
        
        # Reparameterization trick
        z = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)
        
        # Decode
        mol_recon = self.decode(z)
        
        # Loss: reconstruction + KL divergence to the standard normal prior
        recon_loss = self.compute_recon_loss(mol, mol_recon)
        kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
        
        return recon_loss + kl_loss
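The KL term in the loss has a closed form for a diagonal Gaussian against a standard normal prior. A quick numeric sanity check of the per-dimension formula used above, in plain Python:

```python
import math

def kl_term(mean, logvar):
    """Per-dimension KL( N(mean, exp(logvar)) || N(0, 1) )."""
    return -0.5 * (1 + logvar - mean**2 - math.exp(logvar))

# A posterior that matches the prior exactly pays zero KL
print(kl_term(0.0, 0.0))   # 0.0
# Any other posterior pays a positive penalty
print(kl_term(1.0, 0.0) > 0, kl_term(0.0, 1.0) > 0)
```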

4.3 GAN Approach: GCPN

The Graph Convolutional Policy Network (GCPN) combines reinforcement learning with an adversarial (GAN) objective for molecular generation.[^4]

class GCPNModel(nn.Module):
    """GCPN: Graph Convolutional Policy Network.

    GraphGenerationPolicy, MolecularDiscriminator, MolecularRewardNet,
    compute_gradient_penalty and check_validity stand in for components
    not shown here.
    """
    def __init__(self, node_dim, edge_dim, hidden_dim, num_layers=6):
        super().__init__()
        
        # Graph generation policy
        self.graph_generator = GraphGenerationPolicy(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers
        )
        
        # Discriminator
        self.discriminator = MolecularDiscriminator(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim
        )
        
        # Reward network
        self.reward_net = MolecularRewardNet(hidden_dim)
    
    def generate_step(self, graph_state, action):
        """
        Apply one generation action.
        action: (add_node, add_edge, stop)
        """
        if action.type == 'add_node':
            # Add a new node
            graph_state.add_node(action.node_type, action.node_features)
        
        elif action.type == 'add_edge':
            # Add an edge
            graph_state.add_edge(action.src, action.dst, action.edge_type)
        
        elif action.type == 'stop':
            # Stop generation
            pass
        
        return graph_state
    
    def discriminator_loss(self, generated_mols, real_mols):
        """Discriminator (critic) loss."""
        # Generated samples
        fake_scores = self.discriminator(generated_mols)
        
        # Real samples
        real_scores = self.discriminator(real_mols)
        
        # WGAN-GP critic loss: push real scores up, fake scores down
        gp = self.compute_gradient_penalty(generated_mols, real_mols)
        
        return torch.mean(fake_scores) - torch.mean(real_scores) + 10 * gp
    
    def generator_loss(self, generated_mols, target_properties):
        """Generator loss."""
        # Adversarial reward from the discriminator
        d_reward = torch.sigmoid(self.discriminator(generated_mols))
        
        # Property reward: closeness to the target property values
        p_reward = -torch.abs(self.reward_net(generated_mols) - target_properties)
        
        # Validity reward (chemically valid molecules)
        valid_reward = self.check_validity(generated_mols)
        
        return -(d_reward + p_reward + valid_reward).mean()
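The sequential add-node / add-edge / stop interface can be mocked with a minimal graph state to see how a molecule is built action by action (plain Python; this `GraphState` is a hypothetical stand-in, not the GCPN implementation):

```python
class GraphState:
    """Minimal mutable graph used to illustrate the generation actions."""
    def __init__(self):
        self.nodes = []   # atom type per node
        self.edges = []   # (src, dst, bond_type) triples

    def add_node(self, node_type):
        self.nodes.append(node_type)
        return len(self.nodes) - 1  # index of the new node

    def add_edge(self, src, dst, edge_type):
        self.edges.append((src, dst, edge_type))

# Build formaldehyde step by step, as a policy would
g = GraphState()
c = g.add_node("C")
g.add_edge(c, g.add_node("H"), "single")
g.add_edge(c, g.add_node("H"), "single")
g.add_edge(c, g.add_node("O"), "double")
# 'stop' action: generation ends, the graph is complete
print(g.nodes, g.edges)
```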

5. Molecular Graph Transformers

5.1 Molecular Language Models

Represent molecules as SMILES strings and encode them with a Transformer:

class MoleculeTransformer(nn.Module):
    """Transformer over SMILES token sequences."""
    def __init__(self, vocab_size, hidden_dim, num_heads, num_layers):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.pos_embedding = nn.Embedding(512, hidden_dim)  # max sequence length 512
        
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dim_feedforward=hidden_dim * 4
            ),
            num_layers=num_layers
        )
        
        self.predictor = nn.Linear(hidden_dim, 1)
    
    def forward(self, smiles_tokens):
        """
        smiles_tokens: (batch, seq_len) SMILES token ids
        """
        x = self.embedding(smiles_tokens)
        x = x + self.pos_embedding(torch.arange(x.shape[1], device=x.device))
        
        # Transformer encoding (default layout is sequence-first)
        x = x.transpose(0, 1)  # (seq, batch, dim)
        x = self.transformer(x)
        x = x.transpose(0, 1)  # (batch, seq, dim)
        
        # Mean-pool over tokens
        x = x.mean(dim=1)
        
        return self.predictor(x)
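The model above assumes SMILES strings have already been split into tokens. A minimal regex tokenizer sketch (the token set is a simplification; real vocabularies also cover charges, stereo marks, and more elements):

```python
import re

# Two-letter elements must be matched before one-letter ones
SMILES_TOKEN = re.compile(
    r"Cl|Br|[BCNOSPFI]|[bcnops]|\[[^\]]+\]|[=#()\d]"
)

def tokenize(smiles):
    """Split a SMILES string into chemically meaningful tokens."""
    return SMILES_TOKEN.findall(smiles)

print(tokenize("CC(=O)Oc1ccccc1"))  # phenyl acetate
```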

5.2 Geometric Deep Learning

For 3D molecular structures, geometric constraints such as interatomic distances must be taken into account:

class GeometricGNN(nn.Module):
    """Geometric GNN: messages conditioned on 3D coordinates."""
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)
        
        # Geometry-aware message function
        self.message_net = nn.Sequential(
            nn.Linear(hidden_dim * 3 + 1, hidden_dim),  # h_i, h_j, e_ij, dist
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.update_net = nn.GRUCell(hidden_dim, hidden_dim)
    
    def forward(self, x, edge_index, edge_attr, coords):
        """
        coords: (N, 3) atom coordinates
        """
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)
        src, dst = edge_index[0], edge_index[1]
        
        # Distances only along edges (no full pairwise matrix needed)
        dist = (coords[src] - coords[dst]).norm(dim=-1, keepdim=True)
        
        # Distance-conditioned messages for all edges at once
        msg = F.relu(self.message_net(torch.cat([h[src], h[dst], e, dist], dim=-1)))
        
        # Sum-aggregate at destination nodes
        messages = torch.zeros_like(h)
        messages.index_add_(0, dst, msg)
        
        # GRU update for all nodes
        h = self.update_net(messages, h)
        
        return h

6. Hands-On: QM9 Property Prediction

6.1 Data Loading

from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader  # moved out of torch_geometric.data in PyG 2.x

def load_qm9():
    """Load the QM9 dataset."""
    dataset = QM9(root='./data/QM9')
    
    # The 12 commonly used QM9 targets
    props = [
        'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2',
        'zpve', 'u0', 'u298', 'h298', 'g298', 'cv'
    ]
    
    return dataset, props

# No custom collate function is needed: PyG's DataLoader merges the graphs
# in a batch into one large disconnected graph automatically.

6.2 Complete Training Script

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import QM9
from torch_geometric.loader import DataLoader
from torch_geometric.nn import global_mean_pool, global_max_pool

class MolecularGNN(nn.Module):
    """GNN for molecular property prediction."""
    def __init__(self, node_dim, hidden_dim, num_tasks=12):
        super().__init__()
        
        self.encoder = EdgeConditionedGNN(
            node_dim=node_dim,
            edge_dim=4,  # QM9 bond features: one-hot bond type (single/double/triple/aromatic)
            hidden_dim=hidden_dim,
            num_layers=4
        )
        
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )
        
        self.predictor = nn.Linear(hidden_dim // 2, num_tasks)
    
    def forward(self, data):
        # Encode
        h = self.encoder(data.x, data.edge_index, data.edge_attr)
        
        # Graph-level pooling: mean ++ max
        h_mean = global_mean_pool(h, data.batch)
        h_max = global_max_pool(h, data.batch)
        h = torch.cat([h_mean, h_max], dim=-1)
        
        # Readout
        h = self.readout(h)
        
        return self.predictor(h)


def train_qm9():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Load the data
    dataset = QM9(root='./data/QM9')
    
    # Split
    perm = torch.randperm(len(dataset))
    train_idx = perm[:100000]
    val_idx = perm[100000:110000]
    test_idx = perm[110000:]
    
    train_loader = DataLoader(dataset[train_idx], batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset[val_idx], batch_size=64)
    test_loader = DataLoader(dataset[test_idx], batch_size=64)
    
    # Model
    model = MolecularGNN(
        node_dim=11,  # QM9 node feature dimension
        hidden_dim=128,
        num_tasks=12
    ).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.L1Loss()  # MAE loss
    
    # Training
    for epoch in range(50):
        model.train()
        train_loss = 0
        
        for batch in train_loader:
            batch = batch.to(device)
            
            optimizer.zero_grad()
            pred = model(batch)
            
            # First 12 of the 19 targets PyG provides; in practice the targets
            # should also be standardized, since their scales differ widely
            loss = criterion(pred, batch.y[:, :12])
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                pred = model(batch)
                loss = criterion(pred, batch.y[:, :12])
                val_loss += loss.item()
        
        print(f"Epoch {epoch}: Train Loss = {train_loss/len(train_loader):.4f}, "
              f"Val Loss = {val_loss/len(val_loader):.4f}")
    
    # Test
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch)
            loss = criterion(pred, batch.y[:, :12])
            test_loss += loss.item()
    
    print(f"Test Loss: {test_loss/len(test_loader):.4f}")

if __name__ == '__main__':
    train_qm9()

7. Related Topics

| Topic | Description |
|---|---|
| Graph Neural Networks | GNN fundamentals |
| Graph Convolutional Networks | GCN methods |
| Efficient GNN Training | training on large-scale graphs |

References

[^1]: Gilmer et al., "Neural Message Passing for Quantum Chemistry", ICML 2017.

[^2]: Öztürk et al., "DeepDTA: Deep Drug-Target Binding Affinity Prediction", Bioinformatics 2018.

[^3]: Jin et al., "Junction Tree Variational Autoencoder for Molecular Graph Generation", ICML 2018.

[^4]: You et al., "Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation", NeurIPS 2018.