1. 引言
消息传递(Message Passing)是图神经网络(GNN)的核心计算范式,将深度学习的表示学习能力与图结构数据的拓扑信息有机结合。自Gilmer等人于2017年提出”消息传递神经网络”(Message Passing Neural Networks, MPNN)框架以来1,消息传递已成为GNN领域最重要的理论基石之一。
1.1 消息传递的背景
传统的深度学习模型(如CNN、RNN)主要处理规则结构的数据(图像、序列),而图数据具有非欧几里得特性:
- 节点无序性:节点的邻居没有固定顺序
- 结构多样性:图的结构可以任意变化
- 关系复杂性:边可以带权、重边、方向等属性
消息传递机制通过局部信息交换的方式,让每个节点能够感知其邻域的结构和特征信息,从而实现对图数据的深度表示学习。
1.2 与其他GNN方法的关系
┌─────────────────────────────────────────────────────────────────────────┐
│ 图神经网络方法分类 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────┐ ┌─────────────────────────┐ │
│ │ 谱域方法 │ │ 空域方法 │ │
│ │ (Spectral Methods) │ │ (Spatial Methods) │ │
│ ├─────────────────────────┤ ├─────────────────────────┤ │
│ │ • GCN (谱卷积近似) │ │ • Message Passing (核心)│ │
│ │ • ChebNet (切比雪夫多项式)│ │ • GraphSAGE │ │
│ │ • GIN (理论分析) │ │ • GAT (注意力机制) │ │
│ └─────────────────────────┘ └─────────────────────────┘ │
│ │
│ Message Passing 是空域方法的核心抽象 │
│ │
└─────────────────────────────────────────────────────────────────────────┘2. 消息传递框架形式化
2.1 MPNN通用框架
Gilmer等人提出的消息传递神经网络框架是GNN领域最具影响力的形式化定义之一1。整个框架可以分解为三个核心操作:消息生成、消息聚合和节点更新。
2.1.1 单层消息传递的定义
设 为一个图,其中 是节点集合, 是边集合。对于节点 ,记其邻居节点集合为 。
消息传递层(Message Passing Layer)定义为:
其中:
| 符号 | 含义 |
|---|---|
| 第 层节点 的隐藏状态 | |
| 边 的特征向量 | |
| 消息函数,生成从邻居到中心节点的消息 | |
| 聚合操作(可以是求和、均值、最大值等) | |
| 更新函数,结合自身状态和聚合消息更新节点状态 |
2.1.2 消息传递的迭代过程
完整的 层消息传递网络可以形式化为:
2.2 消息传递的计算图
┌─────────────────────────────────────────────────────────────────────────┐
│ 单层消息传递计算图 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ Layer l Layer l+1 │
│ │
│ h_u^(l) ──┐ ┌──→ h_v^(l+1) │
│ │ │ │
│ h_w^(l) ──┼──→ Message ──→ AGG ───┤ │
│ │ Function │ │ │
│ h_x^(l) ──┘ ↓ │ │ │
│ Aggregate │ │
│ │ │ │
│ ┌────┴────┐ │ │
│ │ a_v^(l) │ ─────┘ │
│ └─────────┘ │
│ │ │
│ ↓ │
│ ┌─────────┐ │
│ │ Update │ │
│ └─────────┘ │
│ │
│ 消息函数: m_{u→v} = Message(h_u, h_v, e_uv) │
│ 聚合操作: a_v = ⊕_{u∈N(v)} m_{u→v} │
│ 更新函数: h_v' = Update(h_v, a_v) │
│ │
└─────────────────────────────────────────────────────────────────────────┘2.3 与图卷积网络的关系
消息传递框架是图卷积网络(GCN)的更一般化表示。以GCN为例:
GCN的传播规则:
这可以分解为:
| 组件 | GCN | 通用消息传递 |
|---|---|---|
| 消息函数 | ||
| 聚合操作 | ||
| 更新函数 |
3. 消息函数设计
消息函数(Message Function)是消息传递的核心组件,负责将邻居节点的信息转换为可以传递给目标节点的消息形式。消息函数的设计直接影响GNN的表达能力和信息流动方式。
3.1 线性消息函数
最简单的消息函数是线性变换,直接对邻居特征进行线性投影:
其中 是可学习的权重矩阵。
代码实现:
import torch
import torch.nn as nn
class LinearMessageFunction(nn.Module):
"""线性消息函数"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.linear = nn.Linear(in_channels, out_channels)
def forward(self, x_j, x_i=None, edge_attr=None):
"""
Args:
x_j: 源节点特征 (num_edges, in_channels)
x_i: 目标节点特征 (num_edges, in_channels) - 可选
edge_attr: 边特征 (num_edges, edge_channels) - 可选
Returns:
消息 (num_edges, out_channels)
"""
return self.linear(x_j)3.2 带边信息的线性消息
当边带有额外属性(如关系类型、权重)时,可以将其融入消息函数:
或使用更复杂的交互:
其中 表示向量拼接, 是多层感知机。
代码实现:
class EdgeAwareMessageFunction(nn.Module):
"""带边信息的消息函数"""
def __init__(self, node_channels, edge_channels, out_channels):
super().__init__()
self.node_proj = nn.Linear(node_channels, out_channels)
self.edge_proj = nn.Linear(edge_channels, out_channels)
def forward(self, x_j, x_i=None, edge_attr=None):
if edge_attr is not None:
msg = self.node_proj(x_j) + self.edge_proj(edge_attr)
else:
msg = self.node_proj(x_j)
return msg
class ConcatenationMessageFunction(nn.Module):
"""拼接型消息函数"""
def __init__(self, node_channels, edge_channels, hidden_channels):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(node_channels + edge_channels, hidden_channels),
nn.ReLU(),
nn.Linear(hidden_channels, hidden_channels)
)
def forward(self, x_j, x_i=None, edge_attr=None):
if edge_attr is not None:
x = torch.cat([x_j, edge_attr], dim=-1)
else:
x = x_j
return self.mlp(x)3.3 基于注意力机制的消息
注意力机制允许模型为不同邻居分配不同的重要性,是现代GNN架构的核心组件。
3.3.1 基础注意力消息
其中注意力系数:
3.3.2 多头注意力
为了增强模型的稳定性和表达能力,通常使用多头注意力机制:
代码实现:
class AttentionMessageFunction(nn.Module):
"""基于注意力机制的消息函数"""
def __init__(self, in_channels, out_channels, heads=4, concat=True):
super().__init__()
self.heads = heads
self.concat = concat
self.out_channels = out_channels
self.head_dim = out_channels // heads if concat else out_channels
self.linear = nn.Linear(in_channels, self.head_dim * heads)
self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.head_dim))
nn.init.xavier_uniform_(self.att)
def forward(self, x_j, x_i, edge_attr=None):
# x_j: (num_edges, in_channels) - 源节点
# x_i: (num_edges, in_channels) - 目标节点
h_j = self.linear(x_j) # (num_edges, heads * head_dim)
h_i = self.linear(x_i) # (num_edges, heads * head_dim)
batch_size = h_j.size(0)
h_j = h_j.view(batch_size, self.heads, self.head_dim)
h_i = h_i.view(batch_size, self.heads, self.head_dim)
# 拼接源节点和目标节点
combined = torch.cat([h_j, h_i], dim=-1) # (num_edges, heads, 2*head_dim)
# 计算注意力分数
att_score = (combined * self.att).sum(dim=-1) # (num_edges, heads)
att_score = F.leaky_relu(att_score, 0.2)
# 跨边归一化(需要后续 softmax)
return att_score, h_j
class MultiHeadAttentionMessage(nn.Module):
"""多头注意力消息函数(完整实现)"""
def __init__(self, in_channels, out_channels, heads=4, edge_dim=None):
super().__init__()
self.heads = heads
self.head_dim = out_channels // heads
assert out_channels % heads == 0
self.linear_v = nn.Linear(in_channels, out_channels)
self.linear_qk = nn.Linear(in_channels, out_channels)
if edge_dim is not None:
self.edge_proj = nn.Linear(edge_dim, out_channels)
self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.head_dim))
nn.init.xavier_uniform_(self.att)
def forward(self, x_j, x_i, edge_attr=None):
# V: 消息内容
v = self.linear_v(x_j)
# Q/K: 用于计算注意力
q = self.linear_qk(x_i)
k = self.linear_qk(x_j)
if edge_attr is not None and hasattr(self, 'edge_proj'):
k = k + self.edge_proj(edge_attr)
# Reshape for multi-head attention
batch_size = x_j.size(0)
v = v.view(batch_size, self.heads, self.head_dim)
q = q.view(batch_size, self.heads, self.head_dim)
k = k.view(batch_size, self.heads, self.head_dim)
# 计算注意力分数
qk = torch.cat([q, k], dim=-1)
att = (qk * self.att).sum(dim=-1)
att = F.leaky_relu(att, 0.2)
return v, att3.4 多关系消息函数
在异构图(知识图谱、多关系图)中,不同类型的边需要不同的消息函数:
其中 是关系类型 对应的权重矩阵。
代码实现:
class MultiRelationMessageFunction(nn.Module):
"""多关系消息函数(用于异构图)"""
def __init__(self, in_channels, out_channels, num_relations):
super().__init__()
self.W_r = nn.Parameter(torch.Tensor(num_relations, in_channels, out_channels))
nn.init.xavier_uniform_(self.W_r)
def forward(self, x_j, edge_type=None):
"""
Args:
x_j: 源节点特征 (num_edges, in_channels)
edge_type: 边类型 (num_edges,)
"""
if edge_type is None:
# 如果没有边类型,使用共享权重
return torch.matmul(x_j.unsqueeze(1), self.W_r[0]).squeeze(1)
# 为每条边选择对应的关系矩阵
W = self.W_r[edge_type] # (num_edges, in_channels, out_channels)
return torch.bmm(x_j.unsqueeze(1), W).squeeze(1) # (num_edges, out_channels)4. 聚合函数设计
聚合函数(Aggregation Function)负责将来自多个邻居的消息合并为单一表示,是消息传递框架中最关键的组件之一。不同的聚合方式决定了GNN捕获不同图结构信息的能力。
4.1 基础聚合操作
4.1.1 求和聚合(Sum Pooling)
特点:
- 保留所有邻居的信息
- 对邻居数量不敏感
- 计算效率高
def sum_aggregate(messages):
"""求和聚合"""
return torch.sum(messages, dim=0)4.1.2 均值聚合(Mean Pooling)
特点:
- 对邻居数量敏感
- 有利于学习邻居特征的分布信息
- 常与归一化结合使用
def mean_aggregate(messages):
"""均值聚合"""
return torch.mean(messages, dim=0)4.1.3 最大池化聚合(Max Pooling)
特点:
- 捕获最显著的特征
- 对异常值鲁棒
- 不保留完整的邻居分布信息
def max_aggregate(messages):
"""最大池化聚合"""
return torch.max(messages, dim=0)[0]4.2 注意力聚合
注意力聚合通过学习权重来加权组合邻居消息:
其中 是可学习的注意力系数。
代码实现:
class AttentionAggregation(nn.Module):
"""注意力聚合函数"""
def __init__(self, message_dim, head_dim):
super().__init__()
self.alpha = nn.Sequential(
nn.Linear(2 * message_dim, head_dim),
nn.LeakyReLU(),
nn.Linear(head_dim, 1)
)
def forward(self, messages, targets):
"""
Args:
messages: 邻居消息 (num_neighbors, message_dim)
targets: 目标节点特征 (message_dim)
Returns:
聚合结果 (message_dim,)
"""
num_neighbors = messages.size(0)
# 扩展目标节点以匹配邻居数量
targets_expanded = targets.unsqueeze(0).expand(num_neighbors, -1)
# 计算注意力系数
combined = torch.cat([messages, targets_expanded], dim=-1)
alpha = self.alpha(combined) # (num_neighbors, 1)
alpha = F.softmax(alpha, dim=0)
# 加权求和
return (alpha * messages).sum(dim=0)4.3 Set Pooling(集合池化)
由于图中的邻居集合是无序的,理想的聚合函数应该具有置换不变性(Permutation Invariant)。Set Pooling通过神经网络实现这一点:
4.3.1 Deep Sets
代码实现:
class DeepSetsAggregation(nn.Module):
"""Deep Sets 聚合 - 置换不变聚合"""
def __init__(self, in_channels, hidden_channels):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(in_channels, hidden_channels),
nn.ReLU()
)
self.decoder = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels),
nn.ReLU()
)
def forward(self, messages):
"""
Args:
messages: 邻居消息 (num_neighbors, in_channels)
Returns:
聚合结果 (hidden_channels,)
"""
# Encode each message
encoded = self.encoder(messages) # (num_neighbors, hidden_channels)
# Sum pooling (置换不变操作)
summed = torch.sum(encoded, dim=0) # (hidden_channels,)
# Decode
return self.decoder(summed)4.3.2 Set Transformer
使用注意力机制实现集合聚合,具有更强的表达能力:
代码实现:
class SetTransformerAggregation(nn.Module):
"""Set Transformer 聚合"""
def __init__(self, dim, num_heads=4, num_seeds=1):
super().__init__()
self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
self.feed_forward = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.ReLU(),
nn.Linear(dim * 4, dim)
)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.num_seeds = num_seeds
# 可学习的种子向量
if num_seeds > 0:
self.seeds = nn.Parameter(torch.randn(num_seeds, dim))
def forward(self, messages, target=None):
"""
Args:
messages: 邻居消息 (num_neighbors, dim)
target: 目标节点特征 (dim,) - 可选
Returns:
聚合结果 (dim,) 或 (num_seeds, dim)
"""
num_neighbors = messages.size(0)
if self.num_seeds > 0:
# 使用可学习种子作为Query
queries = self.seeds.unsqueeze(0) # (1, num_seeds, dim)
keys = messages.unsqueeze(0) # (1, num_neighbors, dim)
values = messages.unsqueeze(0) # (1, num_neighbors, dim)
else:
# 使用目标节点作为Query
queries = target.unsqueeze(0).unsqueeze(0) # (1, 1, dim)
keys = messages.unsqueeze(0) # (1, num_neighbors, dim)
values = messages.unsqueeze(0) # (1, num_neighbors, dim)
# Self-attention
attn_out, _ = self.self_attn(queries, keys, values)
x = queries + attn_out
x = self.norm1(x)
# Feed-forward
ff_out = self.feed_forward(x)
x = x + ff_out
x = self.norm2(x)
return x.squeeze(0) # (num_seeds, dim) 或 (dim,)4.4 Sort Pooling
Sort Pooling通过将邻居特征排序来实现置换不变性:
- 将邻居特征排序
- 选择前 个(或全部)邻居
- 应用卷积或全连接层
代码实现:
class SortPooling(nn.Module):
"""Sort Pooling 聚合"""
def __init__(self, k):
super().__init__()
self.k = k
def forward(self, messages, dim=0):
"""
Args:
messages: 邻居消息 (num_neighbors, feature_dim)
dim: 排序维度
Returns:
排序后并截取的结果 (k, feature_dim) 或 (num_neighbors, feature_dim)
"""
# 排序(按所有特征的字典序)
sorted_messages, _ = torch.sort(messages, dim=dim)
# 截取前k个
if self.k < messages.size(dim):
if dim == 0:
return sorted_messages[:self.k]
else:
return sorted_messages[:, :self.k]
return sorted_messages4.5 归一化考量
在实际应用中,聚合操作通常需要归一化以保证数值稳定性和训练效率:
| 归一化类型 | 公式 | 适用场景 |
|---|---|---|
| 节点度归一化 | 防止度数大的节点主导 | |
| 添加自环 | 保留节点自身信息 | |
| L2归一化 | 特征长度归一化 | |
| Batch归一化 | 训练稳定性 |
5. 更新函数设计
更新函数(Update Function)将聚合后的邻居信息与节点当前状态结合,生成新的节点表示。
5.1 基础更新函数
5.1.1 线性更新
其中 表示向量拼接, 是非线性激活函数。
class LinearUpdate(nn.Module):
"""线性更新函数"""
def __init__(self, hidden_channels):
super().__init__()
self.linear = nn.Linear(2 * hidden_channels, hidden_channels)
self.activation = nn.ReLU()
def forward(self, h, aggregated):
combined = torch.cat([h, aggregated], dim=-1)
return self.activation(self.linear(combined))5.1.2 GRU更新
使用门控循环单元(Gated Recurrent Unit)进行更新:
class GRUUpdate(nn.Module):
"""GRU更新函数"""
def __init__(self, hidden_channels):
super().__init__()
self.gru = nn.GRUCell(hidden_channels, hidden_channels)
def forward(self, h, aggregated):
"""
Args:
h: 当前隐藏状态 (batch, hidden_channels)
aggregated: 聚合消息 (batch, hidden_channels)
"""
return self.gru(aggregated, h)5.2 残差连接与层归一化
为了支持更深的网络架构,常使用残差连接和层归一化:
class ResidualUpdate(nn.Module):
"""带残差连接的更新函数"""
def __init__(self, hidden_channels, dropout=0.1):
super().__init__()
self.update = LinearUpdate(hidden_channels)
self.norm = nn.LayerNorm(hidden_channels)
self.dropout = nn.Dropout(dropout)
def forward(self, h, aggregated):
h_new = self.update(h, aggregated)
h_new = self.dropout(h_new)
return self.norm(h + h_new)5.3 动态更新
根据任务需求,更新函数可以采用不同的策略:
| 更新策略 | 公式 | 应用场景 |
|---|---|---|
| 保留自身 | 标准GNN | |
| 替换自身 | GraphSAGE | |
| 门控更新 | 序列建模 | |
| 循环更新 | 记忆增强 |
6. 表达能力分析
6.1 WL图同构测试
Weisfeiler-Lehman(WL)测试是分析GNN表达能力的重要工具。Xu等人证明2:GNN的表达能力上界是1-WL测试(颜色细化算法)。
6.1.1 WL测试回顾
1-WL测试(也称为颜色细化算法)的工作流程:
初始化:所有节点着相同颜色
迭代:
1. 每个节点收集邻居颜色集合
2. 根据自身颜色和邻居颜色集合分配新颜色
3. 如果颜色分布稳定,则停止
6.1.2 GNN与WL测试的关系
关键定理:如果GNN中的聚合函数是单射的(injective),则该GNN的表达能力与1-WL测试等价。
这意味着:
- 具有单射聚合的GNN可以区分任何1-WL测试能区分的图
- 任何无法被1-WL测试区分的图对,也无法被此类GNN区分
6.1.3 什么可以学习?
满足单射性的GNN可以区分:
| 结构类型 | 示例 | 可区分性 |
|---|---|---|
| 不同度分布 | vs | ✓ |
| 不同子图结构 | 三角形 vs 路径 | ✓ |
| 节点角色 | 中心节点 vs 边缘节点 | ✓ |
6.1.4 什么不能学习?
1-WL测试无法区分的结构,GNN也无法区分:
┌─────────────────────────────────────────────────────────────────────────┐
│ 1-WL测试无法区分的典型例子 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 例1:同构子图 │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ ●──●──● │ │ ●──●──● │ 完全相同的多重集 │
│ │ │ │ │ │ = │ │ │ │ │ → WL测试无法区分 │
│ │ ●──●──● │ │ ●──●──● │ → GNN也无法区分 │
│ └──────────────┘ └──────────────┘ │
│ │
│ 例2:不同环结构 │
│ ┌──────────────┐ ┌──────────────┐ │
│ │ ●──● │ │ ●──● │ 邻居都是 2 度节点 │
│ │ /│\ │ │ │ / \ │ │ → WL测试无法区分 │
│ │ ● ● ● ● │ ≠ │ ● ● ● │ → GNN也无法区分 │
│ │ \│/ │ │ │ \ │/ │ │
│ │ ●──● │ │ ●──● │ │
│ └──────────────┘ └──────────────┘ │
│ 六边形 桶形结构 │
│ │
│ 例3:距离信息丢失 │
│ ●───────●───────● │
│ vs │
│ ●─────────●─────────● │
│ │
│ 距离不同但邻居多重集相同 │
│ │
└─────────────────────────────────────────────────────────────────────────┘6.2 提升表达能力的方法
6.2.1 GIN(Graph Isomorphism Network)
Xu等人证明2,Graph Isomorphism Network (GIN) 是表达能力最强的GNN之一:
其中 必须是单射的。
代码实现:
class GINConv(nn.Module):
"""GIN卷积层 - 理论上最具表达能力的GNN"""
def __init__(self, in_channels, out_channels, eps=0.0, train_eps=True):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_channels, out_channels),
nn.ReLU(),
nn.Linear(out_channels, out_channels)
)
if train_eps:
self.eps = nn.Parameter(torch.tensor(eps))
else:
self.eps = eps
def forward(self, x, edge_index):
"""
Args:
x: 节点特征 (num_nodes, in_channels)
edge_index: 边索引 (2, num_edges)
"""
row, col = edge_index
# 邻居求和
out = torch.zeros_like(x)
out.index_add_(0, col, x[row])
# 添加自环和可学习的epsilon
out = (1 + self.eps) * x + out
return self.mlp(out)6.2.2 距离感知GNN
为了捕获节点间的距离信息,可以使用:
class DistanceAwareMessage(nn.Module):
"""距离感知消息函数"""
def __init__(self, hidden_channels):
super().__init__()
self.message = nn.Linear(hidden_channels + hidden_channels, hidden_channels)
def forward(self, x_j, x_i, edge_index):
"""
x_j: 源节点特征
x_i: 目标节点特征
edge_index: 边索引(需要额外计算距离)
"""
# 简化示例:实际需要计算节点间的距离
# 这里假设 edge_attr 已包含距离信息
return self.message(torch.cat([x_j, x_i], dim=-1))7. 高级消息传递变体
7.1 Jump Knowledge Networks (JK-Net)
核心思想:在最后一层之前,将所有中间层的表示连接起来,充分利用不同感知域的信息。
其中 可以是:
- Concat:[所有层拼接]
- Max:[逐位置取最大值]
- LSTM:[通过LSTM选择最重要信息]
class JKNet(nn.Module):
"""Jump Knowledge Networks"""
def __init__(self, hidden_channels, num_layers, mode='max'):
super().__init__()
self.convs = nn.ModuleList()
self.bns = nn.ModuleList()
for i in range(num_layers):
self.convs.append(GINConv(hidden_channels, hidden_channels))
self.bns.append(nn.BatchNorm1d(hidden_channels))
self.mode = mode
if mode == 'lstm':
self.lstm = nn.LSTM(
hidden_channels, hidden_channels // 2,
bidirectional=True, batch_first=True
)
def forward(self, x, edge_index):
xs = []
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
x = conv(x, edge_index)
x = bn(x)
x = F.relu(x)
xs.append(x)
if self.mode == 'concat':
return torch.cat(xs, dim=-1)
elif self.mode == 'max':
return torch.stack(xs, dim=0).max(dim=0)[0]
elif self.mode == 'lstm':
x = torch.stack(xs, dim=1) # (N, num_layers, hidden)
out, _ = self.lstm(x)
return out[:, -1, :]
else:
return xs[-1]7.2 有向消息传递 (Directed Message Passing)
在分子图中,边的方向(如化学键方向)携带重要信息。有向消息传递显式建模方向性:
其中 表示边的方向类型。
class DirectedMessagePassing(nn.Module):
"""有向消息传递"""
def __init__(self, hidden_channels, num_directions=2):
super().__init__()
# 两种方向:正向和反向
self.W_forward = nn.Linear(hidden_channels, hidden_channels)
self.W_backward = nn.Linear(hidden_channels, hidden_channels)
def forward(self, x, edge_index):
row, col = edge_index
# 正向消息(row → col)
msg_forward = self.W_forward(x[row])
# 反向消息(col → row)
msg_backward = self.W_backward(x[col])
# 聚合
out = torch.zeros_like(x)
out.index_add_(0, col, msg_forward)
out.index_add_(0, row, msg_backward)
return out
class BondDirectionMessage(nn.Module):
"""化学键方向感知消息传递(用于分子图)"""
def __init__(self, hidden_channels, num_bond_types):
super().__init__()
self.W = nn.ModuleDict({
'single': nn.Linear(hidden_channels, hidden_channels),
'double': nn.Linear(hidden_channels, hidden_channels),
'triple': nn.Linear(hidden_channels, hidden_channels),
'aromatic': nn.Linear(hidden_channels, hidden_channels),
})
def forward(self, x, edge_index, bond_types):
"""
bond_types: List of bond types for each edge
"""
row, col = edge_index
messages = []
for i in range(len(row)):
src, dst = row[i].item(), col[i].item()
bond = bond_types[i]
msg = self.W[bond](x[src])
messages.append(msg)
return torch.stack(messages)7.3 边缘卷积 (Edge Convolution)
DYNAMIC EDGE CONVOLUTION(也称EdgeConv)由Wang等人提出7,是点云处理中常用的方法。
7.3.1 核心思想
EdgeConv不是聚合邻居的消息,而是为每个节点计算其与每个邻居的边特征:
其中 捕获了节点间的相对位置/特征差异。
7.3.2 k-NN vs 全图连接
EdgeConv可以使用 近邻(-NN)构建局部邻域,而非仅使用图结构中的边:
class EdgeConv(nn.Module):
"""边缘卷积 (EdgeConv)"""
def __init__(self, in_channels, out_channels, k=20):
super().__init__()
self.k = k
self.mlp = nn.Sequential(
nn.Linear(in_channels * 2, out_channels),
nn.ReLU(),
nn.BatchNorm1d(out_channels),
nn.Linear(out_channels, out_channels),
nn.ReLU(),
nn.BatchNorm1d(out_channels)
)
def knn(self, x):
"""计算k近邻"""
pairwise_dist = torch.cdist(x, x, p=2)
_, indices = pairwise_dist.topk(self.k, largest=False)
return indices
def forward(self, x, edge_index=None):
"""
Args:
x: 节点特征 (N, C)
edge_index: 可选的预定义边(如果为None则使用k-NN)
"""
N = x.size(0)
if edge_index is None:
# 使用k-NN构建邻域
indices = self.knn(x) # (N, k)
else:
indices = edge_index[1].view(N, -1)
# 获取邻居特征
neighbor_features = x[indices] # (N, k, C)
# 拼接中心节点特征(与每个邻居的特征差)
center_features = x.unsqueeze(1).expand(-1, self.k, -1) # (N, k, C)
diff = neighbor_features - center_features # (N, k, C)
# 边特征
edge_features = torch.cat([center_features, diff], dim=-1) # (N, k, 2C)
# 通过MLP处理每条边
edge_features = edge_features.view(-1, 2 * x.size(1)) # (N*k, 2C)
edge_features = self.mlp(edge_features) # (N*k, C)
edge_features = edge_features.view(N, self.k, -1) # (N, k, C)
# 聚合(最大池化)
aggregated = edge_features.max(dim=1)[0] # (N, C)
return aggregated + x # 残差连接8. PyTorch Geometric 实现
PyTorch Geometric(PyG)是目前最流行的GNN框架,提供了丰富的消息传递API。
8.1 MessagePassing 基类
PyG的 MessagePassing 类封装了消息传递的核心逻辑:
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
class MyMessagePassingLayer(MessagePassing):
"""自定义消息传递层"""
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add') # 默认聚合方式
self.lin = nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# 添加自环,使节点能够获取自身信息
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# 调用propagate开始消息传递
return self.propagate(edge_index, x=x)
def message(self, x_j, x_i):
"""
消息函数
x_j: 源节点特征
x_i: 目标节点特征
"""
return self.lin(x_j)
def update(self, updated):
"""
更新函数
updated: 聚合后的结果
"""
return updated8.2 完整 GCN 层实现
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
"""GCN卷积层的PyG实现"""
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add')
self.lin = nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# 添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# 计算度归一化系数
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# 线性变换
x = self.lin(x)
# 消息传递
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# 消息函数:节点特征 × 归一化系数
return norm.view(-1, 1) * x_j8.3 GAT 层实现
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
class GATConv(MessagePassing):
"""GAT卷积层的PyG实现"""
def __init__(self, in_channels, out_channels, heads=4, concat=True, dropout=0.0):
super().__init__(node_dim=0) # 指定节点维度
self.heads = heads
self.concat = concat
self.dropout = dropout
self.head_dim = out_channels // heads
assert out_channels % heads == 0
self.lin = nn.Linear(in_channels, heads * self.head_dim, bias=False)
self.att = nn.Parameter(torch.Tensor(1, heads, 2 * self.head_dim))
nn.init.xavier_uniform_(self.att)
if concat:
self.bias = nn.Parameter(torch.Tensor(out_channels))
else:
self.bias = nn.Parameter(torch.Tensor(self.head_dim))
nn.init.zeros_(self.bias)
def forward(self, x, edge_index):
H = self.heads
C = self.head_dim
# 线性变换
x = self.lin(x).view(-1, H, C) # (N, H, C)
# 添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# 消息传递
out = self.propagate(edge_index, x=x)
# 处理输出维度
if self.concat:
out = out.view(-1, H * C)
out = out + self.bias
else:
out = out.mean(dim=1) + self.bias
return out
def message(self, x_i, x_j, index):
# x_i: 目标节点特征 (num_edges, H, C)
# x_j: 源节点特征 (num_edges, H, C)
# 计算注意力分数
alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1) # (num_edges, H)
alpha = F.leaky_relu(alpha, 0.2)
alpha = softmax(alpha, index) # 按目标节点归一化
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return x_j * alpha.view(-1, self.heads, 1)8.4 使用 Sequential 构建多层 GNN
class GraphNeuralNetwork(nn.Module):
"""多层GNN模型"""
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
super().__init__()
self.convs = nn.ModuleList()
# 第一层
self.convs.append(GATConv(in_channels, hidden_channels, heads=4))
# 中间层
for _ in range(num_layers - 2):
self.convs.append(GATConv(hidden_channels, hidden_channels, heads=4))
# 最后一层
self.convs.append(GATConv(hidden_channels, out_channels, heads=1, concat=False))
self.norms = nn.ModuleList([nn.BatchNorm1d(hidden_channels)
for _ in range(num_layers - 1)])
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs[:-1]):
x_new = conv(x, edge_index)
x_new = self.norms[i](x_new)
x_new = F.relu(x_new)
x_new = F.dropout(x_new, p=0.5, training=self.training)
x = x_new + x # 残差连接
# 最后一层不使用激活和dropout
x = self.convs[-1](x, edge_index)
return x9. 不同设计的比较
9.1 消息函数比较
| 类型 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| 线性 | 计算高效、可解释 | 表达能力有限 | 大规模图、边信息简单的场景 |
| MLP | 表达能力强 | 计算开销大 | 异构图、复杂边关系 |
| 注意力 | 自适应权重、捕获重要性差异 | 参数量大 | 关系重要性差异明显 |
| 多关系 | 显式建模不同关系类型 | 关系类型需要预先定义 | 知识图谱、异构图 |
9.2 聚合函数比较
| 类型 | 置换不变性 | 表达能力 | 计算复杂度 |
|---|---|---|---|
| Sum | ✓ | 中等 | |
| Mean | ✓ | 中等 | |
| Max | ✓ | 较弱 | |
| Attention | ✓ | 强 | |
| DeepSets | ✓ | 强 | |
| SetTrans | ✓ | 最强 |
9.3 不同 GNN 架构对比
| 架构 | 消息函数 | 聚合函数 | 更新函数 | 表达能力 |
|---|---|---|---|---|
| GCN | 度归一化和 | 弱于WL | ||
| GraphSAGE | Max/Mean/LSTM | 弱于WL | ||
| GAT | 加权和 | 弱于WL | ||
| GIN | Sum(单射) | 等价WL | ||
| GIN+ | Sum | 等价WL |
9.4 计算效率对比
┌─────────────────────────────────────────────────────────────────────────┐
│ 各GNN架构计算效率对比 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ O(1) ───────────────────────────────────────────────────────→ O(N²) │
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ GCN │ │ Graph- │ │ GAT │ │ GIN │ │ Edge- │ │
│ │ │ │ SAGE │ │ │ │ │ │ Conv │ │
│ │ O(E) │ │ O(E) │ │ O(E·K) │ │ O(E) │ │ O(N²) │ │
│ │ │ │ │ │ │ │ │ │ (k-NN) │ │
│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
│ │
│ E: 边数, N: 节点数, K: 注意力头数 │
│ │
└─────────────────────────────────────────────────────────────────────────┘10. 实践指南
10.1 何时使用何种设计
┌─────────────────────────────────────────────────────────────────────────┐
│ 消息传递设计决策树 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 开始设计 │
│ │ │
│ ↓ │
│ ┌─────────────────┐ │
│ │ 边有权重吗? │ │
│ └────────┬────────┘ │
│ 是 ↓ ↓ 否 │
│ 使用权重消息 使用基础消息 │
│ │ │
│ ↓ │
│ ┌─────────────────┐ │
│ │ 关系类型多样吗? │ │
│ └────────┬────────┘ │
│ 是 ↓ ↓ 否 │
│ 多关系消息函数 ↓ │
│ │ │
│ ↓ │
│ ┌─────────────────┐ │
│ │ 邻居重要性不同? │ │
│ └────────┬────────┘ │
│ 是 ↓ ↓ 否 │
│ 注意力聚合 使用求和/均值 │
│ │
└─────────────────────────────────────────────────────────────────────────┘10.2 常见问题与解决方案
| 问题 | 症状 | 解决方案 |
|---|---|---|
| 过平滑 | 所有节点表示趋于相同 | 使用残差连接、增加层数限制、使用JK-Net |
| 过拟合 | 训练损失低但测试性能差 | 使用dropout、减少层数、数据增强 |
| 数值不稳定 | NaN或Inf | 使用层归一化、降低学习率、梯度裁剪 |
| 异构图处理 | 不同节点/边类型混淆 | 使用异构图专用消息传递 |
10.3 最佳实践清单
-
数据预处理
- 标准化节点特征
- 添加自环(除非明确不需要)
- 考虑是否使用节点度数作为额外特征
-
架构设计
- 2-3层通常足够(除非需要很大感知域)
- 使用残差连接支持更深的网络
- 考虑使用JK-Net或注意力机制
-
训练技巧
- 使用早停(early stopping)
- 使用学习率衰减
- 应用合适的dropout率(0.1-0.5)
11. 总结
消息传递机制是图神经网络的核心计算范式,通过消息生成、消息聚合和节点更新三个步骤实现图结构数据的表示学习。
11.1 核心要点
| 组件 | 作用 | 设计选择 |
|---|---|---|
| 消息函数 | 转换邻居信息 | 线性/MLP/注意力/多关系 |
| 聚合函数 | 合并邻居消息 | Sum/Mean/Max/Attention |
| 更新函数 | 融合自身与聚合 | 线性/GRU/残差连接 |
11.2 表达能力
- 标准GNN的表达能力受限于1-WL测试
- 单射聚合函数是达到WL表达能力的必要条件
- GIN是理论上最具表达能力的架构之一
- 超越WL需要额外的结构编码(如距离、子图信息)
11.3 未来方向
- 更强大的聚合器:超越WL测试的表达能力
- 动态邻域:基于学习的邻域选择
- 异步消息传递:处理时序图和事件流
- 稀疏化:大规模图的高效计算
参考资料
Footnotes
-
Gilmer J, Schoenholz S S, Riley P F, et al. Neural Message Passing for Quantum Chemistry[J]. International Conference on Machine Learning (ICML), 2017. ↩ ↩2
-
Xu K, Hu W, Leskovec J, et al. How Powerful are Graph Neural Networks?[J]. International Conference on Learning Representations (ICLR), 2019. ↩ ↩2
-
You J, Ying R, Ren X, et al. Identity-Aware Graph Neural Networks in Knowledge Graph Completion[C]. AAAI Conference on Artificial Intelligence, 2021. ↩
-
Li J, Ma R, Guo Q, et al. Distance Encoding—Enable Informative Shortest Path Counting and Propagation for GNNs[J]. ACM SIGKDD Conference on Knowledge Discovery and Data Mining (KDD), 2020. ↩
-
Frasca F, Bevilacqua B, Bronstein M M, et al. Understanding Unbalanced Semantic Classes in Graph Neural Networks’ Predictions[J]. arXiv preprint, 2022. ↩
-
Grohe M. The Logic of Graph Neural Networks[C]. ACM SIGLOG News, 2021. ↩
-
Wang Y, Sun Y, Liu Z, et al. Dynamic Graph CNN for Learning on Point Clouds[J]. ACM Transactions on Graphics (TOG), 2019. ↩