概述
深层图神经网络(Deep Graph Neural Networks)的训练面临两个核心挑战:过平滑(Over-smoothing)和过压缩(Over-squashing)。1
这两个问题严重限制了GNN向深层发展的能力:
| 问题 | 本质 | 表现 |
|---|---|---|
| 过平滑 | 多次平滑导致节点表示趋于相同 | 节点特征丧失区分性 |
| 过压缩 | 多跳信息被压缩到固定维度 | 远距离依赖丢失 |
理解这两个问题的数学本质,对于设计更深、更强大的GNN架构至关重要。
1. 过平滑问题
1.1 定义与直观理解
过平滑是指随着GNN层数增加,节点的表示向量逐渐趋同,最终所有节点的嵌入变得几乎相同。
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
def measure_smoothing(embeddings):
"""测量嵌入的平滑程度(越低越平滑)"""
# 计算所有节点对之间嵌入的距离方差
pairwise_dist = torch.pdist(embeddings, p=2)
return pairwise_dist.var().item()
# 测试不同层数的平滑程度
def test_smoothing_levels():
"""
假设我们有一个简单的图结构
观察随着层数增加,节点嵌入的区分性如何变化
"""
# 底层真实嵌入应该有较大差异
# 经过多层GNN后,差异逐渐消失
pass1.2 数学描述
拉普拉斯平滑的视角
对于两层GCN,假设没有激活函数和激活函数:
其中 是对称归一化邻接矩阵。
特征值分解分析
设 的特征分解为:
其中 是特征向量矩阵, 是特征值矩阵。
关键性质:
- 归一化邻接矩阵 的特征值满足:
- 最小特征值 (对应全1向量)
- 特征值越接近1,信号衰减越少
K层传播的数学表达
经过 层图卷积后:
当 时:
其中 是全1向量, 是稳态分布(与图的度分布相关)。
推论:所有节点的表示趋于相同!(除非有残差连接)
1.3 过平滑的数学界限
有效平滑距离
定义平滑速度为节点表示收敛到相同值的速率:2
其中 是对应于特征值 的特征向量分量。
收敛边界
对于任意节点 :
其中 ( 是第二大的特征值)。
越接近1,收敛越慢(不易过平滑)
越接近0,收敛越快(易过平滑)
PairNorm的数学动机
PairNorm的核心思想是保持节点嵌入的总平方范数恒定:3
其中 是归一化算子。
1.4 过平滑的诊断指标
class SmoothingMetrics:
"""计算过平滑程度的多种指标"""
@staticmethod
def pairwise_distance_var(embeddings):
"""节点嵌入对的距离方差"""
pairwise_dist = torch.cdist(embeddings, embeddings)
# 排除对角线
mask = ~torch.eye(embeddings.shape[0], dtype=torch.bool)
return pairwise_dist[mask].var()
@staticmethod
def condition_number(laplacian):
"""图的拉普拉斯矩阵条件数(与过平滑相关)"""
eigenvalues = torch.linalg.eigvalsh(laplacian)
# 排除零特征值
nonzero_eigenvalues = eigenvalues[1:]
return nonzero_eigenvalues.max() / nonzero_eigenvalues.min()
@staticmethod
def spectral_gap(normalized_adj):
"""谱间隙 1 - λ2(越大越不易过平滑)"""
eigenvalues = torch.linalg.eigvalsh(normalized_adj)
# λ1 = 1, λ2 是第二大的
return 1 - eigenvalues[-2].item()2. 过压缩问题
2.1 定义与直观理解
过压缩是指GNN的消息传递机制将来自多个邻居/多条路径的信息”压缩”到固定维度的向量中,导致远距离信息丢失。4
层数1: 节点只能接收1跳邻居的信息
层数2: 节点可以接收2跳邻居的信息(但被压缩)
层数3: 节点可以接收3跳邻居的信息(被进一步压缩)
...
层数K: 节点可以接收K跳邻居的信息(严重压缩)
2.2 数学分析:Jacobian视角
消息传递的回流
考虑一个简化的线性消息传递:
Jacobian矩阵的分析
节点 关于输入 的 Jacobian:
关键问题:当 很大时,这个乘积的范数会如何变化?
如果权重的谱范数 ,梯度会指数衰减(与RNN的梯度消失类似)。
2.3 信息瓶颈理论
通道容量视角
将消息传递视为一个信息通道:
输入: 来自多个邻居的信息 {h_u : u ∈ N(v)}
↓
编码: 通过线性变换 + 聚合 → h_v
↓
输出: 固定维度的节点嵌入
核心约束:固定维度的嵌入无法无损失地表示来自任意数量邻居的信息。
互信息分析
设 表示输入与输出之间的互信息:
其中 是嵌入维度, 是通道容量。
2.4 曲率基测量方法
Topping et al. (2022) 提出使用图Ricci曲率来量化过压缩程度。4
曲率定义
对于无权无向图,节点 和 之间的Ricci曲率定义为:
其中:
- :最短路径上的中间节点数
- :边的权重乘积
负曲率与过压缩
- 正曲率边:信息流动顺畅
- 负曲率边:信息流动受阻,易发生过压缩
import numpy as np
def compute_ricci_curvature(G, edge):
"""
简化版的图Ricci曲率计算
参数:
G: NetworkX图
edge: 边 (u, v)
返回:
curvature: 曲率值(负值表示负曲率边)
"""
u, v = edge
# 获取邻居
neighbors_u = set(G.neighbors(u)) - {v}
neighbors_v = set(G.neighbors(v)) - {u}
# 计算最短路径上的中间节点
common_neighbors = neighbors_u & neighbors_v
u_to_v_via_common = len(common_neighbors)
# 通过其他节点的路径
paths_through_other = 0
for w in (neighbors_u - common_neighbors):
for x in (neighbors_v - common_neighbors):
if G.has_edge(w, x):
paths_through_other += 1
# Ricci曲率公式
curvature = 1 - (u_to_v_via_common + paths_through_other) / (len(neighbors_u) + len(neighbors_v))
return curvature
def identify_bottleneck_edges(G, threshold=-0.1):
"""
识别可能导致过压缩的瓶颈边
参数:
G: NetworkX图
threshold: 曲率阈值(低于此值认为是瓶颈)
返回:
bottleneck_edges: 瓶颈边列表
"""
bottleneck_edges = []
for edge in G.edges():
curvature = compute_ricci_curvature(G, edge)
if curvature < threshold:
bottleneck_edges.append((edge, curvature))
return bottleneck_edges3. 解决方案
3.1 架构层面
3.1.1 残差连接
最直接的方法是在每层GNN后添加残差连接:1
效果:保持原始信号通路,延缓过平滑
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class ResGNN(nn.Module):
"""带残差连接的GNN"""
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super().__init__()
self.num_layers = num_layers
# 输入投影层
self.input_proj = nn.Linear(in_channels, hidden_channels)
# GNN层列表
self.convs = nn.ModuleList()
for _ in range(num_layers):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
# 输出层
self.output_proj = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 输入投影
h = self.input_proj(x)
# 带残差连接的GNN层
for i, conv in enumerate(self.convs):
h_new = conv(h, edge_index)
h_new = F.relu(h_new)
# 残差连接:每两层加一次残差
if i % 2 == 1:
h = h + h_new # 残差连接
else:
h = h_new
return self.output_proj(h)3.1.2 JK-Net(Jump Knowledge Network)
Xu et al. (2018) 提出的JK-Net通过跳转连接聚合多层次表示:5
聚合方式:
- Concat:拼接所有层
- Max-pooling:逐维度取最大值
- LSTM:用注意力加权
class JKNet(nn.Module):
"""Jump Knowledge Network"""
def __init__(self, in_channels, hidden_channels, out_channels,
num_layers, aggr='concat'):
super().__init__()
self.aggr = aggr
self.conv1 = GCNConv(in_channels, hidden_channels)
self.convs = nn.ModuleList()
for _ in range(num_layers - 1):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
if aggr == 'concat':
self.final_proj = nn.Linear(hidden_channels * num_layers, out_channels)
else:
self.final_proj = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index):
# 存储所有层的表示
layer_outputs = []
# 第一层
x = F.relu(self.conv1(x, edge_index))
layer_outputs.append(x)
# 中间层
for conv in self.convs:
x = F.relu(conv(x, edge_index))
layer_outputs.append(x)
# 聚合
if self.aggr == 'concat':
h = torch.cat(layer_outputs, dim=1)
elif self.aggr == 'max':
h = torch.stack(layer_outputs, dim=0).max(dim=0)[0]
elif self.aggr == 'lstm':
# LSTM注意力聚合
h = self.lstm_aggregate(layer_outputs)
else:
h = layer_outputs[-1]
return self.final_proj(h)
def lstm_aggregate(self, layer_outputs):
"""LSTM风格的聚合"""
# 简化版本:使用最后一层
return layer_outputs[-1]3.1.3 深层监督(Deep Supervision)
在中间层添加辅助损失函数:6
class DeeplySupervisedGNN(nn.Module):
"""带深层监督的GNN"""
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super().__init__()
self.num_layers = num_layers
self.convs = nn.ModuleList()
self.classifiers = nn.ModuleList()
# 第一个卷积层
self.convs.append(GCNConv(in_channels, hidden_channels))
self.classifiers.append(nn.Linear(hidden_channels, out_channels))
# 中间层
for _ in range(1, num_layers):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
self.classifiers.append(nn.Linear(hidden_channels, out_channels))
def forward(self, x, edge_index, train_mask=None):
all_logits = []
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
logits = self.classifiers[i](x)
all_logits.append(logits)
# 训练时:使用加权平均的损失
# 推理时:使用最后一层的输出
return all_logits
def compute_loss(self, all_logits, labels, train_mask):
"""计算深层监督的总损失"""
total_loss = 0
for i, logits in enumerate(all_logits):
# 越深层权重越小(课程学习思想)
weight = 1.0 / (i + 1)
loss = F.cross_entropy(logits[train_mask], labels[train_mask])
total_loss += weight * loss
return total_loss3.2 归一化层面
3.2.1 PairNorm
Zhao et al. (2020) 提出的PairNorm通过保持节点嵌入的总相似度来缓解过平滑:3
核心思想:
- 对于每个 batch,计算所有节点对的距离平方和
- 归一化这个和为常数
class PairNorm(nn.Module):
"""PairNorm: 保持节点嵌入对的总体差异"""
def __init__(self, scale=1.0):
super().__init__()
self.scale = scale
def forward(self, x):
# x: (N, D) 其中 N 是节点数
N = x.shape[0]
# 计算所有节点对之间的差异
# x.unsqueeze(1): (N, 1, D)
# x.unsqueeze(0): (1, N, D)
diff = x.unsqueeze(1) - x.unsqueeze(0) # (N, N, D)
# 计算距离平方和
sq_dist = torch.sum(diff ** 2, dim=2) # (N, N)
# 对非对角线元素求和
mask = ~torch.eye(N, dtype=torch.bool, device=x.device)
total_sq_dist = sq_dist[mask].sum()
# PairNorm归一化
if total_sq_dist > 0:
norm_factor = torch.sqrt(N * (N - 1) / total_sq_dist)
x = x * norm_factor * self.scale
return x
class GCNWithPairNorm(nn.Module):
"""使用PairNorm的GCN"""
def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
super().__init__()
self.input_proj = nn.Linear(in_channels, hidden_channels)
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
for _ in range(num_layers):
self.convs.append(GCNConv(hidden_channels, hidden_channels))
self.norms.append(PairNorm())
self.classifier = nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.input_proj(x)
for conv, norm in zip(self.convs, self.norms):
x = conv(x, edge_index)
x = F.relu(x)
x = norm(x) # PairNorm
return self.classifier(x)3.2.2 SUGCON(Self-Attention Graph Normalization)
SUGCON使用自注意力机制进行归一化:
class SUGCON(nn.Module):
"""SUGCON归一化"""
def __init__(self, in_features):
super().__init__()
self.in_features = in_features
self.scale = in_features ** 0.5
def forward(self, x):
# 自注意力权重
attn = torch.matmul(x, x.transpose(-2, -1)) / self.scale
attn = F.softmax(attn, dim=-1)
# 加权聚合
x_normalized = torch.matmul(attn, x)
return x_normalized3.3 图结构层面
3.3.1 图重连(Graph Rewiring)
Topping et al. (2022) 提出通过曲率重连来消除过压缩:4
算法步骤:
- 计算图中所有边的Ricci曲率
- 识别负曲率边(瓶颈边)
- 通过边重连或添加新边来”填补”负曲率区域
def curvature_based_rewiring(G, target_curvature=0.0, max_iterations=100):
"""
基于曲率的图重连算法
参数:
G: NetworkX图
target_curvature: 目标曲率(通常为0或正值)
max_iterations: 最大迭代次数
返回:
G_rewired: 重连后的图
"""
G_rewired = G.copy()
for iteration in range(max_iterations):
bottleneck_edges = identify_bottleneck_edges(G_rewired, threshold=target_curvature)
if not bottleneck_edges:
break
# 对每个瓶颈边,尝试重连
for (u, v), curvature in bottleneck_edges:
# 策略1:添加u和v之间的边
# 策略2:将u或v连接到其他高曲率节点
# 策略3:重新路由
# 简化:添加一个新节点作为桥梁
new_node = G_rewired.number_of_nodes()
G_rewired.add_node(new_node)
G_rewired.add_edge(u, new_node)
G_rewired.add_edge(new_node, v)
# 移除原来的边(可选)
G_rewired.remove_edge(u, v)
return G_rewired3.3.2 DiffWire
另一种方法是使用随机重连来增加图的展开量:
def random_edge_flip(G, rewiring_ratio=0.1):
"""
随机边翻转:增加图的连通性
参数:
G: NetworkX图
rewiring_ratio: 重连比例
"""
import random
G_rewired = G.copy()
num_edges = G_rewired.number_of_edges()
num_rewire = int(num_edges * rewiring_ratio)
nodes = list(G_rewired.nodes())
for _ in range(num_rewire):
# 随机选择一条边
u, v = random.choice(list(G_rewired.edges()))
# 随机选择另一个节点
w = random.choice(nodes)
# 确保不形成自环和重边
if w != u and w != v and not G_rewired.has_edge(u, w):
G_rewired.remove_edge(u, v)
G_rewired.add_edge(u, w)
return G_rewired4. 实用指南
4.1 何时使用更深的GNN
| 场景 | 推荐层数 | 原因 |
|---|---|---|
| 节点特征强,图结构弱 | 2-3层 | 浅层即可捕获特征信息 |
| 图结构复杂,需要多跳信息 | 4-8层 | 需要足够深度捕获远距离依赖 |
| 异构图/异配图 | 2-4层 | 深层更易受异配性影响 |
| 大规模稀疏图 | 2-3层 | 深层导致过度平滑 |
| 小规模同配图 | 6-10层(配合残差) | 可以尝试更深 |
4.2 推荐层数经验值
Cora/Citeseer/Pubmed 数据集:
├── 2层: 通常最优(特征+一阶邻居)
├── 3-4层: 需要残差连接
└── 8层以上: 通常效果下降
分子图(PQM9, ZINC):
├── 4-6层: 常见配置
└── 取决于分子大小
社交网络(Reddit, Flickr):
├── 2-3层: 邻居爆炸,需要采样
└── 深层需要GraphSAINT等采样方法
4.3 诊断工具
def diagnose_gnn_depth_issues(model, data, device='cpu'):
"""
诊断GNN的深度问题
返回:
诊断报告
"""
model.eval()
model = model.to(device)
data = data.to(device)
# 获取嵌入
with torch.no_grad():
embeddings = model(data.x, data.edge_index)
report = {
'smoothing': {},
'capacity': {},
}
# 1. 过平滑检测
pairwise_var = SmoothingMetrics.pairwise_distance_var(embeddings)
report['smoothing']['pairwise_distance_var'] = pairwise_var
# 2. 计算条件数(需要拉普拉斯矩阵)
edge_index = data.edge_index.cpu()
num_nodes = data.num_nodes
adj = torch.zeros(num_nodes, num_nodes)
adj[edge_index[0], edge_index[1]] = 1
adj = adj + adj.T # 对称化
# 度矩阵
D = torch.diag(adj.sum(dim=1))
# 拉普拉斯矩阵
L = D - adj
cond_num = SmoothingMetrics.condition_number(L)
report['smoothing']['laplacian_condition_number'] = cond_num
# 3. 嵌入的秩(反映嵌入空间的维度)
try:
embed_rank = torch.matrix_rank(embeddings).item()
report['capacity']['embedding_rank'] = embed_rank
report['capacity']['max_possible_rank'] = min(embeddings.shape)
report['capacity']['rank_ratio'] = embed_rank / min(embeddings.shape)
except:
pass
return report4.4 超参数建议
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 隐藏维度 | 64-256 | 足够大以捕获信息,但不要过大 |
| 层数 | 2-6 | 通常2-4层足够 |
| dropout | 0.5-0.7 | 深层需要更强的正则化 |
| 学习率 | 0.01 | 标准设置 |
| 权重衰减 | 5e-4 | 防止过拟合 |
| 邻居采样数 | 10-25 | 大图深层时减少方差 |
4.5 最佳实践检查清单
- 层数控制:从2-3层开始,根据需要逐步增加
- 残差连接:层数>3时必须使用
- 批归一化/LayerNorm:深层GNN的标准配置
- 邻居采样:大图时控制感受野
- 监控指标:训练时记录平滑度指标
- 对比实验:对比浅层和深层的性能
- 图重连:考虑对异配图进行预处理
5. 相关主题
- 图神经网络:GNN的基础概念
- 图卷积网络详解:GCN的详细分析
- ResNet残差网络:残差连接的CNN起源
- 神经网络的表达能力:神经网络表达能力的理论分析
参考
Footnotes
-
Li et al., “Deeper Insights into Graph Convolutional Networks for Semi-Supervised Learning”, AAAI 2018 ↩ ↩2
-
Wu et al., “Simplifying Graph Neural Networks”, ICML 2019 ↩
-
Zhao & Akoglu, “PairNorm: Tackling Over-smoothing in GNNs”, ICLR 2020 ↩ ↩2
-
Topping et al., “Understanding over-squashing and over-curvature on graphs”, ICLR 2022 ↩ ↩2 ↩3
-
Xu et al., “Representation Learning on Graphs with Jumping Knowledge Networks”, ICML 2018 ↩
-
Chen et al., “Measuring and Relieving the Over-smoothing Problem for Graph Neural Networks”, ICLR 2020 ↩