1. 引言

消息传递(Message Passing)是图神经网络(GNN)的核心计算范式,将深度学习的表示学习能力与图结构数据的拓扑信息有机结合。自Gilmer等人于2017年提出”消息传递神经网络”(Message Passing Neural Networks, MPNN)框架以来1,消息传递已成为GNN领域最重要的理论基石之一。

1.1 消息传递的背景

传统的深度学习模型(如CNN、RNN)主要处理规则结构的数据(图像、序列),而图数据具有非欧几里得特性:

  • 节点无序性:节点的邻居没有固定顺序
  • 结构多样性:图的结构可以任意变化
  • 关系复杂性:边可以带权、重边、方向等属性

消息传递机制通过局部信息交换的方式,让每个节点能够感知其邻域的结构和特征信息,从而实现对图数据的深度表示学习。

1.2 与其他GNN方法的关系

┌─────────────────────────────────────────────────────────────────────────┐
│                      图神经网络方法分类                                    │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  ┌─────────────────────────┐    ┌─────────────────────────┐           │
│  │       谱域方法           │    │       空域方法           │           │
│  │  (Spectral Methods)      │    │  (Spatial Methods)       │           │
│  ├─────────────────────────┤    ├─────────────────────────┤           │
│  │  • GCN (谱卷积近似)      │    │  • Message Passing (核心)│           │
│  │  • ChebNet (切比雪夫多项式)│    │  • GraphSAGE            │           │
│  │  • GIN (理论分析)        │    │  • GAT (注意力机制)      │           │
│  └─────────────────────────┘    └─────────────────────────┘           │
│                                                                         │
│                        Message Passing 是空域方法的核心抽象                  │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

2. 消息传递框架形式化

2.1 MPNN通用框架

Gilmer等人提出的消息传递神经网络框架是GNN领域最具影响力的形式化定义之一1。整个框架可以分解为三个核心操作:消息生成消息聚合节点更新

2.1.1 单层消息传递的定义

为一个图,其中 是节点集合, 是边集合。对于节点 ,记其邻居节点集合为

消息传递层(Message Passing Layer)定义为:

其中:

符号含义
层节点 的隐藏状态
的特征向量
消息函数,生成从邻居到中心节点的消息
聚合操作(可以是求和、均值、最大值等)
更新函数,结合自身状态和聚合消息更新节点状态

2.1.2 消息传递的迭代过程

完整的 层消息传递网络可以形式化为:

2.2 消息传递的计算图

┌─────────────────────────────────────────────────────────────────────────┐
│                    单层消息传递计算图                                      │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  Layer l                          Layer l+1                             │
│                                                                         │
│      h_u^(l) ──┐                        ┌──→ h_v^(l+1)                   │
│                │                        │                               │
│      h_w^(l) ──┼──→ Message ──→ AGG ───┤                               │
│                │    Function     │      │                               │
│      h_x^(l) ──┘           ↓       │      │                               │
│                         Aggregate      │                               │
│                           │           │                               │
│                      ┌────┴────┐       │                               │
│                      │  a_v^(l) │ ─────┘                               │
│                      └─────────┘                                       │
│                           │                                            │
│                           ↓                                            │
│                      ┌─────────┐                                        │
│                      │ Update  │                                        │
│                      └─────────┘                                        │
│                                                                         │
│  消息函数: m_{u→v} = Message(h_u, h_v, e_uv)                            │
│  聚合操作: a_v = ⊕_{u∈N(v)} m_{u→v}                                     │
│  更新函数: h_v' = Update(h_v, a_v)                                      │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

2.3 与图卷积网络的关系

消息传递框架是图卷积网络(GCN)的更一般化表示。以GCN为例:

GCN的传播规则

这可以分解为:

组件GCN通用消息传递
消息函数
聚合操作
更新函数

3. 消息函数设计

消息函数(Message Function)是消息传递的核心组件,负责将邻居节点的信息转换为可以传递给目标节点的消息形式。消息函数的设计直接影响GNN的表达能力和信息流动方式。

3.1 线性消息函数

最简单的消息函数是线性变换,直接对邻居特征进行线性投影:

其中 是可学习的权重矩阵。

代码实现

import torch
import torch.nn as nn
 
