概述
分子可以用图结构自然表示:原子作为节点,化学键作为边。这种表示使得图神经网络成为分子建模的理想工具。[^1]
分子图表示:
H
|
H - C = O 节点: C, H, H, O
| 边: C-H, C-H, C=O
H
应用场景
| 任务 | 输入 | 输出 | 示例 |
|---|---|---|---|
| 性质预测 | 分子图 | 标量/向量 | 毒性、溶解度 |
| 药物-靶点预测 | 药物图 + 蛋白结构 | 亲和力分数 | 候选药物筛选 |
| 分子生成 | 隐变量/条件 | 分子图 | 新药设计 |
| 反应预测 | 反应物图 | 产物图 | 有机合成 |
1. 分子图的表示
1.1 节点特征
原子节点的特征向量:
| 特征 | 描述 | 维度 |
|---|---|---|
| 原子类型 | C, N, O, S, … | 原子种类数 |
| 度数 | 直接键合的原子数 | 0-6 |
| 手性 | R/S/无 | 3 |
| 形式电荷 | +1, 0, -1, … | 整数 |
| 杂化类型 | sp, sp², sp³, … | 枚举 |
| 芳香性 | 是否在芳香环中 | 1 |
| 氢原子数 | 连接的H数 | 整数 |
| 质量 | 原子质量 | 实数 |
1.2 边特征
化学键的特征向量:
| 特征 | 描述 | 维度 |
|---|---|---|
| 键类型 | 单键/双键/三键/芳香键 | 4 |
| 共轭 | 是否共轭 | 1 |
| 成环 | 是否在环中 | 1 |
| 立体化学 | E/Z/无 | 3 |
1.3 RDKit特征提取
from rdkit import Chem
from rdkit.Chem import AllChem
import torch
def atom_to_feature(atom):
    """Convert an RDKit atom into a numeric feature vector.

    Enum-valued properties (hybridization, chirality) and booleans are cast
    to ``int`` so the resulting list can be fed directly to ``torch.tensor``
    (raw RDKit enum objects are not tensor-convertible), matching the
    all-int convention of ``bond_to_feature``.
    """
    return [
        atom.GetAtomicNum(),           # atomic number
        atom.GetDegree(),              # number of directly bonded atoms
        atom.GetFormalCharge(),        # formal charge
        int(atom.GetHybridization()),  # hybridization type (enum -> int)
        int(atom.GetIsAromatic()),     # aromaticity flag
        atom.GetTotalNumHs(),          # attached hydrogen count
        int(atom.GetChiralTag()),      # chirality tag (enum -> int)
    ]
def bond_to_feature(bond):
    """Convert an RDKit bond into a numeric feature vector.

    Note: the RDKit Bond API exposes ``IsInRing()`` — there is no
    ``GetIsRing()`` method, so the original call raised AttributeError.
    """
    return [
        int(bond.GetBondType()),      # bond type (single/double/triple/aromatic)
        int(bond.GetIsConjugated()),  # conjugation flag
        int(bond.IsInRing()),         # ring-membership flag
        int(bond.GetStereo()),        # stereochemistry (E/Z/none)
    ]
