概述

深层图神经网络(Deep Graph Neural Networks)的训练面临两个核心挑战:过平滑(Over-smoothing)和过压缩(Over-squashing)。1

这两个问题严重限制了GNN向深层发展的能力:

问题本质表现
过平滑多次平滑导致节点表示趋于相同节点特征丧失区分性
过压缩多跳信息被压缩到固定维度远距离依赖丢失

理解这两个问题的数学本质,对于设计更深、更强大的GNN架构至关重要。


1. 过平滑问题

1.1 定义与直观理解

过平滑是指随着GNN层数增加,节点的表示向量逐渐趋同,最终所有节点的嵌入变得几乎相同。

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
 
def measure_smoothing(embeddings):
    """测量嵌入的平滑程度(越低越平滑)"""
    # 计算所有节点对之间嵌入的距离方差
    pairwise_dist = torch.pdist(embeddings, p=2)
    return pairwise_dist.var().item()
 
# 测试不同层数的平滑程度
def test_smoothing_levels():
    """
    假设我们有一个简单的图结构
    观察随着层数增加,节点嵌入的区分性如何变化
    """
    # 底层真实嵌入应该有较大差异
    # 经过多层GNN后,差异逐渐消失
    pass

1.2 数学描述

拉普拉斯平滑的视角

对于两层GCN,假设没有激活函数和激活函数:

其中 是对称归一化邻接矩阵。

特征值分解分析

的特征分解为:

其中 是特征向量矩阵, 是特征值矩阵。

关键性质

  • 归一化邻接矩阵 的特征值满足:
  • 最小特征值 (对应全1向量)
  • 特征值越接近1,信号衰减越少

K层传播的数学表达

经过 层图卷积后:

其中 是全1向量, 是稳态分布(与图的度分布相关)。

推论:所有节点的表示趋于相同!(除非有残差连接)

1.3 过平滑的数学界限

有效平滑距离

定义平滑速度为节点表示收敛到相同值的速率:2

其中 是对应于特征值 的特征向量分量。

收敛边界

对于任意节点

其中 是第二大的特征值)。

越接近1,收敛越慢(不易过平滑)
越接近0,收敛越快(易过平滑)

PairNorm的数学动机

PairNorm的核心思想是保持节点嵌入的总平方范数恒定:3

其中 是归一化算子。

1.4 过平滑的诊断指标

class SmoothingMetrics:
    """计算过平滑程度的多种指标"""
    
    @staticmethod
    def pairwise_distance_var(embeddings):
        """节点嵌入对的距离方差"""
        pairwise_dist = torch.cdist(embeddings, embeddings)
        # 排除对角线
        mask = ~torch.eye(embeddings.shape[0], dtype=torch.bool)
        return pairwise_dist[mask].var()
    
    @staticmethod
    def condition_number(laplacian):
        """图的拉普拉斯矩阵条件数(与过平滑相关)"""
        eigenvalues = torch.linalg.eigvalsh(laplacian)
        # 排除零特征值
        nonzero_eigenvalues = eigenvalues[1:]
        return nonzero_eigenvalues.max() / nonzero_eigenvalues.min()
    
    @staticmethod
    def spectral_gap(normalized_adj):
        """谱间隙 1 - λ2(越大越不易过平滑)"""
        eigenvalues = torch.linalg.eigvalsh(normalized_adj)
        # λ1 = 1, λ2 是第二大的
        return 1 - eigenvalues[-2].item()

2. 过压缩问题

2.1 定义与直观理解

过压缩是指GNN的消息传递机制将来自多个邻居/多条路径的信息”压缩”到固定维度的向量中,导致远距离信息丢失。4

层数1: 节点只能接收1跳邻居的信息
层数2: 节点可以接收2跳邻居的信息(但被压缩)
层数3: 节点可以接收3跳邻居的信息(被进一步压缩)
...
层数K: 节点可以接收K跳邻居的信息(严重压缩)

2.2 数学分析:Jacobian视角

消息传递的回流

考虑一个简化的线性消息传递:

Jacobian矩阵的分析

节点 关于输入 的 Jacobian:

关键问题:当 很大时,这个乘积的范数会如何变化?

如果权重的谱范数 ,梯度会指数衰减(与RNN的梯度消失类似)。

2.3 信息瓶颈理论

通道容量视角

将消息传递视为一个信息通道:

输入: 来自多个邻居的信息 {h_u : u ∈ N(v)}
      ↓
编码: 通过线性变换 + 聚合 → h_v
      ↓
输出: 固定维度的节点嵌入

核心约束:固定维度的嵌入无法无损失地表示来自任意数量邻居的信息。

互信息分析

表示输入与输出之间的互信息:

