E(n)等变拓扑神经网络
E(n)等变拓扑神经网络(E(n) Equivariant Topological Neural Networks)是2025年ICLR的重要工作,将拓扑深度学习与等变性(Equivariance)理论结合,为分子和物理系统建模提供了强大的工具。1
1. 等变性理论基础
1.1 对称性与等变性
对称性是物理学和几何学的核心概念,描述了在某种变换下系统的不变性。
等变性(Equivariance)定义:对于变换群 和函数 ,如果:
其中 和 分别是输入和输出空间上的群作用,则 相对于群 是等变的。
1.2 欧几里得群 E(n)
E(n) 是 n 维欧几里得空间的等距同构群,包含:
| 变换类型 | 记号 | 自由度 |
|---|---|---|
| 平移 | ||
| 旋转 | ||
| 反射 | - | |
| 组合 |
1.3 常见的等变网络
| 网络类型 | 等变群 | 应用 |
|---|---|---|
| CNN | 图像 | |
| 3D CNN | 体积数据 | |
| GNN | 图结构 | |
| TFN | 分子 | |
| EGNN | 3D分子 |
2. 拓扑归纳偏置
2.1 什么是拓扑归纳偏置
拓扑归纳偏置是指网络架构中编码的关于数据拓扑结构的先验知识:
- 结构不变性:不依赖坐标系的结构表示
- 多尺度感知:同时捕获微观和宏观拓扑特征
- 持久性感知:区分重要结构和噪声
2.2 拓扑特征的类型
| 拓扑特征 | 维度 | 物理意义 |
|---|---|---|
| 连通分量 | 0维 | 相邻性、连通域 |
| 环路 | 1维 | 循环、周期性 |
| 空洞 | 2维 | 腔体、孔洞 |
| 高维洞 | k维 | 复杂结构 |
2.3 为什么需要拓扑归纳偏置
传统方法:
输入 → 欧几里得坐标 → 手工特征 → 分类器
↓
对噪声敏感
对变换不稳定
拓扑方法:
输入 → 拓扑结构 → 持久特征 → 分类器
↓
对噪声鲁棒
对变换稳定
3. E(n)等变拓扑网络架构
3.1 核心设计原则
E(n)等变拓扑网络遵循以下设计原则:
- 等变性约束:网络输出对欧几里得变换等变
- 拓扑感知:集成持久同调特征
- 高阶交互:超越成对关系的复杂结构建模
3.2 架构概览
class EquivariantTopoLayer(nn.Module):
"""
E(n)等变拓扑层
整合几何特征、拓扑特征和高阶交互
"""
def __init__(self, node_dim, edge_dim, hidden_dim, n_heads=4):
super().__init__()
# 几何特征处理(等变)
self.geo_encoder = GeometricEncoder(node_dim, hidden_dim)
# 拓扑特征处理
self.topo_encoder = TopologicalEncoder(hidden_dim)
# 高阶交互消息传递
self.higher_order_msg = HigherOrderMessagePassing(
hidden_dim, n_heads
)
# 更新函数(等变)
self.update = EquivariantUpdate(hidden_dim)
def forward(self, x, edge_index, positions, batch=None):
"""
x: 节点特征
edge_index: 边连接
positions: 3D坐标(等变关键)
"""
# 1. 编码几何特征
geo_feat = self.geo_encoder(x, positions)
# 2. 提取局部拓扑特征
local_topo = self.topo_encoder(x, edge_index)
# 3. 高阶交互消息传递
msg = self.higher_order_msg(
geo_feat, local_topo, edge_index, positions
)
# 4. 等变更新
updated = self.update(msg, positions)
return updated3.3 几何编码器
class GeometricEncoder(nn.Module):
"""
几何特征编码器
保持对欧几里得变换的等变性
"""
def __init__(self, node_dim, hidden_dim):
super().__init__()
# 节点特征嵌入
self.node_mlp = nn.Sequential(
nn.Linear(node_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 距离编码(旋转不变)
self.distance_mlp = nn.Sequential(
nn.Linear(1, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 方向编码(旋转等变)
self.direction_mlp = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
def forward(self, x, positions):
"""
positions: (N, 3) 坐标,变换时同步变换
"""
# 节点特征
h = self.node_mlp(x)
# 计算相对位置(等变)
rel_pos = positions.unsqueeze(1) - positions.unsqueeze(0) # (N, N, 3)
rel_dist = torch.norm(rel_pos, dim=-1, keepdim=True) # (N, N, 1)
rel_dir = rel_pos / (rel_dist + 1e-8) # (N, N, 3)
# 编码
dist_feat = self.distance_mlp(rel_dist) # (N, N, H)
dir_feat = self.direction_mlp(rel_dir) # (N, N, H)
return {
'node_feat': h,
'distance_feat': dist_feat,
'direction_feat': dir_feat,
'relative_positions': rel_pos
}3.4 拓扑编码器
class TopologicalEncoder(nn.Module):
"""
拓扑特征编码器
从局部邻域提取拓扑信息
"""
def __init__(self, hidden_dim, k_neighbors=10):
super().__init__()
self.k = k_neighbors
self.topo_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim // 2)
)
def forward(self, x, edge_index):
"""
edge_index: (2, E) 边索引
"""
# 构建邻接信息
adj = self._build_adjacency(x, edge_index)
# 计算局部拓扑特征
topo_feat = self._extract_local_topology(adj)
return self.topo_mlp(topo_feat)
def _build_adjacency(self, x, edge_index):
"""构建邻接矩阵及相关特征"""
N = x.shape[0]
adj = torch.zeros(N, N)
adj[edge_index[0], edge_index[1]] = 1
# 度数特征
degrees = adj.sum(dim=-1)
# 聚类系数(局部拓扑)
adj_sq = torch.matmul(adj, adj)
triangles = torch.diagonal(adj_sq, dim1=-2, dim2=-1)
possible = degrees * (degrees - 1)
clustering = triangles / (possible + 1e-8)
return {
'adj': adj,
'degrees': degrees,
'clustering': clustering
}
def _extract_local_topology(self, adj_info):
"""提取局部拓扑特征"""
return torch.cat([
adj_info['degrees'].unsqueeze(-1),
adj_info['clustering'].unsqueeze(-1)
], dim=-1)3.5 高阶交互消息传递
class HigherOrderMessagePassing(nn.Module):
"""
高阶交互消息传递
超越成对关系,建模复杂结构
"""
def __init__(self, hidden_dim, n_heads=4):
super().__init__()
self.n_heads = n_heads
self.head_dim = hidden_dim // n_heads
# 多头消息生成
self.message_mlp = nn.ModuleList([
nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, self.head_dim)
)
for _ in range(n_heads)
])
# 聚合函数
self.aggregate = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim)
)
def forward(self, node_feat, topo_feat, edge_index, positions):
"""
消息传递,聚合邻域信息
"""
src, dst = edge_index
# 构建消息
messages = []
for head in range(self.n_heads):
# 源节点和目标节点特征拼接
combined = torch.cat([
node_feat[src],
node_feat[dst]
], dim=-1)
# 生成消息
msg = self.message_mlp[head](combined)
messages.append(msg)
# 拼接多头消息
messages = torch.cat(messages, dim=-1) # (E, H)
# 聚合到目标节点
N = node_feat.shape[0]
aggregated = torch.zeros(N, self.n_heads * self.head_dim)
aggregated = aggregated.to(node_feat.device)
aggregated = aggregated.index_add(0, dst, messages)
return self.aggregate(aggregated)3.6 等变更新
class EquivariantUpdate(nn.Module):
"""
等变更新函数
保持对平移和旋转的等变性
"""
def __init__(self, hidden_dim):
super().__init__()
# 特征更新
self.feature_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 位置更新(仅平移等变,不改变方向)
self.position_mlp = nn.Sequential(
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, aggregated_msg, positions):
"""
positions: (N, 3) 坐标
"""
# 更新特征
updated_feat = self.feature_mlp(aggregated_msg)
# 更新位置(等变:平移不变,旋转等变)
# 位置更新使用相对位移加权
scale = self.position_mlp(aggregated_msg) # (N, 1)
return {
'features': updated_feat,
'positions': positions, # 位置保持不变(简化版)
'position_scale': scale
}4. 完整模型
class EquivariantTopoNet(nn.Module):
"""
E(n)等变拓扑神经网络
端到端分子/物理系统建模
"""
def __init__(self, node_dim, edge_dim, hidden_dim=128,
n_layers=4, n_heads=4, output_dim=1):
super().__init__()
# 输入嵌入
self.node_embedding = nn.Linear(node_dim, hidden_dim)
self.edge_embedding = nn.Linear(edge_dim, hidden_dim)
# 等变拓扑层堆叠
self.layers = nn.ModuleList([
EquivariantTopoLayer(
hidden_dim, hidden_dim, hidden_dim, n_heads
)
for _ in range(n_layers)
])
# 输出预测
self.readout = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, output_dim)
)
def forward(self, data):
"""
data: PyG Data对象
- data.x: 节点特征
- data.edge_index: 边索引
- data.edge_attr: 边特征
- data.pos: 3D坐标
"""
x = self.node_embedding(data.x)
pos = data.pos
# 逐层传播
for layer in self.layers:
x = layer(x, data.edge_index, pos)
if isinstance(x, dict):
x = x['features']
# 图级别池化
x = global_mean_pool(x, data.batch)
# 预测
return self.readout(x)5. 应用场景
5.1 分子性质预测
def train_molecular_property_model():
"""
训练E(n)等变拓扑网络预测分子性质
"""
model = EquivariantTopoNet(
node_dim=ATOM_FEATURES_DIM, # 原子特征维度
edge_dim=BOND_FEATURES_DIM, # 键特征维度
hidden_dim=256,
n_layers=6,
n_heads=4,
output_dim=1
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=100
)
for epoch in range(100):
model.train()
total_loss = 0
for batch in dataloader:
optimizer.zero_grad()
pred = model(batch)
loss = F.mse_loss(pred, batch.y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
total_loss += loss.item()
scheduler.step()
print(f"Epoch {epoch}: Loss={total_loss:.4f}")
return model5.2 分子构象生成
class EquivariantTopoConformationGenerator(nn.Module):
"""
等变拓扑分子构象生成器
"""
def __init__(self, latent_dim, hidden_dim):
super().__init__()
# 噪声到特征的解码器
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim * 2),
nn.ReLU(),
nn.Linear(hidden_dim * 2, hidden_dim * 4),
nn.ReLU(),
nn.Linear(hidden_dim * 4, hidden_dim)
)
# 等变位置生成
self.position_generator = EquivariantPositionGenerator(hidden_dim)
# 拓扑约束
self.topo_constraint = TopologicalConstraint()
def forward(self, z, ref_topology):
"""
z: 潜在编码
ref_topology: 参考拓扑结构
"""
# 解码特征
h = self.decoder(z)
# 生成3D位置
positions = self.position_generator(h)
# 拓扑约束损失
topo_loss = self.topo_constraint(positions, ref_topology)
return positions, topo_loss
class EquivariantPositionGenerator(nn.Module):
"""
等变位置生成器
保证生成的3D结构对旋转平移等变
"""
def __init__(self, hidden_dim):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 3) # 输出相对位移
)
def forward(self, features):
"""
生成相对位置(相对于质心)
"""
# 计算质心
centroid = features.mean(dim=0, keepdim=True)
# 生成相对于质心的位移
delta = self.mlp(features)
# 位置 = 质心 + 位移
positions = centroid + delta
return positions5.3 物理系统模拟
class PhysicsSimulationModel(nn.Module):
"""
物理系统模拟:预测粒子系统演化
"""
def __init__(self, particle_dim, hidden_dim):
super().__init__()
self.encoder = nn.Linear(particle_dim, hidden_dim)
self.dynamics = EquivariantTopoNet(
node_dim=hidden_dim,
edge_dim=1, # 距离
hidden_dim=hidden_dim,
n_layers=3
)
self.velocity_predictor = nn.Linear(hidden_dim, 3)
def forward(self, positions, velocities, edge_index):
"""
预测下一时刻的速度
positions: (N, 3)
velocities: (N, 3)
"""
# 构建节点特征(位置+速度)
node_features = torch.cat([positions, velocities], dim=-1)
# 编码
h = self.encoder(node_features)
# 等变动力学预测
output = self.dynamics(h, edge_index, positions)
# 预测速度变化
delta_v = self.velocity_predictor(output)
# 预测下一时刻速度
next_velocities = velocities + delta_v
return next_velocities6. 理论基础
6.1 等变性定理
Steerable CNN定理:对于映射 ,如果 相对于变换群 是等变的,则 必须具有特定的形式(Steerable)。
6.2 拓扑与等变性的关系
| 性质 | 拓扑 | 等变性 |
|---|---|---|
| 不变性 | 拓扑不变量 | 某些变换不变 |
| 结构性 | 捕获内在结构 | 捕获几何结构 |
| 对称性 | 无坐标系依赖 | 坐标变换对称 |
6.3 高阶交互的表示能力
# 1阶交互:成对关系
x_i^{(1)} = \sum_{j} w_{ij} x_j
# 2阶交互:三角形结构
x_i^{(2)} = \sum_{j,k} w_{ijk} x_j x_k
# k阶交互:k-单形结构
x_i^{(k)} = \sum_{j_1,\ldots,j_k} w_{ij_1\ldots j_k} \prod_{l=1}^k x_{j_l}7. 与其他方法的对比
7.1 与标准GNN对比
| 特性 | 标准GNN | E(n)等变拓扑网络 |
|---|---|---|
| 几何感知 | 弱 | 强 |
| 拓扑感知 | 有限 | 完整 |
| 高阶交互 | 成对 | 多阶 |
| 等变性 | 节点置换 | 欧几里得变换 |
7.2 与TFN/EGNN对比
| 特性 | TFN | EGNN | 本文方法 |
|---|---|---|---|
| 旋转变换 | |||
| 拓扑特征 | 无 | 无 | 有 |
| 高阶交互 | 无 | 无 | 有 |
| 计算复杂度 | 中等 | 中等 | 较高 |
7.3 计算效率
# 复杂度分析
class ComplexityAnalysis:
@staticmethod
def gnn_complexity(n_nodes, n_edges, hidden_dim):
return n_edges * hidden_dim # O(E * H)
@staticmethod
def equiv_topo_complexity(n_nodes, n_edges, n_simplices, hidden_dim):
edges_cost = n_edges * hidden_dim
higher_cost = n_simplices * hidden_dim * 2 # 高阶交互
topo_cost = n_nodes * np.log(n_nodes) # 拓扑计算
return edges_cost + higher_cost + topo_cost8. 实践建议
8.1 何时使用
| 场景 | 推荐程度 | 理由 |
|---|---|---|
| 分子性质预测 | ⭐⭐⭐⭐⭐ | 3D结构+拓扑特征 |
| 物理模拟 | ⭐⭐⭐⭐ | 等变性保证物理一致 |
| 蛋白质结构 | ⭐⭐⭐⭐ | 拓扑特征重要 |
| 图像处理 | ⭐⭐ | 不需要3D等变性 |
| 图分类 | ⭐⭐⭐ | 拓扑增强有用 |
8.2 超参数选择
# 推荐配置
config = {
'hidden_dim': 256, # 中等维度
'n_layers': 4-6, # 深度网络
'n_heads': 4-8, # 多头注意力
'dropout': 0.1, # 防止过拟合
'learning_rate': 1e-4, # 标准学习率
'weight_decay': 1e-5 # L2正则化
}8.3 常见问题
- 内存消耗:高阶交互计算量大,使用稀疏表示
- 训练不稳定:等变约束可能导致梯度问题,使用梯度裁剪
- 拓扑特征选择:不是越多越好,选择相关的维度
参考文献
相关文档
Footnotes
-
Hajij, M., et al. (2025). E(n) Equivariant Topological Neural Networks. ICLR 2025. ↩