def mol_to_graph(mol):
    """Convert an RDKit molecule into graph tensors.

    Builds per-atom feature rows, a bidirectional edge index of shape
    (2, 2 * num_bonds) — each chemical bond contributes two directed
    edges — and matching per-edge bond features.

    Fixes: uses the correct RDKit API (``IsInRing``, not ``GetIsInRing`` /
    ``GetIsRing``), casts enum properties to int so ``torch.tensor``
    accepts them, and keeps a well-formed (2, 0) edge index for molecules
    without bonds (e.g. single atoms).
    """
    # Per-atom features (raw integer descriptors — not one-hot encoded).
    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append([
            atom.GetAtomicNum(),
            atom.GetDegree(),
            atom.GetFormalCharge(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            int(atom.GetHybridization()),  # enum -> int
            int(atom.GetIsAromatic()),
            int(atom.IsInRing()),          # RDKit API: IsInRing, not GetIsInRing
            int(atom.GetChiralTag()),      # enum -> int
        ])
    atom_features = torch.tensor(atom_features, dtype=torch.float)
    # Directed edge list plus matching bond features.
    edge_indices = []
    edge_features = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edge_indices.append([i, j])
        edge_indices.append([j, i])
        bond_features = [
            int(bond.GetBondType()),
            int(bond.GetIsConjugated()),
            int(bond.IsInRing()),          # RDKit API: IsInRing, not GetIsRing
        ]
        # Same features for both edge directions.
        edge_features.append(bond_features)
        edge_features.append(bond_features)
    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t()
    else:
        # Bond-less molecule: keep the canonical (2, 0) shape.
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    edge_attr = torch.tensor(edge_features, dtype=torch.float)
    return atom_features, edge_index, edge_attr

2. 分子性质预测
2.1 任务定义
给定分子图 $G = (V, E)$,预测其性质 $\hat{y} = f(G)$:
常见的分子性质包括:
- 量子力学性质:HOMO-LUMO能隙、偶极矩
- 物理化学性质:溶解度、沸点、熔点
- 生物学性质:毒性、血脑屏障穿透性
2.2 常用数据集
| 数据集 | 分子数 | 性质数 | 描述 |
|---|---|---|---|
| QM9 | 134K | 12 | 量子力学性质 |
| ZINC | 250K | ~20 | 药物相似性 |
| Tox21 | 12K | 12 | 毒性 |
| BBBP | 2K | 1 | 血脑屏障穿透 |
2.3 分子GNN架构
消息传递神经网络(MPNN)
MPNN框架在分子图上特别有效:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MolecularMPNN(nn.Module):
    """MPNN for molecular property prediction.

    Fixes relative to the draft version:
    - the readout MLP now accepts the concatenated mean+max pooled graph
      embedding (2 * hidden_dim), matching what ``forward`` actually
      produces (the original declared ``Linear(hidden_dim, ...)`` and
      would fail with a shape mismatch);
    - each propagation layer uses its own message/update module instead of
      always indexing layer 0;
    - the dead per-node neighbor scan in the update step is removed and the
      GRU update is applied batched over all nodes.
    """
    def __init__(self, node_dim, edge_dim, hidden_dim, num_layers=4, num_tasks=1):
        super().__init__()
        self.num_layers = num_layers
        self.num_tasks = num_tasks
        # Input embeddings for nodes and edges.
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)
        # One message network per layer; input is [h_i, h_j, e_ij].
        self.message_layers = nn.ModuleList([
            nn.Linear(hidden_dim * 3, hidden_dim)
            for _ in range(num_layers)
        ])
        # One GRU-style node-state update per layer.
        self.update_layers = nn.ModuleList([
            nn.GRUCell(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        # Readout consumes the mean+max pooled embedding (2 * hidden_dim).
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )
        # Per-task prediction head.
        self.predictor = nn.Linear(hidden_dim // 2, num_tasks)

    def message(self, h_i, h_j, e_ij, layer=0):
        """Message function for the given layer (defaults to layer 0 for
        backward compatibility with callers of the old signature)."""
        return F.relu(self.message_layers[layer](torch.cat([h_i, h_j, e_ij], dim=-1)))

    def forward(self, x, edge_index, edge_attr, batch_idx=None):
        """
        x: (N, node_dim) node features
        edge_index: (2, E) edge indices
        edge_attr: (E, edge_dim) edge features
        batch_idx: optional (N,) graph assignment for batched input
        """
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)
        for layer in range(self.num_layers):
            h_old = h.clone()
            # Sum incoming messages per destination node.
            messages = torch.zeros_like(h)
            for k in range(edge_index.shape[1]):
                src, dst = edge_index[0, k], edge_index[1, k]
                messages[dst] += self.message(h_old[src], h_old[dst], e[k], layer)
            # Batched GRU update of every node state.
            h = self.update_layers[layer](messages, h_old)
            h = F.dropout(h, p=0.2, training=self.training)
        # Graph-level readout: concat of mean and max pooling.
        if batch_idx is None:
            # Single graph: global pooling over all nodes.
            graph_h = torch.cat([h.mean(dim=0), h.max(dim=0)[0]], dim=-1)
        else:
            # Mini-batch of graphs: pool per graph id.
            graph_h = []
            for b in range(batch_idx.max() + 1):
                mask = batch_idx == b
                graph_h.append(torch.cat([
                    h[mask].mean(dim=0),
                    h[mask].max(dim=0)[0]
                ], dim=-1))
            graph_h = torch.stack(graph_h)
        h = self.readout(graph_h)
        out = self.predictor(h)
        return out

2.4 带边特征的GNN
分子性质预测中,边特征(如键类型)至关重要:
class EdgeConditionedGNN(nn.Module):
    """Edge-conditioned GNN.

    Every directed edge (u -> v) produces a message from the source node
    state concatenated with the edge embedding; incoming messages are
    averaged per destination node and folded into the node state with a
    per-layer GRU cell.
    """
    def __init__(self, node_dim, edge_dim, hidden_dim, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        # Input projections for node and edge features.
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)
        # One message MLP per layer; input is [source state | edge embedding].
        self.message_nets = nn.ModuleList(
            nn.Linear(hidden_dim * 2, hidden_dim) for _ in range(num_layers)
        )
        # One GRU update cell per layer.
        self.update_nets = nn.ModuleList(
            nn.GRUCell(hidden_dim, hidden_dim) for _ in range(num_layers)
        )

    def forward(self, x, edge_index, edge_attr):
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)
        for layer in range(self.num_layers):
            aggregated = torch.zeros_like(h)
            in_degree = torch.zeros(h.shape[0], device=h.device)
            # Accumulate edge-conditioned messages at each destination.
            for k in range(edge_index.shape[1]):
                u, v = edge_index[0, k], edge_index[1, k]
                raw = torch.cat([h[u], e[k]], dim=-1)
                aggregated[v] = aggregated[v] + F.relu(self.message_nets[layer](raw))
                in_degree[v] = in_degree[v] + 1
            # Mean-aggregate; isolated nodes keep a zero message (divide by 1).
            in_degree[in_degree == 0] = 1
            aggregated = aggregated / in_degree.unsqueeze(-1)
            # GRU update, node by node.
            h_next = torch.zeros_like(h)
            for node in range(h.shape[0]):
                h_next[node] = self.update_nets[layer](aggregated[node], h[node])
            h = h_next
        return h

3. 药物-靶点交互预测
3.1 任务定义
预测药物分子与蛋白质靶点之间的相互作用强度(亲和力)。[^2]
药物图 + 蛋白图/序列 → 亲和力预测
3.2 蛋白表示
蛋白质可以用以下方式表示:
- 序列表示:氨基酸序列 → 蛋白质语言模型
- 图表示:氨基酸作为节点,空间接触作为边
class ProteinGraphBuilder:
    """Build residue-level protein graphs.

    Two construction modes:
    - from a 3D structure: nodes are residues, edges connect C-alpha pairs
      closer than ``contact_threshold`` angstroms (self-contacts excluded);
    - from a sequence: nodes are overlapping 3-mers chained linearly.
    """
    def __init__(self, contact_threshold=8.0):
        # Contact distance cutoff in angstroms (8 A is a common choice).
        self.contact_threshold = contact_threshold

    def build_from_structure(self, pdb_file):
        """Build a contact graph from a PDB structure file.

        ``extract_ca_coordinates`` is expected to be provided elsewhere
        (e.g. parsing an RCSB PDB or AlphaFold model).
        """
        coords = self.extract_ca_coordinates(pdb_file)
        # Pairwise C-alpha distance matrix.
        dist_matrix = torch.cdist(coords, coords)
        contact = dist_matrix < self.contact_threshold
        # Drop self-contacts: the diagonal is always 0 < threshold, so the
        # original version emitted a self-loop for every residue.
        contact.fill_diagonal_(False)
        edge_index = contact.nonzero().t()
        return edge_index, coords

    def build_from_sequence(self, sequence):
        """Build a linear k-mer chain graph from an amino-acid sequence."""
        k = 3  # k-mer size
        nodes = [sequence[i:i + k] for i in range(len(sequence) - k + 1)]
        edges = []
        for i in range(len(nodes) - 1):
            # Connect neighbouring k-mers in both directions.
            edges.append([i, i + 1])
            edges.append([i + 1, i])
        return torch.tensor(nodes), torch.tensor(edges)

3.3 药物-靶点交互模型
class DrugTargetInteractionModel(nn.Module):
    """Predict drug-target binding affinity from a molecular graph and a protein sequence."""
    def __init__(self, drug_node_dim, drug_edge_dim, target_dim, hidden_dim):
        super().__init__()
        # Drug encoder: molecular GNN over the drug graph.
        # NOTE(review): MolecularMPNN as written returns its task-head output
        # (num_tasks wide, 1 by default), not a hidden_dim embedding — confirm
        # the intended encoder output before relying on the dims below.
        # NOTE(review): target_dim is currently unused — verify.
        self.drug_encoder = MolecularMPNN(
            node_dim=drug_node_dim,
            edge_dim=drug_edge_dim,
            hidden_dim=hidden_dim,
            num_layers=4
        )
        # Target encoder: linear projection of amino-acid one-hot vectors
        # followed by a Transformer encoder.
        # NOTE(review): nn.TransformerEncoderLayer defaults to seq-first
        # input (seq, batch, dim) — verify target_seq layout at the call site.
        self.target_encoder = nn.Sequential(
            nn.Linear(20, hidden_dim),  # amino-acid one-hot -> hidden
            nn.TransformerEncoder(
                nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4),
                num_layers=3
            )
        )
        # Interaction head over [drug, target, drug*target] concatenation.
        self.predictor = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )
    def forward(self, drug_x, drug_edge_index, drug_edge_attr,
                target_seq, drug_batch=None, target_batch=None):
        # Encode the drug molecule into a graph-level vector.
        drug_emb = self.drug_encoder(drug_x, drug_edge_index, drug_edge_attr, drug_batch)
        # Encode the protein target.
        target_emb = self.target_encoder(target_seq)
        # Simple concatenation-based interaction features: drug embedding,
        # sequence-pooled target embedding, and their elementwise product.
        interaction = torch.cat([
            drug_emb,
            target_emb.mean(dim=1),  # pool over the sequence dimension
            drug_emb * target_emb.mean(dim=1)  # elementwise interaction
        ], dim=-1)
        affinity = self.predictor(interaction)
        return affinity