其中 是嵌入维度, 是通道容量。

2.4 曲率基测量方法

Topping et al. (2022) 提出使用图Ricci曲率来量化过压缩程度。4

曲率定义

对于无权无向图,节点 之间的Ricci曲率定义为:

其中:

  • :最短路径上的中间节点数
  • :边的权重乘积

负曲率与过压缩

  • 正曲率边:信息流动顺畅
  • 负曲率边:信息流动受阻,易发生过压缩
import numpy as np
 
def compute_ricci_curvature(G, edge):
    """
    简化版的图Ricci曲率计算
    
    参数:
        G: NetworkX图
        edge: 边 (u, v)
    
    返回:
        curvature: 曲率值(负值表示负曲率边)
    """
    u, v = edge
    
    # 获取邻居
    neighbors_u = set(G.neighbors(u)) - {v}
    neighbors_v = set(G.neighbors(v)) - {u}
    
    # 计算最短路径上的中间节点
    common_neighbors = neighbors_u & neighbors_v
    u_to_v_via_common = len(common_neighbors)
    
    # 通过其他节点的路径
    paths_through_other = 0
    for w in (neighbors_u - common_neighbors):
        for x in (neighbors_v - common_neighbors):
            if G.has_edge(w, x):
                paths_through_other += 1
    
    # Ricci曲率公式
    curvature = 1 - (u_to_v_via_common + paths_through_other) / (len(neighbors_u) + len(neighbors_v))
    
    return curvature
 
def identify_bottleneck_edges(G, threshold=-0.1):
    """
    识别可能导致过压缩的瓶颈边
    
    参数:
        G: NetworkX图
        threshold: 曲率阈值(低于此值认为是瓶颈)
    
    返回:
        bottleneck_edges: 瓶颈边列表
    """
    bottleneck_edges = []
    
    for edge in G.edges():
        curvature = compute_ricci_curvature(G, edge)
        if curvature < threshold:
            bottleneck_edges.append((edge, curvature))
    
    return bottleneck_edges

3. 解决方案

3.1 架构层面

3.1.1 残差连接

最直接的方法是在每层GNN后添加残差连接:1

效果:保持原始信号通路,延缓过平滑

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
 
class ResGNN(nn.Module):
    """带残差连接的GNN"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.num_layers = num_layers
        
        # 输入投影层
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        
        # GNN层列表
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        
        # 输出层
        self.output_proj = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        # 输入投影
        h = self.input_proj(x)
        
        # 带残差连接的GNN层
        for i, conv in enumerate(self.convs):
            h_new = conv(h, edge_index)
            h_new = F.relu(h_new)
            
            # 残差连接:每两层加一次残差
            if i % 2 == 1:
                h = h + h_new  # 残差连接
            else:
                h = h_new
        
        return self.output_proj(h)

3.1.2 JK-Net(Jump Knowledge Network)

Xu et al. (2018) 提出的JK-Net通过跳转连接聚合多层次表示:5

聚合方式

  • Concat:拼接所有层
  • Max-pooling:逐维度取最大值
  • LSTM:用注意力加权
class JKNet(nn.Module):
    """Jump Knowledge Network"""
    def __init__(self, in_channels, hidden_channels, out_channels, 
                 num_layers, aggr='concat'):
        super().__init__()
        self.aggr = aggr
        
        self.conv1 = GCNConv(in_channels, hidden_channels)
        
        self.convs = nn.ModuleList()
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        
        if aggr == 'concat':
            self.final_proj = nn.Linear(hidden_channels * num_layers, out_channels)
        else:
            self.final_proj = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        # 存储所有层的表示
        layer_outputs = []
        
        # 第一层
        x = F.relu(self.conv1(x, edge_index))
        layer_outputs.append(x)
        
        # 中间层
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            layer_outputs.append(x)
        
        # 聚合
        if self.aggr == 'concat':
            h = torch.cat(layer_outputs, dim=1)
        elif self.aggr == 'max':
            h = torch.stack(layer_outputs, dim=0).max(dim=0)[0]
        elif self.aggr == 'lstm':
            # LSTM注意力聚合
            h = self.lstm_aggregate(layer_outputs)
        else:
            h = layer_outputs[-1]
        
        return self.final_proj(h)
    
    def lstm_aggregate(self, layer_outputs):
        """LSTM风格的聚合"""
        # 简化版本:使用最后一层
        return layer_outputs[-1]

3.1.3 深层监督(Deep Supervision)

在中间层添加辅助损失函数:6

class DeeplySupervisedGNN(nn.Module):
    """带深层监督的GNN"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        self.num_layers = num_layers
        
        self.convs = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        
        # 第一个卷积层
        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.classifiers.append(nn.Linear(hidden_channels, out_channels))
        
        # 中间层
        for _ in range(1, num_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.classifiers.append(nn.Linear(hidden_channels, out_channels))
    
    def forward(self, x, edge_index, train_mask=None):
        all_logits = []
        
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)
            
            logits = self.classifiers[i](x)
            all_logits.append(logits)
        
        # 训练时:使用加权平均的损失
        # 推理时:使用最后一层的输出
        return all_logits
    
    def compute_loss(self, all_logits, labels, train_mask):
        """计算深层监督的总损失"""
        total_loss = 0
        
        for i, logits in enumerate(all_logits):
            # 越深层权重越小(课程学习思想)
            weight = 1.0 / (i + 1)
            loss = F.cross_entropy(logits[train_mask], labels[train_mask])
            total_loss += weight * loss
        
        return total_loss

