E(n)等变拓扑神经网络

E(n)等变拓扑神经网络(E(n) Equivariant Topological Neural Networks)是2025年ICLR的重要工作,将拓扑深度学习与等变性(Equivariance)理论结合,为分子和物理系统建模提供了强大的工具。1


1. 等变性理论基础

1.1 对称性与等变性

对称性是物理学和几何学的核心概念,描述了在某种变换下系统的不变性。

等变性(Equivariance)定义:对于变换群 和函数 ,如果:

其中 分别是输入和输出空间上的群作用,则 相对于群 是等变的。

1.2 欧几里得群 E(n)

E(n) 是 n 维欧几里得空间的等距同构群,包含:

变换类型记号自由度
平移
旋转
反射-
组合

1.3 常见的等变网络

网络类型等变群应用
CNN图像
3D CNN体积数据
GNN图结构
TFN分子
EGNN3D分子

2. 拓扑归纳偏置

2.1 什么是拓扑归纳偏置

拓扑归纳偏置是指网络架构中编码的关于数据拓扑结构的先验知识:

  1. 结构不变性:不依赖坐标系的结构表示
  2. 多尺度感知:同时捕获微观和宏观拓扑特征
  3. 持久性感知:区分重要结构和噪声

2.2 拓扑特征的类型

拓扑特征维度物理意义
连通分量0维相邻性、连通域
环路1维循环、周期性
空洞2维腔体、孔洞
高维洞k维复杂结构

2.3 为什么需要拓扑归纳偏置

传统方法:
输入 → 欧几里得坐标 → 手工特征 → 分类器
         ↓
    对噪声敏感
    对变换不稳定

拓扑方法:
输入 → 拓扑结构 → 持久特征 → 分类器
         ↓
    对噪声鲁棒
    对变换稳定

3. E(n)等变拓扑网络架构

3.1 核心设计原则

E(n)等变拓扑网络遵循以下设计原则:

  1. 等变性约束:网络输出对欧几里得变换等变
  2. 拓扑感知:集成持久同调特征
  3. 高阶交互:超越成对关系的复杂结构建模

3.2 架构概览

class EquivariantTopoLayer(nn.Module):
    """
    E(n)等变拓扑层
    整合几何特征、拓扑特征和高阶交互
    """
    
    def __init__(self, node_dim, edge_dim, hidden_dim, n_heads=4):
        super().__init__()
        
        # 几何特征处理(等变)
        self.geo_encoder = GeometricEncoder(node_dim, hidden_dim)
        
        # 拓扑特征处理
        self.topo_encoder = TopologicalEncoder(hidden_dim)
        
        # 高阶交互消息传递
        self.higher_order_msg = HigherOrderMessagePassing(
            hidden_dim, n_heads
        )
        
        # 更新函数(等变)
        self.update = EquivariantUpdate(hidden_dim)
    
    def forward(self, x, edge_index, positions, batch=None):
        """
        x: 节点特征
        edge_index: 边连接
        positions: 3D坐标(等变关键)
        """
        # 1. 编码几何特征
        geo_feat = self.geo_encoder(x, positions)
        
        # 2. 提取局部拓扑特征
        local_topo = self.topo_encoder(x, edge_index)
        
        # 3. 高阶交互消息传递
        msg = self.higher_order_msg(
            geo_feat, local_topo, edge_index, positions
        )
        
        # 4. 等变更新
        updated = self.update(msg, positions)
        
        return updated

3.3 几何编码器

