概述

Graph Transformer将Transformer的自注意力机制扩展到图结构数据,通过全局注意力捕获任意节点对之间的依赖关系,克服了传统消息传递GNN无法捕获长距离依赖的局限。12

Graph Transformer vs 传统GNN

特性传统GNNGraph Transformer
邻居聚合仅邻居节点所有节点(全连接)
依赖距离受限于层数任意距离
计算复杂度
位置感知依赖图结构可编码任意位置
表达能力≤1-WL超越1-WL

核心挑战

1. 位置编码问题

传统Transformer使用序列位置编码,但图没有自然的位置概念。需要设计图专属的位置编码

2. 结构感知问题

节点之间的结构关系(如距离、同构性)需要在注意力中体现。

3. 计算效率

全连接注意力在大型图上计算成本高昂,需要高效实现。


位置编码方法

拉普拉斯特征向量位置编码(LAPE)

理论基础

使用拉普拉斯矩阵的特征向量作为位置编码:

取最小的 个非平凡特征向量:

物理意义

  • 特征向量对应图的傅里叶基
  • 小特征值 → 低频 → 全局/粗粒度信息
  • 大特征值 → 高频 → 局部/细粒度信息

PyTorch实现

import torch
import torch.nn as nn
from scipy import sparse
from scipy.sparse.linalg import eigsh
import numpy as np
 
def compute_laplacian_pe(edge_index, num_nodes, k=16):
    """计算拉普拉斯位置编码"""
    # 构建稀疏邻接矩阵
    adj = sparse.lil_matrix((num_nodes, num_nodes))
    for i, j in zip(edge_index[0], edge_index[1]):
        adj[i, j] = 1.0
    adj = adj.tocsr()
    
    # 度矩阵
    d = np.array(adj.sum(axis=1)).flatten()
    D = sparse.diags(d)
    
    # 归一化拉普拉斯
    D_inv_sqrt = sparse.diags(1.0 / np.sqrt(d + 1e-10))
    L = sparse.eye(num_nodes) - D_inv_sqrt @ adj @ D_inv_sqrt
    
    # 特征分解
    eigenvalues, eigenvectors = eigsh(L.astype(float), k=k+1, which='SM')
    
    # 去掉第一个特征向量(全1)
    eigenvectors = eigenvectors[:, 1:k+1]
    
    return torch.from_numpy(eigenvectors).float()

在模型中使用

class LapPE_GNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, k_pe=16):
        super().__init__()
        self.k_pe = k_pe
        self.pe_encoder = nn.Linear(k_pe, hidden_channels)
        self.lin1 = nn.Linear(in_channels + hidden_channels, hidden_channels)
        self.lin2 = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index, pe=None):
        # pe: (N, k_pe) 拉普拉斯位置编码
        if pe is not None:
            pe_emb = self.pe_encoder(pe)
            x = torch.cat([x, pe_emb], dim=-1)
        
        x = self.lin1(x)
        x = F.relu(x)
        x = self.lin2(x)
        return x

随机游走位置编码(RWPE)

理论基础

定义从节点 开始长度为 的随机游走分布:

嵌入方法

使用负无穷范数编码随机游走概率:

有效半径分析

随机游走有效半径:

对于树状图,

最短路径距离编码

定义

编码为可学习的嵌入:

其中 是距离 的独热编码。

位置编码对比

方法维度表达信息计算成本
LAPE谱域/全局中等
RWPE随机游走统计
SPD最短路径
相对位置相对距离

Self-Attention Network (SAN)

架构设计

SAN (Dwivedi & Bresson, 2020) 将Transformer直接应用于图结构:

其中 是边嵌入。

边嵌入

边嵌入通过MLP编码边属性:

拉普拉斯位置编码集成

