GNN过压缩瓶颈问题
1. 问题引入
图神经网络(GNN)通过迭代的消息传递机制聚合邻域信息,这一设计在处理局部结构时表现优异。然而,当任务需要长距离节点间交互时(如判断两个相距甚远的节点是否满足某种关系),标准MPNN面临一个根本性限制——过压缩(Over-squashing)。1
核心问题:MPNN将来自遥远节点的信息”压缩”到固定维度的节点表示中,导致远距离信息丢失,模型难以学习依赖长程依赖的任务。
2. Over-Squashing的形式化
2.1 Jacobian分析与信息瓶颈
考虑MPNN的节点更新过程。设节点 的第 层表示为 ,定义从输入节点 到节点 的信息贡献由Jacobian矩阵衡量:
其谱范数 衡量了输入信号 对最终输出 的影响程度。
Over-squashing发生在Jacobian范数过小时,表明远距离节点的信息在传播过程中被”压缩”或”丢失”。
2.2 消息传递方程
考虑标准的消息传递范式:
其中 是聚合函数(如mean、sum)。
聚合瓶颈:当节点度很大时,大量邻居的信息被压缩到一个固定维度的向量中。
2.3 分离秩(Separation Rank)
文献2提出用分离秩量化信息压缩程度。对于图 和特征 ,考虑张量形式:
MPNN的输出可视为对 的线性投影,其分离秩上界与图的结构相关。
3. 图拓扑与Over-Squashing
3.1 曲率与有效电阻
文献1证明,over-squashing的严重程度与图的Ricci曲率和有效电阻密切相关。
图Ricci曲率(Ollivier-Ricci曲率):
其中 是 的邻域上的均匀分布, 是Wasserstein距离。
关键发现:负曲率边(表示瓶颈结构)会显著加剧over-squashing。
3.2 有效电阻分析
节点 和 之间的有效电阻 可通过拉普拉斯伪逆计算:
定理3:对于 层的MPNN,从节点 到 的Jacobian上界为:
其中 是与网络宽度和激活函数相关的常数。
3.3 图结构对信息传播的影响
不同的图结构导致不同程度的信息瓶颈:
| 图结构 | 特征 | Over-Squashing程度 |
|---|---|---|
| 完全图 | 无瓶颈 | 最小 |
| 树结构 | 路径长 | 中等 |
| 网格图 | 路径多 | 较小 |
| 小世界网络 | 局部稠密 | 严重(跨社区) |
| 星形图 | 中心节点度大 | 严重 |
4. 深度与宽度的影响
4.1 深度不能缓解问题
关键结论:增加网络深度不能有效缓解over-squashing,因为信息在每一层都会经历压缩。
# 深度与信息传递的关系
# 假设每层传播效率为 η < 1
# L层后的有效信息: η^L
def information_propagation(distance, eta, depth):
"""
模拟信息在MPNN中的传播
distance: 节点间距离
eta: 每跳传播效率 (< 1)
depth: 网络深度
"""
# 实际上需要depth >= distance才能传递信息
if depth < distance:
# 信息被过度压缩
return eta ** distance
else:
# 有足够深度传播
return eta ** distance
# 问题:增加depth并不能增加eta
# 而是增加跳数,每次都要经过瓶颈4.2 宽度可以缓解
定理3:增大网络宽度(隐藏维度)可以提高信息容量,从而缓解over-squashing。
具体而言,Jacobian范数与隐藏维度 的关系为:
实践意义:对于需要长程依赖的任务,选择更宽的GNN层比更深的层更有效。
4.3 理论 impossibility 结果
对于某些图结构,即使深度和宽度都很大,MPNN也可能无法完成任务。
Commute Time条件3:设 是节点 间的commute time(随机游走的往返期望步数)。若任务需要强混合但 ,则任何 层MPNN都无法准确学习该任务。
5. 过平滑与过压缩的对比
5.1 过平滑(Over-Smoothing)
定义:随着层数增加,节点表示趋于相同。
数学表示:
其中 是与输入特征无关的常数。
原因:重复的邻域聚合导致信息稀释。
5.2 过压缩(Over-Squashing)
定义:来自远距离节点的信息被压缩到固定维度的向量中导致丢失。
数学表示:
原因:信息在传播过程中被”挤压”到有限容量中。
5.3 对比总结
| 特性 | 过平滑 | 过压缩 |
|---|---|---|
| 受影响范围 | 所有节点 | 远距离节点对 |
| 增加深度的效果 | 恶化 | 无改善 |
| 增加宽度的效果 | 无直接改善 | 可缓解 |
| 解决方案 | 残差连接、归一化 | 图重布线 |
6. 解决方案
6.1 图重布线(Graph Rewiring)
通过修改图结构消除瓶颈,主要方法包括:
6.1.1 基于曲率的边添加
添加负曲率边(跨社区边)减少有效距离:
def curvature_based_rewiring(G, k):
"""
基于Ricci曲率的图重布线
策略:识别负曲率边,添加绕过这些边的捷径
"""
# 计算所有边的Ollivier-Ricci曲率
curvatures = compute_ricci_curvature(G)
# 识别瓶颈边(负曲率)
bottleneck_edges = [e for e, c in curvatures.items() if c < -epsilon]
# 添加捷径
for u, v in bottleneck_edges:
# 添加u-v的若干跳邻居作为中继
shortcuts = get_common_neighbors(G, u, v, k)
G.add_edges_from([(u, s) for s in shortcuts])
return G6.1.2 DIGL(Diffusion Graph Rewiring)
利用图的扩散矩阵添加”虚拟边”:
其中 是随机游走转移矩阵。
6.1.3 Topping等人的方法
通过负曲率边检测和添加反向边:
其中 是边曲率矩阵。
6.2 架构改进
6.2.1 注意力机制
使用注意力加权缓解信息瓶颈:
class AttentionRewiring(nn.Module):
def __init__(self, dim, num_heads=4):
super().__init__()
self.attention = nn.MultiheadAttention(dim, num_heads)
def forward(self, x, edge_index, long_range_edge_index):
"""
x: 节点特征
edge_index: 局部边
long_range_edge_index: 额外添加的长程边
"""
# 局部消息传递
local_out = local_message_passing(x, edge_index)
# 长程注意力
long_range_out = self.attention(
x, x, x,
attn_mask=create_long_range_mask(x, long_range_edge_index)
)[0]
return local_out + long_range_out6.2.2 读出层改进
使用更强大的读出函数:
def expressive_readout(node_features, edge_features=None):
"""
比SUM/MEAN更强大的读出函数
"""
# 1-WL等价的读出
# 2-WL需要考虑节点对
if edge_features is not None:
# 包含边特征的更丰富表示
node_pair_features = outer_product(node_features)
return sum(node_features) + sum(edge_features)
return soft_sort(node_features, k=10) # Top-k排序6.3 可学习边
SAN(Sample Aggregation Network)
通过可学习的采样策略选择信息传递路径:
class LearnableSampler(nn.Module):
def __init__(self, dim, k):
super().__init__()
self.k = k
self.score_fn = nn.Sequential(
nn.Linear(dim * 2, dim),
nn.ReLU(),
nn.Linear(dim, 1)
)
def forward(self, x, edge_index):
"""
学习选择最重要的k条边
"""
src, dst = edge_index
scores = self.score_fn(torch.cat([x[src], x[dst]], dim=-1))
# 为每个节点选择top-k邻居
sampled_edges = []
for v in range(x.size(0)):
neighbors = torch.where(src == v)[0]
top_k = torch.topk(scores[neighbors], min(self.k, len(neighbors))).indices
sampled_edges.extend([(v, dst[neighbors[i]]) for i in top_k])
return torch.tensor(sampled_edges).t()6.4 动态图
在动态图或时序图上,信息可以直接通过时间维度传递:
class TemporalAggregator(nn.Module):
def forward(self, x_history):
"""
x_history: (T, n, d) 时间序列特征
"""
# 跨时间步注意力
temporal_attn = self.temporal_attention(x_history)
# 结合空间和时间信息
return spatial_gnn(x_history) + temporal_attn7. Over-Squashing的量化指标
7.1 Bottleneck Ratio
定义图的瓶颈比:
7.2 Jacobian Sum
7.3 压缩比
8. 实践指南
8.1 何时关注Over-Squashing
需要关注的场景:
- 节点分类依赖远距离节点特征
- 图级别任务需要全局信息
- 图存在明显社区结构
- 链路预测涉及跨社区节点
影响较小的场景:
- 局部节点特征足够完成任务
- 图结构相对均匀
- 任务主要依赖直接邻居
8.2 诊断方法
import torch
from scipy.sparse.csgraph import shortest_path
def diagnose_over_squashing(model, graph, x):
"""
诊断模型是否受到over-squashing影响
"""
# 1. 检查节点度数分布
degrees = graph.sum(dim=1).numpy()
max_degree = degrees.max()
degree_variance = degrees.var()
# 2. 计算图的谱间隙
L = compute_laplacian(graph)
eigenvalues = torch.linalg.eigvalsh(L).numpy()
spectral_gap = eigenvalues[1] - eigenvalues[0]
# 3. 计算平均最短路径
dist_matrix = shortest_path(graph.numpy())
avg_path = dist_matrix[dist_matrix < np.inf].mean()
# 4. 计算有效电阻
eff_resistances = compute_effective_resistance(graph)
return {
'max_degree': max_degree,
'degree_variance': degree_variance,
'spectral_gap': spectral_gap,
'avg_path_length': avg_path,
'avg_eff_resistance': eff_resistances.mean()
}8.3 缓解策略选择
| 场景 | 推荐策略 |
|---|---|
| 社区结构明显 | 图重布线(添加跨社区边) |
| 资源充足 | 增大模型宽度 |
| 需要保持效率 | 稀疏注意力 |
| 动态图 | 时序注意力 |
9. 小结
Over-squashing是MPNN的一个根本性限制,它源于两个因素:
- 结构因素:图拓扑中的瓶颈(负曲率边、高有效电阻)
- 容量因素:固定维度的节点表示限制了信息容量
关键要点:
| 观点 | 结论 |
|---|---|
| 增加深度 | ❌ 不能解决over-squashing |
| 增加宽度 | ✅ 可以缓解(提高容量) |
| 图重布线 | ✅ 从根本上解决问题 |
| 注意力机制 | ✅ 动态选择信息传递路径 |
理解over-squashing对于设计更强大的图神经网络架构至关重要,特别是在需要长程依赖的应用中。
参考文献
Footnotes
-
Topping, J., Di Giovanni, F., et al. (2022). “Understanding over-squashing and over-smoothing on graphs.” ICLR 2022. arXiv:2110.00597 ↩ ↩2
-
Razin, N., et al. (2022). “Rethinking the expressive power of GNNs via graph homomorphism counting.” NeurIPS 2022. ↩
-
Di Giovanni, A., Konstantin Rusch, T., et al. (2023). “How does over-squashing affect the power of GNNs?” NeurIPS 2023. arXiv:2306.03589 ↩ ↩2 ↩3