概述

大规模图神经网络训练面临的核心挑战是:计算和内存开销随GNN层数呈指数增长(即"邻居爆炸")。传统全图训练在小规模图上效果良好,但在数百万节点的大规模图(如社交网络、知识图谱)上几乎不可行。[1]

本章介绍解决这一问题的两类主流技术:图采样方法与高效架构设计。

挑战的本质

节点数 N = 1,000,000
平均度 d = 50
2跳邻居数 ≈ N × d² = 2.5B(不可能全部计算)
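上面的量级估算可以用几行Python复现(纯算术演示,数值与上文一致):

```python
def receptive_field(avg_degree: int, num_layers: int) -> int:
    """单个节点L跳感受野大小的粗略上界(忽略重复访问):d^L"""
    return avg_degree ** num_layers

N = 1_000_000   # 节点数
d = 50          # 平均度

# 全图做一次2层GNN前向需要触达的(含重复)邻居总数
total = N * receptive_field(d, 2)
print(f"{total:,}")  # 2,500,000,000,即 2.5B
```

层数 L 每加一层,该数字就乘以 d,这正是"邻居爆炸"的来源。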
| 问题 | 表现 | 影响 |
| --- | --- | --- |
| 邻居爆炸 | 指数增长的邻居数 | 内存溢出 |
| 计算复杂度 | O(N·d^L) | 训练时间过长 |
| GPU利用率低 | 稀疏计算 | 硬件浪费 |

解决方案分类

| 方法类型 | 代表工作 | 核心思想 |
| --- | --- | --- |
| 节点采样 | GraphSAINT | 采样节点和子图 |
| 层采样 | FastGCN、LADIES | 每层独立采样 |
| 历史表示 | 历史嵌入缓存 | 避免重复计算 |
| 简化架构 | SIGN | 预计算多跳特征 |
| 图分割 | ClusterGCN | 聚类后分批训练 |

1. GraphSAINT:基于采样的归纳学习

1.1 核心思想

GraphSAINT(Graph Sampling Based INductive learning Framework)通过图采样器在每个训练步构建一个mini-batch子图,然后在此子图上执行标准GNN前向传播。[1]

原始图 (100万节点)
       ↓ 采样
子图 (如1000节点)
       ↓
GNN前向传播
       ↓
参数更新

1.2 采样策略

GraphSAINT提供三种采样器:

节点采样(Node Sampler)

按节点度的倒数进行采样:

P(v) = (1/d_v) / Σ_u (1/d_u)

其中 d_v 是节点 v 的度。低度节点被优先采样,以降低聚合结果的邻居方差(下方实现中的采样概率即按此定义)。

边采样(Edge Sampler)

按与两端节点度相关的概率进行边采样:

P((u,v)) ∝ 1/d_u + 1/d_v

端点度数越低的边被保留的概率越高,从而在保留高相关边的同时平衡度数的影响。
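按上式采样边的一个最小草图(纯Python;`degree` 与 `edges` 为示意数据,不是任何库的API):

```python
# 玩具图:4个节点的度与4条边
degree = {0: 1, 1: 3, 2: 2, 3: 2}
edges = [(0, 1), (1, 2), (1, 3), (2, 3)]

# 边采样权重:p(u,v) ∝ 1/d_u + 1/d_v,端点度低的边更容易被保留
weights = [1 / degree[u] + 1 / degree[v] for u, v in edges]
total = sum(weights)
probs = [w / total for w in weights]

# 连接低度节点0的边 (0,1) 获得最高采样概率
best_edge = edges[probs.index(max(probs))]
print(best_edge)  # (0, 1)
```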

MRF采样(Markov Random Field Sampler)

基于MRF的能量函数定义子图的采样分布:

P(S) ∝ exp(−E(S)),  E(S) = Σ_{v∈S} φ(d_v) + Σ_{(u,v): u,v∈S} ψ(d_u, d_v)

其中 φ 是度相关的节点势函数,ψ 是边势函数。

1.3 归一化修正

采样后需要按采样概率修正邻接矩阵,使聚合结果在期望上保持无偏:

Ã_{uv} = (A_{uv} / d_u) × c_v / (m·p_v)

其中 m 是采样数,d_u 取自度矩阵 D,c_v 是采样指示(节点 v 在本次采样中被抽中的次数),p_v 是其采样概率。由于 E[c_v] = m·p_v,有 E[Ã] = D^{-1}A。
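"除以采样概率"为何能保证无偏,可以用一个标量版的重要性加权估计直接验证(示意代码,与具体采样器无关):

```python
import numpy as np

# 目标:估计 total = Σ_v x_v,但每步只观察一个按概率 p 采到的 v
x = np.array([4.0, 1.0, 2.0, 3.0])
p = np.array([0.4, 0.1, 0.2, 0.3])   # 采样概率,和为1

# 估计量:采到 v 时输出 x_v / p_v(与邻接矩阵修正中的 1/(m·p_v) 同理)
estimator = x / p

# 解析验证无偏性:E[x_v / p_v] = Σ_v p_v · (x_v / p_v) = Σ_v x_v
expected = float(np.sum(p * estimator))
true_total = float(x.sum())
print(expected, true_total)  # 10.0 10.0
```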

1.4 PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
 
class GraphSAINTNodeSampler:
    """GraphSAINT节点采样器"""
    def __init__(self, edge_index, num_nodes, num_samples, device='cpu'):
        self.edge_index = edge_index.to(device)
        self.num_nodes = num_nodes
        self.num_samples = num_samples
        self.device = device
        
        # 计算度
        self.degrees = self._compute_degrees()
        
        # 计算采样概率(度数倒数)
        self.probs = 1.0 / (self.degrees + 1)  # 加1避免除零
        self.probs = self.probs / self.probs.sum()
    
    def _compute_degrees(self):
        """计算每个节点的度数(两个端点各计一次)"""
        row, col = self.edge_index
        degrees = torch.bincount(row, minlength=self.num_nodes)
        degrees = degrees + torch.bincount(col, minlength=self.num_nodes)
        return degrees.float()
    
    def sample(self):
        """采样一个子图"""
        # 1. 采样节点
        sampled_nodes = torch.multinomial(
            self.probs, self.num_samples, replacement=False
        ).to(self.device)
        
        # 2. 构建子图邻接矩阵
        sub_adj, sub_edge_index = self._get_subgraph(sampled_nodes)
        
        return sampled_nodes, sub_edge_index, sub_adj
    
    def _get_subgraph(self, nodes):
        """提取由采样节点诱导的子图"""
        mask = torch.zeros(self.num_nodes, dtype=torch.bool, device=self.device)
        mask[nodes] = True
        
        # 过滤两端都在采样集合内的边
        edge_mask = mask[self.edge_index[0]] & mask[self.edge_index[1]]
        sub_edge_index = self.edge_index[:, edge_mask]
        
        # 重映射节点ID到 [0, len(nodes))
        node_map = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device)
        node_map[nodes] = torch.arange(len(nodes), device=self.device)
        sub_edge_index = node_map[sub_edge_index]
        
        # 构建稠密邻接矩阵(向量化写入,避免逐边循环)
        num_sub_nodes = len(nodes)
        sub_adj = torch.zeros(num_sub_nodes, num_sub_nodes, device=self.device)
        sub_adj[sub_edge_index[0], sub_edge_index[1]] = 1
        
        return sub_adj, sub_edge_index
 
 
class GraphSAINTGNN(nn.Module):
    """GraphSAINT框架下的GNN"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        
        # 输入层
        self.convs.append(nn.Linear(in_channels, hidden_channels))
        self.norms.append(nn.LayerNorm(hidden_channels))
        
        # 隐藏层
        for _ in range(num_layers - 1):
            self.convs.append(nn.Linear(hidden_channels, hidden_channels))
            self.norms.append(nn.LayerNorm(hidden_channels))
        
        # 输出层
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        h = x
        
        for i in range(self.num_layers):
            # 邻居聚合
            h = self._propagate(h, edge_index)
            
            # 线性变换 + 归一化
            h = self.convs[i](h)
            h = self.norms[i](h)
            h = F.relu(h)
            h = F.dropout(h, p=0.5, training=self.training)
        
        return self.classifier(h)
    
    def _propagate(self, h, edge_index):
        """简化的均值聚合消息传递"""
        N = h.shape[0]
        out = torch.zeros_like(h)
        
        # 按目的节点聚合源节点特征(向量化,等价于逐边累加)
        out.index_add_(0, edge_index[1], h[edge_index[0]])
        
        # 按入度做均值归一化
        degrees = torch.bincount(edge_index[1], minlength=N).float()
        degrees = degrees.clamp(min=1)  # 避免除零
        out = out / degrees.unsqueeze(-1)
        
        return out
 
 
def train_graphsaint(data, model, sampler, optimizer, epochs=100):
    """GraphSAINT训练循环"""
    model.train()
    
    for epoch in range(epochs):
        # 采样一个子图
        nodes, edge_index, adj = sampler.sample()
        
        # 获取子图数据
        sub_x = data.x[nodes]
        sub_y = data.y[nodes]
        
        optimizer.zero_grad()
        out = model(sub_x, edge_index)
        loss = F.cross_entropy(out, sub_y)
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
    
    return model

1.5 采样器对比

| 采样器 | 优点 | 缺点 | 适用场景 |
| --- | --- | --- | --- |
| 节点采样 | 实现简单 | 可能丢失高连接节点 | 度分布均匀的图 |
| 边采样 | 保留结构信息 | 高方差 | 稀疏图 |
| MRF采样 | 方差控制好 | 计算复杂 | 大规模异构图 |

2. FastGCN:逐层采样

2.1 核心思想

FastGCN在每一层独立采样固定数量的节点,而不是展开完整的K跳邻域。[2]

层0: 采样512个起始节点
       ↓
层1: 每个起始节点采样16个邻居
       ↓
层2: 每个1跳节点采样16个邻居
       ↓
... (继续)

2.2 采样概率

每层按如下重要性分布独立采样节点:

q(v) = ‖Ã_{:,v}‖² / Σ_u ‖Ã_{:,u}‖²

其中 Ã 是归一化邻接矩阵,q(v) 是由图结构决定的重要性权重(而非学习得到)。

2.3 方差分析

关键性质:按重要性分布 q 采样可以显著降低逐层聚合估计的方差,从而保证梯度估计的稳定性;这也是 FastGCN 相对于均匀采样的主要优势。
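上式中的 q(v) 可以直接由归一化邻接矩阵算出;下面用numpy演示(A 为示意小图,不是FastGCN官方实现):

```python
import numpy as np

# 玩具无向图的邻接矩阵
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

# 行归一化:Ã = D^{-1} A
deg = A.sum(axis=1, keepdims=True)
A_norm = A / deg

# FastGCN重要性分布:q(v) ∝ ‖Ã_{:,v}‖²(列范数平方)
col_sq_norm = (A_norm ** 2).sum(axis=0)
q = col_sq_norm / col_sq_norm.sum()
print(q.argmax())  # 2:被最多节点指向的节点重要性最高
```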

class FastGCNSampler:
    """FastGCN采样器"""
    def __init__(self, edge_index, num_nodes, layer_samples):
        """
        layer_samples: 每层的采样数列表,如 [512, 16, 16]
        """
        self.edge_index = edge_index
        self.num_nodes = num_nodes
        self.layer_samples = layer_samples
        
        # 预计算邻居列表
        self.neighbors = self._build_adjacency_list()
    
    def _build_adjacency_list(self):
        """构建邻接表"""
        neighbors = [[] for _ in range(self.num_nodes)]
        for i in range(self.edge_index.shape[1]):
            u, v = self.edge_index[0, i].item(), self.edge_index[1, i].item()
            neighbors[u].append(v)
        return neighbors
    
    def sample(self, start_nodes=None):
        """逐层采样"""
        if start_nodes is None:
            start_nodes = torch.randint(0, self.num_nodes, (self.layer_samples[0],))
        
        layers = [start_nodes]
        importance_weights = []
        
        for l, num_sample in enumerate(self.layer_samples[1:], 1):
            prev_layer = layers[-1]
            # 节点ID -> 重要性权重;用dict去重,保持权重与节点对齐
            sampled = {}
            
            for node in prev_layer.tolist():
                nbrs = self.neighbors[node]
                if len(nbrs) > 0:
                    # 无放回地采样邻居
                    chosen = np.random.choice(nbrs, min(num_sample, len(nbrs)), replace=False)
                    for v in chosen.tolist():
                        sampled.setdefault(int(v), 1.0 / len(nbrs))
                else:
                    # 无邻居,保留自身
                    sampled.setdefault(node, 1.0)
            
            layers.append(torch.tensor(list(sampled.keys())))
            importance_weights.append(torch.tensor(list(sampled.values())))
        
        return layers, importance_weights

3. SIGN:简化的大规模GNN

3.1 核心思想

SIGN(Scalable Inception Graph Neural Networks)通过预计算多跳特征来避免运行时的大规模邻居聚合。[3]

传统方法:                        SIGN方法:
在每个batch中计算多跳邻居           离线预计算多跳特征

时间复杂度: O(B × N_neighbor)       时间复杂度: O(1)(预处理后)

3.2 架构

SIGN的核心公式:

H = σ([X·W₀ ‖ ÃX·W₁ ‖ Ã²X·W₂ ‖ … ‖ Ã^R X·W_R]),  Y = softmax(H·Ω)

其中 R 是预计算的最大跳数,Ã 是归一化邻接矩阵,‖ 表示沿特征维拼接(下面的实现以逐跳求和近似拼接)。
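拼接式的多跳聚合可以用一个numpy草图演示(X、A 均为示意数据):

```python
import numpy as np

N, F, R = 4, 2, 2   # 节点数、特征维度、最大跳数
X = np.random.RandomState(0).rand(N, F)

# 行归一化的环形图邻接矩阵
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
A_norm = A / A.sum(axis=1, keepdims=True)

# 预计算各跳特征 X, ÃX, Ã²X,并沿特征维拼接
hops = [X]
for _ in range(R):
    hops.append(A_norm @ hops[-1])
H = np.concatenate(hops, axis=1)
print(H.shape)  # (4, 6):每个节点得到 (R+1)·F 维特征
```

拼接之后,各跳特征只需一个线性层即可组合,训练时不再依赖邻接矩阵。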

3.3 PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class SIGN(nn.Module):
    """Simplified Graph Neural Networks"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_hops=3, num_layers=2):
        super().__init__()
        self.num_hops = num_hops
        self.num_layers = num_layers
        
        # 输入投影
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        
        # 每层的权重(每跳一个)
        self.weight_layers = nn.ModuleList()
        for l in range(num_layers):
            weights = nn.ModuleList([
                nn.Linear(hidden_channels, hidden_channels)
                for _ in range(num_hops + 1)  # 包括自身
            ])
            self.weight_layers.append(weights)
        
        # BatchNorm
        self.batch_norms = nn.ModuleList([
            nn.BatchNorm1d(hidden_channels)
            for _ in range(num_layers)
        ])
        
        # 输出层
        self.output_proj = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, adj_powers):
        """
        x: 节点特征 (N, in_channels)
        adj_powers: 邻接矩阵的R次幂列表 [(N, N), ...]
        """
        h = x
        
        for l in range(self.num_layers):
            # 聚合多跳特征
            hop_features = [h]  # 自身作为第一跳
            for r in range(1, self.num_hops + 1):
                if r < len(adj_powers) + 1:
                    # 预计算的特征
                    hop_feat = adj_powers[r - 1] @ h
                else:
                    hop_feat = h  # fallback
                hop_features.append(hop_feat)
            
            # 线性组合
            h_new = torch.zeros_like(h)
            for r, feat in enumerate(hop_features):
                h_new += self.weight_layers[l][r](feat)
            
            # 归一化 + 激活
            h = self.batch_norms[l](h_new)
            h = F.relu(h)
            
            # 除了最后一层都应用dropout
            if l < self.num_layers - 1:
                h = F.dropout(h, p=0.5, training=self.training)
        
        return self.output_proj(h)
 
 
def preprocess_graph(adj, num_hops=3, device='cpu'):
    """预计算邻接矩阵的幂"""
    adj = adj.to(device)
    
    # 归一化邻接矩阵
    deg = adj.sum(dim=1, keepdim=True)
    adj_norm = adj / deg.where(deg > 0, torch.ones_like(deg))
    
    # 计算多跳邻接矩阵
    adj_powers = []
    current_power = adj_norm
    
    for r in range(1, num_hops + 1):
        adj_powers.append(current_power)
        current_power = current_power @ adj_norm
    
    return adj_powers
 
 
# 使用示例
def train_sign(data, adj, num_hops=3):
    # 1. 预处理:计算邻接矩阵的幂
    adj_powers = preprocess_graph(adj, num_hops=num_hops)
    
    # 2. 创建模型
    model = SIGN(
        in_channels=data.num_features,
        hidden_channels=256,
        out_channels=data.num_classes,
        num_hops=num_hops
    ).to(data.x.device)
    
    # 3. 训练(无需在forward中计算多跳邻居)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(100):
        model.train()
        optimizer.zero_grad()
        
        # 前向传播使用预计算的特征
        out = model(data.x, adj_powers)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

3.4 SIGN vs 标准GNN

| 特性 | 标准GNN | SIGN |
| --- | --- | --- |
| 邻居聚合 | 运行时计算 | 预处理计算 |
| 时间复杂度 | O(N·d^L) | O(N²·R)(预处理)+ O(N·d·R)(训练) |
| 内存 | O(N·d) | O(N²·R)(预处理大,但可稀疏化) |
| 表达能力 | 完整 | 近似(使用矩阵幂而非真实邻居) |
| 灵活性 | 高 | 中等 |
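表中"可稀疏化"的含义是:真实图的 Ã 高度稀疏,预计算应使用稀疏矩阵乘法,使内存与计算量只随边数增长。下面是一个scipy草图(矩阵数值为示意):

```python
import numpy as np
import scipy.sparse as sp

# 稀疏的行归一化邻接矩阵(CSR格式)
rows = np.array([0, 1, 1, 2])
cols = np.array([1, 0, 2, 1])
vals = np.array([1.0, 0.5, 0.5, 1.0])
A_norm = sp.csr_matrix((vals, (rows, cols)), shape=(3, 3))

X = np.arange(6, dtype=float).reshape(3, 2)

# 逐跳传播特征:每一步都是稀疏乘稠密,复杂度随边数而非 N² 增长
hop1 = A_norm @ X          # ÃX
hop2 = A_norm @ hop1       # Ã²X

# 与稠密计算结果一致
dense = A_norm.toarray() @ (A_norm.toarray() @ X)
```

注意应逐跳传播特征(Ã(ÃX)),而不是先算矩阵幂 Ã² 再乘 X:矩阵幂本身可能是稠密的。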

4. ClusterGCN:图聚类方法

4.1 核心思想

ClusterGCN通过图聚类算法(如Metis)将大图分割成多个子图,然后在每个子图(簇)上训练GNN。[4]

原始图 → Metis聚类 → K个子图 → 分批训练

4.2 优势

  • 低方差:采样的是真实子图,结构完整
  • 高效率:每个batch只处理一个簇
  • 内存友好:不需要存储完整邻接矩阵
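聚类的好坏可以用"边切割率"(跨簇边的占比)衡量:Metis 这类算法的优化目标正是最小化它。一个纯Python草图(edges 与两种分区均为示意数据):

```python
def edge_cut_ratio(edges, partition):
    """跨簇边占比;partition[v] 给出节点 v 的簇ID"""
    cut = sum(1 for u, v in edges if partition[u] != partition[v])
    return cut / len(edges)

# 6节点环形图
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (0, 5)]

good = {0: 0, 1: 0, 2: 0, 3: 1, 4: 1, 5: 1}   # 紧凑分区:只切断2条边
bad = {0: 0, 1: 1, 2: 0, 3: 1, 4: 0, 5: 1}    # 交替分区:切断全部边

print(edge_cut_ratio(edges, good), edge_cut_ratio(edges, bad))
```

切割率越低,每个簇内保留的结构信息越多,ClusterGCN 的梯度估计偏差也越小;这也是下面实现中随机分区只能作为占位的原因。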

4.3 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
 
class ClusterGCN:
    """ClusterGCN聚类器"""
    def __init__(self, num_clusters):
        self.num_clusters = num_clusters
    
    def cluster(self, edge_index, num_nodes):
        """使用简单的随机分区作为聚类(实际应使用Metis)"""
        cluster_id = torch.randint(0, self.num_clusters, (num_nodes,))
        return cluster_id
    
    def create_subgraph_data(self, data, cluster_id):
        """为每个簇创建子图数据"""
        clusters = torch.unique(cluster_id)
        subgraph_data_list = []
        
        for c in clusters:
            # 找到簇中的节点
            node_mask = cluster_id == c
            node_indices = torch.where(node_mask)[0]
            
            # 取出簇内节点的特征和标签
            subgraph_x = data.x[node_indices]
            subgraph_y = data.y[node_indices]
            
            # 过滤子图内的边
            edge_mask = (cluster_id[data.edge_index[0]] == c) & (cluster_id[data.edge_index[1]] == c)
            subgraph_edge_index = data.edge_index[:, edge_mask]
            
            # 重映射节点ID
            node_map = torch.zeros(data.num_nodes, dtype=torch.long)
            node_map[node_indices] = torch.arange(len(node_indices))
            subgraph_edge_index = node_map[subgraph_edge_index]
            
            # 添加自环
            subgraph_edge_index, _ = add_self_loops(subgraph_edge_index, num_nodes=len(node_indices))
            
            # 创建子图数据
            subgraph = Data(
                x=subgraph_x,
                edge_index=subgraph_edge_index,
                y=subgraph_y,
                num_nodes=len(node_indices)
            )
            subgraph_data_list.append(subgraph)
        
        return subgraph_data_list
 
 
class ClusterGNN(nn.Module):
    """用于ClusterGCN的GNN"""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Linear(in_channels, hidden_channels)
        self.conv2 = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        x = self.conv1(x)
        x = F.relu(x)
        x = self._propagate(x, edge_index)
        x = self.conv2(x)
        return x
    
    def _propagate(self, h, edge_index):
        """简化的均值聚合消息传递"""
        N = h.shape[0]
        out = torch.zeros_like(h)
        
        # 按目的节点聚合源节点特征(向量化)
        out.index_add_(0, edge_index[1], h[edge_index[0]])
        
        degrees = torch.bincount(edge_index[1], minlength=N).float()
        degrees = degrees.clamp(min=1)  # 避免除零
        out = out / degrees.unsqueeze(-1)
        
        return out
 
 
def train_clustergcn(data, edge_index, num_clusters=50, epochs=100):
    # 1. 聚类
    clusterer = ClusterGCN(num_clusters)
    cluster_id = clusterer.cluster(edge_index, data.num_nodes)
    
    # 2. 创建子图
    subgraph_list = clusterer.create_subgraph_data(data, cluster_id)
    loader = DataLoader(subgraph_list, batch_size=1, shuffle=True)
    
    # 3. 创建模型
    model = ClusterGNN(
        in_channels=data.num_features,
        hidden_channels=256,
        out_channels=data.num_classes
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    # 4. 训练
    for epoch in range(epochs):
        for subgraph in loader:
            model.train()
            optimizer.zero_grad()
            
            out = model(subgraph.x, subgraph.edge_index)
            loss = F.cross_entropy(out, subgraph.y)
            loss.backward()
            optimizer.step()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
    
    return model

5. 历史表示缓存

5.1 核心思想

对于归纳学习场景,训练中会不断遇到未见过的节点,需要为其计算嵌入;通过缓存历史嵌入可以避免重复计算,从而加速训练。[5]

新节点 → 查找缓存
        ├─ 命中 → 直接使用历史嵌入
        └─ 未命中 → 计算并写入缓存

5.2 实现

class HistoryCache:
    """历史嵌入缓存"""
    def __init__(self, gnn_model, feature_dim, cache_size=10000):
        self.gnn_model = gnn_model
        self.cache_size = cache_size
        self.feature_dim = feature_dim
        
        # LRU缓存
        self.cache = {}
        self.access_order = []
        
        # 缓存统计
        self.hits = 0
        self.misses = 0
    
    def get(self, node_ids):
        """获取节点嵌入"""
        embeddings = []
        miss_indices = []
        
        for i, node_id in enumerate(node_ids):
            if node_id in self.cache:
                embeddings.append(self.cache[node_id])
                # 命中时刷新访问顺序,保证LRU语义(否则退化为FIFO)
                self.access_order.remove(node_id)
                self.access_order.append(node_id)
                self.hits += 1
            else:
                embeddings.append(None)
                miss_indices.append(i)
                self.misses += 1
        
        # 计算缺失的嵌入
        if miss_indices:
            miss_nodes = torch.tensor([node_ids[i] for i in miss_indices])
            miss_embeddings = self._compute_embeddings(miss_nodes)
            
            # 缓存新嵌入
            for node_id, emb in zip(miss_nodes.tolist(), miss_embeddings):
                self._add_to_cache(node_id, emb)
            
            # 填充结果
            emb_idx = 0
            for i in range(len(node_ids)):
                if embeddings[i] is None:
                    embeddings[i] = miss_embeddings[emb_idx]
                    emb_idx += 1
        
        return torch.stack(embeddings)
    
    def _compute_embeddings(self, node_ids):
        """计算节点嵌入(占位实现)"""
        # 实际应用中需要提取每个节点的局部子图并运行完整GNN;
        # 这里仅用随机特征通过输入投影,演示缓存接口
        with torch.no_grad():
            feats = torch.randn(len(node_ids), self.feature_dim)
            embeddings = self.gnn_model.input_proj(feats)
        return embeddings
    
    def _add_to_cache(self, node_id, embedding):
        """添加到缓存"""
        if len(self.cache) >= self.cache_size:
            # LRU淘汰
            oldest = self.access_order.pop(0)
            del self.cache[oldest]
        
        self.cache[node_id] = embedding
        self.access_order.append(node_id)
    
    def get_stats(self):
        """获取缓存统计"""
        total = self.hits + self.misses
        hit_rate = self.hits / total if total > 0 else 0
        return {
            'hits': self.hits,
            'misses': self.misses,
            'hit_rate': hit_rate,
            'cache_size': len(self.cache)
        }
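缓存的访问顺序维护也可以交给标准库的 OrderedDict:命中时用 move_to_end 刷新即可。下面是一个独立的最小草图(与上面的类无直接依赖):

```python
from collections import OrderedDict

class LRUEmbeddingCache:
    """基于OrderedDict的最小LRU缓存"""
    def __init__(self, capacity):
        self.capacity = capacity
        self.data = OrderedDict()
    
    def get(self, key):
        if key not in self.data:
            return None
        self.data.move_to_end(key)       # 命中:刷新为最近使用
        return self.data[key]
    
    def put(self, key, value):
        if key in self.data:
            self.data.move_to_end(key)
        self.data[key] = value
        if len(self.data) > self.capacity:
            self.data.popitem(last=False)  # 淘汰最久未使用的条目

cache = LRUEmbeddingCache(capacity=2)
cache.put('a', 1)
cache.put('b', 2)
cache.get('a')       # 'a' 成为最近使用
cache.put('c', 3)    # 容量超限,淘汰的是 'b' 而非 'a'
```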

6. 实践指南

6.1 方法选择

| 场景 | 推荐方法 | 理由 |
| --- | --- | --- |
| 小规模图 (< 10K节点) | 全图训练 | 无需采样 |
| 中规模图 (10K - 1M) | GraphSAINT / ClusterGCN | 平衡效率与精度 |
| 大规模图 (> 1M) | SIGN + 缓存 | 预计算优势 |
| 异构图 | 历史缓存 | 适应新节点 |

6.2 超参数建议

| 参数 | GraphSAINT | FastGCN | SIGN |
| --- | --- | --- | --- |
| 采样数 | 1000-5000 | 64-512/层 | - |
| 层数 | 2-4 | 2-4 | 2-3 |
| batch size | 1-2子图 | 可较大 | 可较大 |
| dropout | 0.5-0.7 | 0.3-0.5 | 0.3-0.5 |

6.3 评估指标

import time
import psutil

def evaluate_sampling_methods(dataset, methods, metrics=['accuracy', 'time', 'memory']):
    """评估不同采样方法的性能(GCNModel与evaluate需在别处定义)"""
    results = {}
    
    for name, sampler_class in methods.items():
        print(f"Evaluating {name}...")
        
        sampler = sampler_class(dataset)
        model = GCNModel(dataset.num_features, 256, dataset.num_classes)
        
        # 测量训练时间
        start = time.time()
        # ... 训练逻辑 ...
        train_time = time.time() - start
        
        # 测量内存增量
        mem_before = psutil.Process().memory_info().rss
        # ... 训练逻辑 ...
        mem_after = psutil.Process().memory_info().rss
        memory = (mem_after - mem_before) / 1024 / 1024  # MB
        
        # 测量精度
        accuracy = evaluate(model, dataset)
        
        results[name] = {
            'accuracy': accuracy,
            'time': train_time,
            'memory': memory
        }
    
    return results

7. 相关主题

| 主题 | 描述 |
| --- | --- |
| 图神经网络 | GNN基础概念 |
| 图卷积网络 | GCN详细分析 |
| GNN深度限制 | 过平滑、过压缩问题 |

参考

[1] Zeng et al., “GraphSAINT: Graph Sampling Based Inductive Learning Method”, ICLR 2020

[2] Chen et al., “FastGCN: Fast Learning with Graph Convolutional Networks via Importance Sampling”, ICLR 2018

[3] Rossi et al., “SIGN: Scalable Inception Graph Neural Networks”, ICML 2020 GRL Workshop

[4] Chiang et al., “Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks”, KDD 2019

[5] Hu et al., “HeteroGNN: Heterogeneous Graph Neural Networks”, KDD 2020