概述

图神经网络(GNN)的表达能力(Expressivity)理论是理解图神经网络能力边界的核心课题。12

表达能力决定了GNN能够区分什么样的图结构、能够学习什么样的函数。理解这一理论对于设计更强大的GNN架构、选择合适的模型具有重要意义。

核心问题

给定一个GNN架构,它能够区分哪些非同构图?它能够检测或计数哪些子结构?

本文档系统性地介绍GNN表达能力理论的核心内容:

主题核心内容
1-WL测试消息传递GNN的表达能力上界
超越1-WLK-WL层级、K跳消息传递、子图GNN
同态表达性ICLR 2024提出的定量框架
实践意义表达能力与真实任务的关系

1-WL测试与GNN的关系

Weisfeiler-Lehman颜色细化算法

1-WL(又称颜色细化/Color Refinement)是一种经典的图同构启发式算法。3

算法流程

给定两个图 ,1-WL测试迭代地比较它们的结构:

初始化:为所有节点赋予初始颜色(通常是度或特征)

迭代更新:在每次迭代中,节点的颜色根据其邻居颜色集合更新:

颜色更新规则

其中 将邻居颜色聚合为多重集(忽略顺序), 将聚合结果哈希为新颜色。

PyTorch实现

import torch
 
def wl_color_refinement(edge_index, initial_colors, num_iterations=10):
    """
    1-WL颜色细化算法的简化实现
    
    参数:
        edge_index: 边索引 (2, num_edges)
        initial_colors: 初始颜色 (num_nodes,)
        num_iterations: 迭代次数
    
    返回:
        colors: 最终颜色
    """
    num_nodes = initial_colors.shape[0]
    colors = initial_colors.clone()
    
    for _ in range(num_iterations):
        # 收集邻居颜色
        neighbor_colors = colors[edge_index[1]]  # 源节点颜色
        
        # 对每个节点聚合邻居颜色
        new_colors = torch.zeros(num_nodes, dtype=torch.long)
        
        # 遍历所有节点
        for v in range(num_nodes):
            # 找到节点v的所有邻居
            neighbors = edge_index[1][edge_index[0] == v]
            if len(neighbors) > 0:
                # 聚合:排序后拼接
                neighbor_set = colors[neighbors]
                key = torch.cat([colors[v].unsqueeze(0), torch.sort(neighbor_set)[0]])
            else:
                key = colors[v].unsqueeze(0)
            
            # 简化的哈希(实际中用更强的哈希)
            new_colors[v] = (colors[v] * 31 + neighbor_set.sum()) % 1000
        
        colors = new_colors
        
        # 检查是否稳定
        if len(torch.unique(colors)) == num_nodes:
            break
    
    return colors

GNN与1-WL的形式化对应

消息传递神经网络(MPNN)

现代GNN大多采用消息传递范式:

关键定理(Xu et al., 2019 / Morris et al., 2019)

定理:如果两层GNN的聚合函数是单射的(injective),则其表达能力与1-WL测试等价。12

设:

  • :聚合函数(单射)
  • :图级 readout 函数
  • :多层感知机

GIN(Graph Isomorphism Network) 的聚合规则:

证明概要

  1. 聚合步骤 ↔ 颜色细化:GNN的聚合函数类似于1-WL的颜色聚合
  2. 单射性保证:当聚合函数是单射的,它能够区分不同的多重集
  3. 不动点:当颜色不再变化时,GNN达到与1-WL相同的稳定状态
import torch
import torch.nn as nn
import torch.nn.functional as F
 
