GNN过压缩瓶颈问题

1. 问题引入

图神经网络(GNN)通过迭代的消息传递机制聚合邻域信息,这一设计在处理局部结构时表现优异。然而,当任务需要长距离节点间交互时(如判断两个相距甚远的节点是否满足某种关系),标准MPNN面临一个根本性限制——过压缩(Over-squashing)。1

核心问题:MPNN将来自遥远节点的信息”压缩”到固定维度的节点表示中,导致远距离信息丢失,模型难以学习依赖长程依赖的任务。


2. Over-Squashing的形式化

2.1 Jacobian分析与信息瓶颈

考虑MPNN的节点更新过程。设节点 的第 层表示为 ,定义从输入节点 到节点 信息贡献由Jacobian矩阵衡量:

其谱范数 衡量了输入信号 对最终输出 的影响程度。

Over-squashing发生在Jacobian范数过小时,表明远距离节点的信息在传播过程中被”压缩”或”丢失”。

2.2 消息传递方程

考虑标准的消息传递范式:

其中 是聚合函数(如mean、sum)。

聚合瓶颈:当节点度很大时,大量邻居的信息被压缩到一个固定维度的向量中。

2.3 分离秩(Separation Rank)

文献2提出用分离秩量化信息压缩程度。对于图 和特征 ,考虑张量形式:

MPNN的输出可视为对 的线性投影,其分离秩上界与图的结构相关。


3. 图拓扑与Over-Squashing

3.1 曲率与有效电阻

文献1证明,over-squashing的严重程度与图的Ricci曲率有效电阻密切相关。

图Ricci曲率(Ollivier-Ricci曲率):

其中 的邻域上的均匀分布, 是Wasserstein距离。

关键发现:负曲率边(表示瓶颈结构)会显著加剧over-squashing。

3.2 有效电阻分析

节点 之间的有效电阻 可通过拉普拉斯伪逆计算:

定理3:对于 层的MPNN,从节点 的Jacobian上界为:

其中 是与网络宽度和激活函数相关的常数。

3.3 图结构对信息传播的影响

不同的图结构导致不同程度的信息瓶颈:

图结构特征Over-Squashing程度
完全图无瓶颈最小
树结构路径长中等
网格图路径多较小
小世界网络局部稠密严重(跨社区)
星形图中心节点度大严重

4. 深度与宽度的影响

4.1 深度不能缓解问题

关键结论:增加网络深度不能有效缓解over-squashing,因为信息在每一层都会经历压缩。

# 深度与信息传递的关系
# 假设每层传播效率为 η < 1
# L层后的有效信息: η^L
 
def information_propagation(distance, eta, depth):
    """
    模拟信息在MPNN中的传播
    
    distance: 节点间距离
    eta: 每跳传播效率 (< 1)
    depth: 网络深度
    """
    # 实际上需要depth >= distance才能传递信息
    if depth < distance:
        # 信息被过度压缩
        return eta ** distance
    else:
        # 有足够深度传播
        return eta ** distance
 
# 问题:增加depth并不能增加eta
# 而是增加跳数,每次都要经过瓶颈

4.2 宽度可以缓解

定理3:增大网络宽度(隐藏维度)可以提高信息容量,从而缓解over-squashing。

具体而言,Jacobian范数与隐藏维度 的关系为:

实践意义:对于需要长程依赖的任务,选择更宽的GNN层比更深的层更有效。

4.3 理论 impossibility 结果

对于某些图结构,即使深度和宽度都很大,MPNN也可能无法完成任务。

Commute Time条件3:设 是节点 间的commute time(随机游走的往返期望步数)。若任务需要强混合但 ,则任何 层MPNN都无法准确学习该任务。


5. 过平滑与过压缩的对比

5.1 过平滑(Over-Smoothing)

定义:随着层数增加,节点表示趋于相同。

数学表示

其中 是与输入特征无关的常数。

原因:重复的邻域聚合导致信息稀释。

5.2 过压缩(Over-Squashing)

定义:来自远距离节点的信息被压缩到固定维度的向量中导致丢失。

数学表示

原因:信息在传播过程中被”挤压”到有限容量中。

5.3 对比总结

特性过平滑过压缩
受影响范围所有节点远距离节点对
增加深度的效果恶化无改善
增加宽度的效果无直接改善可缓解
解决方案残差连接、归一化图重布线

6. 解决方案

6.1 图重布线(Graph Rewiring)

通过修改图结构消除瓶颈,主要方法包括:

6.1.1 基于曲率的边添加

添加负曲率边(跨社区边)减少有效距离:

def curvature_based_rewiring(G, k):
    """
    基于Ricci曲率的图重布线
    
    策略:识别负曲率边,添加绕过这些边的捷径
    """
    # 计算所有边的Ollivier-Ricci曲率
    curvatures = compute_ricci_curvature(G)
    
    # 识别瓶颈边(负曲率)
    bottleneck_edges = [e for e, c in curvatures.items() if c < -epsilon]
    
    # 添加捷径
    for u, v in bottleneck_edges:
        # 添加u-v的若干跳邻居作为中继
        shortcuts = get_common_neighbors(G, u, v, k)
        G.add_edges_from([(u, s) for s in shortcuts])
    
    return G

6.1.2 DIGL(Diffusion Graph Rewiring)

利用图的扩散矩阵添加”虚拟边”:

