Message Passing Neural Networks (MPNN)
1. 引言
Message Passing Neural Networks (MPNN) 是由 Gilmer 等人在 2017 年提出的统一框架,用于描述图神经网络的消息传递机制。1 该框架将图神经网络的运算形式化为”消息传递”过程,与概率图模型中的信念传播(Belief Propagation)有着深刻的联系。
MPNN 框架的核心思想是:通过迭代地聚合邻居节点信息来更新节点表示。这种思想既体现了深度学习的层次表示学习能力,又融合了概率图模型的图结构建模优势。
2. 形式化定义
2.1 基本框架
MPNN 包含两个核心阶段:消息传递阶段和** readout 阶段**。
消息传递阶段
对于第 层,MPNN 的消息传递遵循以下步骤:
消息计算:
其中:
- :节点 发送给节点 的消息
- :节点 在第 层的隐藏状态
- :边 的特征或属性
消息聚合:
节点更新:
其中 是节点更新函数, 是消息函数, 是节点 的邻居集合。
readout 阶段
其中 是 readout 函数,用于从所有节点表示中生成图级输出。
2.2 消息函数的设计
消息函数 可以有不同的设计方式:
加性消息(Additive):
乘性消息(Multiplicative):
MLP 消息:
3. 与概率图模型消息传递的联系
MPNN 与概率图模型中的**信念传播(Belief Propagation)**有着形式上的对应关系。
3.1 信念传播回顾
在因子图上,信念传播的消息更新规则为:
3.2 对应关系
| MPNN 组件 | 信念传播组件 | 数学联系 |
|---|---|---|
| 消息 | 消息 | 均表示邻居对当前节点的贡献 |
| 聚合 | 求和 | 均对邻居消息进行汇总 |
| 更新 | 归一化 | 均包含归一化/归约操作 |
| 迭代层数 | 消息传递步数 | 均通过迭代传播信息 |
3.3 关键区别
尽管形式相似,MPNN 与信念传播存在本质区别:
- 消息参数化:MPNN 的消息函数是可学习的,而 BP 的消息函数由概率模型固定
- 归一化处理:BP 需要保证消息归一化以避免数值溢出,MPNN 可选择是否归一化
- 目标函数:BP 旨在推断后验分布,MPNN 旨在学习下游任务预测
4. MPNN 变体分类
4.1 按消息类型分类
Message-Passing Neural Networks (MPNN)1:
- 使用 MLP 作为消息函数
- 支持边特征
- 适用于分子图等结构化数据
Graph Networks (GN)2:
- 支持节点、边、全局属性更新
- 更通用的消息传递框架
- 可模拟多种图网络架构
4.2 按聚合方式分类
加性聚合:
优点:简单、梯度流动好
缺点:丢失邻居节点数量信息
注意力聚合:
优点:自适应权重分配
缺点:计算复杂度较高
Set2Set 聚合:
优点:处理置换不变性
缺点:额外的时间步
4.3 按更新方式分类
GRU 更新:
LSTM 更新:
其中 是重置门
5. 代码实现
5.1 PyTorch Geometric 实现
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
class MPNNLayer(MessagePassing):
def __init__(self, in_channels, out_channels, edge_dim):
super(MPNNLayer, self).__init__(aggr='add') # 聚合方式
self.edge_lin = nn.Linear(edge_dim, in_channels)
self.mlp = nn.Sequential(
nn.Linear(2 * in_channels, out_channels),
nn.ReLU(),
nn.Linear(out_channels, out_channels)
)
self.update_mlp = nn.Sequential(
nn.Linear(in_channels + out_channels, out_channels),
nn.ReLU()
)
def forward(self, x, edge_index, edge_attr):
# 添加自环
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# 消息传递
return self.propagate(edge_index, x=x, edge_attr=edge_attr)
def message(self, x_j, edge_attr):
# 消息计算: M(h_v, h_u, e_vu)
if edge_attr is not None:
edge_emb = self.edge_lin(edge_attr)
return self.mlp(torch.cat([x_j + edge_emb], dim=-1))
return self.mlp(x_j)
def update(self, aggr_out, x):
# 节点更新: U(h_u, m_u)
return self.update_mlp(torch.cat([x, aggr_out], dim=-1))5.2 带注意力机制的 MPNN
class AttentionMPNN(MessagePassing):
def __init__(self, in_channels, out_channels):
super(AttentionMPNN, self).__init__(aggr='add')
self.query_lin = nn.Linear(in_channels, out_channels)
self.key_lin = nn.Linear(in_channels, out_channels)
self.value_lin = nn.Linear(in_channels, out_channels)
self.mlp = nn.Sequential(
nn.Linear(2 * out_channels, out_channels),
nn.ReLU()
)
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
# 计算注意力权重
q = self.query_lin(x_i)
k = self.key_lin(x_j)
v = self.value_lin(x_j)
alpha = (q * k).sum(dim=-1) / np.sqrt(k.size(-1))
alpha = torch.softmax(alpha, dim=0)
return alpha.unsqueeze(-1) * v
def update(self, aggr_out, x):
return self.mlp(torch.cat([x, aggr_out], dim=-1))6. 理论性质
6.1 表达能力
MPNN 的表达能力受限于 1-WL 同构测试(即图同构网络 GNN 无法区分非同构图)。[3]
定理(Weisfeiler-Lehman 界限):
具有足够大的嵌入维度和适当的聚合函数,MPNN 的表达能力不超过 1-WL 同构测试。
6.2 置换不变性
MPNN 的输出对于节点顺序的置换是不变的:
其中 是节点置换操作, 表示应用置换到图上。
6.3 过平滑问题
深层 MPNN 面临**过平滑(Over-smoothing)**问题,节点表示趋于相同。
缓解方法:
- Jump Knowledge Network3
- 残差连接
- 层归一化
- 稀疏连接模式
7. 应用场景
7.1 分子图预测
MPNN 在分子性质预测中有广泛应用:
- 药物发现(溶解度、毒性预测)
- 材料科学(带隙、稳定性预测)
- 分子生成
7.2 知识图谱
- 链接预测
- 实体关系推理
- 知识图谱补全
7.3 交通预测
- 路网流量预测
- 出行时间估计
- 异常检测
8. 与其他图神经网络的联系
graph LR A[MPNN框架] --> B[GNN基础] A --> C[Graph Networks] B --> D[GCN] B --> E[GAT] B --> F[GIN] C --> G[Message Passing Neural Networks] G --> H[Interaction Networks]