3.2 归一化层面

3.2.1 PairNorm

Zhao et al. (2020) 提出的PairNorm通过保持节点嵌入的总相似度来缓解过平滑:3

核心思想

  • 对于每个 batch,计算所有节点对的距离平方和
  • 归一化这个和为常数
class PairNorm(nn.Module):
    """PairNorm: 保持节点嵌入对的总体差异"""
    def __init__(self, scale=1.0):
        super().__init__()
        self.scale = scale
    
    def forward(self, x):
        # x: (N, D) 其中 N 是节点数
        N = x.shape[0]
        
        # 计算所有节点对之间的差异
        # x.unsqueeze(1): (N, 1, D)
        # x.unsqueeze(0): (1, N, D)
        diff = x.unsqueeze(1) - x.unsqueeze(0)  # (N, N, D)
        
        # 计算距离平方和
        sq_dist = torch.sum(diff ** 2, dim=2)  # (N, N)
        
        # 对非对角线元素求和
        mask = ~torch.eye(N, dtype=torch.bool, device=x.device)
        total_sq_dist = sq_dist[mask].sum()
        
        # PairNorm归一化
        if total_sq_dist > 0:
            norm_factor = torch.sqrt(N * (N - 1) / total_sq_dist)
            x = x * norm_factor * self.scale
        
        return x
 
