概述
Graph Transformer将Transformer的自注意力机制扩展到图结构数据,通过全局注意力捕获任意节点对之间的依赖关系,克服了传统消息传递GNN无法捕获长距离依赖的局限。12
Graph Transformer vs 传统GNN
| 特性 | 传统GNN | Graph Transformer |
|---|---|---|
| 邻居聚合 | 仅邻居节点 | 所有节点(全连接) |
| 依赖距离 | 受限于层数 | 任意距离 |
| 计算复杂度 | ||
| 位置感知 | 依赖图结构 | 可编码任意位置 |
| 表达能力 | ≤1-WL | 超越1-WL |
核心挑战
1. 位置编码问题
传统Transformer使用序列位置编码,但图没有自然的位置概念。需要设计图专属的位置编码。
2. 结构感知问题
节点之间的结构关系(如距离、同构性)需要在注意力中体现。
3. 计算效率
全连接注意力在大型图上计算成本高昂,需要高效实现。
位置编码方法
拉普拉斯特征向量位置编码(LAPE)
理论基础
使用拉普拉斯矩阵的特征向量作为位置编码:
取最小的 个非平凡特征向量:
物理意义
- 特征向量对应图的傅里叶基
- 小特征值 → 低频 → 全局/粗粒度信息
- 大特征值 → 高频 → 局部/细粒度信息
PyTorch实现
import torch
import torch.nn as nn
from scipy import sparse
from scipy.sparse.linalg import eigsh
import numpy as np
def compute_laplacian_pe(edge_index, num_nodes, k=16):
"""计算拉普拉斯位置编码"""
# 构建稀疏邻接矩阵
adj = sparse.lil_matrix((num_nodes, num_nodes))
for i, j in zip(edge_index[0], edge_index[1]):
adj[i, j] = 1.0
adj = adj.tocsr()
# 度矩阵
d = np.array(adj.sum(axis=1)).flatten()
D = sparse.diags(d)
# 归一化拉普拉斯
D_inv_sqrt = sparse.diags(1.0 / np.sqrt(d + 1e-10))
L = sparse.eye(num_nodes) - D_inv_sqrt @ adj @ D_inv_sqrt
# 特征分解
eigenvalues, eigenvectors = eigsh(L.astype(float), k=k+1, which='SM')
# 去掉第一个特征向量(全1)
eigenvectors = eigenvectors[:, 1:k+1]
return torch.from_numpy(eigenvectors).float()在模型中使用
class LapPE_GNN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, k_pe=16):
super().__init__()
self.k_pe = k_pe
self.pe_encoder = nn.Linear(k_pe, hidden_channels)
self.lin1 = nn.Linear(in_channels + hidden_channels, hidden_channels)
self.lin2 = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index, pe=None):
# pe: (N, k_pe) 拉普拉斯位置编码
if pe is not None:
pe_emb = self.pe_encoder(pe)
x = torch.cat([x, pe_emb], dim=-1)
x = self.lin1(x)
x = F.relu(x)
x = self.lin2(x)
return x随机游走位置编码(RWPE)
理论基础
定义从节点 开始长度为 的随机游走分布:
嵌入方法
使用负无穷范数编码随机游走概率:
有效半径分析
随机游走有效半径:
对于树状图,。
最短路径距离编码
定义
编码为可学习的嵌入:
其中 是距离 的独热编码。
位置编码对比
| 方法 | 维度 | 表达信息 | 计算成本 |
|---|---|---|---|
| LAPE | 谱域/全局 | 中等 | |
| RWPE | 随机游走统计 | 高 | |
| SPD | 最短路径 | 高 | |
| 相对位置 | 相对距离 | 低 |
Self-Attention Network (SAN)
架构设计
SAN (Dwivedi & Bresson, 2020) 将Transformer直接应用于图结构:
其中 是边嵌入。
边嵌入
边嵌入通过MLP编码边属性:
拉普拉斯位置编码集成
class SAN(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers, num_heads, k_pe=16):
super().__init__()
self.k_pe = k_pe
# 特征编码
self.node_emb = nn.Linear(in_channels, hidden_channels)
self.pe_encoder = nn.Linear(k_pe, hidden_channels)
# 边编码(如果需要)
self.edge_emb = nn.Linear(edge_dim, hidden_channels)
# 多层Transformer
self.layers = nn.ModuleList([
TransformerLayer(hidden_channels, num_heads)
for _ in range(num_layers)
])
# 输出
self.classifier = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index, pe, edge_attr=None):
N = x.shape[0]
# 编码节点特征
h = self.node_emb(x)
# 编码位置信息
pe_emb = self.pe_encoder(pe)
h = h + pe_emb
# 可选:边嵌入
if edge_attr is not None:
edge_emb = self.edge_emb(edge_attr)
else:
edge_emb = torch.zeros(N, N, h.shape[-1], device=h.device)
# Transformer层
for layer in self.layers:
h = layer(h, edge_emb)
# 图级别输出
h = h.mean(dim=0) # 池化
return self.classifier(h)复杂度分析
| 操作 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 注意力计算 | ||
| 位置编码 | ||
| 总计 |
Graphormer
设计理念
Graphormer (Ying et al., 2021) 是微软设计的图Transformer,在分子预测等任务上取得SOTA。
核心组件
1. 中心性编码 (Centrality Encoding)
编码节点的度数信息:
其中 是节点 的度。
2. 空间编码 (Spatial Encoding)
编码节点间的最短路径距离:
或使用可学习的嵌入表。
3. 边编码 (Edge Encoding)
对于有边属性的图:
其中 是将边属性投影到维度的函数。
完整注意力计算
其中 是中心性增强的查询。
分子图应用
class Graphormer(nn.Module):
def __init__(self, num_atoms, num_bonds, num_classes, hidden_dim=768, num_layers=12, num_heads=32):
super().__init__()
# 原子嵌入
self.atom_embedding = nn.Embedding(num_atoms, hidden_dim)
# 中心性嵌入
self.degree_embedding = nn.Embedding(128, hidden_dim) # 最大度数为127
# 偏置嵌入
self.distance_embedding = nn.Embedding(128, num_heads) # 距离嵌入
# 边嵌入
self.bond_embedding = nn.Embedding(num_bonds, hidden_dim)
# Transformer层
self.layers = nn.ModuleList([
GraphormerLayer(hidden_dim, num_heads)
for _ in range(num_layers)
])
# 输出
self.norm = nn.LayerNorm(hidden_dim)
self.classifier = nn.Linear(hidden_dim, num_classes)
def forward(self, batch):
# 获取节点嵌入
x = self.atom_embedding(batch['atoms'])
deg = self.degree_embedding(batch['degrees'])
x = x + deg
# 计算注意力偏置
attn_bias = self._compute_spatial_bias(batch) # (B, num_heads, N, N)
# 边嵌入
if 'bonds' in batch:
edge_emb = self.bond_embedding(batch['bonds'])
else:
edge_emb = 0
# Transformer
for layer in self.layers:
x = layer(x, attn_bias, edge_emb)
x = self.norm(x)
# 图级别池化
graph_emb = x.mean(dim=0)
return self.classifier(graph_emb)
def _compute_spatial_bias(self, batch):
# 计算空间偏置(基于最短路径距离)
# ...
return spatial_bias与传统GNN的对比
表达能力分析
| 模型 | WL测试 | 表达能力 |
|---|---|---|
| GCN | 1-WL | ≤1-WL |
| GIN | 1-WL | =1-WL |
| Graph Transformer | ≤k-WL | 超越1-WL |
实验对比
在ZINC分子数据集上的结果:
| 模型 | MAE ↓ |
|---|---|
| GCN | 0.246 |
| GIN | 0.210 |
| Graphormer | 0.123 |
| SAN | 0.139 |
适用场景
| 场景 | 推荐模型 |
|---|---|
| 小图(<1K节点) | Graph Transformer |
| 大图(>10K节点) | GCN/GAT(稀疏注意力) |
| 分子图 | Graphormer |
| 超大图 | 图采样 + Graph Transformer |
高效实现
图采样策略
def sample_subgraph(node_idx, num_hops=2, num_neighbors=10):
"""Graphormer风格的图采样"""
# BFS采样
frontier = {node_idx}
for _ in range(num_hops):
neighbors = []
for node in frontier:
neighbors.extend(get_neighbors(node))
# 随机选择邻居
neighbors = random.sample(neighbors, min(num_neighbors, len(neighbors)))
frontier.update(neighbors)
return list(frontier)稀疏注意力
from torch_sparse import SparseTensor
def sparse_graph_attention(x, adj, num_heads=8):
"""稀疏注意力实现"""
N = x.shape[0]
d_k = x.shape[1] // num_heads
# 稀疏邻接矩阵
adj_t = SparseTensor(row=adj[0], col=adj[1])
# QKV投影
Q = x @ W_q # (N, num_heads, d_k)
K = x @ W_k
V = x @ W_v
# 稀疏矩阵乘法实现注意力
# ...实战代码:分子性质预测
数据集
使用ZINC分子数据集,预测分子的溶解度。
import torch
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
# 加载数据
train_dataset = ZINC(root='/tmp/ZINC', subset=True, split='train')
val_dataset = ZINC(root='/tmp/ZINC', subset=True, split='val')
test_dataset = ZINC(root='/tmp/ZINC', subset=True, split='test')
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# 计算拉普拉斯PE
def add_laplacian_pe(data, k=16):
pe = compute_laplacian_pe(data.edge_index, data.num_nodes, k)
data.pe = pe
return data
train_dataset = train_dataset.map(add_laplacian_pe)模型定义
class GraphTransformerForZINC(nn.Module):
def __init__(self, num_atom_types, num_bond_types, hidden_dim=168,
num_layers=6, num_heads=8, k_pe=16):
super().__init__()
self.atom_embedding = nn.Embedding(num_atom_types, hidden_dim)
self.pe_encoder = nn.Linear(k_pe, hidden_dim)
self.bond_embedding = nn.Embedding(num_bond_types, hidden_dim)
self.layers = nn.ModuleList([
GraphTransformerLayer(hidden_dim, num_heads)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(hidden_dim)
self.out_mlp = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1)
)
def forward(self, batch):
x = self.atom_embedding(batch.x.squeeze())
pe = self.pe_encoder(batch.pe)
x = x + pe
# 可选的边嵌入
if batch.edge_attr is not None:
edge_emb = self.bond_embedding(batch.edge_attr)
x = x + edge_emb.mean(dim=0, keepdim=True)
# Transformer
for layer in self.layers:
x = layer(x, batch.edge_index)
x = self.final_norm(x)
# 分子级别池化
mol_emb = global_mean_pool(x, batch.batch)
return self.out_mlp(mol_emb).squeeze()训练
model = GraphTransformerForZINC(
num_atom_types=train_dataset.num_atom_types,
num_bond_types=train_dataset.num_bond_types
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()
for epoch in range(100):
model.train()
total_loss = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
pred = model(batch)
loss = criterion(pred, batch.y)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}")总结与展望
核心要点
- 位置编码是Graph Transformer的关键组件
- 拉普拉斯PE提供谱域视角,随机游走PE提供扩散视角
- 中心性编码和空间编码增强结构感知
- Graph Transformer在小图上效果显著优于传统GNN
未来方向
| 方向 | 研究问题 |
|---|---|
| 高效大规模 | 如何处理百万节点图? |
| 动态图 | 时序图的Transformer? |
| 多模态 | 结合分子3D结构信息 |
| 可解释性 | Graph Transformer的电路分析 |
参考
相关词条:图神经网络,GAT图注意力网络,Transformer数学基础,Swin Transformer