4. 分子生成
4.1 任务定义
生成具有特定性质的分子图:
4.2 VAE方法:JT-VAE
连接树变分自编码器(JT-VAE)将分子分解为支架树和分子图两部分。[^3]
分子 → 分解 → 连接树 → 编码 → 隐向量 → 解码 → 连接树 → 组装 → 分子
class JTVAE(nn.Module):
    """JT-VAE: junction-tree variational autoencoder for molecule generation.

    NOTE(review): TreeLSTM, GraphAFDecoder, JunctionTreePrior,
    junction_tree_decompose, assemble_molecule and compute_recon_loss are
    not defined in this file — this class is an architectural sketch.
    """
    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        # Tree encoder (Tree LSTM) over the junction tree.
        self.tree_encoder = TreeLSTM(vocab_size, hidden_dim)
        # Graph encoder (MPNN) over the full molecular graph.
        self.graph_encoder = MolecularMPNN(..., hidden_dim)
        # Latent Gaussian parameter heads.
        self.mean_net = nn.Linear(hidden_dim, hidden_dim)
        self.logvar_net = nn.Linear(hidden_dim, hidden_dim)
        # Decoders.
        self.tree_decoder = GraphAFDecoder(hidden_dim)
        self.junction_tree_prior = JunctionTreePrior(hidden_dim)
    def encode(self, mol):
        """Encode a molecule into the parameters of the latent Gaussian."""
        # Decompose into junction tree + molecular graph.
        tree, graph = self.junction_tree_decompose(mol)
        # Encode both views.
        tree_h = self.tree_encoder(tree)
        graph_h = self.graph_encoder(graph)
        # Combine by summation.
        combined_h = tree_h + graph_h
        # Latent mean and log-variance.
        mean = self.mean_net(combined_h)
        logvar = self.logvar_net(combined_h)
        return mean, logvar
    def decode(self, z):
        """Decode a latent vector back into a molecule.

        NOTE(review): this injects noise using z itself as if it were a
        log-variance (z + eps * exp(0.5 * z)); a standard JT-VAE decodes
        deterministically from z — confirm this is intended.
        """
        eps = torch.randn_like(z)
        h = z + eps * torch.exp(0.5 * z)
        # Latent -> junction tree.
        tree = self.tree_decoder(h)
        # Assemble tree nodes back into a molecular graph.
        mol = self.assemble_molecule(tree)
        return mol
    def forward(self, mol):
        mean, logvar = self.encode(mol)
        # Reparameterization trick.
        z = mean + torch.randn_like(mean) * torch.exp(0.5 * logvar)
        # Decode and score the reconstruction.
        mol_recon = self.decode(z)
        # ELBO terms: reconstruction + KL(q(z|x) || N(0, I)).
        recon_loss = self.compute_recon_loss(mol, mol_recon)
        kl_loss = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
        return recon_loss + kl_loss