class GIN(nn.Module):
    """Graph Isomorphism Network - 最强表达能力的消息传递GNN"""
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        
        # 每层的MLP
        self.mlps = nn.ModuleList()
        for i in range(num_layers):
            if i == 0:
                self.mlps.append(nn.Sequential(
                    nn.Linear(in_channels, hidden_channels),
                    nn.ReLU(),
                    nn.Linear(hidden_channels, hidden_channels)
                ))
            else:
                self.mlps.append(nn.Sequential(
                    nn.Linear(hidden_channels, hidden_channels),
                    nn.ReLU(),
                    nn.Linear(hidden_channels, hidden_channels)
                ))
        
        # 图级READOUT
        self.classifier = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )
        
        # 可学习参数 eps
        self.eps = nn.Parameter(torch.randn(num_layers))
    
    def aggregate(self, x, edge_index):
        """邻接表聚合"""
        num_nodes = x.shape[0]
        out = torch.zeros_like(x)
        
        for v in range(num_nodes):
            neighbors = edge_index[1][edge_index[0] == v]
            if len(neighbors) > 0:
                out[v] = x[neighbors].sum(dim=0)
        
        return out
    
    def forward(self, x, edge_index):
        h = x
        
        for i in range(self.num_layers):
            # 消息传递
            neighbor_agg = self.aggregate(h, edge_index)
            
            # GIN更新规则:(1 + eps) * h_v + neighbor_sum
            h = self.mlps[i]((1 + self.eps[i]) * h + neighbor_agg)
        
        # 图级READOUT(求和)
        graph_repr = torch.zeros_like(h[0]).unsqueeze(0).repeat(h.shape[0], 1)
        for v in range(h.shape[0]):
            neighbors = edge_index[1][edge_index[0] == v]
            if len(neighbors) > 0:
                graph_repr[v] = h[neighbors].sum(dim=0)
            else:
                graph_repr[v] = h[v]
        
        # 使用第一个节点的表示作为图表示(实际应使用更好的READOUT)
        return self.classifier(h[0])

1-WL无法区分的图结构

1-WL测试存在已知的盲点,以下图结构是其无法区分的典型例子:

1. 同构但结构不同的图对

图A (C6: 6节点环)           图B (两个三角形+3条独立边)
     ○──○──○                     △     △     │
     │  ↑│  ↑                     \   /      |
     ○──○──○                       \ /       |
                                   ─┼─       │
                                     │       │

这两个图对于1-WL来说是”颜色相同”的。

2. 正则图(Regular Graphs)

1-WL无法区分度数相同的正则图:

图类型描述1-WL能否区分
完全图 所有节点度数为
所有节点度数为2
强正则图参数相同的SRG

3. 循环计数

1-WL可以检测长度 ≤ 6 的环,但无法区分长度 ≥ 8 的环

4. 子图计数

子结构1-WL能否计数
三角形
四边形
五边形
六边形
八边形

5. WL盲点的形式化例子

定理(Cai et al., 1992):存在无限多对 节点图,使得1-WL无法区分它们,且这些图在图同构意义上不同构。


K-WL层级与表达能力边界

k维Weisfeiler-Lehman算法

1-WL只考虑单个节点的颜色,而k-WL考虑k元组的颜色4

2-WL(Dyck测试)

2-WL不是1-WL的简单扩展,它作用于有序节点对

2-WL比1-WL严格更强,但仍无法区分所有图。

3-WL(Brueckner-Servatius)

3-WL作用于三元组,表达能力进一步提升。

K-WL层级关系

算法表达能力计算复杂度
1-WL
2-WL
k-WL递增
图同构最高未知(疑似准多项式)

不同GNN架构的WL层级

架构表达能力上界备注
GCN1-WL聚合函数非单射
GAT1-WL注意力权重不增加表达能力
GraphSAGE1-WL聚合函数非单射
GIN1-WL(最紧)单射聚合可达1-WL上界
k-GNNk-WL需要k阶邻接矩阵
2-WL GNN2-WL基于节点对的GNN
子图GNN3-WL详见下文

超越1-WL的方法

K跳消息传递

基本思想

传统1-WL MPNN只聚合1跳邻居的信息。K跳消息传递同时聚合K跳邻居的信息:

其中 表示距离节点 跳的邻居集合。

表达能力分析

定理(Feng et al., 2022):K跳消息传递的表达能力严格强于1-WL,但仍被3-WL上界限制。5

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import k_hop_subgraph
 