class LinearMessageFunction(nn.Module):
    """线性消息函数"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)
    
    def forward(self, x_j, x_i=None, edge_attr=None):
        """
        Args:
            x_j: 源节点特征 (num_edges, in_channels)
            x_i: 目标节点特征 (num_edges, in_channels) - 可选
            edge_attr: 边特征 (num_edges, edge_channels) - 可选
        Returns:
            消息 (num_edges, out_channels)
        """
        return self.linear(x_j)

3.2 带边信息的线性消息

当边带有额外属性(如关系类型、权重)时,可以将其融入消息函数:

或使用更复杂的交互:

其中 表示向量拼接, 是多层感知机。

代码实现

class EdgeAwareMessageFunction(nn.Module):
    """带边信息的消息函数"""
    
    def __init__(self, node_channels, edge_channels, out_channels):
        super().__init__()
        self.node_proj = nn.Linear(node_channels, out_channels)
        self.edge_proj = nn.Linear(edge_channels, out_channels)
    
    def forward(self, x_j, x_i=None, edge_attr=None):
        if edge_attr is not None:
            msg = self.node_proj(x_j) + self.edge_proj(edge_attr)
        else:
            msg = self.node_proj(x_j)
        return msg
 
 
class ConcatenationMessageFunction(nn.Module):
    """拼接型消息函数"""
    
    def __init__(self, node_channels, edge_channels, hidden_channels):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(node_channels + edge_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels)
        )
    
    def forward(self, x_j, x_i=None, edge_attr=None):
        if edge_attr is not None:
            x = torch.cat([x_j, edge_attr], dim=-1)
        else:
            x = x_j
        return self.mlp(x)

3.3 基于注意力机制的消息

注意力机制允许模型为不同邻居分配不同的重要性,是现代GNN架构的核心组件。

3.3.1 基础注意力消息

其中注意力系数:

3.3.2 多头注意力

为了增强模型的稳定性和表达能力,通常使用多头注意力机制:

代码实现

class AttentionMessageFunction(nn.Module):
    """基于注意力机制的消息函数"""
    
    def __init__(self, in_channels, out_channels, heads=4, concat=True):
        super().__init__()
        self.heads = heads
        self.concat = concat
        self.out_channels = out_channels
        
        self.head_dim = out_channels // heads if concat else out_channels
        self.linear = nn.Linear(in_channels, self.head_dim * heads)
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.head_dim))
        
        nn.init.xavier_uniform_(self.att)
    
    def forward(self, x_j, x_i, edge_attr=None):
        # x_j: (num_edges, in_channels) - 源节点
        # x_i: (num_edges, in_channels) - 目标节点
        
        h_j = self.linear(x_j)  # (num_edges, heads * head_dim)
        h_i = self.linear(x_i)  # (num_edges, heads * head_dim)
        
        batch_size = h_j.size(0)
        h_j = h_j.view(batch_size, self.heads, self.head_dim)
        h_i = h_i.view(batch_size, self.heads, self.head_dim)
        
        # 拼接源节点和目标节点
        combined = torch.cat([h_j, h_i], dim=-1)  # (num_edges, heads, 2*head_dim)
        
        # 计算注意力分数
        att_score = (combined * self.att).sum(dim=-1)  # (num_edges, heads)
        att_score = F.leaky_relu(att_score, 0.2)
        
        # 跨边归一化(需要后续 softmax)
        return att_score, h_j
 
 
class MultiHeadAttentionMessage(nn.Module):
    """多头注意力消息函数(完整实现)"""
    
    def __init__(self, in_channels, out_channels, heads=4, edge_dim=None):
        super().__init__()
        self.heads = heads
        self.head_dim = out_channels // heads
        assert out_channels % heads == 0
        
        self.linear_v = nn.Linear(in_channels, out_channels)
        self.linear_qk = nn.Linear(in_channels, out_channels)
        
        if edge_dim is not None:
            self.edge_proj = nn.Linear(edge_dim, out_channels)
        
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.head_dim))
        nn.init.xavier_uniform_(self.att)
    
    def forward(self, x_j, x_i, edge_attr=None):
        # V: 消息内容
        v = self.linear_v(x_j)
        # Q/K: 用于计算注意力
        q = self.linear_qk(x_i)
        k = self.linear_qk(x_j)
        
        if edge_attr is not None and hasattr(self, 'edge_proj'):
            k = k + self.edge_proj(edge_attr)
        
        # Reshape for multi-head attention
        batch_size = x_j.size(0)
        v = v.view(batch_size, self.heads, self.head_dim)
        q = q.view(batch_size, self.heads, self.head_dim)
        k = k.view(batch_size, self.heads, self.head_dim)
        
        # 计算注意力分数
        qk = torch.cat([q, k], dim=-1)
        att = (qk * self.att).sum(dim=-1)
        att = F.leaky_relu(att, 0.2)
        
        return v, att

3.4 多关系消息函数

在异构图(知识图谱、多关系图)中,不同类型的边需要不同的消息函数:

其中 是关系类型 对应的权重矩阵。

代码实现

class MultiRelationMessageFunction(nn.Module):
    """多关系消息函数(用于异构图)"""
    
    def __init__(self, in_channels, out_channels, num_relations):
        super().__init__()
        self.W_r = nn.Parameter(torch.Tensor(num_relations, in_channels, out_channels))
        nn.init.xavier_uniform_(self.W_r)
    
    def forward(self, x_j, edge_type=None):
        """
        Args:
            x_j: 源节点特征 (num_edges, in_channels)
            edge_type: 边类型 (num_edges,)
        """
        if edge_type is None:
            # 如果没有边类型,使用共享权重
            return torch.matmul(x_j.unsqueeze(1), self.W_r[0]).squeeze(1)
        
        # 为每条边选择对应的关系矩阵
        W = self.W_r[edge_type]  # (num_edges, in_channels, out_channels)
        return torch.bmm(x_j.unsqueeze(1), W).squeeze(1)  # (num_edges, out_channels)

4. 聚合函数设计

聚合函数(Aggregation Function)负责将来自多个邻居的消息合并为单一表示,是消息传递框架中最关键的组件之一。不同的聚合方式决定了GNN捕获不同图结构信息的能力。

4.1 基础聚合操作

4.1.1 求和聚合(Sum Pooling)

特点

  • 保留所有邻居的信息
  • 对邻居数量不敏感
  • 计算效率高
def sum_aggregate(messages):
    """求和聚合"""
    return torch.sum(messages, dim=0)

4.1.2 均值聚合(Mean Pooling)

特点

  • 对邻居数量敏感
  • 有利于学习邻居特征的分布信息
  • 常与归一化结合使用
def mean_aggregate(messages):
    """均值聚合"""
    return torch.mean(messages, dim=0)

4.1.3 最大池化聚合(Max Pooling)

特点

  • 捕获最显著的特征
  • 对异常值鲁棒
  • 不保留完整的邻居分布信息
def max_aggregate(messages):
    """最大池化聚合"""
    return torch.max(messages, dim=0)[0]

4.2 注意力聚合

注意力聚合通过学习权重来加权组合邻居消息:

其中 是可学习的注意力系数。

代码实现

class AttentionAggregation(nn.Module):
    """注意力聚合函数"""
    
    def __init__(self, message_dim, head_dim):
        super().__init__()
        self.alpha = nn.Sequential(
            nn.Linear(2 * message_dim, head_dim),
            nn.LeakyReLU(),
            nn.Linear(head_dim, 1)
        )
    
    def forward(self, messages, targets):
        """
        Args:
            messages: 邻居消息 (num_neighbors, message_dim)
            targets: 目标节点特征 (message_dim)
        Returns:
            聚合结果 (message_dim,)
        """
        num_neighbors = messages.size(0)
        # 扩展目标节点以匹配邻居数量
        targets_expanded = targets.unsqueeze(0).expand(num_neighbors, -1)
        
        # 计算注意力系数
        combined = torch.cat([messages, targets_expanded], dim=-1)
        alpha = self.alpha(combined)  # (num_neighbors, 1)
        alpha = F.softmax(alpha, dim=0)
        
        # 加权求和
        return (alpha * messages).sum(dim=0)

4.3 Set Pooling(集合池化)

由于图中的邻居集合是无序的,理想的聚合函数应该具有置换不变性(Permutation Invariant)。Set Pooling通过神经网络实现这一点:

4.3.1 Deep Sets

代码实现

class DeepSetsAggregation(nn.Module):
    """Deep Sets 聚合 - 置换不变聚合"""
    
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU()
        )
    
    def forward(self, messages):
        """
        Args:
            messages: 邻居消息 (num_neighbors, in_channels)
        Returns:
            聚合结果 (hidden_channels,)
        """
        # Encode each message
        encoded = self.encoder(messages)  # (num_neighbors, hidden_channels)
        
        # Sum pooling (置换不变操作)
        summed = torch.sum(encoded, dim=0)  # (hidden_channels,)
        
        # Decode
        return self.decoder(summed)

4.3.2 Set Transformer

使用注意力机制实现集合聚合,具有更强的表达能力:

代码实现

class SetTransformerAggregation(nn.Module):
    """Set Transformer 聚合"""
    
    def __init__(self, dim, num_heads=4, num_seeds=1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.ReLU(),
            nn.Linear(dim * 4, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.num_seeds = num_seeds
        
        # 可学习的种子向量
        if num_seeds > 0:
            self.seeds = nn.Parameter(torch.randn(num_seeds, dim))
    
    def forward(self, messages, target=None):
        """
        Args:
            messages: 邻居消息 (num_neighbors, dim)
            target: 目标节点特征 (dim,) - 可选
        Returns:
            聚合结果 (dim,) 或 (num_seeds, dim)
        """
        num_neighbors = messages.size(0)
        
        if self.num_seeds > 0:
            # 使用可学习种子作为Query
            queries = self.seeds.unsqueeze(0)  # (1, num_seeds, dim)
            keys = messages.unsqueeze(0)  # (1, num_neighbors, dim)
            values = messages.unsqueeze(0)  # (1, num_neighbors, dim)
        else:
            # 使用目标节点作为Query
            queries = target.unsqueeze(0).unsqueeze(0)  # (1, 1, dim)
            keys = messages.unsqueeze(0)  # (1, num_neighbors, dim)
            values = messages.unsqueeze(0)  # (1, num_neighbors, dim)
        
        # Self-attention
        attn_out, _ = self.self_attn(queries, keys, values)
        x = queries + attn_out
        x = self.norm1(x)
        
        # Feed-forward
        ff_out = self.feed_forward(x)
        x = x + ff_out
        x = self.norm2(x)
        
        return x.squeeze(0)  # (num_seeds, dim) 或 (dim,)

4.4 Sort Pooling

Sort Pooling通过将邻居特征排序来实现置换不变性:

  1. 将邻居特征排序
  2. 选择前 个(或全部)邻居
  3. 应用卷积或全连接层

代码实现

class SortPooling(nn.Module):
    """Sort Pooling 聚合"""
    
    def __init__(self, k):
        super().__init__()
        self.k = k
    
    def forward(self, messages, dim=0):
        """
        Args:
            messages: 邻居消息 (num_neighbors, feature_dim)
            dim: 排序维度
        Returns:
            排序后并截取的结果 (k, feature_dim) 或 (num_neighbors, feature_dim)
        """
        # 排序(按所有特征的字典序)
        sorted_messages, _ = torch.sort(messages, dim=dim)
        
        # 截取前k个
        if self.k < messages.size(dim):
            if dim == 0:
                return sorted_messages[:self.k]
            else:
                return sorted_messages[:, :self.k]
        return sorted_messages

4.5 归一化考量

在实际应用中,聚合操作通常需要归一化以保证数值稳定性和训练效率:

归一化类型公式适用场景
节点度归一化防止度数大的节点主导
添加自环保留节点自身信息
L2归一化特征长度归一化
Batch归一化训练稳定性

5. 更新函数设计

更新函数(Update Function)将聚合后的邻居信息与节点当前状态结合,生成新的节点表示。

5.1 基础更新函数

5.1.1 线性更新

其中 表示向量拼接, 是非线性激活函数。

class LinearUpdate(nn.Module):
    """线性更新函数"""
    
    def __init__(self, hidden_channels):
        super().__init__()
        self.linear = nn.Linear(2 * hidden_channels, hidden_channels)
        self.activation = nn.ReLU()
    
    def forward(self, h, aggregated):
        combined = torch.cat([h, aggregated], dim=-1)
        return self.activation(self.linear(combined))

5.1.2 GRU更新

使用门控循环单元(Gated Recurrent Unit)进行更新:

class GRUUpdate(nn.Module):
    """GRU更新函数"""
    
    def __init__(self, hidden_channels):
        super().__init__()
        self.gru = nn.GRUCell(hidden_channels, hidden_channels)
    
    def forward(self, h, aggregated):
        """
        Args:
            h: 当前隐藏状态 (batch, hidden_channels)
            aggregated: 聚合消息 (batch, hidden_channels)
        """
        return self.gru(aggregated, h)

5.2 残差连接与层归一化

为了支持更深的网络架构,常使用残差连接和层归一化:

class ResidualUpdate(nn.Module):
    """带残差连接的更新函数"""
    
    def __init__(self, hidden_channels, dropout=0.1):
        super().__init__()
        self.update = LinearUpdate(hidden_channels)
        self.norm = nn.LayerNorm(hidden_channels)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, h, aggregated):
        h_new = self.update(h, aggregated)
        h_new = self.dropout(h_new)
        return self.norm(h + h_new)

5.3 动态更新

根据任务需求,更新函数可以采用不同的策略:

更新策略公式应用场景
保留自身标准GNN
替换自身GraphSAGE
门控更新序列建模
循环更新记忆增强

6. 表达能力分析

6.1 WL图同构测试

Weisfeiler-Lehman(WL)测试是分析GNN表达能力的重要工具。Xu等人证明2GNN的表达能力上界是1-WL测试(颜色细化算法)

6.1.1 WL测试回顾

1-WL测试(也称为颜色细化算法)的工作流程:

初始化:所有节点着相同颜色
迭代:
    1. 每个节点收集邻居颜色集合
    2. 根据自身颜色和邻居颜色集合分配新颜色
    3. 如果颜色分布稳定,则停止

6.1.2 GNN与WL测试的关系

关键定理:如果GNN中的聚合函数是单射的(injective),则该GNN的表达能力与1-WL测试等价。

这意味着:

  • 具有单射聚合的GNN可以区分任何1-WL测试能区分的图
  • 任何无法被1-WL测试区分的图对,也无法被此类GNN区分

6.1.3 什么可以学习?

满足单射性的GNN可以区分:

结构类型示例可区分性
不同度分布 vs
不同子图结构三角形 vs 路径
节点角色中心节点 vs 边缘节点

6.1.4 什么不能学习?

1-WL测试无法区分的结构,GNN也无法区分:

┌─────────────────────────────────────────────────────────────────────────┐
│                    1-WL测试无法区分的典型例子                              │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  例1:同构子图                                                         │
│  ┌──────────────┐     ┌──────────────┐                                 │
│  │   ●──●──●    │     │   ●──●──●    │  完全相同的多重集                 │
│  │   │  │  │    │  =  │   │  │  │    │  → WL测试无法区分                  │
│  │   ●──●──●    │     │   ●──●──●    │  → GNN也无法区分                   │
│  └──────────────┘     └──────────────┘                                 │
│                                                                         │
│  例2:不同环结构                                                        │
│  ┌──────────────┐     ┌──────────────┐                                 │
│  │    ●──●      │     │    ●──●      │  邻居都是 2 度节点                 │
│  │   /│\ │      │     │   / \ │      │  → WL测试无法区分                  │
│  │  ● ● ● ●     │  ≠  │  ●   ● ●     │  → GNN也无法区分                   │
│  │   \│/ │      │     │   \ │/       │                                 │
│  │    ●──●      │     │    ●──●      │                                 │
│  └──────────────┘     └──────────────┘                                 │
│       六边形               桶形结构                                       │
│                                                                         │
│  例3:距离信息丢失                                                      │
│  ●───────●───────●                                                      │
│  vs                                                                        │
│  ●─────────●─────────●                                                  │
│                                                                         │
│  距离不同但邻居多重集相同                                                 │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

6.2 提升表达能力的方法

6.2.1 GIN(Graph Isomorphism Network)

Xu等人证明2Graph Isomorphism Network (GIN) 是表达能力最强的GNN之一:

其中 必须是单射的。

代码实现

class GINConv(nn.Module):
    """GIN卷积层 - 理论上最具表达能力的GNN"""
    
    def __init__(self, in_channels, out_channels, eps=0.0, train_eps=True):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
        
        if train_eps:
            self.eps = nn.Parameter(torch.tensor(eps))
        else:
            self.eps = eps
    
    def forward(self, x, edge_index):
        """
        Args:
            x: 节点特征 (num_nodes, in_channels)
            edge_index: 边索引 (2, num_edges)
        """
        row, col = edge_index
        
        # 邻居求和
        out = torch.zeros_like(x)
        out.index_add_(0, col, x[row])
        
        # 添加自环和可学习的epsilon
        out = (1 + self.eps) * x + out
        
        return self.mlp(out)

6.2.2 距离感知GNN

为了捕获节点间的距离信息,可以使用:

方法核心思想代表工作
ID-GNN循环消息传递中保留节点身份3
Distance Encoding将距离编码为特征4
Subgraph GNN在子图级别进行消息传递5
-WL测试使用更强大的WL变体6
class DistanceAwareMessage(nn.Module):
    """距离感知消息函数"""
    
    def __init__(self, hidden_channels):
        super().__init__()
        self.message = nn.Linear(hidden_channels + hidden_channels, hidden_channels)
    
    def forward(self, x_j, x_i, edge_index):
        """
        x_j: 源节点特征
        x_i: 目标节点特征
        edge_index: 边索引(需要额外计算距离)
        """
        # 简化示例:实际需要计算节点间的距离
        # 这里假设 edge_attr 已包含距离信息
        return self.message(torch.cat([x_j, x_i], dim=-1))

7. 高级消息传递变体

7.1 Jump Knowledge Networks (JK-Net)

核心思想:在最后一层之前,将所有中间层的表示连接起来,充分利用不同感知域的信息。

其中 可以是:

  • Concat:[所有层拼接]
  • Max:[逐位置取最大值]
  • LSTM:[通过LSTM选择最重要信息]
class JKNet(nn.Module):
    """Jump Knowledge Networks"""
    
    def __init__(self, hidden_channels, num_layers, mode='max'):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        
        for i in range(num_layers):
            self.convs.append(GINConv(hidden_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
        
        self.mode = mode
        
        if mode == 'lstm':
            self.lstm = nn.LSTM(
                hidden_channels, hidden_channels // 2,
                bidirectional=True, batch_first=True
            )
    
    def forward(self, x, edge_index):
        xs = []
        
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            xs.append(x)
        
        if self.mode == 'concat':
            return torch.cat(xs, dim=-1)
        elif self.mode == 'max':
            return torch.stack(xs, dim=0).max(dim=0)[0]
        elif self.mode == 'lstm':
            x = torch.stack(xs, dim=1)  # (N, num_layers, hidden)
            out, _ = self.lstm(x)
            return out[:, -1, :]
        else:
            return xs[-1]

7.2 有向消息传递 (Directed Message Passing)

在分子图中,边的方向(如化学键方向)携带重要信息。有向消息传递显式建模方向性:

其中 表示边的方向类型。

class DirectedMessagePassing(nn.Module):
    """有向消息传递"""
    
    def __init__(self, hidden_channels, num_directions=2):
        super().__init__()
        # 两种方向:正向和反向
        self.W_forward = nn.Linear(hidden_channels, hidden_channels)
        self.W_backward = nn.Linear(hidden_channels, hidden_channels)
    
    def forward(self, x, edge_index):
        row, col = edge_index
        
        # 正向消息(row → col)
        msg_forward = self.W_forward(x[row])
        # 反向消息(col → row)
        msg_backward = self.W_backward(x[col])
        
        # 聚合
        out = torch.zeros_like(x)
        out.index_add_(0, col, msg_forward)
        out.index_add_(0, row, msg_backward)
        
        return out
 
 
class BondDirectionMessage(nn.Module):
    """化学键方向感知消息传递(用于分子图)"""
    
    def __init__(self, hidden_channels, num_bond_types):
        super().__init__()
        self.W = nn.ModuleDict({
            'single': nn.Linear(hidden_channels, hidden_channels),
            'double': nn.Linear(hidden_channels, hidden_channels),
            'triple': nn.Linear(hidden_channels, hidden_channels),
            'aromatic': nn.Linear(hidden_channels, hidden_channels),
        })
    
    def forward(self, x, edge_index, bond_types):
        """
        bond_types: List of bond types for each edge
        """
        row, col = edge_index
        messages = []
        
        for i in range(len(row)):
            src, dst = row[i].item(), col[i].item()
            bond = bond_types[i]
            
            msg = self.W[bond](x[src])
            messages.append(msg)
        
        return torch.stack(messages)

7.3 边缘卷积 (Edge Convolution)

DYNAMIC EDGE CONVOLUTION(也称EdgeConv)由Wang等人提出7,是点云处理中常用的方法。

7.3.1 核心思想

EdgeConv不是聚合邻居的消息,而是为每个节点计算其与每个邻居的边特征

其中 捕获了节点间的相对位置/特征差异。

7.3.2 k-NN vs 全图连接

EdgeConv可以使用 近邻(-NN)构建局部邻域,而非仅使用图结构中的边:

class EdgeConv(nn.Module):
    """边缘卷积 (EdgeConv)"""
    
    def __init__(self, in_channels, out_channels, k=20):
        super().__init__()
        self.k = k
        self.mlp = nn.Sequential(
            nn.Linear(in_channels * 2, out_channels),
            nn.ReLU(),
            nn.BatchNorm1d(out_channels),
            nn.Linear(out_channels, out_channels),
            nn.ReLU(),
            nn.BatchNorm1d(out_channels)
        )
    
    def knn(self, x):
        """计算k近邻"""
        pairwise_dist = torch.cdist(x, x, p=2)
        _, indices = pairwise_dist.topk(self.k, largest=False)
        return indices
    
    def forward(self, x, edge_index=None):
        """
        Args:
            x: 节点特征 (N, C)
            edge_index: 可选的预定义边(如果为None则使用k-NN)
        """
        N = x.size(0)
        
        if edge_index is None:
            # 使用k-NN构建邻域
            indices = self.knn(x)  # (N, k)
        else:
            indices = edge_index[1].view(N, -1)
        
        # 获取邻居特征
        neighbor_features = x[indices]  # (N, k, C)
        
        # 拼接中心节点特征(与每个邻居的特征差)
        center_features = x.unsqueeze(1).expand(-1, self.k, -1)  # (N, k, C)
        diff = neighbor_features - center_features  # (N, k, C)
        
        # 边特征
        edge_features = torch.cat([center_features, diff], dim=-1)  # (N, k, 2C)
        
        # 通过MLP处理每条边
        edge_features = edge_features.view(-1, 2 * x.size(1))  # (N*k, 2C)
        edge_features = self.mlp(edge_features)  # (N*k, C)
        edge_features = edge_features.view(N, self.k, -1)  # (N, k, C)
        
        # 聚合(最大池化)
        aggregated = edge_features.max(dim=1)[0]  # (N, C)
        
        return aggregated + x  # 残差连接

8. PyTorch Geometric 实现

PyTorch Geometric(PyG)是目前最流行的GNN框架,提供了丰富的消息传递API。

8.1 MessagePassing 基类

PyG的 MessagePassing 类封装了消息传递的核心逻辑:

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
 
class MyMessagePassingLayer(MessagePassing):
    """自定义消息传递层"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # 默认聚合方式
        self.lin = nn.Linear(in_channels, out_channels)
    
    def forward(self, x, edge_index):
        # 添加自环,使节点能够获取自身信息
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        # 调用propagate开始消息传递
        return self.propagate(edge_index, x=x)
    
    def message(self, x_j, x_i):
        """
        消息函数
        x_j: 源节点特征
        x_i: 目标节点特征
        """
        return self.lin(x_j)
    
    def update(self, updated):
        """
        更新函数
        updated: 聚合后的结果
        """
        return updated

