等变图神经网络与分子科学
1. 引言
等变图神经网络(Equivariant Graph Neural Networks)是近年来分子科学领域最重要的突破之一。与传统GNN不同,等变GNN在输入发生几何变换时,输出会相应地发生可预测的变换,这种特性使其能够自然地编码物理对称性,避免学习冗余的表示。1
为什么需要等变性?
在分子系统中,物理规律具有以下不变性/等变性:
| 变换类型 | 不变性 | 等变性 |
|---|---|---|
| 平移 | 能量、力 | 原子坐标 |
| 旋转 | 能量 | 原子坐标、力向量 |
| 反射 | 能量(手性分子除外) | 坐标、偶极矩 |
核心问题:如果训练数据中包含某个分子及其旋转版本,模型应该输出相同的能量预测,但传统GNN可能学到不同的表示,导致过拟合。
2. 群论基础
2.1 E(3)群
E(3)是三维欧几里得空间中的等距变换群,包含:
其中:
- :三维旋转群(特殊正交群)
- :三维平移群
- :半直积
2.2 标量与矢量
在E(3)变换下,物理量的行为不同:
| 类型 | 定义 | E(3)变换 |
|---|---|---|
| 标量 (0阶) | 与方向无关 | (不变) |
| 矢量 (1阶) | 有方向 | |
| 二阶张量 | 矩阵 |
3. 球谐函数与SO(3)表示
3.1 球谐函数基础
球谐函数 是 群在球面上的完备基函数,其中 是阶数(角度动量量子数),。
前几阶球谐函数:
3.2 Wigner-D矩阵
Wigner-D矩阵 描述了球谐函数在旋转 下的变换:
这意味着如果我们有阶数为 的特征向量 ,在旋转后变为:
4. 等变消息传递
4.1 等变卷积核
在传统GNN中,消息函数 可以是任意神经网络。但在等变GNN中,卷积核必须满足:
其中 是 诱导的表示变换。
核心约束:卷积核只能依赖于相对位移的旋转不变量(标量)。
4.2 TFN架构
Tensor Field Network (TFN) 提出了基于球谐函数的消息传递框架:
import torch
import torch.nn as nn
class EquivariantLayer(nn.Module):
"""
基于球谐函数的等变卷积层
"""
def __init__(self, l_max, in_channels, out_channels):
super().__init__()
self.l_max = l_max
self.in_channels = in_channels
self.out_channels = out_channels
# 每阶的MLP(不随旋转变化)
self.radial_mlp = nn.ModuleDict({
str(l): nn.Sequential(
nn.Linear(in_channels, 64),
nn.SiLU(),
nn.Linear(64, out_channels)
) for l in range(l_max + 1)
})
def compute_clebsch_gordan(self, l1, l2, l3):
"""
计算Clebsch-Gordan系数
CG系数描述了如何组合两个不可约表示
"""
# 简化实现,实际使用e3nn库
pass
def forward(self, node_features, edge_vectors):
"""
Args:
node_features: {l: [B, N, 2l+1, F]} 阶数为l的特征
edge_vectors: [B, N, N, 3] 边向量
"""
output = {l: torch.zeros_like(node_features[l]) for l in range(self.l_max + 1)}
for l1 in range(self.l_max + 1):
for l2 in range(self.l_max + 1):
# 输出阶数约束: |l1 - l2| <= l3 <= l1 + l2
for l3 in range(abs(l1 - l2), min(l1 + l2, self.l_max) + 1):
# 球谐函数编码边方向
Y_l2 = spherical_harmonics(edge_vectors, l2) # [B, N, N, 2l2+1]
# 径向MLP
radial = self.radial_mlp[str(l3)](node_features[l1]) # [B, N, 1, F]
# Clebsch-Gordan耦合
coupled = self.clebsch_gordan(Y_l2, radial, l1, l2, l3)
output[l3] += coupled
return output5. e3nn库核心操作
e3nn是目前最广泛使用的等变神经网络库,提供了:
5.1 irreps:不可约表示
from e3nn import o3
from e3nn.nn import Gate
# 定义不可约表示
# "1x0e" = 1个标量(阶数0,偶宇称)
# "1x1o" = 1个矢量(阶数1,奇宇称)
# "1x2e" = 1个二阶张量(阶数2,偶宇称)
irreps_in = o3.Irreps("16x0e + 8x1o + 4x2e")
irreps_out = o3.Irreps("32x0e + 16x1e")
# 创建线性等变层
lin = o3.Linear(irreps_in, irreps_out)5.2 TensorProduct:等变张量积
from e3nn.nn import TensorProduct
# 逐通道张量积
tp = TensorProduct(
"16x0e + 8x1o", # 输入1
"16x0e + 8x1o", # 输入2
"16x0e + 8x1o + 4x2e", # 输出
compilation_mode="xyzn"
)
# 输入必须是可加的(阶数相同),输出可以是混合阶数5.3 Gate:门控非线性
由于非线性激活(如ReLU)会破坏等变性,e3nn使用门控机制:
from e3nn.nn import Gate
# 标量特征使用非线性激活
# 门控特征控制门开闭
irreps_scalars = o3.Irreps("16x0e")
irreps_gates = o3.Irreps("8x1o") # 门控必须是标量或1阶
irreps_gated = o3.Irreps("8x1o + 4x2e")
gate = Gate(
irreps_scalars, [torch.nn.SiLU()], # 标量非线性
irreps_gates, [torch.nn.functional.sigmoid], # 门控
irreps_gated # 门控特征的等变操作
)6. 等变GNN架构示例
6.1 SchNet
SchNet是最早的等变分子属性预测模型之一,使用连续卷积:
class SchNetLayer(nn.Module):
"""
SchNet的连续卷积层
关键点:卷积核只依赖于距离(旋转不变量)
"""
def __init__(self, hidden_dim, n_filters):
super().__init__()
self.conv = nn.Linear(hidden_dim, n_filters)
self.filter = nn.Sequential(
nn.Linear(1, n_filters),
nn.ReLU(),
nn.Linear(n_filters, hidden_dim)
)
def forward(self, x, edge_index, edge_attr):
# edge_attr: 相对位移的距离
j, i = edge_index
# 消息传递
msg = self.conv(x[i]) * self.filter(edge_attr.unsqueeze(-1))
return scatter_add(msg, j, dim=0)6.2 NequIP架构
NequIP使用更复杂的等变消息传递(见NequIP与Allegro)。
7. 与标准GNN的对比
| 特性 | 标准GNN | 等变GNN |
|---|---|---|
| 表示能力 | 旋转不变 | 旋转等变(更丰富) |
| 参数效率 | 较低 | 较高(对称性约束) |
| 数据效率 | 需要数据增强 | 自然利用对称性 |
| 可解释性 | 一般 | 物理意义更清晰 |
| 计算复杂度 |
8. 应用场景
- 分子动力学势能面:预测原子间力和能量
- 蛋白质结构预测:三维坐标生成
- 材料科学:晶体性质预测
- 药物设计:分子对接、亲和力预测
参考文献
Footnotes
-
Batzner et al. “E(3)-Equivariant Graph Neural Networks for Data-Efficient and Accurate Molecular Dynamics.” arXiv:2201.01288, 2022. ↩