class SAN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, num_heads, k_pe=16):
        super().__init__()
        self.k_pe = k_pe
        
        # 特征编码
        self.node_emb = nn.Linear(in_channels, hidden_channels)
        self.pe_encoder = nn.Linear(k_pe, hidden_channels)
        
        # 边编码(如果需要)
        self.edge_emb = nn.Linear(edge_dim, hidden_channels)
        
        # 多层Transformer
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_channels, num_heads)
            for _ in range(num_layers)
        ])
        
        # 输出
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index, pe, edge_attr=None):
        N = x.shape[0]
        
        # 编码节点特征
        h = self.node_emb(x)
        
        # 编码位置信息
        pe_emb = self.pe_encoder(pe)
        h = h + pe_emb
        
        # 可选:边嵌入
        if edge_attr is not None:
            edge_emb = self.edge_emb(edge_attr)
        else:
            edge_emb = torch.zeros(N, N, h.shape[-1], device=h.device)
        
        # Transformer层
        for layer in self.layers:
            h = layer(h, edge_emb)
        
        # 图级别输出
        h = h.mean(dim=0)  # 池化
        return self.classifier(h)

复杂度分析

操作时间复杂度空间复杂度
注意力计算
位置编码
总计

Graphormer

设计理念

Graphormer (Ying et al., 2021) 是微软设计的图Transformer,在分子预测等任务上取得SOTA。

核心组件

1. 中心性编码 (Centrality Encoding)

编码节点的度数信息:

其中 是节点 的度。

2. 空间编码 (Spatial Encoding)

编码节点间的最短路径距离:

或使用可学习的嵌入表。

3. 边编码 (Edge Encoding)

对于有边属性的图:

其中 是将边属性投影到维度的函数。

完整注意力计算

其中 是中心性增强的查询。

分子图应用