class KHopMPNN(nn.Module):
    """K跳消息传递神经网络"""
    def __init__(self, in_channels, hidden_channels, out_channels, K=2):
        super().__init__()
        self.K = K
        
        # 每跳的投影层
        self.hop_proj = nn.ModuleList([
            nn.Linear(in_channels, hidden_channels)
            for _ in range(K)
        ])
        
        # 更新层
        self.update = nn.Sequential(
            nn.Linear(hidden_channels * K, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels)
        )
        
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def forward(self, x, edge_index, batch=None):
        num_nodes = x.shape[0]
        hop_features = []
        
        # 收集每跳的信息
        for k in range(1, self.K + 1):
            if k == 1:
                # 1跳邻居
                neighbors = edge_index[1]
            else:
                # 计算k跳邻居
                # 简化的实现:多次扩展
                neighbors = self._get_k_hop_neighbors(x, edge_index, k)
            
            # 聚合k跳邻居的特征
            agg_feat = torch.zeros_like(x)
            for v in range(num_nodes):
                mask = edge_index[0] == v
                nbrs = neighbors[mask]
                if len(nbrs) > 0:
                    agg_feat[v] = x[nbrs].mean(dim=0)
            
            hop_features.append(self.hop_proj[k-1](agg_feat))
        
        # 拼接所有跳的信息
        combined = torch.cat(hop_features, dim=1)
        h = self.update(combined)
        
        # 图级READOUT
        graph_h = torch.zeros(h.shape[1]).to(x.device)
        for v in range(num_nodes):
            graph_h += h[v]
        graph_h /= num_nodes
        
        return self.classifier(graph_h.unsqueeze(0))
    
    def _get_k_hop_neighbors(self, x, edge_index, k):
        """获取k跳邻居的简化实现"""
        # 实际应使用更高效的实现
        current = edge_index[1].clone()
        seen = set(current.tolist())
        
        for _ in range(k - 1):
            new_neighbors = []
            for src in current.unique():
                mask = edge_index[0] == src
                nbrs = edge_index[1][mask]
                for n in nbrs:
                    if n.item() not in seen:
                        new_neighbors.append(n.item())
                        seen.add(n.item())
            current = torch.tensor(new_neighbors + current.tolist()).to(x.device)
        
        return current

子图GNN

基本思想

子图GNN将原始图的子结构作为输入,而不是整个图的结构。6

代表性架构

架构子图定义表达能力
SUN以节点为根的子图2-WL
DS以边为根的子图2-WL
ESAN节点对的连通子图3-WL
ID-GNN节点诱导子图依赖ID

子图GNN的消息传递

其中 是以节点 为中心的子图。

PyTorch实现(简化版)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import subgraph
 
class SubgraphGNN(nn.Module):
    """基于子图的GNN(简化版)"""
    def __init__(self, in_channels, hidden_channels, out_channels, 
                 subgraph_size=2):
        super().__init__()
        self.subgraph_size = subgraph_size
        
        # 子图内部的GNN
        self.subgraph_gnn = GIN(in_channels, hidden_channels, hidden_channels)
        
        # 图级聚合
        self.node_gnn = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels)
        )
        
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def extract_subgraph(self, x, edge_index, center_node, size):
        """提取以center_node为中心的子图"""
        # 简单的BFS提取
        nodes = {center_node.item()}
        frontier = {center_node.item()}
        
        for _ in range(size - 1):
            new_frontier = set()
            for u in frontier:
                # 找到u的所有邻居
                mask = (edge_index[0] == u) | (edge_index[1] == u)
                neighbors = edge_index[0][mask].tolist() + edge_index[1][mask].tolist()
                new_frontier.update(neighbors)
            frontier = new_frontier - nodes
            nodes.update(frontier)
        
        return torch.tensor(list(nodes))
    
    def forward(self, x, edge_index):
        num_nodes = x.shape[0]
        subgraph_reprs = []
        
        # 为每个节点提取子图
        for v in range(num_nodes):
            subgraph_nodes = self.extract_subgraph(x, edge_index, 
                                                   torch.tensor([v]), 
                                                   self.subgraph_size)
            
            # 提取子图的边和节点特征
            sub_edge_index, _ = subgraph(subgraph_nodes, edge_index)
            sub_x = x[subgraph_nodes]
            
            # 在子图上运行GIN
            sub_repr = self.subgraph_gnn(sub_x, sub_edge_index)
            subgraph_reprs.append(sub_repr.mean(dim=0))
        
        # 聚合所有子图表示
        subgraph_reprs = torch.stack(subgraph_reprs)
        graph_repr = self.node_gnn(subgraph_reprs).mean(dim=0)
        
        return self.classifier(graph_repr.unsqueeze(0))

