等变图神经网络与分子科学

1. 引言

等变图神经网络(Equivariant Graph Neural Networks)是近年来分子科学领域最重要的突破之一。与传统GNN不同,等变GNN在输入发生几何变换时,输出会相应地发生可预测的变换,这种特性使其能够自然地编码物理对称性,避免学习冗余的表示。1

为什么需要等变性?

在分子系统中,物理规律具有以下不变性/等变性:

变换类型不变性等变性
平移能量、力原子坐标
旋转能量原子坐标、力向量
反射能量(手性分子除外)坐标、偶极矩

核心问题:如果训练数据中包含某个分子及其旋转版本,模型应该输出相同的能量预测,但传统GNN可能学到不同的表示,导致过拟合。


2. 群论基础

2.1 E(3)群

E(3)是三维欧几里得空间中的等距变换群,包含:

其中:

  • :三维旋转群(特殊正交群)
  • :三维平移群
  • :半直积

2.2 标量与矢量

在E(3)变换下,物理量的行为不同:

类型定义E(3)变换
标量 (0阶)与方向无关 (不变)
矢量 (1阶)有方向
二阶张量 矩阵

3. 球谐函数与SO(3)表示

3.1 球谐函数基础

球谐函数 群在球面上的完备基函数,其中 阶数(角度动量量子数),

前几阶球谐函数

3.2 Wigner-D矩阵

Wigner-D矩阵 描述了球谐函数在旋转 下的变换:

这意味着如果我们有阶数为 的特征向量 ,在旋转后变为:


4. 等变消息传递

4.1 等变卷积核

在传统GNN中,消息函数 可以是任意神经网络。但在等变GNN中,卷积核必须满足:

其中 诱导的表示变换。

核心约束:卷积核只能依赖于相对位移的旋转不变量(标量)。

4.2 TFN架构

Tensor Field Network (TFN) 提出了基于球谐函数的消息传递框架:

import torch
import torch.nn as nn
 
class EquivariantLayer(nn.Module):
    """
    基于球谐函数的等变卷积层
    """
    def __init__(self, l_max, in_channels, out_channels):
        super().__init__()
        self.l_max = l_max
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        # 每阶的MLP(不随旋转变化)
        self.radial_mlp = nn.ModuleDict({
            str(l): nn.Sequential(
                nn.Linear(in_channels, 64),
                nn.SiLU(),
                nn.Linear(64, out_channels)
            ) for l in range(l_max + 1)
        })
    
    def compute_clebsch_gordan(self, l1, l2, l3):
        """
        计算Clebsch-Gordan系数
        CG系数描述了如何组合两个不可约表示
        """
        # 简化实现,实际使用e3nn库
        pass
    
    def forward(self, node_features, edge_vectors):
        """
        Args:
            node_features: {l: [B, N, 2l+1, F]} 阶数为l的特征
            edge_vectors: [B, N, N, 3] 边向量
        """
        output = {l: torch.zeros_like(node_features[l]) for l in range(self.l_max + 1)}
        
        for l1 in range(self.l_max + 1):
            for l2 in range(self.l_max + 1):
                # 输出阶数约束: |l1 - l2| <= l3 <= l1 + l2
                for l3 in range(abs(l1 - l2), min(l1 + l2, self.l_max) + 1):
                    # 球谐函数编码边方向
                    Y_l2 = spherical_harmonics(edge_vectors, l2)  # [B, N, N, 2l2+1]
                    
                    # 径向MLP
                    radial = self.radial_mlp[str(l3)](node_features[l1])  # [B, N, 1, F]
                    
                    # Clebsch-Gordan耦合
                    coupled = self.clebsch_gordan(Y_l2, radial, l1, l2, l3)
                    
                    output[l3] += coupled
        
        return output

5. e3nn库核心操作

e3nn是目前最广泛使用的等变神经网络库,提供了:

5.1 irreps:不可约表示

from e3nn import o3
from e3nn.nn import Gate
 
# 定义不可约表示
# "1x0e" = 1个标量(阶数0,偶宇称)
# "1x1o" = 1个矢量(阶数1,奇宇称)  
# "1x2e" = 1个二阶张量(阶数2,偶宇称)
irreps_in = o3.Irreps("16x0e + 8x1o + 4x2e")
irreps_out = o3.Irreps("32x0e + 16x1e")
 
# 创建线性等变层
lin = o3.Linear(irreps_in, irreps_out)

5.2 TensorProduct:等变张量积

from e3nn.nn import TensorProduct
 
# 逐通道张量积
tp = TensorProduct(
    "16x0e + 8x1o",      # 输入1
    "16x0e + 8x1o",      # 输入2
    "16x0e + 8x1o + 4x2e", # 输出
    compilation_mode="xyzn"
)
 
# 输入必须是可加的(阶数相同),输出可以是混合阶数

5.3 Gate:门控非线性

由于非线性激活(如ReLU)会破坏等变性,e3nn使用门控机制

from e3nn.nn import Gate
 
# 标量特征使用非线性激活
# 门控特征控制门开闭
irreps_scalars = o3.Irreps("16x0e")
irreps_gates = o3.Irreps("8x1o")  # 门控必须是标量或1阶
irreps_gated = o3.Irreps("8x1o + 4x2e")
 
gate = Gate(
    irreps_scalars, [torch.nn.SiLU()],      # 标量非线性
    irreps_gates, [torch.nn.functional.sigmoid],  # 门控
    irreps_gated  # 门控特征的等变操作
)

6. 等变GNN架构示例

6.1 SchNet

SchNet是最早的等变分子属性预测模型之一,使用连续卷积:

class SchNetLayer(nn.Module):
    """
    SchNet的连续卷积层
    关键点:卷积核只依赖于距离(旋转不变量)
    """
    def __init__(self, hidden_dim, n_filters):
        super().__init__()
        self.conv = nn.Linear(hidden_dim, n_filters)
        self.filter = nn.Sequential(
            nn.Linear(1, n_filters),
            nn.ReLU(),
            nn.Linear(n_filters, hidden_dim)
        )
    
    def forward(self, x, edge_index, edge_attr):
        # edge_attr: 相对位移的距离
        j, i = edge_index
        # 消息传递
        msg = self.conv(x[i]) * self.filter(edge_attr.unsqueeze(-1))
        return scatter_add(msg, j, dim=0)

6.2 NequIP架构

NequIP使用更复杂的等变消息传递(见NequIP与Allegro)。


7. 与标准GNN的对比

特性标准GNN等变GNN
表示能力旋转不变旋转等变(更丰富)
参数效率较低较高(对称性约束)
数据效率需要数据增强自然利用对称性
可解释性一般物理意义更清晰
计算复杂度

8. 应用场景

  1. 分子动力学势能面:预测原子间力和能量
  2. 蛋白质结构预测:三维坐标生成
  3. 材料科学:晶体性质预测
  4. 药物设计:分子对接、亲和力预测

参考文献


相关主题图神经网络, 图卷积网络, 物理信息神经网络

Footnotes

  1. Batzner et al. “E(3)-Equivariant Graph Neural Networks for Data-Efficient and Accurate Molecular Dynamics.” arXiv:2201.01288, 2022.