拓扑感知图扩散模型
NeurIPS 2025的研究提出了将拓扑约束融入图生成扩散模型的方法,为分子设计、材料发现等应用提供了新的可能性。[1]
1. 图生成与拓扑约束
1.1 为什么需要拓扑约束
传统图生成方法的问题:
- 连通性:可能生成非连通的图
- 环结构:无法保证正确的环结构(如芳香环)
- 拓扑有效性:生成的图可能违反化学规则
拓扑约束的作用:
- 强制连通性:确保生成连通的分子图
- 环约束:保证正确的化学环结构
- 结构保持:保持目标拓扑性质
1.2 图的拓扑特征
| 拓扑特征 | 维度 | 化学意义 |
|---|---|---|
| Betti数 $\beta_0$ | 0 | 连通分量数 |
| Betti数 $\beta_1$ | 1 | 独立环数(基础环) |
| Betti数 $\beta_2$ | 2 | 空腔数(3D结构) |
| 度数分布 | - | 原子价态 |
| 环分布 | - | 环的大小和数量 |
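Betti数定义:$\beta_k = \operatorname{rank} H_k(X)$,即空间 $X$ 的第 $k$ 个同调群的秩。对于图 $G=(V,E)$(视为一维单纯复形),有 $\beta_0 = c$(连通分量数)、$\beta_1 = |E| - |V| + c$(独立环数),且 $\beta_2 = 0$。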
2. 图扩散模型基础
2.1 扩散过程
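前向过程(噪声添加):$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\big)$,等价地 $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$,其中 $\epsilon \sim \mathcal{N}(0, I)$,$\alpha_t = 1-\beta_t$,$\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$(此处 $\beta_t$ 为噪声调度,而非Betti数)。
反向过程(去噪生成):$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2 I\big)$,从 $x_T \sim \mathcal{N}(0, I)$ 出发逐步去噪得到 $x_0$。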
2.2 图的表示
import torch
import torch.nn as nn
from torch_geometric.data import Data
class GraphRepresentation:
    """Convert between edge-index graphs and dense adjacency tensors.

    Used to prepare graphs for, and read graphs back from, the graph
    diffusion model.
    """

    def __init__(self):
        self.node_dim = 100  # node feature dimension (not read by the methods below)
        self.edge_dim = 50  # edge feature dimension (not read by the methods below)

    def graph_to_tensor(self, edge_index, node_features=None):
        """Build the dense tensor representation used by the diffusion model.

        Parameters
        ----------
        edge_index : torch.Tensor, shape (2, E)
            Edge list; column ``k`` is the edge ``edge_index[0, k] -> edge_index[1, k]``.
        node_features : torch.Tensor, shape (N, F), optional
            Node features. When given, ``N`` also bounds the node count from
            below, so isolated nodes with the highest indices are kept.

        Returns
        -------
        dict
            ``'adjacency'``: symmetric 0/1 float matrix of shape (n, n);
            ``'n_nodes'``: n; ``'node_features'``: passed through unchanged.
        """
        edge_index = torch.as_tensor(edge_index, dtype=torch.long)
        # An empty edge list has no max(); it contributes 0 nodes.
        n_nodes = int(edge_index.max().item()) + 1 if edge_index.numel() else 0
        if node_features is not None:
            n_nodes = max(n_nodes, node_features.shape[0])

        adjacency = torch.zeros(n_nodes, n_nodes, device=edge_index.device)
        adjacency[edge_index[0], edge_index[1]] = 1
        # Symmetrize with an element-wise max: a one-directional edge becomes
        # 1 in both directions (averaging would leave 0.5 entries).
        adjacency = torch.maximum(adjacency, adjacency.T)

        return {
            'adjacency': adjacency,
            'n_nodes': n_nodes,
            'node_features': node_features,
        }

    def tensor_to_graph(self, adj, node_features=None, threshold=None):
        """Convert a (possibly soft) adjacency matrix back to a PyG ``Data`` graph.

        Parameters
        ----------
        adj : torch.Tensor, shape (n, n)
            Adjacency matrix.
        node_features : torch.Tensor, optional
            Stored as ``Data.x``.
        threshold : float, optional
            When ``None`` (default) every non-zero entry is an edge; otherwise
            only entries strictly greater than ``threshold`` are.

        Returns
        -------
        torch_geometric.data.Data
        """
        mask = adj != 0 if threshold is None else adj > threshold
        edge_index = mask.nonzero(as_tuple=False).T
        return Data(x=node_features, edge_index=edge_index)


# 3. 拓扑约束的融入方法
3.1 拓扑损失函数
import torch
import torch.nn as nn
from ripser import ripser
import numpy as np
class TopologicalConstraintLoss(nn.Module):
    """Topological constraint loss.

    Penalizes the squared difference between the Betti numbers of a
    generated graph and either a reference graph or fixed target values.

    Note: Betti numbers are integers computed with NumPy/SciPy from a
    detached copy of the adjacency matrix, so this loss carries no gradient
    back to ``generated_adj``. It works as a metric or as a
    non-differentiable penalty only.

    Parameters
    ----------
    target_betti : dict, optional
        Target values keyed by ``'beta_0'``, ``'beta_1'``, ``'beta_2'``.
        When a reference graph is passed to ``forward``, these keys only
        select which Betti numbers are compared.
    lambda_topo : float
        Weight applied to the loss.
    """

    def __init__(self, target_betti=None, lambda_topo=1.0):
        super().__init__()
        self.target_betti = target_betti or {}
        self.lambda_topo = lambda_topo

    def compute_betti_numbers(self, adj_matrix, threshold=0.0):
        """Compute the Betti numbers of a graph.

        The graph is treated as a 1-dimensional simplicial complex (nodes and
        edges only), for which:

        - ``beta_0`` is the number of connected components;
        - ``beta_1`` is ``|E| - |V| + beta_0``, the number of independent cycles;
        - ``beta_2`` is 0.

        Parameters
        ----------
        adj_matrix : torch.Tensor or np.ndarray, shape (n, n)
            Adjacency matrix. An entry (i, j) or (j, i) greater than
            ``threshold`` makes {i, j} an undirected edge. The diagonal
            (self-loops) is ignored.
        threshold : float
            Edge-existence threshold (default 0.0, i.e. any positive weight).

        Returns
        -------
        dict
            ``{'beta_0': int, 'beta_1': int, 'beta_2': int}``
        """
        from scipy.sparse.csgraph import connected_components

        if isinstance(adj_matrix, torch.Tensor):
            # detach() so tensors that require grad can be converted.
            adj_matrix = adj_matrix.detach().cpu().numpy()
        adj = np.asarray(adj_matrix)
        n_nodes = adj.shape[0]
        if n_nodes == 0:
            return {'beta_0': 0, 'beta_1': 0, 'beta_2': 0}

        edges = (adj > threshold) | (adj.T > threshold)
        np.fill_diagonal(edges, False)
        n_edges = int(np.triu(edges, k=1).sum())
        n_components, _ = connected_components(edges.astype(float), directed=False)
        n_components = int(n_components)

        return {
            'beta_0': n_components,
            'beta_1': n_edges - n_nodes + n_components,
            'beta_2': 0,  # a graph has no 2-cells
        }

    def forward(self, generated_adj, reference_adj=None):
        """Compute the weighted topological loss.

        Parameters
        ----------
        generated_adj : torch.Tensor or np.ndarray
            Generated adjacency matrix.
        reference_adj : torch.Tensor or np.ndarray, optional
            Reference adjacency matrix. When given, the generated Betti
            numbers are compared with the reference's, restricted to the keys
            of ``target_betti`` if that is non-empty. Otherwise they are
            compared with ``target_betti``.

        Returns
        -------
        float
            ``lambda_topo * sum((generated - target) ** 2)``; 0 when there is
            nothing to compare.
        """
        generated_betti = self.compute_betti_numbers(generated_adj)

        if reference_adj is not None:
            reference_betti = self.compute_betti_numbers(reference_adj)
            keys = list(self.target_betti) or list(reference_betti)
            targets = {key: reference_betti[key] for key in keys if key in reference_betti}
        else:
            targets = self.target_betti

        loss = sum(
            (generated_betti.get(key, 0) - target) ** 2
            for key, target in targets.items()
        )
        return self.lambda_topo * loss


## 3.2 持久同调引导的生成
class PersistenceGuidedGenerator(nn.Module):
    """Graph generator with a persistence-feature prediction head.

    Maps a latent code to a symmetric, non-negative adjacency matrix. From
    the same hidden representation it also predicts a 27-dimensional vector
    of persistence-diagram features (3 homology dimensions x 9 statistics).

    Parameters
    ----------
    latent_dim : int
        Size of the latent code ``z``.
    hidden_dim : int
        Width of the hidden layers.
    node_dim : int
        Number of nodes ``n`` of the generated ``n x n`` adjacency matrix.
    """

    def __init__(self, latent_dim, hidden_dim, node_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.node_dim = node_dim

        # Generator: latent code -> flattened adjacency matrix.
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, node_dim * node_dim),
        )

        # Persistence-feature predictor on the generator's hidden layer.
        self.persistence_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Linear(128, 27),  # 3 dims * 9 stats
        )

    def forward(self, z, target_persistence=None):
        """Generate an adjacency matrix and persistence features.

        Parameters
        ----------
        z : torch.Tensor, shape (latent_dim,) or (batch, latent_dim)
            Latent code(s).
        target_persistence : torch.Tensor, optional
            Target persistence features. Must broadcast against the
            predicted features of shape (..., 27).

        Returns
        -------
        dict
            ``'adjacency'`` with shape (..., n, n);
            ``'persistence_features'`` with shape (..., 27);
            ``'persistence_loss'``, the L2 norm of the feature error, or 0
            when no target is given.
        """
        # generator[:3] is Linear -> ReLU -> Linear. Its output feeds both the
        # rest of the generator and the persistence head, so compute it once.
        hidden = self.generator[:3](z)
        adj_flat = self.generator[3:](hidden)

        # Reshape with the known node count (len(adj_flat) would be the batch
        # size for batched z) and keep any leading batch dimensions.
        adj = adj_flat.reshape(*adj_flat.shape[:-1], self.node_dim, self.node_dim)
        # Symmetrize over the last two dims (valid for batched input), then
        # clip negative weights.
        adj = torch.relu((adj + adj.transpose(-1, -2)) / 2)

        pred_persistence = self.persistence_predictor(hidden)

        persistence_loss = 0
        if target_persistence is not None:
            persistence_loss = torch.norm(pred_persistence - target_persistence)

        return {
            'adjacency': adj,
            'persistence_features': pred_persistence,
            'persistence_loss': persistence_loss,
        }


# 4. 完整拓扑感知扩散模型
4.1 模型架构
import torch
import torch.nn as nn
import torch.nn.functional as F
class TopoAwareGraphDiffusion(nn.Module):
    """Topology-aware graph diffusion model.

    Combines three parts:

    - a message-passing encoder (``GraphConv``);
    - a denoising network (``Denoiser``);
    - a topology-prediction head (``TopologicalConstraintModule``).

    NOTE(review): ``Denoiser`` is not defined anywhere in this document. Its
    assumed contract is ``denoiser(h, x_t) -> x_{t-1}``, with ``x_{t-1}``
    shaped like ``x_t`` — TODO confirm.

    Parameters
    ----------
    node_dim : int
        Feature size of the noisy node features ``x_t``.
    hidden_dim : int
        Width of the time embedding and of every encoder layer.
    edge_dim : int
        Currently unused.
    n_steps : int
        Number of reverse-diffusion steps run in generation mode.
    """

    def __init__(self, node_dim, hidden_dim, edge_dim, n_steps=1000):
        super().__init__()
        self.n_steps = n_steps

        # Time embedding: scalar timestep -> hidden_dim vector.
        self.time_embedding = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Graph encoder. The time embedding is concatenated to every layer's
        # input, so the first layer sees node_dim + hidden_dim features and
        # later layers see hidden_dim (previous output) + hidden_dim.
        in_dims = [node_dim + hidden_dim] + [2 * hidden_dim] * 2
        self.encoder = nn.ModuleList([GraphConv(d, hidden_dim) for d in in_dims])

        # Edge-probability decoder over concatenated endpoint embeddings.
        self.adj_decoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

        # Topology prediction head; consumes node embeddings of size hidden_dim.
        self.topo_constraint = TopologicalConstraintModule(hidden_dim)

        # Denoising network (see class docstring note).
        self.denoiser = Denoiser(hidden_dim)

    def forward(self, x_t, t, edge_index=None, return_loss=False, topo_targets=None):
        """Run a training step (``return_loss=True``) or the reverse diffusion.

        Parameters
        ----------
        x_t : torch.Tensor, shape (N, node_dim)
            Noisy node features.
        t : int or torch.Tensor
            Diffusion timestep(s). Either a scalar / one-element tensor shared
            by all nodes, or one value per node (shape (N,)). Only used when
            ``return_loss`` is True; generation always runs all ``n_steps``.
        edge_index : torch.Tensor, shape (2, E)
            Edge list with indices into the rows of ``x_t``.
        return_loss : bool
            Select training mode.
        topo_targets : dict, optional
            Targets for ``TopologicalConstraintModule.compute_topo_loss``,
            e.g. ``{'betti': [...], 'connected': 1}``.

        Returns
        -------
        In training mode, a scalar loss tensor when ``topo_targets`` is given;
        otherwise the dict of topology predictions plus ``'edge_probs'``.
        In generation mode, the denoised node features.
        """
        if return_loss:
            t_emb = self._embed_time(t, x_t.shape[0])
            return self._compute_loss(x_t, t_emb, edge_index, topo_targets)
        return self._denoise(x_t, edge_index)

    def _embed_time(self, t, n_nodes):
        """Embed timestep(s) as a tensor of shape (n_nodes, hidden_dim)."""
        weight = self.time_embedding[0].weight
        # nn.Linear needs floating input; timesteps are usually integers.
        t = torch.as_tensor(t, dtype=weight.dtype, device=weight.device).reshape(-1, 1)
        t_emb = self.time_embedding(t)
        if t_emb.shape[0] == 1:
            # One shared timestep: broadcast it to every node.
            t_emb = t_emb.expand(n_nodes, -1)
        return t_emb

    def _encode(self, x, t_emb, edge_index):
        """Run the encoder, concatenating the time embedding at every layer."""
        h = x
        for conv in self.encoder:
            h = conv(torch.cat([h, t_emb], dim=-1), edge_index)
        return h

    def _compute_loss(self, x_t, t_emb, edge_index, topo_targets=None):
        """Encode the noisy graph and score its predicted topology."""
        h = self._encode(x_t, t_emb, edge_index)
        # The topology head expects node embeddings (N, hidden_dim), not the
        # per-edge probabilities produced by the adjacency decoder.
        predictions = self.topo_constraint(h)
        predictions['edge_probs'] = self._predict_adj(h, edge_index)
        if topo_targets is None:
            return predictions
        return self.topo_constraint.compute_topo_loss(predictions, topo_targets)

    def _denoise(self, x_t, edge_index):
        """Run the reverse process from step n_steps - 1 down to 0."""
        for step in reversed(range(self.n_steps)):
            # Re-embed the current step; a fixed embedding would condition
            # every step on the same noise level.
            t_emb = self._embed_time(step, x_t.shape[0])
            h = self._encode(x_t, t_emb, edge_index)
            x_t = self.denoiser(h, x_t)
        return x_t

    def _predict_adj(self, h, edge_index):
        """Predict an existence probability for each edge in ``edge_index``."""
        src, dst = edge_index
        combined = torch.cat([h[src], h[dst]], dim=-1)
        return self.adj_decoder(combined).squeeze(-1)


## 4.2 图卷积层
class GraphConv(nn.Module):
    """Mean-aggregation message-passing layer.

    For every edge ``src -> dst`` a message ``edge_lin([x_src, x_dst])`` is
    computed. Messages are averaged at the destination node, then passed
    through a linear layer, LayerNorm and ReLU. Nodes without incoming edges
    receive a zero aggregate.

    Parameters
    ----------
    in_dim : int
        Input node-feature size.
    out_dim : int
        Output node-feature size.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Applied after aggregation, so it maps out_dim -> out_dim.
        self.lin = nn.Linear(out_dim, out_dim)
        # Messages are built from two concatenated in_dim node vectors.
        self.edge_lin = nn.Linear(in_dim * 2, out_dim)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index):
        """Apply one round of message passing.

        Parameters
        ----------
        x : torch.Tensor, shape (N, in_dim)
            Node features.
        edge_index : torch.Tensor, shape (2, E)
            Edge list (row 0 = source, row 1 = destination).

        Returns
        -------
        torch.Tensor, shape (N, out_dim)
        """
        src, dst = edge_index

        # Message passing: sum messages into their destination nodes.
        messages = self.edge_lin(torch.cat([x[src], x[dst]], dim=-1))
        out = messages.new_zeros(x.shape[0], messages.shape[-1]).index_add(0, dst, messages)

        # Mean aggregation; clamp avoids division by zero for isolated nodes.
        deg = torch.bincount(dst, minlength=x.shape[0]).to(out.dtype).clamp(min=1)
        out = out / deg.unsqueeze(-1)

        return F.relu(self.norm(self.lin(out)))


## 4.3 拓扑约束模块
class TopologicalConstraintModule(nn.Module):
    """Topology prediction head.

    Mean-pools node embeddings and predicts Betti numbers, ring counts and a
    connectivity probability. ``compute_topo_loss`` scores those predictions
    against targets.

    Parameters
    ----------
    hidden_dim : int
        Size of the node embeddings passed to ``forward``.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        # Betti-number regressor: beta_0, beta_1, beta_2.
        self.betti_predictor = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        # Ring-count regressor (10 outputs, one per ring category).
        self.ring_predictor = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )
        # Connectivity probability in (0, 1).
        self.connectivity_predictor = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, h):
        """Predict topological properties from node embeddings.

        Parameters
        ----------
        h : torch.Tensor, shape (N, hidden_dim)
            Node embeddings. They are mean-pooled over all N rows, so one
            prediction is produced per call (not per graph when several
            graphs are batched together).

        Returns
        -------
        dict
            ``'betti'`` (1, 3), ``'rings'`` (1, 10), ``'connectivity'`` (1, 1).
        """
        pooled = h.mean(dim=0, keepdim=True)
        return {
            'betti': self.betti_predictor(pooled),
            'rings': self.ring_predictor(pooled),
            'connectivity': self.connectivity_predictor(pooled),
        }

    def compute_topo_loss(self, predictions, targets):
        """Compute the topology loss from ``forward`` predictions.

        Parameters
        ----------
        predictions : dict
            Output of ``forward``.
        targets : dict
            Optional keys:

            - ``'betti'``: 3 values (list or tensor), scored with MSE;
            - ``'connected'``: 0/1 or a probability, scored with BCE.

        Returns
        -------
        torch.Tensor
            Scalar loss; a zero tensor when no targets are given.
        """
        betti_pred = predictions['betti']
        loss = betti_pred.new_zeros(())

        if 'betti' in targets:
            target = torch.as_tensor(
                targets['betti'], dtype=betti_pred.dtype, device=betti_pred.device
            )
            # Match the (1, 3) prediction shape so MSE does not broadcast.
            loss = loss + F.mse_loss(betti_pred, target.reshape(betti_pred.shape))

        if 'connected' in targets:
            conn_pred = predictions['connectivity']
            # BCE requires the target to have exactly the prediction's shape.
            target_conn = torch.full_like(conn_pred, float(targets['connected']))
            loss = loss + F.binary_cross_entropy(conn_pred, target_conn)

        return loss


# 5. 分子图生成应用
5.1 分子约束
class MolecularTopologicalConstraints:
    """Topological constraints for molecular graphs.

    Intended to keep generated molecules valid.
    """

    @staticmethod
    def valid_molecular_betti(n_atoms, n_heavy_atoms, rings, *, bonds_per_heavy_atom=1.5):
        """Estimate the Betti numbers expected for a valid molecule.

        The molecular graph is taken over heavy atoms and assumed connected.

        Parameters
        ----------
        n_atoms : int
            Total atom count (not used by the estimate).
        n_heavy_atoms : int
            Number of heavy (non-hydrogen) atoms, i.e. the graph's node count.
        rings : dict or None
            Ring information. Recognized keys:

            - ``'total'``: explicit number of rings, used as beta_1 directly;
            - ``'aromatic'``: number of aromatic rings, a lower bound on beta_1.
        bonds_per_heavy_atom : float
            Average bonds per heavy atom, used to estimate the edge count
            when ``'total'`` is not given (default 1.5).

        Returns
        -------
        dict
            ``{'beta_0': 1, 'beta_1': ..., 'beta_2': 0}``
        """
        rings = rings or {}
        if 'total' in rings:
            beta_1 = int(rings['total'])
        else:
            n = n_heavy_atoms
            m = int(bonds_per_heavy_atom * n)  # estimated bond count
            c = 1  # connected molecule
            # Cycle rank of a graph: m - n + c.
            base_cycles = max(0, m - n + c)
            # Aromatic rings are ordinary cycles, already counted by the cycle
            # rank. They bound beta_1 from below rather than adding to it.
            beta_1 = max(base_cycles, rings.get('aromatic', 0))

        return {
            'beta_0': 1,  # connected
            'beta_1': beta_1,  # number of independent rings
            'beta_2': 0,  # a graph (1-dimensional complex) has no cavities
        }

    @staticmethod
    def ring_size_distribution(rings, *, max_ring_size=10):
        """Return ring counts indexed by ring size.

        Parameters
        ----------
        rings : dict
            Maps ring size (int) to count. Non-int keys (e.g. ``'aromatic'``)
            and sizes outside ``[3, max_ring_size]`` are ignored.
        max_ring_size : int
            Largest ring size recorded (default 10).

        Returns
        -------
        torch.Tensor, shape (max_ring_size + 1,)
            ``distribution[k]`` is the number of k-membered rings; indices 0-2
            are always 0. The length is ``max_ring_size + 1`` so that the
            largest size itself has a slot.
        """
        distribution = torch.zeros(max_ring_size + 1)
        for size, count in rings.items():
            if isinstance(size, int) and 3 <= size <= max_ring_size:
                distribution[size] = count
        return distribution
def generate_molecule_with_topology(
    n_atoms,
    target_rings=None,
    target_properties=None,
    model=None,
):
    """Sample a molecular graph together with its target topology.

    Parameters
    ----------
    n_atoms : int
        Target total atom count. 80% of it (rounded down) is used as the
        number of heavy atoms, i.e. the node count of the generated graph.
    target_rings : dict, optional
        Ring targets forwarded to
        ``MolecularTopologicalConstraints.valid_molecular_betti``.
    target_properties : dict, optional
        Target molecular properties (currently unused).
    model : nn.Module, optional
        A generator exposing ``latent_dim`` and returning a dict with
        ``'adjacency'`` (the ``PersistenceGuidedGenerator`` interface),
        ideally already trained. When omitted, an untrained
        ``PersistenceGuidedGenerator`` is built, so the sample is random.

    Returns
    -------
    dict
        The generator's output dict plus ``'target_betti'``: the Betti
        numbers expected for a valid molecule of this size.
    """
    n_heavy_atoms = int(n_atoms * 0.8)  # assume 80% heavy atoms
    topo_constraints = MolecularTopologicalConstraints.valid_molecular_betti(
        n_atoms,
        n_heavy_atoms,
        target_rings or {},
    )

    if model is None:
        # TopoAwareGraphDiffusion has no latent_dim and no latent-sampling
        # entry point, so use the generator whose interface this function
        # relies on (latent_dim, dict output with 'adjacency').
        model = PersistenceGuidedGenerator(
            latent_dim=64,
            hidden_dim=256,
            node_dim=max(n_heavy_atoms, 1),
        )

    model.eval()
    with torch.no_grad():
        # A 1-D latent code yields a single (n, n) adjacency matrix.
        z = torch.randn(model.latent_dim)
        # The 3 target Betti numbers are not passed as target_persistence.
        # In PersistenceGuidedGenerator they would be compared with 27-dim
        # persistence features, and that comparison only feeds a loss value,
        # not the sampled adjacency.
        generated = model(z)

    generated['target_betti'] = topo_constraints
    return generated


## 5.2 完整训练流程
def train_topo_graph_diffusion(
    train_graphs,
    n_epochs=100,
    batch_size=32,
    lr=1e-4,
    n_steps=1000,
    topo_targets=None,
):
    """Train the topology-aware graph diffusion model.

    Parameters
    ----------
    train_graphs : sequence of torch_geometric.data.Data
        Training graphs with node features ``x``. The feature size of the
        first graph sets the model's ``node_dim``.
    n_epochs, batch_size, lr : int, int, float
        Usual training hyper-parameters.
    n_steps : int
        Number of diffusion steps (length of the linear beta schedule).
    topo_targets : dict, optional
        Forwarded to the model's loss (``return_loss=True``). The model must
        return a scalar loss tensor; ``TopoAwareGraphDiffusion`` does so only
        when topology targets are supplied.

    Returns
    -------
    TopoAwareGraphDiffusion
        The trained model.

    Raises
    ------
    ValueError
        If ``train_graphs`` is empty.
    TypeError
        If the model does not return a loss tensor.
    """
    from torch_geometric.loader import DataLoader

    if len(train_graphs) == 0:
        raise ValueError("train_graphs is empty")

    node_dim = train_graphs[0].x.size(-1)
    model = TopoAwareGraphDiffusion(
        node_dim=node_dim,
        hidden_dim=256,
        edge_dim=10,
        n_steps=n_steps,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    dataloader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)

    # Linear noise schedule and its cumulative products.
    betas = torch.linspace(1e-4, 0.02, n_steps)
    alpha_bar = torch.cumprod(1 - betas, dim=0)

    # Only pass topo_targets when given, so models without that keyword work.
    loss_kwargs = {} if topo_targets is None else {'topo_targets': topo_targets}

    for epoch in range(n_epochs):
        model.train()
        total_loss = 0.0

        for batch in dataloader:
            optimizer.zero_grad()

            x0 = batch.x.float()  # clean node features (randn_like needs floats)

            # One timestep per graph, broadcast to that graph's nodes.
            t = torch.randint(0, n_steps, (batch.num_graphs,))
            t_node = t[batch.batch]

            # Forward diffusion: x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps.
            # NOTE(review): the sampled noise is not used by the loss below; a
            # standard DDPM objective would regress it — TODO.
            noise = torch.randn_like(x0)
            a_bar = alpha_bar[t_node].unsqueeze(-1)
            x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise

            # PyG's DataLoader already offsets batch.edge_index per graph, so
            # it must not be shifted again.
            loss = model(x_t, t_node, batch.edge_index, return_loss=True, **loss_kwargs)
            if not isinstance(loss, torch.Tensor):
                raise TypeError(
                    "model did not return a loss tensor; pass topo_targets so "
                    "the model can compute a scalar loss"
                )

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch}: Loss={total_loss/len(dataloader):.4f}")

    return model
def expand_edge_index(edge_index, batch, edge_batch=None):
    """Shift per-graph local edge indices into the node numbering of a batch.

    Parameters
    ----------
    edge_index : torch.Tensor, shape (2, E)
        Edge indices local to each graph (each graph numbers its nodes from 0).
    batch : torch.Tensor, shape (N,)
        Graph id of every node in the batch, grouped by graph in ascending
        order (as in PyG's ``batch.batch``).
    edge_batch : torch.Tensor, shape (E,)
        Graph id of every edge. Required: which graph an edge belongs to
        cannot be recovered from local indices and node ids alone.

    Returns
    -------
    torch.Tensor, shape (2, E), dtype long
        Edge indices offset by the number of nodes in all preceding graphs.

    Raises
    ------
    ValueError
        If ``edge_batch`` is not provided.

    Note: ``Batch.edge_index`` produced by PyG's DataLoader is already offset
    and must not be passed through this function.
    """
    if edge_batch is None:
        raise ValueError("edge_batch (graph id of every edge) is required")

    edge_index = torch.as_tensor(edge_index, dtype=torch.long)
    batch = torch.as_tensor(batch, dtype=torch.long)
    edge_batch = torch.as_tensor(edge_batch, dtype=torch.long, device=edge_index.device)

    num_nodes_per_graph = torch.bincount(batch)
    # Exclusive prefix sum: offsets[g] = number of nodes in graphs 0..g-1.
    offsets = torch.cumsum(num_nodes_per_graph, dim=0) - num_nodes_per_graph

    # Per-edge offset, broadcast over the source and destination rows.
    return edge_index + offsets.to(edge_index.device)[edge_batch]


# 6. 实验与评估
6.1 评估指标
class TopologyAwareMetrics:
    """Topology-aware evaluation metrics for generated graphs.

    Adjacency inputs may be NumPy arrays or torch tensors. An entry (i, j)
    or (j, i) greater than 0 makes {i, j} an undirected edge; self-loops are
    ignored.
    """

    @staticmethod
    def _edge_mask(adj):
        """Return the symmetric boolean edge mask (no self-loops) of ``adj``."""
        if isinstance(adj, torch.Tensor):
            adj = adj.detach().cpu().numpy()
        adj = np.asarray(adj)
        mask = (adj > 0) | (adj.T > 0)
        np.fill_diagonal(mask, False)
        return mask

    @staticmethod
    def _num_components(mask):
        """Count the connected components of the graph given by ``mask`` (BFS)."""
        from collections import deque

        n = mask.shape[0]
        seen = np.zeros(n, dtype=bool)
        components = 0
        for start in range(n):
            if seen[start]:
                continue
            components += 1
            seen[start] = True
            queue = deque([start])
            while queue:
                node = queue.popleft()
                new_nodes = np.flatnonzero(mask[node] & ~seen)
                seen[new_nodes] = True
                queue.extend(new_nodes.tolist())
        return components

    @staticmethod
    def betti_accuracy(generated_adj, target_betti):
        """Check beta_1 of a generated graph against ``target_betti['beta_1']``.

        beta_1 is the cycle rank ``|E| - |V| + c``, where c is the number of
        connected components (no connectivity is assumed). A missing
        ``'beta_1'`` target counts as a match.

        Returns
        -------
        float
            1.0 on a match, else 0.0.
        """
        mask = TopologyAwareMetrics._edge_mask(generated_adj)
        n_nodes = mask.shape[0]
        n_edges = int(mask.sum()) // 2  # the mask is symmetric
        n_components = TopologyAwareMetrics._num_components(mask)
        beta_1 = n_edges - n_nodes + n_components
        return 1.0 if beta_1 == target_betti.get('beta_1', beta_1) else 0.0

    @staticmethod
    def connectivity_rate(graphs):
        """Return the fraction of graphs that are connected.

        A 0-node graph is not counted as connected. An empty input gives 0.0.
        """
        graphs = list(graphs)
        if not graphs:
            return 0.0

        connected = 0
        for adj in graphs:
            mask = TopologyAwareMetrics._edge_mask(adj)
            if mask.shape[0] > 0 and TopologyAwareMetrics._num_components(mask) == 1:
                connected += 1
        return connected / len(graphs)

    @staticmethod
    def ring_validity(generated_molecules, target_rings):
        """Return the fraction of molecules whose ring count fits the target.

        ``None`` entries and objects whose ring information cannot be read
        count as invalid. With a non-empty ``target_rings``, a molecule is
        valid when its ring count is within 1 of ``target_rings['total']``
        (a missing ``'total'`` accepts any count). With an empty
        ``target_rings``, every readable molecule is valid. An empty input
        gives 0.0.
        """
        generated_molecules = list(generated_molecules)
        if not generated_molecules:
            return 0.0

        valid = 0
        for mol in generated_molecules:
            if mol is None:
                continue
            try:
                n_rings = mol.GetRingInfo().NumRings()
            except (AttributeError, RuntimeError, ValueError):
                # Not an RDKit-style molecule, or its ring info is unavailable
                # (RDKit presumably raises RuntimeError here) — treat as invalid.
                continue
            if not target_rings or abs(n_rings - target_rings.get('total', n_rings)) <= 1:
                valid += 1
        return valid / len(generated_molecules)


## 6.2 实验设置
def evaluate_topo_graph_model(model, test_graphs, target_topology):
    """Evaluate a latent-code graph generator with topology-aware metrics.

    Draws one sample per element of ``test_graphs`` (only its length is
    used) and scores the generated adjacency matrices.

    Parameters
    ----------
    model : nn.Module
        Generator exposing ``latent_dim`` and returning a dict with an
        ``'adjacency'`` entry (e.g. ``PersistenceGuidedGenerator``).
    test_graphs : sized collection
        Determines the number of samples.
    target_topology : dict
        Target Betti numbers; ``'beta_1'`` is used by the Betti accuracy.

    Returns
    -------
    dict
        ``'connectivity_rate'`` and ``'betti_accuracy'``, both 0.0 when no
        samples are drawn.
    """
    metrics = TopologyAwareMetrics()
    model.eval()

    generated_graphs = []
    with torch.no_grad():
        for _ in range(len(test_graphs)):
            # A 1-D latent code yields one (n, n) adjacency matrix, the 2-D
            # shape the metrics expect.
            z = torch.randn(model.latent_dim)
            # target_persistence is intentionally not passed. The 3 Betti
            # targets do not match PersistenceGuidedGenerator's 27-dim
            # persistence features, and only 'adjacency' is used here.
            generated_graphs.append(model(z)['adjacency'])

    if not generated_graphs:
        return {'connectivity_rate': 0.0, 'betti_accuracy': 0.0}

    return {
        'connectivity_rate': metrics.connectivity_rate(generated_graphs),
        'betti_accuracy': float(np.mean([
            metrics.betti_accuracy(g, target_topology)
            for g in generated_graphs
        ])),
    }


# 7. 最新研究进展
7.1 NeurIPS 2025工作
Topology-Aware Graph Diffusion Model with Persistent Homology
核心贡献:
- 持久同调引导:使用PH特征引导生成过程
- 拓扑损失:强制生成图满足拓扑约束
- 分子应用:生成有效的药物分子
7.2 未来方向
| 方向 | 描述 | 潜力 |
|---|---|---|
| 3D分子生成 | 拓扑+几何约束 | ⭐⭐⭐⭐⭐ |
| 材料设计 | 晶体结构生成 | ⭐⭐⭐⭐ |
| 动态图生成 | 时变网络 | ⭐⭐⭐⭐ |
参考文献
相关文档
Footnotes
1. Chen, Y., et al. (2025). Topology-Aware Graph Diffusion Model with Persistent Homology. NeurIPS 2025.