路径聚合

PATH-WL算法

路径聚合是一类基于路径信息增强表达能力的算法。7

核心思想:不仅考虑节点颜色,还考虑路径上节点的颜色序列。

表达能力

方法与WL层级的关系
1-WL1-WL
最短路径-WL与2-WL不可比较
所有路径-WL强于2-WL

环路计数增强

Count-GNN

通过显式计数特定子结构(如环、路径)来增强表达能力:

class CountGNN(nn.Module):
    """带子结构计数的GNN"""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.gnn = GIN(in_channels, hidden_channels, hidden_channels)
        
        # 额外的环路计数特征
        self.count_proj = nn.Sequential(
            nn.Linear(10, hidden_channels),  # 假设计数10种不同的环
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels)
        )
        
        # 融合GNN特征和计数特征
        self.fusion = nn.Linear(hidden_channels * 2, hidden_channels)
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def count_cycles(self, edge_index, num_nodes):
        """简化版环计数"""
        # 实际应使用更高效的算法
        cycle_counts = torch.zeros(10)
        # ... 环路计数逻辑 ...
        return cycle_counts
    
    def forward(self, x, edge_index):
        # GNN特征
        gnn_repr = self.gnn(x, edge_index)
        
        # 环路计数特征
        cycle_counts = self.count_cycles(edge_index, x.shape[0])
        count_repr = self.count_proj(cycle_counts)
        
        # 融合
        combined = torch.cat([gnn_repr, count_repr], dim=-1)
        fused = self.fusion(combined)
        
        return self.classifier(fused)

同态表达性框架

ICLR 2024 论文核心

论文:Beyond Weisfeiler-Lehman: A Quantitative Framework for GNN Expressiveness8

核心贡献:提出同态表达性(Homomorphism Expressivity)作为GNN表达能力的定量度量。

图同态基础

同态的定义

是两个图。同态 满足:

即保持边的映射关系。

同态计数的意义

的同态数:

定理(Cai-Fürer-Immerman):图的同态计数与WL层级密切相关:

同态类型WL可检测性
(三角形)1-WL 可检测
(四边形)1-WL 无法检测
(k-完全图)k-1-WL 可检测

同态表达性的定义

定量度量

对于GNN模型 ,其在图 上对模式 同态表达性定义为:

核心定理

定理(Zhang et al., 2024):GNN的同态计数能力具有层次结构

对于任意图 和模式

不同GNN架构的同态表达能力

架构同态表达能力精确计数能力
GIN可精确计数树状结构
GCN/GAT仅计数无圈结构
1-WL等价于GIN与GIN相同上界
k-GNN更高可计数k阶结构

GIN的同态计数能力

定理:GIN(带适当的READOUT)可以精确计数树结构的同态:

其中 是所有树结构的集合, 是可学习的权重。

同态表达性的优势

传统WL层级同态表达性
定性(能/不能区分)定量(计数多少)
粗糙的层级划分连续的值
无法衡量子结构检测能力可衡量任意子结构的检测能力
表达能力与实际任务脱节与实际任务性能相关

统一框架

同态表达性框架统一了多个研究方向:

                    同态表达性框架
                          │
        ┌─────────────────┼─────────────────┐
        │                 │                 │
   子图计数方法       k-WL层级方法      环路计数方法
        │                 │                 │
   (Bouritsas et al.)  (k-GNN)        (Count-GNN)
        │                 │                 │
        └─────────────────┼─────────────────┘
                          │
                    统一的表达能力度量