8.2 完整 GCN 层实现

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
 
class GCNConv(MessagePassing):
    """GCN卷积层的PyG实现"""
    
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)
    
    def forward(self, x, edge_index):
        # 添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        # 计算度归一化系数
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        
        # 线性变换
        x = self.lin(x)
        
        # 消息传递
        return self.propagate(edge_index, x=x, norm=norm)
    
    def message(self, x_j, norm):
        # 消息函数:节点特征 × 归一化系数
        return norm.view(-1, 1) * x_j

8.3 GAT 层实现

from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
 
class GATConv(MessagePassing):
    """GAT卷积层的PyG实现"""
    
    def __init__(self, in_channels, out_channels, heads=4, concat=True, dropout=0.0):
        super().__init__(node_dim=0)  # 指定节点维度
        self.heads = heads
        self.concat = concat
        self.dropout = dropout
        self.head_dim = out_channels // heads
        
        assert out_channels % heads == 0
        
        self.lin = nn.Linear(in_channels, heads * self.head_dim, bias=False)
        self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.head_dim))
        nn.init.xavier_uniform_(self.att)
        
        if concat:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.bias = nn.Parameter(torch.Tensor(self.head_dim))
        
        nn.init.zeros_(self.bias)
    
    def forward(self, x, edge_index):
        H = self.heads
        C = self.head_dim
        
        # 线性变换
        x = self.lin(x).view(-1, H, C)  # (N, H, C)
        
        # 添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        # 消息传递
        out = self.propagate(edge_index, x=x)
        
        # 处理输出维度
        if self.concat:
            out = out.view(-1, H * C)
            out = out + self.bias
        else:
            out = out.mean(dim=1) + self.bias
        
        return out
    
    def message(self, x_i, x_j, index):
        # x_i: 目标节点特征 (num_edges, H, C)
        # x_j: 源节点特征 (num_edges, H, C)
        
        # 计算注意力分数
        alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)  # (num_edges, H)
        alpha = F.leaky_relu(alpha, 0.2)
        alpha = softmax(alpha, index)  # 按目标节点归一化
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        return x_j * alpha.view(-1, self.heads, 1)