4.3 GAN方法:GCPN
图卷积策略网络(GCPN)使用强化学习结合GAN进行分子生成。[^4]
class GCPNModel(nn.Module):
    """GCPN: Graph Convolutional Policy Network for goal-directed molecule generation.

    NOTE(review): GraphGenerationPolicy, MolecularDiscriminator,
    MolecularRewardNet, compute_gradient_penalty and check_validity are
    not defined in this file — this class sketches the training setup.
    """
    def __init__(self, node_dim, edge_dim, hidden_dim, num_layers=6):
        super().__init__()
        # Policy network that grows a molecular graph step by step.
        self.graph_generator = GraphGenerationPolicy(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers
        )
        # Adversarial critic over molecules.
        self.discriminator = MolecularDiscriminator(
            node_dim=node_dim,
            edge_dim=edge_dim,
            hidden_dim=hidden_dim
        )
        # Property-reward network.
        self.reward_net = MolecularRewardNet(hidden_dim)

    def generate_step(self, graph_state, action):
        """
        Apply one generation action to the partial graph.
        action: (add_node, add_edge, stop)
        """
        if action.type == 'add_node':
            # Grow the graph with a new typed node.
            graph_state.add_node(action.node_type, action.node_features)
        elif action.type == 'add_edge':
            # Connect two existing nodes with a typed edge.
            graph_state.add_edge(action.src, action.dst, action.edge_type)
        elif action.type == 'stop':
            # Terminal action: leave the graph unchanged.
            pass
        return graph_state

    def discriminator_loss(self, generated_mols, real_mols):
        """WGAN-GP critic loss: E[D(fake)] - E[D(real)] + lambda * GP.

        Sign convention fixed: the critic minimizes fake - real; the
        original returned real - fake, which trains the critic backwards.
        """
        fake_scores = self.discriminator(generated_mols)
        real_scores = self.discriminator(real_mols)
        # Gradient penalty with the standard lambda = 10.
        gp = self.compute_gradient_penalty(generated_mols, real_mols)
        return torch.mean(fake_scores) - torch.mean(real_scores) + 10 * gp

    def generator_loss(self, generated_mols, target_properties):
        """Policy loss: negative sum of critic, property and validity rewards."""
        # Critic reward squashed to (0, 1).
        d_reward = torch.sigmoid(self.discriminator(generated_mols))
        # Property reward: negative absolute deviation from the target.
        p_reward = -torch.abs(self.reward_net(generated_mols) - target_properties)
        # Validity reward for chemically valid graphs.
        valid_reward = self.check_validity(generated_mols)
        return -(d_reward + p_reward + valid_reward).mean()