PyTorch实现示例

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class HomomorphismAwareGNN(nn.Module):
    """基于同态表达性的GNN(概念实现)"""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        
        # 基础GNN
        self.gnn = GIN(in_channels, hidden_channels, hidden_channels)
        
        # 同态计数的可学习权重
        # 假设我们关注以下模式
        self.pattern_weights = nn.Parameter(torch.ones(10))  # 10种模式
        
        # 融合层
        self.fusion = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )
    
    def count_homomorphisms(self, edge_index, num_nodes, pattern_id):
        """
        简化的同态计数
        实际应用中应使用更高效的算法
        """
        # 这是一个概念示例
        # 实际需要根据pattern_id计算具体的同态数
        if pattern_id == 0:  # 三角形
            return self._count_triangles(edge_index, num_nodes)
        elif pattern_id == 1:  # 四边形
            return self._count_cycles(edge_index, num_nodes, 4)
        # ... 其他模式
        return torch.zeros(num_nodes)
    
    def _count_triangles(self, edge_index, num_nodes):
        """计算每个节点参与多少个三角形"""
        triangles = torch.zeros(num_nodes)
        for i in range(edge_index.shape[1]):
            u, v = edge_index[0, i], edge_index[1, i]
            # 找u和v的公共邻居
            u_neighbors = edge_index[1][edge_index[0] == u]
            v_neighbors = edge_index[1][edge_index[0] == v]
            common = set(u_neighbors.tolist()) & set(v_neighbors.tolist())
            for w in common:
                triangles[u] += 1
                triangles[v] += 1
                triangles[w] += 1
        return triangles
    
    def _count_cycles(self, edge_index, num_nodes, length):
        """简化的环计数"""
        # 实际应用需要更复杂的实现
        return torch.zeros(num_nodes)
    
    def forward(self, x, edge_index):
        # 基础GNN表示
        gnn_repr = self.gnn(x, edge_index)
        
        # 计算各模式的同态数
        hom_counts = []
        for p in range(10):
            count = self.count_homomorphisms(edge_index, x.shape[0], p)
            hom_counts.append(count * self.pattern_weights[p])
        
        # 拼接同态计数
        hom_repr = torch.stack(hom_counts).t()  # (num_nodes, num_patterns)
        hom_repr = F.relu(self.pattern_weights[:1].expand_as(hom_repr))  # 投影
        
        # 融合GNN表示和同态表示
        combined = torch.cat([gnn_repr, hom_repr], dim=-1)
        output = self.fusion(combined)
        
        return output

表达能力与真实任务的关系

何时需要超越1-WL?

不需要超越1-WL的场景

任务类型示例原因
节点分类Cora引用网络仅需局部结构
链接预测推荐系统邻居信息足够
大多数真实应用分子性质预测1-WL足够区分关键结构

需要超越1-WL的场景

任务类型示例需要检测的结构
环路检测分子中的苯环 检测
子图匹配化学基团检测特定子图
图同构相关精确图匹配全局结构
计数任务三角形计数精确计数

表达能力与泛化的权衡

理论发现

核心问题:更强的表达能力是否意味着更好的泛化?

研究结论:表达能力与泛化之间存在复杂的权衡关系9

表达能力过强的风险

  1. 过拟合:模型可能记忆训练数据而非学习泛化模式
  2. 优化困难:高表达能力模型可能更难训练
  3. 计算成本:更复杂的架构通常需要更多计算

表达能力不足的风险

  1. 欠拟合:模型无法捕捉数据中的重要模式
  2. 偏差:即使数据无限,模型也存在系统性错误

经验观察

表达能力训练数据充足训练数据有限
可能过拟合推荐使用
可能欠拟合可能刚好

过平滑与过压缩

过平滑(Over-Smoothing)

随着GNN层数增加,所有节点的表示趋于相同:

原因:多次邻居聚合导致信息损失。

解决

  • 残差连接(ResNet)
  • 跳跃知识网络(JK-Net)
  • 适当限制层数

过压缩(Over-Squashing)