8.4 使用 Sequential 构建多层 GNN

class GraphNeuralNetwork(nn.Module):
    """多层GNN模型"""
    
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        super().__init__()
        self.convs = nn.ModuleList()
        
        # 第一层
        self.convs.append(GATConv(in_channels, hidden_channels, heads=4))
        
        # 中间层
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_channels, hidden_channels, heads=4))
        
        # 最后一层
        self.convs.append(GATConv(hidden_channels, out_channels, heads=1, concat=False))
        
        self.norms = nn.ModuleList([nn.BatchNorm1d(hidden_channels) 
                                     for _ in range(num_layers - 1)])
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x_new = conv(x, edge_index)
            x_new = self.norms[i](x_new)
            x_new = F.relu(x_new)
            x_new = F.dropout(x_new, p=0.5, training=self.training)
            x = x_new + x  # 残差连接
        
        # 最后一层不使用激活和dropout
        x = self.convs[-1](x, edge_index)
        return x

9. 不同设计的比较

9.1 消息函数比较

类型优点缺点适用场景
线性计算高效、可解释表达能力有限大规模图、边信息简单的场景
MLP表达能力强计算开销大异构图、复杂边关系
注意力自适应权重、捕获重要性差异参数量大关系重要性差异明显
多关系显式建模不同关系类型关系类型需要预先定义知识图谱、异构图