5. 分子图Transformer
5.1 分子语言模型
将分子表示为SMILES序列,使用Transformer进行编码:
class MoleculeTransformer(nn.Module):
    """Transformer encoder over SMILES token sequences for property prediction."""
    def __init__(self, vocab_size, hidden_dim, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        # Learned positional embeddings for sequences up to 512 tokens.
        self.pos_embedding = nn.Embedding(512, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.predictor = nn.Linear(hidden_dim, 1)

    def forward(self, smiles_tokens):
        """
        smiles_tokens: (batch, seq_len) SMILES token ids
        """
        seq_len = smiles_tokens.shape[1]
        positions = torch.arange(seq_len, device=smiles_tokens.device)
        # Token + positional embeddings.
        x = self.embedding(smiles_tokens) + self.pos_embedding(positions)
        # nn.TransformerEncoder expects (seq, batch, dim) by default.
        encoded = self.transformer(x.transpose(0, 1)).transpose(0, 1)
        # Mean-pool over the sequence dimension.
        x = encoded.mean(dim=1)
        return self.predictor(x)

5.2 几何深度学习
对于3D分子结构,需要考虑几何约束:
class GeometricGNN(nn.Module):
    """Geometry-aware GNN that conditions messages on interatomic distance.

    Fixes relative to the draft version:
    - distances are computed only for edges that exist (O(E)) instead of a
      full O(N^2) pairwise matrix;
    - node states are updated into a fresh tensor via a batched GRU call —
      the original assigned into ``h`` in place after ``h`` had already
      been used to build the messages, which invalidates those activations
      for backpropagation.
    """
    def __init__(self, node_dim, edge_dim, hidden_dim):
        super().__init__()
        self.node_embed = nn.Linear(node_dim, hidden_dim)
        self.edge_embed = nn.Linear(edge_dim, hidden_dim)
        # Message network input: [h_i, h_j, e_ij, distance scalar].
        self.message_net = nn.Sequential(
            nn.Linear(hidden_dim * 3 + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.update_net = nn.GRUCell(hidden_dim, hidden_dim)

    def message(self, h_i, h_j, e_ij, dist_ij):
        """Distance-conditioned message."""
        msg_input = torch.cat([h_i, h_j, e_ij, dist_ij.unsqueeze(-1)], dim=-1)
        return F.relu(self.message_net(msg_input))

    def forward(self, x, edge_index, edge_attr, coords):
        """
        x: (N, node_dim) node features
        edge_index: (2, E) edge indices
        edge_attr: (E, edge_dim) edge features
        coords: (N, 3) atom coordinates
        """
        h = self.node_embed(x)
        e = self.edge_embed(edge_attr)
        # Per-edge Euclidean distances only (no full pairwise matrix).
        src, dst = edge_index[0], edge_index[1]
        edge_dist = torch.norm(coords[src] - coords[dst], dim=-1)
        # Aggregate incoming messages per destination node.
        messages = torch.zeros_like(h)
        for k in range(edge_index.shape[1]):
            msg = self.message(h[src[k]], h[dst[k]], e[k], edge_dist[k])
            messages[dst[k]] += msg
        # Batched GRU update into a fresh tensor (no in-place writes to h).
        h = self.update_net(messages, h)
        return h

6. 实践:QM9性质预测
6.1 数据加载
from torch_geometric.datasets import QM9
from torch_geometric.data import DataLoader
def load_qm9(batch_size=32):
    """Load the QM9 dataset and the names of its 12 regression targets.

    NOTE(review): batch_size is currently unused — DataLoaders are built by
    the caller; confirm whether this helper should construct them instead.
    """
    dataset = QM9(root='./data/QM9')
    # The 12 QM9 target properties.
    props = [
        'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2',
        'zpve', 'u0', 'u298', 'h298', 'g298', 'cv'
    ]
    return dataset, props
def qm9_collate(batch):
    """Collate QM9 samples: identity pass-through (PyG batches graphs itself)."""
    return batch

6.2 完整训练脚本
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import DataLoader
from torch_geometric.datasets import QM9
from torch_geometric.nn import global_mean_pool, global_add_pool, global_max_pool
class MolecularGNN(nn.Module):
    """Molecular property prediction: edge-conditioned encoder + pooled readout."""
    def __init__(self, node_dim, hidden_dim, num_tasks=12):
        super().__init__()
        self.encoder = EdgeConditionedGNN(
            node_dim=node_dim,
            edge_dim=5,  # bond-feature dimension
            hidden_dim=hidden_dim,
            num_layers=4
        )
        # Readout consumes the concatenated mean+max pooled embedding.
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )
        self.predictor = nn.Linear(hidden_dim // 2, num_tasks)

    def forward(self, data):
        # Node-level encoding.
        h = self.encoder(data.x, data.edge_index, data.edge_attr)
        # Graph-level pooling: mean and max. (Fixed: h_max previously used
        # global_add_pool, contradicting its name and the mean+max readout
        # used elsewhere in this document.)
        h_mean = global_mean_pool(h, data.batch)
        h_max = global_max_pool(h, data.batch)
        h = torch.cat([h_mean, h_max], dim=-1)
        # Readout + task heads.
        h = self.readout(h)
        return self.predictor(h)
def train_qm9():
    """Train a MolecularGNN on QM9 and report train/val/test MAE."""
    # Fall back to CPU when CUDA is unavailable (was hard-coded to 'cuda').
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Load and randomly split the dataset (~110K train/val, rest test).
    dataset = QM9(root='./data/QM9')
    perm = torch.randperm(len(dataset))
    train_idx = perm[:100000]
    val_idx = perm[100000:110000]
    test_idx = perm[110000:]
    train_loader = DataLoader(dataset[train_idx.tolist()], batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset[val_idx.tolist()], batch_size=64)
    test_loader = DataLoader(dataset[test_idx.tolist()], batch_size=64)
    # Model: QM9 has 11 node features and 12 regression targets.
    model = MolecularGNN(
        node_dim=11,
        hidden_dim=128,
        num_tasks=12
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.L1Loss()  # MAE loss
    for epoch in range(50):
        # --- training ---
        model.train()
        train_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            pred = model(batch)
            # All 12 targets predicted jointly.
            loss = criterion(pred, batch.y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # --- validation ---
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                pred = model(batch)
                val_loss += criterion(pred, batch.y).item()
        print(f"Epoch {epoch}: Train Loss = {train_loss/len(train_loader):.4f}, "
              f"Val Loss = {val_loss/len(val_loader):.4f}")
    # --- final test evaluation ---
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            pred = model(batch)
            test_loss += criterion(pred, batch.y).item()
    print(f"Test Loss: {test_loss/len(test_loader):.4f}")
if __name__ == '__main__':
    train_qm9()

7. 相关主题
参考
Footnotes
[^1]: Gilmer et al., “Neural Message Passing for Quantum Chemistry”, ICML 2017
[^2]: Öztürk et al., “DeepDTA: Deep Drug-Target Binding Affinity Prediction”, Bioinformatics 2018
[^3]: Jin et al., “Junction Tree Variational Autoencoder for Molecular Graph Generation”, ICML 2018
[^4]: You et al., “Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation”, NeurIPS 2018