class GeometricEncoder(nn.Module):
    """
    几何特征编码器
    保持对欧几里得变换的等变性
    """
    
    def __init__(self, node_dim, hidden_dim):
        super().__init__()
        
        # 节点特征嵌入
        self.node_mlp = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 距离编码(旋转不变)
        self.distance_mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 方向编码(旋转等变)
        self.direction_mlp = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
    
    def forward(self, x, positions):
        """
        positions: (N, 3) 坐标,变换时同步变换
        """
        # 节点特征
        h = self.node_mlp(x)
        
        # 计算相对位置(等变)
        rel_pos = positions.unsqueeze(1) - positions.unsqueeze(0)  # (N, N, 3)
        rel_dist = torch.norm(rel_pos, dim=-1, keepdim=True)  # (N, N, 1)
        rel_dir = rel_pos / (rel_dist + 1e-8)  # (N, N, 3)
        
        # 编码
        dist_feat = self.distance_mlp(rel_dist)  # (N, N, H)
        dir_feat = self.direction_mlp(rel_dir)  # (N, N, H)
        
        return {
            'node_feat': h,
            'distance_feat': dist_feat,
            'direction_feat': dir_feat,
            'relative_positions': rel_pos
        }

3.4 拓扑编码器

class TopologicalEncoder(nn.Module):
    """
    拓扑特征编码器
    从局部邻域提取拓扑信息
    """
    
    def __init__(self, hidden_dim, k_neighbors=10):
        super().__init__()
        self.k = k_neighbors
        
        self.topo_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2)
        )
    
    def forward(self, x, edge_index):
        """
        edge_index: (2, E) 边索引
        """
        # 构建邻接信息
        adj = self._build_adjacency(x, edge_index)
        
        # 计算局部拓扑特征
        topo_feat = self._extract_local_topology(adj)
        
        return self.topo_mlp(topo_feat)
    
    def _build_adjacency(self, x, edge_index):
        """构建邻接矩阵及相关特征"""
        N = x.shape[0]
        adj = torch.zeros(N, N)
        adj[edge_index[0], edge_index[1]] = 1
        
        # 度数特征
        degrees = adj.sum(dim=-1)
        
        # 聚类系数(局部拓扑)
        adj_sq = torch.matmul(adj, adj)
        triangles = torch.diagonal(adj_sq, dim1=-2, dim2=-1)
        
        possible = degrees * (degrees - 1)
        clustering = triangles / (possible + 1e-8)
        
        return {
            'adj': adj,
            'degrees': degrees,
            'clustering': clustering
        }
    
    def _extract_local_topology(self, adj_info):
        """提取局部拓扑特征"""
        return torch.cat([
            adj_info['degrees'].unsqueeze(-1),
            adj_info['clustering'].unsqueeze(-1)
        ], dim=-1)

3.5 高阶交互消息传递

class HigherOrderMessagePassing(nn.Module):
    """
    高阶交互消息传递
    超越成对关系,建模复杂结构
    """
    
    def __init__(self, hidden_dim, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        
        # 多头消息生成
        self.message_mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, self.head_dim)
            )
            for _ in range(n_heads)
        ])
        
        # 聚合函数
        self.aggregate = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )
    
    def forward(self, node_feat, topo_feat, edge_index, positions):
        """
        消息传递,聚合邻域信息
        """
        src, dst = edge_index
        
        # 构建消息
        messages = []
        for head in range(self.n_heads):
            # 源节点和目标节点特征拼接
            combined = torch.cat([
                node_feat[src],
                node_feat[dst]
            ], dim=-1)
            
            # 生成消息
            msg = self.message_mlp[head](combined)
            messages.append(msg)
        
        # 拼接多头消息
        messages = torch.cat(messages, dim=-1)  # (E, H)
        
        # 聚合到目标节点
        N = node_feat.shape[0]
        aggregated = torch.zeros(N, self.n_heads * self.head_dim)
        aggregated = aggregated.to(node_feat.device)
        aggregated = aggregated.index_add(0, dst, messages)
        
        return self.aggregate(aggregated)

3.6 等变更新