9.2 聚合函数比较

类型置换不变性表达能力计算复杂度
Sum中等
Mean中等
Max较弱
Attention
DeepSets
SetTrans最强

9.3 不同 GNN 架构对比

架构消息函数聚合函数更新函数表达能力
GCN度归一化和弱于WL
GraphSAGEMax/Mean/LSTM弱于WL
GAT加权和弱于WL
GINSum(单射)等价WL
GIN+Sum等价WL

9.4 计算效率对比

┌─────────────────────────────────────────────────────────────────────────┐
│                      各GNN架构计算效率对比                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│  O(1)  ───────────────────────────────────────────────────────→ O(N²)  │
│                                                                         │
│  ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐   ┌─────────┐  │
│  │   GCN   │   │ Graph-  │   │   GAT   │   │  GIN    │   │  Edge-  │  │
│  │         │   │  SAGE   │   │         │   │         │   │  Conv   │  │
│  │  O(E)   │   │  O(E)   │   │  O(E·K) │   │  O(E)   │   │ O(N²)   │  │
│  │         │   │         │   │         │   │         │   │ (k-NN)  │  │
│  └─────────┘   └─────────┘   └─────────┘   └─────────┘   └─────────┘  │
│                                                                         │
│  E: 边数, N: 节点数, K: 注意力头数                                      │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