class GCNWithPairNorm(nn.Module):
    """使用PairNorm的GCN"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()
        
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        
        for _ in range(num_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.norms.append(PairNorm())
        
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        x = self.input_proj(x)
        
        for conv, norm in zip(self.convs, self.norms):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = norm(x)  # PairNorm
        
        return self.classifier(x)

3.2.2 SUGCON(Self-Attention Graph Normalization)

SUGCON使用自注意力机制进行归一化:

class SUGCON(nn.Module):
    """SUGCON归一化"""
    def __init__(self, in_features):
        super().__init__()
        self.in_features = in_features
        self.scale = in_features ** 0.5
    
    def forward(self, x):
        # 自注意力权重
        attn = torch.matmul(x, x.transpose(-2, -1)) / self.scale
        attn = F.softmax(attn, dim=-1)
        
        # 加权聚合
        x_normalized = torch.matmul(attn, x)
        
        return x_normalized

3.3 图结构层面

3.3.1 图重连(Graph Rewiring)

Topping et al. (2022) 提出通过曲率重连来消除过压缩:4

算法步骤

  1. 计算图中所有边的Ricci曲率
  2. 识别负曲率边(瓶颈边)
  3. 通过边重连或添加新边来”填补”负曲率区域
def curvature_based_rewiring(G, target_curvature=0.0, max_iterations=100):
    """
    基于曲率的图重连算法
    
    参数:
        G: NetworkX图
        target_curvature: 目标曲率(通常为0或正值)
        max_iterations: 最大迭代次数
    
    返回:
        G_rewired: 重连后的图
    """
    G_rewired = G.copy()
    
    for iteration in range(max_iterations):
        bottleneck_edges = identify_bottleneck_edges(G_rewired, threshold=target_curvature)
        
        if not bottleneck_edges:
            break
        
        # 对每个瓶颈边,尝试重连
        for (u, v), curvature in bottleneck_edges:
            # 策略1:添加u和v之间的边
            # 策略2:将u或v连接到其他高曲率节点
            # 策略3:重新路由
            
            # 简化:添加一个新节点作为桥梁
            new_node = G_rewired.number_of_nodes()
            G_rewired.add_node(new_node)
            G_rewired.add_edge(u, new_node)
            G_rewired.add_edge(new_node, v)
            
            # 移除原来的边(可选)
            G_rewired.remove_edge(u, v)
    
    return G_rewired

3.3.2 DiffWire

另一种方法是使用随机重连来增加图的展开量:

def random_edge_flip(G, rewiring_ratio=0.1):
    """
    随机边翻转:增加图的连通性
    
    参数:
        G: NetworkX图
        rewiring_ratio: 重连比例
    """
    import random
    
    G_rewired = G.copy()
    num_edges = G_rewired.number_of_edges()
    num_rewire = int(num_edges * rewiring_ratio)
    
    nodes = list(G_rewired.nodes())
    
    for _ in range(num_rewire):
        # 随机选择一条边
        u, v = random.choice(list(G_rewired.edges()))
        
        # 随机选择另一个节点
        w = random.choice(nodes)
        
        # 确保不形成自环和重边
        if w != u and w != v and not G_rewired.has_edge(u, w):
            G_rewired.remove_edge(u, v)
            G_rewired.add_edge(u, w)
    
    return G_rewired

4. 实用指南

4.1 何时使用更深的GNN

场景推荐层数原因
节点特征强,图结构弱2-3层浅层即可捕获特征信息
图结构复杂,需要多跳信息4-8层需要足够深度捕获远距离依赖
异构图/异配图2-4层深层更易受异配性影响
大规模稀疏图2-3层深层导致过度平滑
小规模同配图6-10层(配合残差)可以尝试更深

4.2 推荐层数经验值

Cora/Citeseer/Pubmed 数据集:
├── 2层: 通常最优(特征+一阶邻居)
├── 3-4层: 需要残差连接
└── 8层以上: 通常效果下降

分子图(PQM9, ZINC):
├── 4-6层: 常见配置
└── 取决于分子大小

社交网络(Reddit, Flickr):
├── 2-3层: 邻居爆炸,需要采样
└── 深层需要GraphSAINT等采样方法

4.3 诊断工具

def diagnose_gnn_depth_issues(model, data, device='cpu'):
    """
    诊断GNN的深度问题
    
    返回:
        诊断报告
    """
    model.eval()
    model = model.to(device)
    data = data.to(device)
    
    # 获取嵌入
    with torch.no_grad():
        embeddings = model(data.x, data.edge_index)
    
    report = {
        'smoothing': {},
        'capacity': {},
    }
    
    # 1. 过平滑检测
    pairwise_var = SmoothingMetrics.pairwise_distance_var(embeddings)
    report['smoothing']['pairwise_distance_var'] = pairwise_var
    
    # 2. 计算条件数(需要拉普拉斯矩阵)
    edge_index = data.edge_index.cpu()
    num_nodes = data.num_nodes
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1
    adj = adj + adj.T  # 对称化
    
    # 度矩阵
    D = torch.diag(adj.sum(dim=1))
    # 拉普拉斯矩阵
    L = D - adj
    
    cond_num = SmoothingMetrics.condition_number(L)
    report['smoothing']['laplacian_condition_number'] = cond_num
    
    # 3. 嵌入的秩(反映嵌入空间的维度)
    try:
        embed_rank = torch.matrix_rank(embeddings).item()
        report['capacity']['embedding_rank'] = embed_rank
        report['capacity']['max_possible_rank'] = min(embeddings.shape)
        report['capacity']['rank_ratio'] = embed_rank / min(embeddings.shape)
    except:
        pass
    
    return report

4.4 超参数建议

参数推荐值说明
隐藏维度64-256足够大以捕获信息,但不要过大
层数2-6通常2-4层足够
dropout0.5-0.7深层需要更强的正则化
学习率0.01标准设置
权重衰减5e-4防止过拟合
邻居采样数10-25大图深层时减少方差

4.5 最佳实践检查清单

  • 层数控制:从2-3层开始,根据需要逐步增加
  • 残差连接:层数>3时必须使用
  • 批归一化/LayerNorm:深层GNN的标准配置
  • 邻居采样:大图时控制感受野
  • 监控指标:训练时记录平滑度指标
  • 对比实验:对比浅层和深层的性能
  • 图重连:考虑对异配图进行预处理

5. 相关主题


参考


Footnotes

  1. Li et al., “Deeper Insights into Graph Convolutional Networks for Semi-Supervised Learning”, AAAI 2018 2

  2. Wu et al., “Simplifying Graph Neural Networks”, ICML 2019

  3. Zhao & Akoglu, “PairNorm: Tackling Over-smoothing in GNNs”, ICLR 2020 2

  4. Topping et al., “Understanding over-squashing and over-curvature on graphs”, ICLR 2022 2 3

  5. Xu et al., “Representation Learning on Graphs with Jumping Knowledge Networks”, ICML 2018

  6. Chen et al., “Measuring and Relieving the Over-smoothing Problem for Graph Neural Networks”, ICLR 2020