信息从多个源头压缩到固定大小的向量中:

其中 是邻居数, 越大,信息损失越多。10

解决

  • 图Transformer(无固定邻居限制)
  • 远程建模
  • 邻居采样
# 过平滑的可视化示例
import matplotlib.pyplot as plt
 
def measure_smoothing(model, data):
    """测量节点表示的平滑程度"""
    model.eval()
    with torch.no_grad():
        h = model(data.x, data.edge_index)
    
    # 计算节点表示的距离
    pairwise_dist = torch.pdist(h)
    return pairwise_dist.mean().item()
 
# 层数增加时,平滑度应该下降
for num_layers in [1, 2, 3, 4, 5]:
    model = GIN(in_channels, hidden, out_channels, num_layers=num_layers)
    smoothing = measure_smoothing(model, data)
    print(f"层数 {num_layers}: 平滑度 = {smoothing:.4f}")

实践建议

如何选择合适的表达能力

考虑因素建议
任务类型结构敏感任务需要更高表达能力
数据规模大数据可用高表达能力模型
计算资源复杂模型需要更多资源
可解释性简单模型更易解释

评估指标

def evaluate_expressivity(model, dataset):
    """
    评估模型的实际表达能力
    """
    results = {}
    
    # 1. 结构计数任务
    results['cycle_count'] = evaluate_cycle_counting(model, dataset)
    
    # 2. 子图匹配任务
    results['subgraph_match'] = evaluate_subgraph_matching(model, dataset)
    
    # 3. 图同构区分
    results['graph_distinguish'] = evaluate_graph_distinction(model, dataset)
    
    # 4. 泛化性能
    results['generalization'] = evaluate_generalization(model, dataset)
    
    return results

架构选择指南

┌─────────────────────────────────────────────────────────┐
│                    GNN架构选择流程                       │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  1. 评估任务需求                                         │
│     │                                                  │
│     ├── 局部结构信息 → 标准MPNN (GCN/GAT/GraphSAGE)    │
│     │                                                  │
│     ├── 需要子图信息 → 子图GNN                           │
│     │                                                  │
│     └── 需要全局信息 → 深层MPNN + 残差连接               │
│                                                         │
│  2. 考虑数据规模                                         │
│     │                                                  │
│     ├── 小规模数据 → 可用高表达能力模型                  │
│     │                                                  │
│     └── 大规模数据 → 考虑效率与表达的平衡                │
│                                                         │
│  3. 评估计算资源                                         │
│     │                                                  │
│     ├── 资源充足 → 可尝试Transformer                    │
│     │                                                  │
│     └── 资源有限 → 标准MPNN + 优化                      │
│                                                         │
└─────────────────────────────────────────────────────────┘

相关主题

主题描述
图神经网络GNN的基本概念和消息传递范式
图卷积网络GCN的谱域和空域方法
神经网络表达能力通用逼近定理与VC维度

参考

Footnotes

  1. Xu et al., “How Powerful Are Graph Neural Networks?”, ICLR 2019 2

  2. Morris et al., “Weisfeiler and Leman Go Neural: Higher-Order Graph Neural Networks”, AAAI 2019 2

  3. Weisfeiler & Lehman, “A Reduction of a Graph to a Canonical Form”, 1968

  4. Grohe, “The Weisfeiler-Leman Dimension of Graphs below Gensemer”, ICALP 2017

  5. Feng et al., “How Powerful are K-hop Message Passing Graph Neural Networks”, NeurIPS 2022

  6. Frasca et al., “Sign: Scalable Inception Graph Neural Networks”, GRL 2022

  7. Graziani et al., “The Expressive Power of Path-Based Graph Neural Networks”, ICML 2024

  8. Zhang et al., “Beyond Weisfeiler-Lehman: A Quantitative Framework for GNN Expressiveness”, ICLR 2024

  9. Wang et al., “An Empirical Study of Realized GNN Expressiveness”, ICML 2024

  10. Di Giovanni et al., “How Does Over-Squashing Affect the Power of GNNs?”, TMLR 2024