10. 实践指南

10.1 何时使用何种设计

┌─────────────────────────────────────────────────────────────────────────┐
│                        消息传递设计决策树                                 │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                         │
│                         开始设计                                         │
│                              │                                           │
│                              ↓                                           │
│                    ┌─────────────────┐                                   │
│                    │   边有权重吗?    │                                   │
│                    └────────┬────────┘                                   │
│                      是 ↓        ↓ 否                                    │
│                    使用权重消息    使用基础消息                            │
│                              │                                           │
│                              ↓                                           │
│                    ┌─────────────────┐                                   │
│                    │ 关系类型多样吗? │                                   │
│                    └────────┬────────┘                                   │
│                      是 ↓        ↓ 否                                    │
│               多关系消息函数     ↓                                       │
│                              │                                           │
│                              ↓                                           │
│                    ┌─────────────────┐                                   │
│                    │ 邻居重要性不同? │                                   │
│                    └────────┬────────┘                                   │
│                      是 ↓        ↓ 否                                    │
│                   注意力聚合     使用求和/均值                             │
│                                                                         │
└─────────────────────────────────────────────────────────────────────────┘

10.2 常见问题与解决方案

问题症状解决方案
过平滑所有节点表示趋于相同使用残差连接、增加层数限制、使用JK-Net
过拟合训练损失低但测试性能差使用dropout、减少层数、数据增强
数值不稳定NaN或Inf使用层归一化、降低学习率、梯度裁剪
异构图处理不同节点/边类型混淆使用异构图专用消息传递

