Topology-Aware Graph Diffusion Models

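Research presented at NeurIPS 2025 proposes incorporating topological constraints into diffusion models for graph generation, which opens up new possibilities for applications such as molecular design and materials discovery.[1]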


1. Graph Generation and Topological Constraints

1.1 Why Topological Constraints Are Needed

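Problems with conventional graph-generation methods:

  • Connectivity: they may produce disconnected graphs
  • Ring structure: they cannot guarantee correct ring structures (e.g. aromatic rings)
  • Topological validity: generated graphs may violate chemical rules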

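What topological constraints provide:

  • Enforced connectivity: the generated molecular graph is connected
  • Ring constraints: ring structures are chemically correct
  • Structure preservation: target topological properties are preserved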

1.2 Topological Features of Graphs

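Topological feature | Dimension | Meaning
Betti number β₀ | 0 | Number of connected components
Betti number β₁ | 1 | Number of independent cycles (cycle basis)
Betti number β₂ | 2 | Number of cavities (3D structure)
Degree distribution | - | Atomic valence
Ring distribution | - | Ring sizes and counts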

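Betti number definition: for a simplicial complex K, the k-th Betti number is the rank of its k-th homology group,

$$\beta_k = \operatorname{rank} H_k(K) = \dim\ker\partial_k - \dim\operatorname{im}\partial_{k+1}.$$

A graph G = (V, E) is a 1-dimensional complex, so β_k = 0 for k ≥ 2, and

$$\beta_0 = \#\{\text{connected components}\}, \qquad \beta_1 = |E| - |V| + \beta_0.$$

A nonzero β₂ (cavities) only appears once higher-dimensional cells are added, for example in a clique complex or a 3D simplicial model of the molecule.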
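For graphs, the definition above needs no homology software: β₀ can be counted with a single union-find pass over the edges, and β₁ then follows from the cyclomatic number. The sketch below is illustrative only; `graph_betti_numbers` is a hypothetical helper, not part of any library or of the paper's code, and it assumes an undirected graph given as an edge list.

```python
def graph_betti_numbers(n_nodes, edges):
    """Return (beta_0, beta_1) of an undirected graph given as an edge list."""
    parent = list(range(n_nodes))

    def find(u):
        # Path halving keeps the union-find trees shallow
        while parent[u] != u:
            parent[u] = parent[parent[u]]
            u = parent[u]
        return u

    # Treat (u, v) and (v, u) as the same edge and drop self-loops
    unique_edges = {tuple(sorted(e)) for e in edges if e[0] != e[1]}

    components = n_nodes
    for u, v in unique_edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            components -= 1

    beta_0 = components
    beta_1 = len(unique_edges) - n_nodes + beta_0  # cyclomatic number
    return beta_0, beta_1


# Benzene carbon skeleton: a single 6-membered ring
benzene = [(i, (i + 1) % 6) for i in range(6)]
print(graph_betti_numbers(6, benzene))  # (1, 1)

# Naphthalene carbon skeleton: a 10-atom perimeter plus the fusion bond 0-5
naphthalene = [(i, i + 1) for i in range(9)] + [(9, 0), (0, 5)]
print(graph_betti_numbers(10, naphthalene))  # (1, 2)
```

The same counts are what a topological loss would compare against a target, e.g. β₀ = 1 for a connected molecule and β₁ = number of rings.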


2. Graph Diffusion Model Basics

2.1 Diffusion Process

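Forward process (adding noise): with a noise schedule β_t, t = 1, …, T (not to be confused with the Betti numbers β_k), each step adds Gaussian noise,

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\big),$$

which, with α_t = 1 − β_t and ᾱ_t = ∏_{s=1}^{t} α_s, gives the closed form

$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar\alpha_t}\,x_0,\ (1-\bar\alpha_t) I\big).$$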

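Reverse process (denoising generation): a network ε_θ predicts the added noise and parameterizes

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 I\big), \qquad \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\Big(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\Big),$$

trained with the noise-prediction objective $\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\,\lVert \epsilon - \epsilon_\theta(x_t, t) \rVert^2$.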
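The two processes can be checked numerically in a few lines. This is a minimal NumPy sketch, assuming the standard DDPM linear schedule (1e-4 to 0.02 over 1000 steps, the same values the training loop in Section 5.2 uses) and the ε-parameterization above; array index 0 corresponds to t = 1. The helper names are hypothetical.

```python
import numpy as np

def linear_schedule(n_steps=1000, beta_start=1e-4, beta_end=0.02):
    # beta_t, alpha_t = 1 - beta_t, and the cumulative product alpha_bar_t
    betas = np.linspace(beta_start, beta_end, n_steps)
    alphas = 1.0 - betas
    alpha_bar = np.cumprod(alphas)
    return betas, alphas, alpha_bar

def q_sample(x0, t, alpha_bar, noise):
    # Closed-form forward process: x_t = sqrt(abar_t) x0 + sqrt(1 - abar_t) eps
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise

def p_mean(x_t, t, eps_pred, betas, alphas, alpha_bar):
    # Reverse-step mean mu_theta under the epsilon parameterization
    return (x_t - betas[t] / np.sqrt(1.0 - alpha_bar[t]) * eps_pred) / np.sqrt(alphas[t])

betas, alphas, alpha_bar = linear_schedule()
rng = np.random.default_rng(0)
x0 = rng.normal(size=(5, 3))  # e.g. 5 nodes with 3 features each
noise = rng.normal(size=x0.shape)

# First noising step, then one reverse step with a perfect noise prediction
x_t = q_sample(x0, 0, alpha_bar, noise)
x0_rec = p_mean(x_t, 0, noise, betas, alphas, alpha_bar)
print(np.allclose(x0_rec, x0))  # True
print(alpha_bar[-1] < 1e-3)     # True: x_T is almost pure noise
```

At the first step ᾱ₁ = α₁, so the reverse mean cancels the noise exactly when the prediction is perfect; at later steps the sampler also adds fresh noise σ_t z.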

2.2 Graph Representation

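import torch
import torch.nn as nn
from torch_geometric.data import Data
 
class GraphRepresentation:
    """
    Graph representation helpers
    for graph diffusion models
    """
    
    def __init__(self, node_dim=100, edge_dim=50):
        self.node_dim = node_dim  # node feature dimension
        self.edge_dim = edge_dim  # edge feature dimension
    
    def graph_to_tensor(self, edge_index, node_features=None, n_nodes=None):
        """
        Convert a graph into the tensor representation used by the diffusion model
        
        Parameters:
        -----------
        edge_index : torch.Tensor (2, E)
            Edge index
        node_features : torch.Tensor (N, F), optional
            Node features
        n_nodes : int, optional
            Number of nodes; inferred from node_features or edge_index if omitted
        
        Returns:
        --------
        dict : adjacency matrix, number of nodes and node features
        """
        # Prefer an explicit node count: edge_index.max() + 1 misses trailing isolated nodes
        if n_nodes is None:
            if node_features is not None:
                n_nodes = node_features.size(0)
            else:
                n_nodes = int(edge_index.max()) + 1
        
        # Adjacency matrix
        adj = torch.zeros(n_nodes, n_nodes)
        adj[edge_index[0], edge_index[1]] = 1
        # Symmetrize with max; averaging would give 0.5 for edges listed in one direction only
        adj = torch.maximum(adj, adj.T)
        
        return {
            'adjacency': adj,
            'n_nodes': n_nodes,
            'node_features': node_features
        }
    
    def tensor_to_graph(self, adj, node_features=None, threshold=0.5):
        """
        Convert the tensor representation back into a PyG graph
        """
        # Keeps both directions of every undirected edge, as PyG expects
        edge_index = (adj > threshold).nonzero(as_tuple=False).T
        return Data(x=node_features, edge_index=edge_index)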
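The dense/sparse round trip can be sketched in plain NumPy to make two details visible. The node count is passed explicitly, because inferring it from the largest index in `edge_index` silently drops trailing isolated nodes, and symmetrization uses an element-wise maximum so that an edge listed in only one direction still gets weight 1. Both function names are hypothetical illustrations, not PyG APIs.

```python
import numpy as np

def edge_index_to_adj(edge_index, n_nodes):
    # edge_index: (2, E) integer array of (source, target) pairs
    adj = np.zeros((n_nodes, n_nodes), dtype=np.float32)
    adj[edge_index[0], edge_index[1]] = 1.0
    # Max-symmetrization: a one-directional entry stays 1 instead of becoming 0.5
    return np.maximum(adj, adj.T)

def adj_to_edge_index(adj, threshold=0.5):
    # Returns both directions of every undirected edge, in row-major order
    src, dst = np.nonzero(adj > threshold)
    return np.stack([src, dst])

# Path 0-1-2 with edges listed in one direction only, plus an isolated node 3
ei = np.array([[0, 1], [1, 2]])
adj = edge_index_to_adj(ei, n_nodes=4)
print(adj_to_edge_index(adj))
```

Node 3 survives as an all-zero row, which would have been lost had the size been computed as `ei.max() + 1`.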

3. Incorporating Topological Constraints

3.1 Topological Loss Function

import numpy as np
import torch
import torch.nn as nn
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
 
class TopologicalConstraintLoss(nn.Module):
    """
    Topological constraint loss
    Penalizes deviations of a generated graph's Betti numbers from a target.
    Betti numbers are integer counts, so this loss has no gradient with respect to
    the adjacency matrix: use it for monitoring, reranking or rejection sampling,
    or replace it with a differentiable persistence-based surrogate for training.
    """
    
    def __init__(self, target_betti=None, lambda_topo=1.0, edge_threshold=0.5):
        super().__init__()
        self.target_betti = target_betti or {}
        self.lambda_topo = lambda_topo
        self.edge_threshold = edge_threshold
    
    def compute_betti_numbers(self, adj_matrix):
        """
        Compute the Betti numbers of a graph viewed as a 1-dimensional complex
        
        Parameters:
        -----------
        adj_matrix : torch.Tensor or np.ndarray (N, N)
            (Possibly weighted) adjacency matrix; entries above edge_threshold are edges
        
        Returns:
        --------
        dict : Betti numbers {'beta_0', 'beta_1', 'beta_2'}
        """
        if isinstance(adj_matrix, torch.Tensor):
            adj_matrix = adj_matrix.detach().cpu().numpy()
        
        # Binarize, symmetrize and ignore self-loops
        adj = adj_matrix > self.edge_threshold
        adj = adj | adj.T
        np.fill_diagonal(adj, False)
        
        n = adj.shape[0]
        m = int(adj.sum()) // 2  # each undirected edge appears twice
        
        # beta_0: number of connected components
        beta_0, _ = connected_components(csr_matrix(adj.astype(np.int8)), directed=False)
        
        # beta_1: cyclomatic number |E| - |V| + beta_0
        beta_1 = m - n + beta_0
        
        # For a graph, the closed forms above are exact; persistent homology (e.g.
        # ripser) is only needed for filtrations or higher-dimensional complexes.
        # beta_2 is always 0 for a 1-dimensional complex.
        return {
            'beta_0': int(beta_0),
            'beta_1': int(beta_1),
            'beta_2': 0
        }
    
    def forward(self, generated_adj, reference_adj=None):
        """
        Compute the topological constraint loss
        
        Parameters:
        -----------
        generated_adj : torch.Tensor
            Generated adjacency matrix
        reference_adj : torch.Tensor, optional
            Reference adjacency matrix; if given, its Betti numbers are the target
        """
        # Betti numbers of the generated graph
        generated_betti = self.compute_betti_numbers(generated_adj)
        
        # A reference graph takes precedence over the fixed target Betti numbers
        if reference_adj is not None:
            target_betti = self.compute_betti_numbers(reference_adj)
        else:
            target_betti = self.target_betti
        
        loss = 0.0
        for key, target in target_betti.items():
            generated = generated_betti.get(key, 0)
            loss += (generated - target) ** 2
        
        return torch.tensor(self.lambda_topo * loss, dtype=torch.float32)

3.2 Persistent-Homology-Guided Generation

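import torch
import torch.nn as nn
import torch.nn.functional as F

class PersistenceGuidedGenerator(nn.Module):
    """
    Persistent-homology-guided graph generator
    Uses persistence-diagram features to guide the generation process
    """
    
    def __init__(self, latent_dim, hidden_dim, node_dim, persistence_dim=27):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.n_nodes = node_dim  # here node_dim is the number of nodes N of the generated graph
        
        # Shared trunk: latent code -> hidden representation
        self.trunk = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU()
        )
        
        # Adjacency head: hidden representation -> flattened N x N logits
        self.adj_head = nn.Linear(hidden_dim * 2, node_dim * node_dim)
        
        # Persistence-feature predictor (default: 3 homology dims * 9 statistics)
        self.persistence_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Linear(128, persistence_dim)
        )
    
    def forward(self, z, target_persistence=None):
        """
        z: latent code, shape (B, latent_dim)
        target_persistence: target persistence features, shape (B, persistence_dim)
        """
        h = self.trunk(z)
        
        # Generate a batch of adjacency matrices
        n = self.n_nodes
        adj = self.adj_head(h).view(-1, n, n)
        
        # Symmetrize, map to edge probabilities in [0, 1], remove self-loops
        adj = torch.sigmoid((adj + adj.transpose(-1, -2)) / 2)
        adj = adj * (1 - torch.eye(n, device=adj.device))
        
        # Persistence features are predicted from the shared representation, not
        # computed from adj, so their coupling to the generated topology is indirect
        pred_persistence = self.persistence_predictor(h)
        
        # Persistence loss
        persistence_loss = torch.zeros((), device=z.device)
        if target_persistence is not None:
            persistence_loss = F.mse_loss(pred_persistence, target_persistence)
        
        return {
            'adjacency': adj,
            'persistence_features': pred_persistence,
            'persistence_loss': persistence_loss
        }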
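The generator above works with a fixed-length persistence feature vector (27 = 3 homology dimensions × 9 statistics in the code), but how those features are obtained from a graph is not specified. One standard way, sketched here under stated assumptions, is persistent homology of an edge-weight filtration: all vertices enter at time 0, each edge enters at its weight (for predicted edge probabilities p one might use 1 − p so confident edges enter first; that choice is an assumption), and a Kruskal-style union-find pass yields the H0 and H1 classes. `graph_persistence` and the 6-number `summary_features` are hypothetical helpers, not the paper's feature set.

```python
import numpy as np

def graph_persistence(n_nodes, weighted_edges):
    """Persistent homology of an edge-weight filtration on a graph."""
    parent = list(range(n_nodes))

    def find(u):
        while parent[u] != u:
            parent[u] = parent[parent[u]]
            u = parent[u]
        return u

    h0_pairs, h1_births = [], []
    # Add edges in filtration order
    for u, v, w in sorted(weighted_edges, key=lambda e: e[2]):
        ru, rv = find(u), find(v)
        if ru != rv:
            # Edge merges two components: one H0 class (born at 0) dies at w
            parent[ru] = rv
            h0_pairs.append((0.0, w))
        else:
            # Edge closes a cycle: an H1 class is born at w and never dies in a graph
            h1_births.append(w)
    n_essential_h0 = n_nodes - len(h0_pairs)
    return h0_pairs, n_essential_h0, h1_births

def summary_features(h0_pairs, n_essential_h0, h1_births):
    # Hypothetical fixed-length summary of the diagrams
    deaths = np.array([d for _, d in h0_pairs]) if h0_pairs else np.zeros(1)
    births = np.array(h1_births) if h1_births else np.zeros(1)
    return np.array([
        n_essential_h0,   # essential H0 classes (= beta_0 of the full graph)
        len(h0_pairs),    # finite H0 pairs (merge events)
        deaths.mean(),    # mean finite H0 lifetime
        deaths.max(),     # longest finite H0 lifetime
        len(h1_births),   # H1 classes (= beta_1 of the full graph)
        births.mean(),    # mean H1 birth time
    ], dtype=float)

# A 4-cycle whose closing edge enters last
edges = [(0, 1, 0.1), (1, 2, 0.2), (2, 3, 0.3), (3, 0, 0.4)]
h0, ess, h1 = graph_persistence(4, edges)
print(summary_features(h0, ess, h1))
```

At the end of the filtration the essential H0 count and the number of H1 classes reduce to β₀ and β₁ of the full graph, which ties the persistence features back to the Betti-number constraints.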

4. Full Topology-Aware Diffusion Model

4.1 Model Architecture

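import torch
import torch.nn as nn
import torch.nn.functional as F
 
class TopoAwareGraphDiffusion(nn.Module):
    """
    Topology-aware graph diffusion model
    Combines a graph neural network, a DDPM-style diffusion process over node
    features, and an auxiliary topological-constraint head.
    GraphConv and TopologicalConstraintModule are defined in Sections 4.2 and 4.3.
    """
    
    def __init__(self, node_dim, hidden_dim, edge_dim, n_steps=1000, lambda_topo=0.1):
        super().__init__()
        
        self.n_steps = n_steps
        self.lambda_topo = lambda_topo
        
        # Time embedding (input: timestep normalized to [0, 1])
        self.time_embedding = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Projection of the concatenated [x_t, t_emb] into hidden space
        self.input_proj = nn.Linear(node_dim + hidden_dim, hidden_dim)
        
        # Graph encoder
        self.encoder = nn.ModuleList([
            GraphConv(hidden_dim, hidden_dim)
            for _ in range(3)
        ])
        
        # Noise-prediction head (epsilon parameterization)
        self.noise_head = nn.Linear(hidden_dim, node_dim)
        
        # Edge decoder: node-pair embedding -> edge probability
        self.adj_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # Topological constraint module
        self.topo_constraint = TopologicalConstraintModule(hidden_dim)
        
        # Linear noise schedule, stored as buffers so it moves with the model
        betas = torch.linspace(1e-4, 0.02, n_steps)
        alphas = 1 - betas
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alpha_bar', torch.cumprod(alphas, dim=0))
    
    def encode(self, x_t, t, edge_index, batch):
        """
        Node embeddings for noisy features x_t (N, node_dim) at per-graph timesteps t (B,)
        """
        t_emb = self.time_embedding(t.float().unsqueeze(-1) / self.n_steps)
        # Broadcast each graph's time embedding to its nodes
        h = self.input_proj(torch.cat([x_t, t_emb[batch]], dim=-1))
        for conv in self.encoder:
            h = h + conv(h, edge_index)  # residual connection
        return h
    
    def forward(self, x_t, t, edge_index, batch=None, noise=None,
                topo_targets=None, return_loss=False):
        """
        Default: return the predicted noise eps_theta(x_t, t).
        return_loss=True: denoising MSE plus the optional topology loss.
        """
        if batch is None:
            # Treat all nodes as a single graph
            batch = torch.zeros(x_t.size(0), dtype=torch.long, device=x_t.device)
        
        h = self.encode(x_t, t, edge_index, batch)
        eps_pred = self.noise_head(h)
        if not return_loss:
            return eps_pred
        
        loss = x_t.new_zeros(())
        if noise is not None:
            loss = loss + F.mse_loss(eps_pred, noise)
        if topo_targets is not None:
            topo_pred = self.topo_constraint(h, batch)
            loss = loss + self.lambda_topo * self.topo_constraint.compute_topo_loss(
                topo_pred, topo_targets)
        return loss
    
    @torch.no_grad()
    def sample(self, x_T, edge_index, batch=None):
        """
        Ancestral DDPM sampling of node features on a fixed graph skeleton
        """
        if batch is None:
            batch = torch.zeros(x_T.size(0), dtype=torch.long, device=x_T.device)
        num_graphs = int(batch.max()) + 1
        x_t = x_T
        for step in reversed(range(self.n_steps)):
            # Embed the current timestep at every iteration
            t = torch.full((num_graphs,), step, dtype=torch.long, device=x_t.device)
            eps = self.forward(x_t, t, edge_index, batch)
            beta, alpha, a_bar = self.betas[step], self.alphas[step], self.alpha_bar[step]
            x_t = (x_t - beta / torch.sqrt(1 - a_bar) * eps) / torch.sqrt(alpha)
            if step > 0:
                x_t = x_t + torch.sqrt(beta) * torch.randn_like(x_t)
        return x_t
    
    def predict_adj(self, h, edge_index):
        """
        Edge probabilities for the candidate node pairs in edge_index
        """
        src, dst = edge_index
        
        # Node-pair features
        combined = torch.cat([h[src], h[dst]], dim=-1)
        
        # Predicted edge probabilities
        return self.adj_decoder(combined).squeeze(-1)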

4.2 Graph Convolution Layer

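class GraphConv(nn.Module):
    """
    Graph convolution layer with mean aggregation of edge messages
    """
    
    def __init__(self, in_dim, out_dim):
        super().__init__()
        
        self.lin = nn.Linear(in_dim, out_dim)            # node-wise projection
        self.edge_lin = nn.Linear(out_dim * 2, out_dim)  # message from a (source, target) pair
        
        self.norm = nn.LayerNorm(out_dim)
    
    def forward(self, x, edge_index):
        src, dst = edge_index
        
        # Project first so that messages are built in out_dim space
        x = self.lin(x)
        
        # Message passing: one message per edge, summed at the target node
        messages = self.edge_lin(torch.cat([x[src], x[dst]], dim=-1))
        out = torch.zeros_like(x).index_add(0, dst, messages)
        
        # Normalize by in-degree (nodes without incoming edges are clamped to 1)
        deg = torch.bincount(dst, minlength=x.size(0)).clamp(min=1).to(x.dtype)
        out = out / deg.unsqueeze(-1)
        
        out = self.norm(out)
        
        return F.relu(out)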

4.3 Topological Constraint Module

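class TopologicalConstraintModule(nn.Module):
    """
    Topological constraint module
    Predicts graph-level topological properties from node embeddings so that
    they can be supervised against target values
    """
    
    def __init__(self, hidden_dim):
        super().__init__()
        
        # Betti-number predictor
        self.betti_predictor = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3)  # beta_0, beta_1, beta_2
        )
        
        # Ring-structure predictor
        self.ring_predictor = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 10)  # counts for 10 ring categories
        )
        
        # Connectivity predictor
        self.connectivity_predictor = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()  # probability that the graph is connected
        )
    
    def forward(self, h, batch=None):
        """
        Predict topological properties of each graph
        h : (N, hidden_dim) node embeddings; batch : (N,) graph id of each node
        """
        # Global mean pooling per graph (all nodes form one graph if batch is None)
        if batch is None:
            batch = torch.zeros(h.size(0), dtype=torch.long, device=h.device)
        num_graphs = int(batch.max()) + 1
        h_pool = h.new_zeros((num_graphs, h.size(1))).index_add(0, batch, h)
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(h.dtype)
        h_pool = h_pool / counts.unsqueeze(-1)
        
        return {
            'betti': self.betti_predictor(h_pool),
            'rings': self.ring_predictor(h_pool),
            'connectivity': self.connectivity_predictor(h_pool)
        }
    
    def compute_topo_loss(self, predictions, targets):
        """
        Compute the topological constraint loss
        targets : dict with optional 'betti' (B, 3) and 'connected' (B,) 0/1 values
        """
        loss = predictions['betti'].new_zeros(())
        
        # Betti-number loss
        if 'betti' in targets:
            target_betti = torch.as_tensor(targets['betti']).to(predictions['betti'])
            loss = loss + F.mse_loss(predictions['betti'], target_betti)
        
        # Connectivity loss
        if 'connected' in targets:
            pred_conn = predictions['connectivity']
            target_conn = torch.as_tensor(targets['connected']).to(pred_conn).view_as(pred_conn)
            loss = loss + F.binary_cross_entropy(pred_conn, target_conn)
        
        return loss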

5. Application: Molecular Graph Generation

5.1 Molecular Constraints

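class MolecularTopologicalConstraints:
    """
    Topological constraints for molecular graphs
    Used to steer generation toward valid molecules
    """
    
    @staticmethod
    def valid_molecular_betti(n_atoms, n_heavy_atoms, rings):
        """
        Compute the expected Betti numbers of a valid molecule
        
        Parameters:
        -----------
        n_atoms : int
            Total number of atoms
        n_heavy_atoms : int
            Number of heavy (non-hydrogen) atoms
        rings : dict
            Ring counts keyed either by ring type, e.g. {'aromatic': 1, 'aliphatic': 2},
            or by ring size, e.g. {6: 2, 5: 1} (do not mix the two conventions)
        
        Returns:
        --------
        dict : target Betti numbers
        """
        if rings:
            # Aromatic rings are ordinary cycles: count every ring exactly once
            n_rings = sum(count for count in rings.values() if isinstance(count, int))
        else:
            # No ring specification: rough estimate from a bond-count heuristic
            # (about 1.5 bonds per heavy atom), which tends to overestimate ring counts
            n = n_heavy_atoms
            m = int(1.5 * n_heavy_atoms)
            c = 1  # connected molecule
            # Cyclomatic number m - n + c; for a planar graph this is the number of
            # bounded faces in Euler's formula n - m + f = 1 + c (f includes the outer face)
            n_rings = max(0, m - n + c)
        
        # For a connected molecular graph, beta_1 = m - n + 1 equals the size of the
        # smallest set of smallest rings (SSSR)
        return {
            'beta_0': 1,        # a single connected component
            'beta_1': n_rings,  # number of independent rings
            'beta_2': 0         # always 0 for a graph (a 1-dimensional complex)
        }
    
    @staticmethod
    def ring_size_distribution(rings):
        """
        Ring-size distribution constraint
        rings : dict keyed by ring size, e.g. {6: 2, 5: 1}
        """
        # Index i holds the number of i-membered rings; sizes 3 to 10 are considered
        distribution = torch.zeros(11)
        
        for ring_size, count in rings.items():
            if isinstance(ring_size, int) and 3 <= ring_size <= 10:
                distribution[ring_size] = count
        
        return distribution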
 
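def generate_molecule_with_topology(
    n_atoms, 
    target_rings=None,
    target_properties=None
):
    """
    Generate a molecule with a specified topology
    
    Parameters:
    -----------
    n_atoms : int
        Target number of atoms
    target_rings : dict, optional
        Target ring structure
    target_properties : dict, optional
        Target molecular properties (not used in this sketch)
    """
    n_heavy = int(n_atoms * 0.8)  # assume 80% of the atoms are heavy atoms
    
    # Compute the topological constraints
    topo_constraints = MolecularTopologicalConstraints.valid_molecular_betti(
        n_atoms, 
        n_heavy,
        target_rings or {}
    )
    
    # Build the generator. The calls below (latent_dim, target_persistence, the
    # 'adjacency' output) follow the PersistenceGuidedGenerator interface from
    # Section 3.2, here with a 3-dimensional persistence target (beta_0, beta_1, beta_2)
    model = PersistenceGuidedGenerator(
        latent_dim=64,
        hidden_dim=256,
        node_dim=n_heavy,
        persistence_dim=3
    )
    
    # Train or load the model
    # ...
    
    # Generate. target_persistence only enters the persistence loss, so it does not
    # steer z directly; samples can be reranked with TopologicalConstraintLoss
    model.eval()
    with torch.no_grad():
        z = torch.randn(1, model.latent_dim)
        target_persistence = torch.tensor([[
            topo_constraints['beta_0'],
            topo_constraints['beta_1'],
            topo_constraints['beta_2']
        ]], dtype=torch.float32)
        
        generated = model(z, target_persistence=target_persistence)
    
    return generated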
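As a hand-computed sanity check on the cyclomatic-number formula β₁ = m − n + c used in `valid_molecular_betti` (carbon skeletons only, hydrogens omitted, c = 1), and on how far the 1.5-bonds-per-heavy-atom fallback can drift for 10 heavy atoms:

$$
\begin{aligned}
\text{benzene:}\quad & n = 6,\ m = 6 &&\Rightarrow\ \beta_1 = 6 - 6 + 1 = 1\\
\text{naphthalene:}\quad & n = 10,\ m = 11 &&\Rightarrow\ \beta_1 = 11 - 10 + 1 = 2\\
\text{adamantane:}\quad & n = 10,\ m = 12 &&\Rightarrow\ \beta_1 = 12 - 10 + 1 = 3\\
\text{heuristic } m \approx 1.5n:\quad & n = 10,\ m = 15 &&\Rightarrow\ \beta_1 = 15 - 10 + 1 = 6
\end{aligned}
$$

The heuristic predicts 6 rings for a 10-heavy-atom molecule, while real bicyclic and tricyclic C10 skeletons have 2 or 3, which is why an explicit ring specification is preferable when one is available.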

5.2 Full Training Pipeline

def train_topo_graph_diffusion(
    train_graphs,
    n_epochs=100,
    batch_size=32,
    lr=1e-4,
    n_steps=1000
):
    """
    Train the topology-aware graph diffusion model
    train_graphs: list of torch_geometric Data objects with float node features x of size 50
    """
    from torch_geometric.loader import DataLoader
    
    # Create the model
    model = TopoAwareGraphDiffusion(
        node_dim=50,
        hidden_dim=256,
        edge_dim=10,
        n_steps=n_steps
    )
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    # Data loader
    dataloader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    
    # Noise schedule
    betas = torch.linspace(1e-4, 0.02, n_steps)
    alphas = 1 - betas
    alpha_bar = torch.cumprod(alphas, dim=0)
    
    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0
        
        for batch in dataloader:
            optimizer.zero_grad()
            
            # PyG's DataLoader already offsets edge_index for each graph in the batch
            edge_index = batch.edge_index
            x0 = batch.x  # clean node features
            
            # Sample one timestep per graph
            t = torch.randint(0, n_steps, (batch.num_graphs,))
            
            # Add noise; batch.batch maps every node to its graph's timestep
            noise = torch.randn_like(x0)
            a_bar = alpha_bar[t][batch.batch].unsqueeze(-1)
            x_t = torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise
            
            # Denoising loss (epsilon prediction)
            loss = model(x_t, t, edge_index, batch=batch.batch, noise=noise, return_loss=True)
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        print(f"Epoch {epoch}: Loss={total_loss/len(dataloader):.4f}")
    
    return model
 
def expand_edge_index(edge_index, edge_batch, num_nodes_per_graph):
    """
    Shift per-graph (local) edge indices into one batched index space.
    Only needed for manual batching: torch_geometric's DataLoader already
    applies these offsets to batch.edge_index.
    
    edge_index : (2, E) local node indices within each graph
    edge_batch : (E,) graph id of each edge
    num_nodes_per_graph : (B,) node count of each graph
    """
    # Offset of graph g = total number of nodes in graphs 0..g-1
    offsets = torch.cumsum(num_nodes_per_graph, dim=0) - num_nodes_per_graph
    
    # Add each edge's graph offset to both endpoints
    return edge_index + offsets[edge_batch].unsqueeze(0)

6. Experiments and Evaluation

6.1 Evaluation Metrics

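import numpy as np
import torch
from collections import deque
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components
 
class TopologyAwareMetrics:
    """
    Topology-aware evaluation metrics
    Adjacency inputs are expected to be binary (0/1) matrices
    """
    
    @staticmethod
    def betti_accuracy(generated_adj, target_betti):
        """
        Betti-number accuracy: 1.0 if beta_1 matches the target, else 0.0
        """
        if isinstance(generated_adj, torch.Tensor):
            generated_adj = generated_adj.detach().cpu().numpy()
        adj = generated_adj > 0
        np.fill_diagonal(adj, False)  # ignore self-loops
        
        n = len(adj)
        m = int((adj | adj.T).sum()) // 2  # number of undirected edges
        beta_0, _ = connected_components(csr_matrix(adj.astype(np.int8)), directed=False)
        
        # beta_1 = |E| - |V| + beta_0 (correct for disconnected graphs too)
        beta_1 = m - n + beta_0
        
        return 1.0 if beta_1 == target_betti.get('beta_1', beta_1) else 0.0
    
    @staticmethod
    def connectivity_rate(graphs):
        """
        Fraction of graphs that are connected
        """
        connected = 0
        for adj in graphs:
            if isinstance(adj, torch.Tensor):
                adj = adj.detach().cpu().numpy()
            
            n = len(adj)
            if n == 0:
                continue
            
            # BFS from node 0
            visited = np.zeros(n, dtype=bool)
            visited[0] = True
            queue = deque([0])
            while queue:
                node = queue.popleft()
                for neighbor in np.nonzero(adj[node] > 0)[0]:
                    if not visited[neighbor]:
                        visited[neighbor] = True
                        queue.append(neighbor)
            
            if visited.all():
                connected += 1
        
        return connected / len(graphs) if len(graphs) > 0 else 0.0
    
    @staticmethod
    def ring_validity(generated_molecules, target_rings):
        """
        Ring-structure validity: fraction of RDKit molecules whose ring count is within
        1 of target_rings['total'] (any parsed molecule counts if no target is given)
        """
        if not generated_molecules:
            return 0.0
        
        valid = 0
        for mol in generated_molecules:
            if mol is None:  # molecule failed to parse or sanitize
                continue
            
            try:
                # Ring information
                n_rings = mol.GetRingInfo().NumRings()
            except Exception:
                continue
            
            # Simplified check against the target
            if target_rings:
                if abs(n_rings - target_rings.get('total', n_rings)) <= 1:
                    valid += 1
            else:
                valid += 1
        
        return valid / len(generated_molecules)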

6.2 Experimental Setup

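def evaluate_topo_graph_model(model, test_graphs, target_topology, threshold=0.5):
    """
    Evaluate a topology-aware graph generator.
    model is assumed to follow the PersistenceGuidedGenerator interface (Section 3.2)
    with a 3-dimensional persistence target (beta_0, beta_1, beta_2).
    """
    metrics = TopologyAwareMetrics()
    
    target_persistence = torch.tensor([[
        target_topology['beta_0'],
        target_topology['beta_1'],
        target_topology['beta_2']
    ]], dtype=torch.float32)
    
    # Generate one graph per test graph
    model.eval()
    generated_graphs = []
    
    with torch.no_grad():
        for _ in range(len(test_graphs)):
            z = torch.randn(1, model.latent_dim)
            generated = model(z, target_persistence=target_persistence)
            # Drop the batch dimension and binarize the edge probabilities
            adj = generated['adjacency'].squeeze(0)
            generated_graphs.append((adj > threshold).float())
    
    # Compute the metrics
    results = {
        'connectivity_rate': metrics.connectivity_rate(generated_graphs),
        'betti_accuracy': np.mean([
            metrics.betti_accuracy(g, target_topology) 
            for g in generated_graphs
        ])
    }
    
    return results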

7. Recent Research Progress

7.1 NeurIPS 2025 Work

Topology-Aware Graph Diffusion Model with Persistent Homology

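Core contributions:

  1. Persistent-homology guidance: uses PH features to guide the generation process
  2. Topological loss: enforces topological constraints on generated graphs
  3. Molecular application: generates valid drug molecules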

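7.2 Future Directions

Direction | Description | Potential
3D molecule generation | Topological + geometric constraints | ⭐⭐⭐⭐⭐
Materials design | Crystal structure generation | ⭐⭐⭐⭐
Dynamic graph generation | Time-varying networks | ⭐⭐⭐⭐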


Footnotes

  1. Chen, Y., et al. (2025). Topology-Aware Graph Diffusion Model with Persistent Homology. NeurIPS 2025.