class Graphormer(nn.Module):
    def __init__(self, num_atoms, num_bonds, num_classes, hidden_dim=768, num_layers=12, num_heads=32):
        super().__init__()
        
        # 原子嵌入
        self.atom_embedding = nn.Embedding(num_atoms, hidden_dim)
        
        # 中心性嵌入
        self.degree_embedding = nn.Embedding(128, hidden_dim)  # 最大度数为127
        
        # 偏置嵌入
        self.distance_embedding = nn.Embedding(128, num_heads)  # 距离嵌入
        
        # 边嵌入
        self.bond_embedding = nn.Embedding(num_bonds, hidden_dim)
        
        # Transformer层
        self.layers = nn.ModuleList([
            GraphormerLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        
        # 输出
        self.norm = nn.LayerNorm(hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, batch):
        # 获取节点嵌入
        x = self.atom_embedding(batch['atoms'])
        deg = self.degree_embedding(batch['degrees'])
        x = x + deg
        
        # 计算注意力偏置
        attn_bias = self._compute_spatial_bias(batch)  # (B, num_heads, N, N)
        
        # 边嵌入
        if 'bonds' in batch:
            edge_emb = self.bond_embedding(batch['bonds'])
        else:
            edge_emb = 0
        
        # Transformer
        for layer in self.layers:
            x = layer(x, attn_bias, edge_emb)
        
        x = self.norm(x)
        
        # 图级别池化
        graph_emb = x.mean(dim=0)
        
        return self.classifier(graph_emb)
    
    def _compute_spatial_bias(self, batch):
        # 计算空间偏置(基于最短路径距离)
        # ...
        return spatial_bias

与传统GNN的对比

表达能力分析

模型WL测试表达能力
GCN1-WL≤1-WL
GIN1-WL=1-WL
Graph Transformer≤k-WL超越1-WL

实验对比

在ZINC分子数据集上的结果:

模型MAE ↓
GCN0.246
GIN0.210
Graphormer0.123
SAN0.139

适用场景

场景推荐模型
小图(<1K节点)Graph Transformer
大图(>10K节点)GCN/GAT(稀疏注意力)
分子图Graphormer
超大图图采样 + Graph Transformer

高效实现

图采样策略

def sample_subgraph(node_idx, num_hops=2, num_neighbors=10):
    """Graphormer风格的图采样"""
    # BFS采样
    frontier = {node_idx}
    for _ in range(num_hops):
        neighbors = []
        for node in frontier:
            neighbors.extend(get_neighbors(node))
        # 随机选择邻居
        neighbors = random.sample(neighbors, min(num_neighbors, len(neighbors)))
        frontier.update(neighbors)
    return list(frontier)

稀疏注意力

from torch_sparse import SparseTensor
 
def sparse_graph_attention(x, adj, num_heads=8):
    """稀疏注意力实现"""
    N = x.shape[0]
    d_k = x.shape[1] // num_heads
    
    # 稀疏邻接矩阵
    adj_t = SparseTensor(row=adj[0], col=adj[1])
    
    # QKV投影
    Q = x @ W_q  # (N, num_heads, d_k)
    K = x @ W_k
    V = x @ W_v
    
    # 稀疏矩阵乘法实现注意力
    # ...

实战代码:分子性质预测

数据集

使用ZINC分子数据集,预测分子的溶解度。

import torch
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
 
# 加载数据
train_dataset = ZINC(root='/tmp/ZINC', subset=True, split='train')
val_dataset = ZINC(root='/tmp/ZINC', subset=True, split='val')
test_dataset = ZINC(root='/tmp/ZINC', subset=True, split='test')
 
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
 
# 计算拉普拉斯PE
def add_laplacian_pe(data, k=16):
    pe = compute_laplacian_pe(data.edge_index, data.num_nodes, k)
    data.pe = pe
    return data
 
train_dataset = train_dataset.map(add_laplacian_pe)

模型定义

class GraphTransformerForZINC(nn.Module):
    def __init__(self, num_atom_types, num_bond_types, hidden_dim=168, 
                 num_layers=6, num_heads=8, k_pe=16):
        super().__init__()
        self.atom_embedding = nn.Embedding(num_atom_types, hidden_dim)
        self.pe_encoder = nn.Linear(k_pe, hidden_dim)
        self.bond_embedding = nn.Embedding(num_bond_types, hidden_dim)
        
        self.layers = nn.ModuleList([
            GraphTransformerLayer(hidden_dim, num_heads)
            for _ in range(num_layers)
        ])
        
        self.final_norm = nn.LayerNorm(hidden_dim)
        self.out_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
    
    def forward(self, batch):
        x = self.atom_embedding(batch.x.squeeze())
        pe = self.pe_encoder(batch.pe)
        x = x + pe
        
        # 可选的边嵌入
        if batch.edge_attr is not None:
            edge_emb = self.bond_embedding(batch.edge_attr)
            x = x + edge_emb.mean(dim=0, keepdim=True)
        
        # Transformer
        for layer in self.layers:
            x = layer(x, batch.edge_index)
        
        x = self.final_norm(x)
        
        # 分子级别池化
        mol_emb = global_mean_pool(x, batch.batch)
        
        return self.out_mlp(mol_emb).squeeze()

训练

model = GraphTransformerForZINC(
    num_atom_types=train_dataset.num_atom_types,
    num_bond_types=train_dataset.num_bond_types
).to(device)
 
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.L1Loss()
 
for epoch in range(100):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        pred = model(batch)
        loss = criterion(pred, batch.y)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch}: Loss = {total_loss/len(train_loader):.4f}")

总结与展望

核心要点

  1. 位置编码是Graph Transformer的关键组件
  2. 拉普拉斯PE提供谱域视角,随机游走PE提供扩散视角
  3. 中心性编码空间编码增强结构感知
  4. Graph Transformer在小图上效果显著优于传统GNN

未来方向

方向研究问题
高效大规模如何处理百万节点图?
动态图时序图的Transformer?
多模态结合分子3D结构信息
可解释性Graph Transformer的电路分析

参考


相关词条:图神经网络GAT图注意力网络Transformer数学基础Swin Transformer

Footnotes

  1. Dwivedi & Bresson, “A Generalization of Transformer Networks to Graphs”, AAAI 2021 Workshop

  2. Ying et al., “Do Transformers Actually Perform Better than Nets on Molecular Data?”, NeurIPS 2021