10.3 最佳实践清单

  1. 数据预处理

    • 标准化节点特征
    • 添加自环(除非明确不需要)
    • 考虑是否使用节点度数作为额外特征
  2. 架构设计

    • 2-3层通常足够(除非需要很大感知域)
    • 使用残差连接支持更深的网络
    • 考虑使用JK-Net或注意力机制
  3. 训练技巧

    • 使用早停(early stopping)
    • 使用学习率衰减
    • 应用合适的dropout率(0.1-0.5)

11. 总结

消息传递机制是图神经网络的核心计算范式,通过消息生成消息聚合节点更新三个步骤实现图结构数据的表示学习。

11.1 核心要点

组件作用设计选择
消息函数转换邻居信息线性/MLP/注意力/多关系
聚合函数合并邻居消息Sum/Mean/Max/Attention
更新函数融合自身与聚合线性/GRU/残差连接

11.2 表达能力

  • 标准GNN的表达能力受限于1-WL测试
  • 单射聚合函数是达到WL表达能力的必要条件
  • GIN是理论上最具表达能力的架构之一
  • 超越WL需要额外的结构编码(如距离、子图信息)

11.3 未来方向

  • 更强大的聚合器:超越WL测试的表达能力
  • 动态邻域:基于学习的邻域选择
  • 异步消息传递:处理时序图和事件流
  • 稀疏化:大规模图的高效计算

参考资料

Footnotes

  1. Gilmer J, Schoenholz S S, Riley P F, et al. Neural Message Passing for Quantum Chemistry[J]. International Conference on Machine Learning (ICML), 2017. 2

  2. Xu K, Hu W, Leskovec J, et al. How Powerful are Graph Neural Networks?[J]. International Conference on Learning Representations (ICLR), 2019. 2

  3. You J, Ying R, Ren X, et al. Identity-Aware Graph Neural Networks in Knowledge Graph Completion[C]. AAAI Conference on Artificial Intelligence, 2021.

  4. Li J, Ma R, Guo Q, et al. Distance Encoding—Enable Informative Shortest Path Counting and Propagation for GNNs[J]. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2020.

  5. Frasca F, Bevilacqua B, Bronstein M M, et al. Understanding Unbalanced Semantic Classes in Graph Neural Networks’ Predictions[J]. arXiv preprint, 2022.

  6. Grohe M. The Logic of Graph Neural Networks[C]. ACM SIGLOG News, 2021.

  7. Wang Y, Sun Y, Liu Z, et al. Dynamic Graph CNN for Learning on Point Clouds[J]. ACM Transactions on Graphics (TOG), 2019.