概述
大规模图神经网络训练面临的核心挑战是:随着消息传递层数增加,参与计算的邻居数呈指数增长,计算和内存开销随之爆炸。传统全图训练在小规模图上效果良好,但在数百万节点的大规模图(如社交网络、知识图谱)上几乎不可行。[^1]
本章介绍解决这一问题的主流技术:图采样方法和高效架构设计。
挑战的本质
以一个典型的大规模图为例:
- 节点数 N = 1,000,000
- 平均度 d = 50
- 单个节点的2跳感受野 ≈ d² = 2,500 个节点
- 全图做2跳消息传递的聚合次数 ≈ N × d² = 2.5 × 10⁹(不可能全部计算)
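下面用一小段 Python 直观展示感受野随层数的指数增长(纯示意计算,假设每个节点恰有 d 个邻居):

```python
# 感受野规模随层数 L 指数增长的示意计算
N, d = 1_000_000, 50

for L in range(1, 5):
    receptive_field = d ** L               # 单节点 L 跳感受野上界
    total_messages = N * receptive_field   # 全图 L 跳消息传递的聚合次数上界
    print(f"L={L}: 单节点感受野≈{receptive_field:,}, 全图聚合≈{total_messages:,}")
```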
| 问题 | 表现 | 影响 |
|---|---|---|
| 邻居爆炸 | 指数增长的邻居数 | 内存溢出 |
| 计算复杂度 | O(N·d^L) | 训练时间过长 |
| GPU利用率低 | 稀疏计算 | 硬件浪费 |
解决方案分类
| 方法类型 | 代表工作 | 核心思想 |
|---|---|---|
| 节点采样 | GraphSAINT | 采样节点和子图 |
| 层采样 | FastGCN、LADIES | 每层独立采样 |
| 历史表示 | 历史嵌入缓存 | 避免重复计算 |
| 简化架构 | SIGN | 预计算多跳特征 |
| 图分割 | ClusterGCN | 聚类后分批训练 |
1. GraphSAINT:基于采样的归纳学习
1.1 核心思想
GraphSAINT(Graph SAmpling based INductive learning meThod)通过图采样器在每个训练步构建一个 mini-batch 子图,然后在该子图上执行标准GNN的前向与反向传播。[^1]
原始图 (100万节点)
↓ 采样
子图 (如1000节点)
↓
GNN前向传播
↓
参数更新
1.2 采样策略
GraphSAINT提供三种采样器:
节点采样(Node Sampler)
按与度数近似成反比的概率对节点独立采样(与下文实现一致):

$$P(v) \propto \frac{1}{d_v + 1}$$

其中 $d_v$ 是节点 $v$ 的度。低度节点被优先采样,以降低邻居聚合的方差。
边采样(Edge Sampler)
按如下概率对边独立采样:

$$p_{(u,v)} \propto \frac{1}{d_u} + \frac{1}{d_v}$$

端点度数越低的边越容易被保留,从而平衡高度数节点的影响。
随机游走采样(Random Walk Sampler)
从均匀采样的一组根节点出发,各自做长度为 $h$ 的随机游走,用访问到的节点诱导子图;单步转移概率为

$$P(u \to v) = \frac{1}{d_u}$$

游走采样能保留高阶局部结构,在原论文的实验中通常表现稳健。
1.3 归一化修正
采样使节点和边以不同概率进入子图,直接在子图上训练会引入偏差。GraphSAINT 对聚合与损失分别归一化,以保持估计无偏:

$$\zeta_v^{(l+1)} = \sum_{u} \frac{\tilde{A}_{v,u}}{\alpha_{u,v}}\, x_u^{(l)}, \qquad L = \sum_{v \in \mathcal{G}_s} \frac{L_v}{\lambda_v}$$

其中 $\alpha_{u,v}$ 是边 $(u,v)$ 出现在采样子图中的概率,$\lambda_v$ 是节点 $v$ 被采样的概率,二者都可在训练前通过多次预采样统计估计。
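下面是一个通过预采样估计损失归一化系数 $\lambda_v$ 的极简示意(假设 sampler 提供与下文实现一致的 sample() 接口;仅作演示,未做工程优化):

```python
import torch

def estimate_loss_norm(sampler, num_nodes, num_runs=200):
    """多次预采样,统计每个节点进入子图的频率,作为 λ_v 的估计"""
    node_count = torch.zeros(num_nodes)
    for _ in range(num_runs):
        nodes, _, _ = sampler.sample()
        node_count[nodes.cpu()] += 1
    lam = node_count / num_runs
    return lam.clamp(min=1.0 / num_runs)  # 下限截断,避免后续按 L_v/λ_v 加权时除零
```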
1.4 PyTorch实现
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphSAINTNodeSampler:
"""GraphSAINT节点采样器"""
def __init__(self, edge_index, num_nodes, num_samples, device='cpu'):
self.edge_index = edge_index.to(device)
self.num_nodes = num_nodes
self.num_samples = num_samples
self.device = device
# 计算度
self.degrees = self._compute_degrees()
# 计算采样概率(度数倒数)
self.probs = 1.0 / (self.degrees + 1) # 加1避免除零
self.probs = self.probs / self.probs.sum()
    def _compute_degrees(self):
        """计算每个节点的度数(边的两个端点各计一次,向量化实现)"""
        return torch.bincount(
            self.edge_index.reshape(-1).cpu(), minlength=self.num_nodes
        ).float()
def sample(self):
"""采样一个子图"""
# 1. 采样节点
sampled_nodes = torch.multinomial(
self.probs, self.num_samples, replacement=False
).to(self.device)
# 2. 构建子图邻接矩阵
sub_adj, sub_edge_index = self._get_subgraph(sampled_nodes)
return sampled_nodes, sub_edge_index, sub_adj
    def _get_subgraph(self, nodes):
        """提取由采样节点诱导的子图"""
        mask = torch.zeros(self.num_nodes, dtype=torch.bool, device=self.device)
        mask[nodes] = True
        # 过滤两端都在采样集合内的边
        edge_mask = mask[self.edge_index[0]] & mask[self.edge_index[1]]
        sub_edge_index = self.edge_index[:, edge_mask]
        # 重映射节点ID到 [0, num_sub_nodes)
        node_map = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device)
        node_map[nodes] = torch.arange(len(nodes), device=self.device)
        sub_edge_index = node_map[sub_edge_index]
        # 构建稠密子图邻接矩阵(向量化,避免Python循环)
        num_sub_nodes = len(nodes)
        sub_adj = torch.zeros(num_sub_nodes, num_sub_nodes, device=self.device)
        sub_adj[sub_edge_index[0], sub_edge_index[1]] = 1
        return sub_adj, sub_edge_index
class GraphSAINTGNN(nn.Module):
"""GraphSAINT框架下的GNN"""
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
super().__init__()
self.num_layers = num_layers
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
# 输入层
self.convs.append(nn.Linear(in_channels, hidden_channels))
self.norms.append(nn.LayerNorm(hidden_channels))
# 隐藏层
for _ in range(num_layers - 1):
self.convs.append(nn.Linear(hidden_channels, hidden_channels))
self.norms.append(nn.LayerNorm(hidden_channels))
# 输出层
self.classifier = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index):
h = x
for i in range(self.num_layers):
# 邻居聚合
h = self._propagate(h, edge_index)
# 线性变换 + 归一化
h = self.convs[i](h)
h = self.norms[i](h)
h = F.relu(h)
h = F.dropout(h, p=0.5, training=self.training)
return self.classifier(h)
    def _propagate(self, h, edge_index):
        """简化的均值聚合消息传递"""
        N = h.shape[0]
        out = torch.zeros_like(h)
        # 按目的节点聚合源节点特征(向量化的scatter-add)
        out.index_add_(0, edge_index[1], h[edge_index[0]])
        # 按入度做均值归一化
        degrees = torch.bincount(edge_index[1], minlength=N).float().clamp(min=1)
        out = out / degrees.unsqueeze(-1)
        return out
def train_graphsaint(data, model, sampler, optimizer, epochs=100):
"""GraphSAINT训练循环"""
model.train()
for epoch in range(epochs):
# 采样一个子图
nodes, edge_index, adj = sampler.sample()
# 获取子图数据
sub_x = data.x[nodes]
sub_y = data.y[nodes]
optimizer.zero_grad()
out = model(sub_x, edge_index)
loss = F.cross_entropy(out, sub_y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
    return model
```
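把采样器、模型和训练循环串起来的最小用法示意(假设 data 提供 x、y、edge_index、num_nodes 等字段,字段名仅为示例):

```python
# 最小用法示意:串联采样器、模型与训练循环
sampler = GraphSAINTNodeSampler(data.edge_index, data.num_nodes, num_samples=2000)
model = GraphSAINTGNN(
    in_channels=data.x.shape[1],
    hidden_channels=256,
    out_channels=int(data.y.max()) + 1,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model = train_graphsaint(data, model, sampler, optimizer, epochs=100)
```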
1.5 采样器对比
| 采样器 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| 节点采样 | 实现简单 | 可能丢失高连接节点 | 度分布均匀的图 |
| 边采样 | 保留结构信息 | 方差较高 | 稀疏图 |
| 随机游走采样 | 保留高阶局部结构、方差较低 | 采样开销较高 | 深层GNN与大规模图 |
2. FastGCN:逐层采样
2.1 核心思想
FastGCN在每一层独立采样固定数量的节点(逐层采样),而不是为每个节点递归展开完整的K跳邻域。[^2]
层0: 采样512个起始节点
↓
层1: 每个起始节点采样16个邻居
↓
层2: 每个1跳节点采样16个邻居
↓
... (继续)
2.2 采样概率
FastGCN 为每一层独立选取节点,采样分布与归一化邻接矩阵对应列的平方范数成正比:

$$q(v) = \frac{\lVert \hat{A}_{:,v} \rVert^2}{\sum_{u} \lVert \hat{A}_{:,u} \rVert^2}$$

其中 $\hat{A}$ 是归一化邻接矩阵。该重要性分布由图结构决定,是固定的、无需学习的。
2.3 方差分析
关键性质:按上述 $q(v)$ 做重要性采样,可以在保持逐层聚合无偏估计的同时显著降低其方差,从而稳定梯度估计。
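下面用一个自包含的小实验演示这一点:对某个节点的邻居聚合 $\sum_u \hat{A}_{v,u} h_u$,比较均匀采样与重要性采样两种无偏估计的均方误差(纯演示代码,与下文采样器实现无关;随机矩阵仅为构造倾斜的度分布):

```python
import torch

torch.manual_seed(0)
N = 1000
# 构造一个列范数高度倾斜的稀疏"归一化"邻接矩阵(仅演示)
col_scale = torch.rand(N) ** 4
A_hat = (torch.rand(N, N) < 0.05).float() * torch.rand(N, N) * col_scale
h = torch.randn(N)
v = 0
exact = A_hat[v] @ h                        # 精确聚合值 Σ_u Â[v,u]·h[u]

q = A_hat.pow(2).sum(dim=0)                 # FastGCN: q(u) ∝ ||Â_{:,u}||²
q = q / q.sum()
uniform = torch.full((N,), 1.0 / N)

def mse(probs, num_samples=64, num_trials=2000):
    """蒙特卡洛估计 Σ_u f(u) = E_{u~q}[f(u)/q(u)] 的均方误差"""
    f = A_hat[v] * h
    idx = torch.multinomial(probs.expand(num_trials, -1), num_samples, replacement=True)
    est = (f[idx] / probs[idx]).mean(dim=1)
    return (est - exact).pow(2).mean().item()

print(f"均匀采样   MSE: {mse(uniform):.3e}")
print(f"重要性采样 MSE: {mse(q):.3e}")
```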
```python
import numpy as np
import torch

class FastGCNSampler:
"""FastGCN采样器"""
def __init__(self, edge_index, num_nodes, layer_samples):
"""
layer_samples: 每层的采样数列表,如 [512, 16, 16]
"""
self.edge_index = edge_index
self.num_nodes = num_nodes
self.layer_samples = layer_samples
# 预计算邻居列表
self.neighbors = self._build_adjacency_list()
def _build_adjacency_list(self):
"""构建邻接表"""
neighbors = [[] for _ in range(self.num_nodes)]
for i in range(self.edge_index.shape[1]):
u, v = self.edge_index[0, i].item(), self.edge_index[1, i].item()
neighbors[u].append(v)
return neighbors
def sample(self, start_nodes=None):
"""逐层采样"""
if start_nodes is None:
start_nodes = torch.randint(0, self.num_nodes, (self.layer_samples[0],))
layers = [start_nodes]
importance_weights = []
for l, num_sample in enumerate(self.layer_samples[1:], 1):
prev_layer = layers[-1]
            node_weight = {}  # 节点 -> 重要性权重,用字典同时完成去重
            for node in prev_layer.tolist():
                nbrs = self.neighbors[node]
                if len(nbrs) > 0:
                    # 无放回地采样该节点的邻居
                    sampled = np.random.choice(
                        nbrs, size=min(num_sample, len(nbrs)), replace=False
                    )
                    for s in sampled:
                        node_weight.setdefault(int(s), 1.0 / len(nbrs))
                else:
                    # 无邻居时保留自身
                    node_weight.setdefault(node, 1.0)
            next_layer_nodes = torch.tensor(list(node_weight.keys()), dtype=torch.long)
            layers.append(next_layer_nodes)
            importance_weights.append(torch.tensor(list(node_weight.values())))
        return layers, importance_weights
```
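一个简要的调用示意(假设 data 提供 edge_index 与 num_nodes;返回的各层节点可用于构建逐层稀疏聚合):

```python
# 用法示意:首层 512 个节点,之后每层每个节点最多采 16 个邻居
sampler = FastGCNSampler(data.edge_index, data.num_nodes, layer_samples=[512, 16, 16])
layers, weights = sampler.sample()
for l, nodes in enumerate(layers):
    print(f"第 {l} 层节点数: {len(nodes)}")
```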
3. SIGN:简化的大规模GNN
3.1 核心思想
SIGN(Scalable Inception Graph Neural Networks)通过离线预计算多跳扩散特征,避免训练时的大规模邻居聚合。[^3]
| | 传统GNN | SIGN |
|---|---|---|
| 多跳邻居聚合 | 每个batch运行时计算,开销 O(B × N_neighbor) | 离线预处理一次完成 |
| 训练时开销 | 随层数指数增长 | 与图结构无关(仅MLP前向) |
3.2 架构
SIGN 先离线计算各跳扩散特征 $X, \hat{A}X, \dots, \hat{A}^RX$,再逐跳线性变换并拼接,最后经MLP输出:

$$Z = \sigma\!\left(\big[\,X\Theta_0 \;\|\; \hat{A}X\Theta_1 \;\|\; \cdots \;\|\; \hat{A}^{R}X\Theta_R\,\big]\right), \qquad \hat{Y} = \xi\,(Z\,\Omega)$$

其中 $R$ 是预计算的最大跳数,$\hat{A}$ 是归一化邻接矩阵,$\Theta_r$、$\Omega$ 是可学习参数,$\|$ 表示特征拼接。
3.3 PyTorch实现
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SIGN(nn.Module):
    """Scalable Inception Graph Neural Networks"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_hops=3, num_layers=2):
        super().__init__()
        self.num_hops = num_hops
        # 每一跳一个线性投影(含第0跳,即原始特征自身)
        self.hop_projs = nn.ModuleList([
            nn.Linear(in_channels, hidden_channels)
            for _ in range(num_hops + 1)
        ])
        # 拼接后的MLP分类头
        layers = []
        dim = hidden_channels * (num_hops + 1)
        for _ in range(num_layers - 1):
            layers += [
                nn.Linear(dim, hidden_channels),
                nn.BatchNorm1d(hidden_channels),
                nn.ReLU(),
                nn.Dropout(0.5),
            ]
            dim = hidden_channels
        layers.append(nn.Linear(dim, out_channels))
        self.mlp = nn.Sequential(*layers)
    def forward(self, xs):
        """
        xs: 预计算的多跳特征列表 [X, ÂX, Â²X, ..., Â^R X],每项形状 (N, in_channels)
        """
        # 逐跳线性变换后拼接(Inception式组合)
        h = torch.cat(
            [F.relu(proj(x)) for proj, x in zip(self.hop_projs, xs)],
            dim=-1,
        )
        # 之后的前向传播只是MLP,与图结构完全无关
        return self.mlp(h)
def preprocess_graph(x, adj, num_hops=3, device='cpu'):
    """离线预计算多跳扩散特征 [X, ÂX, Â²X, ...]:SIGN全部的图计算都发生在这一步"""
    x, adj = x.to(device), adj.to(device)
    # 行归一化邻接矩阵 Â = D⁻¹A
    deg = adj.sum(dim=1, keepdim=True)
    adj_norm = adj / deg.where(deg > 0, torch.ones_like(deg))
    xs = [x]
    for _ in range(num_hops):
        xs.append(adj_norm @ xs[-1])
    return xs
# 使用示例
def train_sign(data, adj, num_hops=3):
    # 1. 预处理:一次性计算多跳特征
    xs = preprocess_graph(data.x, adj, num_hops=num_hops)
    # 2. 创建模型
    model = SIGN(
        in_channels=data.num_features,
        hidden_channels=256,
        out_channels=data.num_classes,
        num_hops=num_hops
    ).to(data.x.device)
    # 3. 训练(forward中不再有任何邻居聚合)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(100):
        model.train()
        optimizer.zero_grad()
        out = model(xs)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
```
3.4 SIGN vs 标准GNN
| 特性 | 标准GNN | SIGN |
|---|---|---|
| 邻居聚合 | 运行时计算 | 预处理计算 |
| 时间复杂度 | O(N·d^L) | 预处理 O(R·m·F)(m为边数,稀疏实现)+ 训练 O(N·R·F²)(纯MLP) |
| 内存 | O(N·d) | O(N·R·F)(存储预计算特征) |
| 表达能力 | 完整(逐层非线性聚合) | 受限(扩散算子固定,跳与跳之间无非线性) |
| 灵活性 | 高 | 中等 |
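上表提到预处理可以稀疏化:多跳特征只需反复做"稀疏矩阵 × 稠密特征"乘法,不必构造稠密的 $\hat{A}^r$。下面是基于 torch.sparse 的一个简要示意(假设 edge_index 描述无权边):

```python
import torch

def precompute_sparse(x, edge_index, num_nodes, num_hops=3):
    """用稀疏矩阵乘法预计算 [X, ÂX, Â²X, ...],复杂度约 O(R·m·F)"""
    # 行归一化权重:边 (u, v) 的权重取 1/deg(u)
    deg = torch.bincount(edge_index[0], minlength=num_nodes).clamp(min=1).float()
    values = 1.0 / deg[edge_index[0]]
    adj_norm = torch.sparse_coo_tensor(
        edge_index, values, (num_nodes, num_nodes)
    ).coalesce()
    xs = [x]
    for _ in range(num_hops):
        xs.append(torch.sparse.mm(adj_norm, xs[-1]))  # 稀疏 × 稠密乘法
    return xs
```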
4. ClusterGCN:图聚类方法
4.1 核心思想
ClusterGCN通过图聚类算法(如METIS)将大图分割成多个簇,然后以簇诱导的子图为单位分批训练GNN。[^4]
原始图 → Metis聚类 → K个子图 → 分批训练
4.2 优势
- 低方差:采样的是真实子图,结构完整
- 高效率:每个batch只处理一个簇
- 内存友好:不需要存储完整邻接矩阵
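实践中通常不必手写聚类:PyTorch Geometric 内置了基于METIS的分区与加载工具。一个简要的用法示意(假设 data 为 PyG 的 Data 对象,model 为任意接受 (x, edge_index) 的GNN,且环境已编译METIS支持):

```python
# 基于 PyG 内置 METIS 分区的示意用法
from torch_geometric.loader import ClusterData, ClusterLoader

cluster_data = ClusterData(data, num_parts=50)                    # METIS 分成 50 个簇
loader = ClusterLoader(cluster_data, batch_size=2, shuffle=True)  # 每个batch合并2个簇
for batch in loader:
    out = model(batch.x, batch.edge_index)                        # 在簇诱导子图上前向传播
```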
4.3 实现
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
class ClusterGCN:
"""ClusterGCN聚类器"""
def __init__(self, num_clusters):
self.num_clusters = num_clusters
def cluster(self, edge_index, num_nodes):
"""使用简单的随机分区作为聚类(实际应使用Metis)"""
cluster_id = torch.randint(0, self.num_clusters, (num_nodes,))
return cluster_id
    def create_subgraph_data(self, data, edge_index, cluster_id):
        """为每个簇创建诱导子图数据"""
clusters = torch.unique(cluster_id)
subgraph_data_list = []
for c in clusters:
# 找到簇中的节点
node_mask = cluster_id == c
node_indices = torch.where(node_mask)[0]
# 构建子图
            subgraph_x = data.x[node_indices]
            subgraph_y = data.y[node_indices]
# 过滤子图内的边
edge_mask = (cluster_id[edge_index[0]] == c) & (cluster_id[edge_index[1]] == c)
subgraph_edge_index = edge_index[:, edge_mask]
# 重映射节点ID
node_map = torch.zeros(data.num_nodes, dtype=torch.long)
node_map[node_indices] = torch.arange(len(node_indices))
subgraph_edge_index = node_map[subgraph_edge_index]
# 添加自环
subgraph_edge_index, _ = add_self_loops(subgraph_edge_index, num_nodes=len(node_indices))
# 创建子图数据
subgraph = Data(
x=subgraph_x,
edge_index=subgraph_edge_index,
y=subgraph_y,
num_nodes=len(node_indices)
)
subgraph_data_list.append(subgraph)
return subgraph_data_list
class ClusterGNN(nn.Module):
"""用于ClusterGCN的GNN"""
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = nn.Linear(in_channels, hidden_channels)
self.conv2 = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x)
x = F.relu(x)
x = self._propagate(x, edge_index)
x = self.conv2(x)
return x
    def _propagate(self, h, edge_index):
        """简化的均值聚合消息传递(向量化)"""
        N = h.shape[0]
        out = torch.zeros_like(h)
        out.index_add_(0, edge_index[1], h[edge_index[0]])
        degrees = torch.bincount(edge_index[1], minlength=N).float().clamp(min=1)
        out = out / degrees.unsqueeze(-1)
        return out
def train_clustergcn(data, edge_index, num_clusters=50, epochs=100):
# 1. 聚类
clusterer = ClusterGCN(num_clusters)
cluster_id = clusterer.cluster(edge_index, data.num_nodes)
# 2. 创建子图
    subgraph_list = clusterer.create_subgraph_data(data, edge_index, cluster_id)
loader = DataLoader(subgraph_list, batch_size=1, shuffle=True)
# 3. 创建模型
model = ClusterGNN(
in_channels=data.num_features,
hidden_channels=256,
out_channels=data.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# 4. 训练
for epoch in range(epochs):
for subgraph in loader:
model.train()
optimizer.zero_grad()
out = model(subgraph.x, subgraph.edge_index)
loss = F.cross_entropy(out, subgraph.y)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
    return model
```
5. 历史表示缓存
5.1 核心思想
在归纳学习场景中,训练过程会不断遇到尚未计算过嵌入的节点。通过缓存历史嵌入、避免重复计算,可以显著加速训练。[^5]
新节点 → 查找缓存
├─ 命中 → 直接使用历史嵌入
└─ 未命中 → 计算嵌入并写入缓存
5.2 实现
```python
import torch

class HistoryCache:
"""历史嵌入缓存"""
def __init__(self, gnn_model, feature_dim, cache_size=10000):
self.gnn_model = gnn_model
self.cache_size = cache_size
self.feature_dim = feature_dim
# LRU缓存
self.cache = {}
self.access_order = []
# 缓存统计
self.hits = 0
self.misses = 0
def get(self, node_ids):
"""获取节点嵌入"""
embeddings = []
miss_indices = []
        for i, node_id in enumerate(node_ids):
            if node_id in self.cache:
                embeddings.append(self.cache[node_id])
                # 命中时将节点移到访问序列末尾,维持真正的LRU语义
                self.access_order.remove(node_id)
                self.access_order.append(node_id)
                self.hits += 1
            else:
                embeddings.append(None)
                miss_indices.append(i)
                self.misses += 1
# 计算缺失的嵌入
if miss_indices:
miss_nodes = torch.tensor([node_ids[i] for i in miss_indices])
miss_embeddings = self._compute_embeddings(miss_nodes)
# 缓存新嵌入
for node_id, emb in zip(miss_nodes.tolist(), miss_embeddings):
self._add_to_cache(node_id, emb)
# 填充结果
emb_idx = 0
for i in range(len(node_ids)):
if embeddings[i] is None:
embeddings[i] = miss_embeddings[emb_idx]
emb_idx += 1
return torch.stack(embeddings)
    def _compute_embeddings(self, node_ids):
        """计算节点嵌入(占位实现)"""
        # 实际应用中应提取 node_ids 相关的子图并调用GNN;
        # 这里仅用随机特征演示接口,不具备实际意义
        with torch.no_grad():
            feats = torch.randn(len(node_ids), self.feature_dim)
            embeddings = self.gnn_model(feats)
        return embeddings
def _add_to_cache(self, node_id, embedding):
"""添加到缓存"""
if len(self.cache) >= self.cache_size:
# LRU淘汰
oldest = self.access_order.pop(0)
del self.cache[oldest]
self.cache[node_id] = embedding
self.access_order.append(node_id)
def get_stats(self):
"""获取缓存统计"""
total = self.hits + self.misses
hit_rate = self.hits / total if total > 0 else 0
return {
'hits': self.hits,
'misses': self.misses,
'hit_rate': hit_rate,
'cache_size': len(self.cache)
        }
```
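一个最小的使用示意(用一个线性层代替真实GNN,仅演示缓存命中统计;维度均为假设):

```python
import torch.nn as nn

cache = HistoryCache(gnn_model=nn.Linear(64, 32), feature_dim=64, cache_size=1000)
cache.get([1, 2, 3])      # 首次访问:全部未命中,计算并写入缓存
cache.get([2, 3, 4])      # 第二次:节点 2、3 命中
print(cache.get_stats())  # 期望 hits=2, misses=4
```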
6. 实践指南
6.1 方法选择
| 场景 | 推荐方法 | 理由 |
|---|---|---|
| 小规模图 (< 10K节点) | 全图训练 | 无需采样 |
| 中规模图 (10K - 1M) | GraphSAINT / ClusterGCN | 平衡效率与精度 |
| 大规模图 (> 1M) | SIGN + 缓存 | 预计算优势 |
| 异构图 | 历史缓存 | 适应新节点 |
6.2 超参数建议
| 参数 | GraphSAINT | FastGCN | SIGN |
|---|---|---|---|
| 采样数 | 1000-5000 | 64-512/层 | - |
| 层数 | 2-4 | 2-4 | 2-3 |
| batch size | 1-2子图 | 可较大 | 可较大 |
| dropout | 0.5-0.7 | 0.3-0.5 | 0.3-0.5 |
6.3 评估指标
```python
import time
import psutil

def evaluate_sampling_methods(dataset, methods, metrics=('accuracy', 'time', 'memory')):
    """评估不同采样方法的性能(GCNModel 与 evaluate 假定已在别处定义)"""
    results = {}
for name, sampler_class in methods.items():
print(f"Evaluating {name}...")
sampler = sampler_class(dataset)
model = GCNModel(dataset.num_features, 256, dataset.num_classes)
# 测量时间
        start = time.time()
# ... 训练逻辑 ...
train_time = time.time() - start
# 测量内存
        mem_before = psutil.Process().memory_info().rss
# ... 训练逻辑 ...
mem_after = psutil.Process().memory_info().rss
memory = (mem_after - mem_before) / 1024 / 1024 # MB
# 测量精度
accuracy = evaluate(model, dataset)
results[name] = {
'accuracy': accuracy,
'time': train_time,
'memory': memory
}
    return results
```
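调用方式的简要示意(methods 传入"名称 → 采样器构造器"的映射,名称与参数仅为示例):

```python
# 用法示意:对比前文实现的两种采样器
methods = {
    'GraphSAINT-Node': lambda ds: GraphSAINTNodeSampler(ds.edge_index, ds.num_nodes, 2000),
    'FastGCN': lambda ds: FastGCNSampler(ds.edge_index, ds.num_nodes, [512, 16, 16]),
}
results = evaluate_sampling_methods(dataset, methods)
for name, metrics in results.items():
    print(name, metrics)
```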
参考

[^1]: Zeng et al., “GraphSAINT: Graph Sampling Based Inductive Learning Method”, ICLR 2020
[^2]: Chen et al., “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling”, ICLR 2018
[^3]: Rossi et al., “SIGN: Scalable Inception Graph Neural Networks”, GRL 2020
[^4]: Chiang et al., “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks”, KDD 2019
[^5]: Hu et al., “HeteroGNN: Heterogeneous Graph Neural Networks”, KDD 2020