class EquivariantUpdate(nn.Module):
    """
    等变更新函数
    保持对平移和旋转的等变性
    """
    
    def __init__(self, hidden_dim):
        super().__init__()
        
        # 特征更新
        self.feature_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 位置更新(仅平移等变,不改变方向)
        self.position_mlp = nn.Sequential(
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
    
    def forward(self, aggregated_msg, positions):
        """
        positions: (N, 3) 坐标
        """
        # 更新特征
        updated_feat = self.feature_mlp(aggregated_msg)
        
        # 更新位置(等变:平移不变,旋转等变)
        # 位置更新使用相对位移加权
        scale = self.position_mlp(aggregated_msg)  # (N, 1)
        
        return {
            'features': updated_feat,
            'positions': positions,  # 位置保持不变(简化版)
            'position_scale': scale
        }

4. 完整模型

class EquivariantTopoNet(nn.Module):
    """
    E(n)等变拓扑神经网络
    端到端分子/物理系统建模
    """
    
    def __init__(self, node_dim, edge_dim, hidden_dim=128, 
                 n_layers=4, n_heads=4, output_dim=1):
        super().__init__()
        
        # 输入嵌入
        self.node_embedding = nn.Linear(node_dim, hidden_dim)
        self.edge_embedding = nn.Linear(edge_dim, hidden_dim)
        
        # 等变拓扑层堆叠
        self.layers = nn.ModuleList([
            EquivariantTopoLayer(
                hidden_dim, hidden_dim, hidden_dim, n_heads
            )
            for _ in range(n_layers)
        ])
        
        # 输出预测
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    
    def forward(self, data):
        """
        data: PyG Data对象
        - data.x: 节点特征
        - data.edge_index: 边索引
        - data.edge_attr: 边特征
        - data.pos: 3D坐标
        """
        x = self.node_embedding(data.x)
        pos = data.pos
        
        # 逐层传播
        for layer in self.layers:
            x = layer(x, data.edge_index, pos)
            if isinstance(x, dict):
                x = x['features']
        
        # 图级别池化
        x = global_mean_pool(x, data.batch)
        
        # 预测
        return self.readout(x)

5. 应用场景

5.1 分子性质预测

def train_molecular_property_model():
    """
    训练E(n)等变拓扑网络预测分子性质
    """
    model = EquivariantTopoNet(
        node_dim=ATOM_FEATURES_DIM,  # 原子特征维度
        edge_dim=BOND_FEATURES_DIM,  # 键特征维度
        hidden_dim=256,
        n_layers=6,
        n_heads=4,
        output_dim=1
    )
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=100
    )
    
    for epoch in range(100):
        model.train()
        total_loss = 0
        
        for batch in dataloader:
            optimizer.zero_grad()
            
            pred = model(batch)
            loss = F.mse_loss(pred, batch.y)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            total_loss += loss.item()
        
        scheduler.step()
        print(f"Epoch {epoch}: Loss={total_loss:.4f}")
    
    return model

5.2 分子构象生成

class EquivariantTopoConformationGenerator(nn.Module):
    """
    等变拓扑分子构象生成器
    """
    
    def __init__(self, latent_dim, hidden_dim):
        super().__init__()
        
        # 噪声到特征的解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim)
        )
        
        # 等变位置生成
        self.position_generator = EquivariantPositionGenerator(hidden_dim)
        
        # 拓扑约束
        self.topo_constraint = TopologicalConstraint()
    
    def forward(self, z, ref_topology):
        """
        z: 潜在编码
        ref_topology: 参考拓扑结构
        """
        # 解码特征
        h = self.decoder(z)
        
        # 生成3D位置
        positions = self.position_generator(h)
        
        # 拓扑约束损失
        topo_loss = self.topo_constraint(positions, ref_topology)
        
        return positions, topo_loss
 
class EquivariantPositionGenerator(nn.Module):
    """
    等变位置生成器
    保证生成的3D结构对旋转平移等变
    """
    
    def __init__(self, hidden_dim):
        super().__init__()
        
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 3)  # 输出相对位移
        )
    
    def forward(self, features):
        """
        生成相对位置(相对于质心)
        """
        # 计算质心
        centroid = features.mean(dim=0, keepdim=True)
        
        # 生成相对于质心的位移
        delta = self.mlp(features)
        
        # 位置 = 质心 + 位移
        positions = centroid + delta
        
        return positions

5.3 物理系统模拟