其中 是随机游走转移矩阵。

6.1.3 Topping等人的方法

通过负曲率边检测和添加反向边:

其中 是边曲率矩阵。

6.2 架构改进

6.2.1 注意力机制

使用注意力加权缓解信息瓶颈:

class AttentionRewiring(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads)
        
    def forward(self, x, edge_index, long_range_edge_index):
        """
        x: 节点特征
        edge_index: 局部边
        long_range_edge_index: 额外添加的长程边
        """
        # 局部消息传递
        local_out = local_message_passing(x, edge_index)
        
        # 长程注意力
        long_range_out = self.attention(
            x, x, x, 
            attn_mask=create_long_range_mask(x, long_range_edge_index)
        )[0]
        
        return local_out + long_range_out

6.2.2 读出层改进

使用更强大的读出函数:

def expressive_readout(node_features, edge_features=None):
    """
    比SUM/MEAN更强大的读出函数
    """
    # 1-WL等价的读出
    # 2-WL需要考虑节点对
    if edge_features is not None:
        # 包含边特征的更丰富表示
        node_pair_features = outer_product(node_features)
        return sum(node_features) + sum(edge_features)
    return soft_sort(node_features, k=10)  # Top-k排序

6.3 可学习边

SAN(Sample Aggregation Network)

通过可学习的采样策略选择信息传递路径:

class LearnableSampler(nn.Module):
    def __init__(self, dim, k):
        super().__init__()
        self.k = k
        self.score_fn = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, 1)
        )
        
    def forward(self, x, edge_index):
        """
        学习选择最重要的k条边
        """
        src, dst = edge_index
        scores = self.score_fn(torch.cat([x[src], x[dst]], dim=-1))
        
        # 为每个节点选择top-k邻居
        sampled_edges = []
        for v in range(x.size(0)):
            neighbors = torch.where(src == v)[0]
            top_k = torch.topk(scores[neighbors], min(self.k, len(neighbors))).indices
            sampled_edges.extend([(v, dst[neighbors[i]]) for i in top_k])
        
        return torch.tensor(sampled_edges).t()

6.4 动态图

在动态图或时序图上,信息可以直接通过时间维度传递:

class TemporalAggregator(nn.Module):
    def forward(self, x_history):
        """
        x_history: (T, n, d) 时间序列特征
        """
        # 跨时间步注意力
        temporal_attn = self.temporal_attention(x_history)
        
        # 结合空间和时间信息
        return spatial_gnn(x_history) + temporal_attn

7. Over-Squashing的量化指标

7.1 Bottleneck Ratio

定义图的瓶颈比:

7.2 Jacobian Sum

7.3 压缩比


8. 实践指南

8.1 何时关注Over-Squashing

需要关注的场景

  • 节点分类依赖远距离节点特征
  • 图级别任务需要全局信息
  • 图存在明显社区结构
  • 链路预测涉及跨社区节点

影响较小的场景

  • 局部节点特征足够完成任务
  • 图结构相对均匀
  • 任务主要依赖直接邻居

8.2 诊断方法

import torch
from scipy.sparse.csgraph import shortest_path
 
def diagnose_over_squashing(model, graph, x):
    """
    诊断模型是否受到over-squashing影响
    """
    # 1. 检查节点度数分布
    degrees = graph.sum(dim=1).numpy()
    max_degree = degrees.max()
    degree_variance = degrees.var()
    
    # 2. 计算图的谱间隙
    L = compute_laplacian(graph)
    eigenvalues = torch.linalg.eigvalsh(L).numpy()
    spectral_gap = eigenvalues[1] - eigenvalues[0]
    
    # 3. 计算平均最短路径
    dist_matrix = shortest_path(graph.numpy())
    avg_path = dist_matrix[dist_matrix < np.inf].mean()
    
    # 4. 计算有效电阻
    eff_resistances = compute_effective_resistance(graph)
    
    return {
        'max_degree': max_degree,
        'degree_variance': degree_variance,
        'spectral_gap': spectral_gap,
        'avg_path_length': avg_path,
        'avg_eff_resistance': eff_resistances.mean()
    }

8.3 缓解策略选择

场景推荐策略
社区结构明显图重布线(添加跨社区边)
资源充足增大模型宽度
需要保持效率稀疏注意力
动态图时序注意力

9. 小结

Over-squashing是MPNN的一个根本性限制,它源于两个因素:

  1. 结构因素:图拓扑中的瓶颈(负曲率边、高有效电阻)
  2. 容量因素:固定维度的节点表示限制了信息容量

关键要点

观点结论
增加深度❌ 不能解决over-squashing
增加宽度✅ 可以缓解(提高容量)
图重布线✅ 从根本上解决问题
注意力机制✅ 动态选择信息传递路径

理解over-squashing对于设计更强大的图神经网络架构至关重要,特别是在需要长程依赖的应用中。


参考文献

Footnotes

  1. Topping, J., Di Giovanni, F., et al. (2022). “Understanding over-squashing and over-smoothing on graphs.” ICLR 2022. arXiv:2110.00597 2

  2. Razin, N., et al. (2022). “Rethinking the expressive power of GNNs via graph homomorphism counting.” NeurIPS 2022.

  3. Di Giovanni, A., Konstantin Rusch, T., et al. (2023). “How does over-squashing affect the power of GNNs?” NeurIPS 2023. arXiv:2306.03589 2 3