Message Passing Neural Networks (MPNN)

1. 引言

Message Passing Neural Networks (MPNN) 是由 Gilmer 等人在 2017 年提出的统一框架,用于描述图神经网络的消息传递机制。1 该框架将图神经网络的运算形式化为”消息传递”过程,与概率图模型中的信念传播(Belief Propagation)有着深刻的联系。

MPNN 框架的核心思想是:通过迭代地聚合邻居节点信息来更新节点表示。这种思想既体现了深度学习的层次表示学习能力,又融合了概率图模型的图结构建模优势。

2. 形式化定义

2.1 基本框架

MPNN 包含两个核心阶段:消息传递阶段和** readout 阶段**。

消息传递阶段

对于第 层,MPNN 的消息传递遵循以下步骤:

消息计算

其中:

  • :节点 发送给节点 的消息
  • :节点 在第 层的隐藏状态
  • :边 的特征或属性

消息聚合

节点更新

其中 是节点更新函数, 是消息函数, 是节点 的邻居集合。

readout 阶段

其中 是 readout 函数,用于从所有节点表示中生成图级输出。

2.2 消息函数的设计

消息函数 可以有不同的设计方式:

加性消息(Additive)

乘性消息(Multiplicative)

MLP 消息

3. 与概率图模型消息传递的联系

MPNN 与概率图模型中的**信念传播(Belief Propagation)**有着形式上的对应关系。

3.1 信念传播回顾

在因子图上,信念传播的消息更新规则为:

3.2 对应关系

MPNN 组件信念传播组件数学联系
消息 消息 均表示邻居对当前节点的贡献
聚合 求和 均对邻居消息进行汇总
更新 归一化均包含归一化/归约操作
迭代层数消息传递步数均通过迭代传播信息

3.3 关键区别

尽管形式相似,MPNN 与信念传播存在本质区别:

  1. 消息参数化:MPNN 的消息函数是可学习的,而 BP 的消息函数由概率模型固定
  2. 归一化处理:BP 需要保证消息归一化以避免数值溢出,MPNN 可选择是否归一化
  3. 目标函数:BP 旨在推断后验分布,MPNN 旨在学习下游任务预测

4. MPNN 变体分类

4.1 按消息类型分类

Message-Passing Neural Networks (MPNN)1

  • 使用 MLP 作为消息函数
  • 支持边特征
  • 适用于分子图等结构化数据

Graph Networks (GN)2

  • 支持节点、边、全局属性更新
  • 更通用的消息传递框架
  • 可模拟多种图网络架构

4.2 按聚合方式分类

加性聚合

优点:简单、梯度流动好
缺点:丢失邻居节点数量信息

注意力聚合

优点:自适应权重分配
缺点:计算复杂度较高

Set2Set 聚合

优点:处理置换不变性
缺点:额外的时间步

4.3 按更新方式分类

GRU 更新

LSTM 更新

其中 是重置门

5. 代码实现

5.1 PyTorch Geometric 实现

import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops
 
class MPNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, edge_dim):
        super(MPNNLayer, self).__init__(aggr='add')  # 聚合方式
        self.edge_lin = nn.Linear(edge_dim, in_channels)
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )
        self.update_mlp = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.ReLU()
        )
    
    def forward(self, x, edge_index, edge_attr):
        # 添加自环
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        
        # 消息传递
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)
    
    def message(self, x_j, edge_attr):
        # 消息计算: M(h_v, h_u, e_vu)
        if edge_attr is not None:
            edge_emb = self.edge_lin(edge_attr)
            return self.mlp(torch.cat([x_j + edge_emb], dim=-1))
        return self.mlp(x_j)
    
    def update(self, aggr_out, x):
        # 节点更新: U(h_u, m_u)
        return self.update_mlp(torch.cat([x, aggr_out], dim=-1))

5.2 带注意力机制的 MPNN

class AttentionMPNN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(AttentionMPNN, self).__init__(aggr='add')
        self.query_lin = nn.Linear(in_channels, out_channels)
        self.key_lin = nn.Linear(in_channels, out_channels)
        self.value_lin = nn.Linear(in_channels, out_channels)
        self.mlp = nn.Sequential(
            nn.Linear(2 * out_channels, out_channels),
            nn.ReLU()
        )
    
    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_i, x_j):
        # 计算注意力权重
        q = self.query_lin(x_i)
        k = self.key_lin(x_j)
        v = self.value_lin(x_j)
        
        alpha = (q * k).sum(dim=-1) / np.sqrt(k.size(-1))
        alpha = torch.softmax(alpha, dim=0)
        
        return alpha.unsqueeze(-1) * v
    
    def update(self, aggr_out, x):
        return self.mlp(torch.cat([x, aggr_out], dim=-1))

6. 理论性质

6.1 表达能力

MPNN 的表达能力受限于 1-WL 同构测试(即图同构网络 GNN 无法区分非同构图)。[3]

定理(Weisfeiler-Lehman 界限)
具有足够大的嵌入维度和适当的聚合函数,MPNN 的表达能力不超过 1-WL 同构测试。

6.2 置换不变性

MPNN 的输出对于节点顺序的置换是不变的:

其中 是节点置换操作, 表示应用置换到图上。

6.3 过平滑问题

深层 MPNN 面临**过平滑(Over-smoothing)**问题,节点表示趋于相同。

缓解方法

  • Jump Knowledge Network3
  • 残差连接
  • 层归一化
  • 稀疏连接模式

7. 应用场景

7.1 分子图预测

MPNN 在分子性质预测中有广泛应用:

  • 药物发现(溶解度、毒性预测)
  • 材料科学(带隙、稳定性预测)
  • 分子生成

7.2 知识图谱

  • 链接预测
  • 实体关系推理
  • 知识图谱补全

7.3 交通预测

  • 路网流量预测
  • 出行时间估计
  • 异常检测

8. 与其他图神经网络的联系

graph LR
    A[MPNN框架] --> B[GNN基础]
    A --> C[Graph Networks]
    B --> D[GCN]
    B --> E[GAT]
    B --> F[GIN]
    C --> G[Message Passing Neural Networks]
    G --> H[Interaction Networks]

9. 参考资料

Footnotes

  1. Gilmer et al. (2017). “Neural Message Passing for Quantum Chemistry.” ICML 2017. 2

  2. Battaglia et al. (2018). “Relational Inductive Biases, Deep Learning, and Graph Networks.” arXiv:1806.01261.

  3. Xu et al. (2018). “Representation Learning on Graphs with Jumping Knowledge Networks.” ICML 2018.