class PhysicsSimulationModel(nn.Module):
    """
    物理系统模拟:预测粒子系统演化
    """
    
    def __init__(self, particle_dim, hidden_dim):
        super().__init__()
        
        self.encoder = nn.Linear(particle_dim, hidden_dim)
        self.dynamics = EquivariantTopoNet(
            node_dim=hidden_dim,
            edge_dim=1,  # 距离
            hidden_dim=hidden_dim,
            n_layers=3
        )
        self.velocity_predictor = nn.Linear(hidden_dim, 3)
    
    def forward(self, positions, velocities, edge_index):
        """
        预测下一时刻的速度
        positions: (N, 3)
        velocities: (N, 3)
        """
        # 构建节点特征(位置+速度)
        node_features = torch.cat([positions, velocities], dim=-1)
        
        # 编码
        h = self.encoder(node_features)
        
        # 等变动力学预测
        output = self.dynamics(h, edge_index, positions)
        
        # 预测速度变化
        delta_v = self.velocity_predictor(output)
        
        # 预测下一时刻速度
        next_velocities = velocities + delta_v
        
        return next_velocities

6. 理论基础

6.1 等变性定理

Steerable CNN定理:对于映射 ,如果 相对于变换群 是等变的,则 必须具有特定的形式(Steerable)。

6.2 拓扑与等变性的关系

性质拓扑等变性
不变性拓扑不变量某些变换不变
结构性捕获内在结构捕获几何结构
对称性无坐标系依赖坐标变换对称

6.3 高阶交互的表示能力

# 1阶交互:成对关系
x_i^{(1)} = \sum_{j} w_{ij} x_j
 
# 2阶交互:三角形结构
x_i^{(2)} = \sum_{j,k} w_{ijk} x_j x_k
 
# k阶交互:k-单形结构
x_i^{(k)} = \sum_{j_1,\ldots,j_k} w_{ij_1\ldots j_k} \prod_{l=1}^k x_{j_l}

7. 与其他方法的对比

7.1 与标准GNN对比

特性标准GNNE(n)等变拓扑网络
几何感知
拓扑感知有限完整
高阶交互成对多阶
等变性节点置换欧几里得变换

7.2 与TFN/EGNN对比

特性TFNEGNN本文方法
旋转变换
拓扑特征
高阶交互
计算复杂度中等中等较高

7.3 计算效率

# 复杂度分析
class ComplexityAnalysis:
    @staticmethod
    def gnn_complexity(n_nodes, n_edges, hidden_dim):
        return n_edges * hidden_dim  # O(E * H)
    
    @staticmethod
    def equiv_topo_complexity(n_nodes, n_edges, n_simplices, hidden_dim):
        edges_cost = n_edges * hidden_dim
        higher_cost = n_simplices * hidden_dim * 2  # 高阶交互
        topo_cost = n_nodes * np.log(n_nodes)  # 拓扑计算
        return edges_cost + higher_cost + topo_cost

8. 实践建议

8.1 何时使用

场景推荐程度理由
分子性质预测⭐⭐⭐⭐⭐3D结构+拓扑特征
物理模拟⭐⭐⭐⭐等变性保证物理一致
蛋白质结构⭐⭐⭐⭐拓扑特征重要
图像处理⭐⭐不需要3D等变性
图分类⭐⭐⭐拓扑增强有用

8.2 超参数选择

# 推荐配置
config = {
    'hidden_dim': 256,      # 中等维度
    'n_layers': 4-6,       # 深度网络
    'n_heads': 4-8,        # 多头注意力
    'dropout': 0.1,        # 防止过拟合
    'learning_rate': 1e-4, # 标准学习率
    'weight_decay': 1e-5   # L2正则化
}

8.3 常见问题

  1. 内存消耗:高阶交互计算量大,使用稀疏表示
  2. 训练不稳定:等变约束可能导致梯度问题,使用梯度裁剪
  3. 拓扑特征选择:不是越多越好,选择相关的维度

参考文献


相关文档

Footnotes

  1. Hajij, M., et al. (2025). E(n) Equivariant Topological Neural Networks. ICLR 2025.