概述
空间方法(Spatial Methods)从图结构本身出发,将图卷积定义为直接在邻居节点上的特征聚合操作。相比谱方法,空间方法具有不依赖图拉普拉斯、支持归纳学习、可处理动态图等优势,已成为图神经网络的主流范式。1
本文档系统讲解GraphSAGE、GAT、GATv2、APPNP、GCNII等代表性空间GCN架构,并扩展到大规模训练方法(SIGN、SAGN、ClusterGCN、GraphSAINT)。
1. 消息传递的统一框架
1.1 Gilmer et al. (MPNN) 框架
Gilmer 等人在 2017 年提出 MPNN 框架,统一了图卷积、空域 GNN 等多种模型:
步骤 1:消息函数(每个边上)
步骤 2:聚合函数(每个节点上)
步骤 3:更新函数(每个节点上)
1.2 框架特例化
不同 GNN 架构是 MPNN 的特例化:
| 模型 | 消息函数 | 聚合函数 | 更新函数 |
|---|---|---|---|
| GCN | 度归一化求和 | ReLU() | |
| GraphSAGE | Mean/Max/LSTM | ||
| GAT | 加权求和 | ||
| GIN | 求和 | MLP |
2. GraphSAGE:归纳式图卷积
2.1 动机
Kipf-Welling GCN 是直推式(transductive)学习:
- 需要所有节点的图结构
- 无法处理训练时未见过的节点
- 新增节点需要重新训练
GraphSAGE 的目标:学习一个聚合函数而非节点嵌入,使其能泛化到新节点。
2.2 核心思想
Hamilton et al. (NeurIPS 2017) 提出:
- 采样邻居:固定数量的邻居(避免指数爆炸)
- 聚合函数:学习多个候选聚合器(Mean、MaxPool、LSTM)
- 拼接 + 线性变换:保留自身特征
2.3 算法流程
输入:节点 ,特征
输出:节点 的新表示
步骤:
- 采样邻居:(如 )
- 聚合邻居:
- 拼接自身:
- 归一化:
2.4 聚合函数设计
Mean Aggregator
注意:这等价于 Kipf-Welling GCN 的传播规则。
MaxPool Aggregator
逐元素取最大,引入非线性。
LSTM Aggregator
LSTM 处理邻居序列,但需要注意邻居无序性,需要随机排序。
2.5 GraphSAGE 的 PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class MeanAggregator(nn.Module):
"""Mean aggregator"""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, neighbor_feats):
# neighbor_feats: (num_samples, in_features)
mean = neighbor_feats.mean(dim=0)
return self.linear(mean)
class MaxPoolAggregator(nn.Module):
"""Max-pooling aggregator"""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, neighbor_feats):
# neighbor_feats: (num_samples, in_features)
transformed = self.linear(neighbor_feats)
return transformed.max(dim=0)[0]
class GraphSAGE(nn.Module):
"""GraphSAGE 完整实现"""
def __init__(self, in_channels, hidden_channels, out_channels,
num_layers=2, aggr='mean', dropout=0.5):
super().__init__()
self.num_layers = num_layers
self.dropout = dropout
# 每层的聚合器和线性变换
self.aggregators = nn.ModuleList()
self.linears = nn.ModuleList()
in_dim = in_channels
for _ in range(num_layers):
if aggr == 'mean':
agg = MeanAggregator(in_dim, hidden_channels)
elif aggr == 'maxpool':
agg = MaxPoolAggregator(in_dim, hidden_channels)
else:
raise ValueError(f"Unknown aggregator: {aggr}")
self.aggregators.append(agg)
self.linears.append(nn.Linear(in_dim + hidden_channels, hidden_channels))
in_dim = hidden_channels
self.output_linear = nn.Linear(hidden_channels, out_channels)
def forward(self, x, adj_list):
"""
x: (N, in_channels) 节点特征
adj_list: list of lists,第 i 个元素是节点 i 的邻居索引列表
"""
h = x
for layer in range(self.num_layers):
neighbor_aggs = []
for node_idx in range(h.size(0)):
neighbors = adj_list[node_idx]
if len(neighbors) == 0:
agg = torch.zeros(h.size(1), device=h.device)
else:
neighbor_feats = h[neighbors]
agg = self.aggregators[layer](neighbor_feats)
neighbor_aggs.append(agg)
neighbor_aggs = torch.stack(neighbor_aggs)
# 拼接自身和邻居聚合
h_concat = torch.cat([h, neighbor_aggs], dim=1)
h_new = self.linears[layer](h_concat)
h_new = F.relu(h_new)
h_new = F.dropout(h_new, p=self.dropout, training=self.training)
# 归一化
h = F.normalize(h_new, p=2, dim=1)
return self.output_linear(h)2.6 GraphSAGE 的局限
- 邻居采样引入方差:不同样本产生不同结果
- LSTM 聚合器的顺序敏感性:理论上无序,但实际有偏
- 采样策略影响性能:需要精心设计
3. GAT:图注意力网络
3.1 动机
Kipf-Welling GCN 使用固定的度归一化作为聚合权重。然而不同邻居的重要性可能不同。
GAT 的目标:学习自适应的邻居权重。
3.2 GAT 注意力机制
步骤 1:线性变换
对每个节点 ,先做线性投影:
步骤 2:计算注意力系数
其中 是可学习向量, 表示拼接。
步骤 3:Softmax 归一化
步骤 4:加权聚合
3.3 多头注意力
为稳定训练,使用多头注意力:
最后一层通常使用平均而非拼接:
3.4 GAT 的 PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
"""GAT 单层"""
def __init__(self, in_features, out_features, num_heads=1,
dropout=0.6, alpha=0.2, concat=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.concat = concat
self.dropout = dropout
# 每个注意力头独立的线性变换
self.W = nn.Parameter(torch.empty(num_heads, in_features, out_features))
# 注意力向量 a
self.a_src = nn.Parameter(torch.empty(num_heads, out_features, 1))
self.a_dst = nn.Parameter(torch.empty(num_heads, out_features, 1))
nn.init.xavier_uniform_(self.W)
nn.init.xavier_uniform_(self.a_src)
nn.init.xavier_uniform_(self.a_dst)
self.leakyrelu = nn.LeakyReLU(alpha)
def forward(self, h, adj):
"""
h: (N, in_features)
adj: (N, N) 邻接矩阵(含自环)
"""
N = h.size(0)
# (num_heads, N, out_features)
h_transformed = torch.einsum('jkl,ij->ikl', self.W, h)
# 计算注意力分数
# e_src: (num_heads, N, 1)
e_src = torch.einsum('hij,hkj->hik', h_transformed, self.a_src)
e_dst = torch.einsum('hij,hkj->hik', h_transformed, self.a_dst)
# (num_heads, N, N)
e = self.leakyrelu(e_src + e_dst.transpose(1, 2))
# Mask: 邻接矩阵为 0 的位置设置为 -inf
mask = -1e9 * (1.0 - adj).unsqueeze(0)
e = e + mask
# Softmax
attention = F.softmax(e, dim=2)
attention = F.dropout(attention, p=self.dropout, training=self.training)
# 加权聚合
# h_transformed: (num_heads, N, out_features)
# attention: (num_heads, N, N)
h_prime = torch.einsum('hij,hjk->hik', attention, h_transformed)
if self.concat:
return F.elu(h_prime.transpose(0, 1).reshape(N, -1))
else:
return h_prime.mean(dim=0)
class GAT(nn.Module):
"""完整 GAT 用于节点分类"""
def __init__(self, in_channels, hidden_channels, out_channels,
num_heads=8, num_layers=2, dropout=0.6):
super().__init__()
self.layers = nn.ModuleList()
# 第一层:多头 + 拼接
self.layers.append(
GATLayer(in_channels, hidden_channels, num_heads=num_heads,
dropout=dropout, concat=True)
)
# 中间层:多头 + 拼接
for _ in range(num_layers - 2):
self.layers.append(
GATLayer(hidden_channels * num_heads, hidden_channels,
num_heads=num_heads, dropout=dropout, concat=True)
)
# 最后一层:多头 + 平均
if num_layers > 1:
self.layers.append(
GATLayer(hidden_channels * num_heads, out_channels,
num_heads=1, dropout=dropout, concat=False)
)
def forward(self, x, adj):
for layer in self.layers:
x = layer(x, adj)
return F.log_softmax(x, dim=1)4. GATv2:修正 GAT 的单调性问题
4.1 GAT 的”静态注意力”问题
Brody et al. (ICLR 2022) 发现:GAT 的注意力是”静态”的,即对不同查询节点的邻居排序相同。2
证明:
GAT 的注意力分数:
设 :
关键观察:LeakyReLU 在正部分是线性的。因此:
- 当 时:
- 这可以分解为 的形式
- softmax 归一化后, 在所有邻居上抵消
- 因此 仅依赖 ,与查询 无关
结论:GAT 的排名永远是关于邻居特征的固定排名,与查询节点无关。
4.2 GATv2 的修正
GATv2 通过调换操作顺序解决这个问题:
关键区别:
- GAT:
- GATv2:
将投影 放在 LeakyReLU 之前,避免分解为 的形式。
4.3 GATv2 的有效性证明
Brody et al. 证明 GATv2 是普遍表达能力(universal)的:
其中 可以是任意函数。这意味着 GATv2 可以实现任意复杂的邻居排序。
4.4 GATv2 的 PyTorch 实现
class GATv2Layer(nn.Module):
"""GATv2 单层:注意操作顺序的变化"""
def __init__(self, in_features, out_features, num_heads=1,
dropout=0.6, alpha=0.2, concat=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.concat = concat
self.dropout = dropout
# 每个头独立的线性变换
self.W = nn.Parameter(torch.empty(num_heads, in_features, out_features))
# 注意力向量 a(单个)
self.a = nn.Parameter(torch.empty(num_heads, 2 * out_features, 1))
nn.init.xavier_uniform_(self.W)
nn.init.xavier_uniform_(self.a)
self.leakyrelu = nn.LeakyReLU(alpha)
def forward(self, h, adj):
N = h.size(0)
# 线性变换
h_transformed = torch.einsum('hij,kj->hik', self.W, h)
# 构造所有节点对(i, j)
# h_transformed_i: (num_heads, N, 1, out_features)
# h_transformed_j: (num_heads, 1, N, out_features)
h_i = h_transformed.unsqueeze(2).expand(-1, -1, N, -1)
h_j = h_transformed.unsqueeze(1).expand(-1, N, -1, -1)
# 拼接 (i, j) 对
h_concat = torch.cat([h_i, h_j], dim=-1) # (num_heads, N, N, 2*out_features)
# GATv2 关键:LeakyReLU 在 a^T 之前
e = self.leakyrelu(h_concat)
e = torch.einsum('hijk,hkl->hijl', e, self.a).squeeze(-1) # (num_heads, N, N)
# Mask
mask = -1e9 * (1.0 - adj).unsqueeze(0)
e = e + mask
attention = F.softmax(e, dim=2)
attention = F.dropout(attention, p=self.dropout, training=self.training)
# 加权聚合
h_prime = torch.einsum('hij,hjk->hik', attention, h_transformed)
if self.concat:
return F.elu(h_prime.transpose(0, 1).reshape(N, -1))
else:
return h_prime.mean(dim=0)4.5 GAT vs GATv2 性能对比
| 任务 | GAT 准确率 | GATv2 准确率 |
|---|---|---|
| Cora | 83.0% | 83.4% |
| Citeseer | 72.5% | 73.1% |
| Pubmed | 79.0% | 79.5% |
| PPI | 97.3% | 98.0% |
5. APPNP:个性化 PageRank 传播
5.1 动机
Klicpera et al. (ICLR 2019) 发现 GCN 的核心瓶颈不是特征变换,而是传播。3
关键洞察:
- GCN 等价于低通滤波器在特征上的应用
- 传播深度增加,准确率反而下降(过平滑)
- PPNP / APPNP 通过解耦特征变换和传播,解决了这个问题
5.2 PPNP 个性化 PageRank
PPNP 首先用神经网络预测节点表示,然后通过个性化 PageRank 传播:
其中:
- :神经网络预测
- :归一化邻接矩阵
- :teleport 概率(通常 )
5.3 APPNP:近似传播
直接计算 PPNP 需要求逆矩阵,复杂度高。APPNP 用迭代近似:
其中:
- :神经网络初始预测
- :从 保留的比例
- :传播步数(通常 10 步)
5.4 APPNP 的 PyTorch 实现
class APPNP(nn.Module):
"""APPNP 模型"""
def __init__(self, in_channels, hidden_channels, out_channels,
K=10, alpha=0.1, dropout=0.5):
super().__init__()
# 特征变换网络(MLP)
self.feature_transform = nn.Sequential(
nn.Linear(in_channels, hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, hidden_channels),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_channels, out_channels)
)
self.K = K
self.alpha = alpha
def forward(self, x, adj_normalized):
# 初始预测
h_0 = self.feature_transform(x)
# 迭代传播
h = h_0
for _ in range(self.K):
h = (1 - self.alpha) * (adj_normalized @ h) + self.alpha * h_0
return F.log_softmax(h, dim=1)5.5 APPNP 的优势
- 解耦特征变换与传播:可以先用 MLP 学表示,再用 PageRank 传播
- 不易过平滑:每次迭代都从 重新注入
- 参数效率:每层不需要额外的可学习参数
- 理论保证:等价于 PPNP 的精确传播
6. GCNII:深度 GCN
6.1 过平滑问题
普通 GCN 在深度增加时性能急剧下降(通常 2-3 层最佳)。Chen et al. (ICML 2020) 分析认为这是过平滑所致。4
6.2 GCNII 的两大技巧
初始残差(Initial Residual)
将输入 作为残差加到每层:
简化为:
恒等映射(Identity Mapping)
在权重矩阵中加入恒等项:
其中 (随层数 衰减)。
6.3 GCNII 的传播规则
完整公式:
其中:
- 是超参数(通常取 0.5, 0.5)
6.4 GCNII 的理论保证
Chen et al. 证明 层 GCNII 等价于一个 阶多项式滤波器,与 Kipf-Welling GCN 的特殊滤波器不同,这个多项式系数可以更灵活。
6.5 GCNII 的 PyTorch 实现
class GCNIILayer(nn.Module):
"""GCNII 单层"""
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias=False)
nn.init.xavier_uniform_(self.weight)
def forward(self, h, h_0, adj, alpha_l, beta_l):
# Identity mapping: (1-beta_l) I + beta_l W
W_eff = (1 - beta_l) * torch.eye(self.linear.weight.size(0),
device=h.device) + \
beta_l * self.linear.weight
support = h @ W_eff
# 邻居聚合
aggregate = adj @ support
# 初始残差 + 自身残差
out = (1 - alpha_l) * aggregate + alpha_l * h_0
return F.relu(out)
class GCNII(nn.Module):
"""完整 GCNII"""
def __init__(self, in_channels, hidden_channels, out_channels,
num_layers=64, lambda_alpha=0.5, lambda_beta=0.5,
dropout=0.6):
super().__init__()
self.input_linear = nn.Linear(in_channels, hidden_channels)
self.input_relu = nn.ReLU()
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(GCNIILayer(hidden_channels, hidden_channels))
self.output_linear = nn.Linear(hidden_channels, out_channels)
self.lambda_alpha = lambda_alpha
self.lambda_beta = lambda_beta
self.num_layers = num_layers
self.dropout = dropout
def forward(self, x, adj):
h_0 = self.input_relu(self.input_linear(x))
h = F.dropout(h_0, p=self.dropout, training=self.training)
for l in range(self.num_layers):
alpha_l = self.lambda_alpha * 1.0 / (l + 1)
beta_l = self.lambda_beta * 1.0 / (l + 1)
h_new = self.layers[l](h, h_0, adj, alpha_l, beta_l)
h = F.dropout(h_new, p=self.dropout, training=self.training)
return F.log_softmax(self.output_linear(h), dim=1)6.6 GCNII 的性能
| 模型 | Cora | Citeseer | Pubmed |
|---|---|---|---|
| GCN (2层) | 81.5% | 70.3% | 79.0% |
| GCN (16层) | 64.9% | 16.9% | 40.4% |
| GCNII (16层) | 85.5% | 73.1% | 80.2% |
| GCNII (64层) | 85.3% | 72.9% | 79.9% |
GCNII 在 64 层时仍能保持性能,远超普通 GCN。
7. 可扩展 GCN 训练
7.1 挑战
真实图常包含数百万节点和边(如社交网络、引文网络)。直接训练 GCN 面临:
- 内存爆炸:存储所有节点的中间表示
- 计算瓶颈:每层都需要全图前向
- 邻居爆炸:深度增加时邻居呈指数增长
7.2 SIGN:可扩展 Inception-like GCN
SIGN (Rossi et al., 2020) 使用多种图算子的并行组合:
其中 是不同传播步数的集合(如 )。
class SIGN(nn.Module):
"""SIGN 模型"""
def __init__(self, in_channels, hidden_channels, out_channels, K=3):
super().__init__()
self.linears = nn.ModuleList()
# 输入层
self.linears.append(nn.Linear(in_channels * (K + 1), hidden_channels))
# 输出层
self.output = nn.Linear(hidden_channels, out_channels)
self.K = K
def forward(self, x, adj):
# 预计算不同传播步数的特征
feats = [x]
h = x
for _ in range(self.K):
h = adj @ h
feats.append(h)
# 拼接所有特征
h = torch.cat(feats, dim=1)
h = self.linears[0](h)
h = F.relu(h)
return self.output(h)7.3 SAGN:自適応图采样网络
SAGN (Chen et al., 2021) 进一步:
- 用 SIGN 作为基础
- 引入自适应采样
- 添加标签传播(Label Propagation)
7.4 ClusterGCN:基于聚类的采样
ClusterGCN (Chiang et al., ICML 2019) 思路:
- 用图聚类算法(如 METIS)将大图划分为多个子图
- 每个 batch 训练一个子图
- 子图内做标准 GCN
优势:训练在子图上进行,计算效率高
局限:聚类质量影响性能
7.5 GraphSAINT:基于采样的归纳式训练
GraphSAINT (Zeng et al., ICLR 2020) 设计了多种采样器:
节点采样(Node Sampler)
每个 batch 随机采样 个节点,仅计算这些节点的表示。
边采样(Edge Sampler)
按边的概率采样,按边的稀疏度归一化。
随机游走采样(Random Walk Sampler)
通过随机游走生成子图。
优势:提供无偏估计,支持归纳学习。
8. 模型对比与选择指南
8.1 性能对比
| 模型 | 时间复杂度 | 空间复杂度 | 表达能力 | 适用场景 |
|---|---|---|---|---|
| GCN | 中等 | 基线、简单图 | ||
| GraphSAGE | 中等 | 归纳学习、大图 | ||
| GAT | 中高 | 邻居重要性不同 | ||
| GATv2 | 高 | 同 GAT,更灵活 | ||
| APPNP | 中高 | 异配图、深度学习 | ||
| GCNII | 高 | 深度图、需要高表达 | ||
| SGC | 低 | 快速基线 | ||
| SIGN | 中等 | 大规模图 |
8.2 选择指南
| 场景 | 推荐模型 |
|---|---|
| 小图基准测试 | GCN, GAT |
| 大图归纳学习 | GraphSAGE, GraphSAINT |
| 邻居重要性不同 | GATv2 |
| 深度图、复杂结构 | GCNII |
| 异配图 | APPNP, GPR-GNN |
| 大规模快速训练 | SGC, SIGN |
| 归纳到新图 | GraphSAGE |
9. 实践建议
9.1 初始化策略
def init_weights(model):
"""GCN 推荐初始化"""
for m in model.modules():
if isinstance(m, nn.Linear):
# Xavier/Glorot 初始化
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (GCNConv, GATLayer)):
if hasattr(m, 'weight') and m.weight is not None:
nn.init.xavier_uniform_(m.weight)9.2 正则化技巧
- Dropout:节点特征上的 Dropout 最重要
- 边 Dropout:训练时随机删除一些边
- L2 正则化:对权重矩阵
- Early Stopping:基于验证集
9.3 常见陷阱
- 过深的层数:2-3 层通常最佳
- 忽略度归一化:导致度大的节点主导
- 邻居采样方差:GraphSAGE 需要多次采样平均
- 标签泄漏:转导式学习的潜在问题
10. 未来方向
10.1 异配图
传统 GCN 假设相邻节点相似(同配性)。但许多真实图是异配的:
- 社交网络:朋友之间不一定相似
- 知识图谱:不同类型实体相连
解决方向:
- H2GCN (Zhu et al., NeurIPS 2020)
- GPR-GNN (Chien et al., ICLR 2021)
- FAGCN (Bo et al., NeurIPS 2021)
10.2 大规模图
亿级节点图的训练:
- 分布式 GCN 训练(多机多卡)
- 模型并行(图划分 + 跨机器消息传递)
- 内存高效训练(特征压缩、量化)
10.3 动态图
时序图(节点和边动态变化):
- TGN (Rossi et al., 2020)
- DyRep (Trivedi et al., 2019)
- EvolveGCN (Pareja et al., 2020)
参考
Footnotes
-
Gilmer et al., “Neural Message Passing for Quantum Chemistry”, ICML 2017 ↩
-
Brody, Alon, Yahav, “How Attentive are Graph Attention Networks?”, ICLR 2022. arXiv:2105.14491 ↩
-
Gasteiger, Bojchevski, Günnemann, “Predict then Propagate: Graph Neural Networks meet Personalized PageRank”, ICLR 2019. arXiv:1810.05997 ↩
-
Chen, Wei, Huang, Ding, Li, “Simple and Deep Graph Convolutional Networks”, ICML 2020. arXiv:2007.02133 ↩