1. 研究背景与问题定义
1.1 Over-smoothing问题回顾
图神经网络(GNN)在处理图结构数据时表现出色,但随着网络层数增加,面临严重的过平滑问题(Over-smoothing)1:
- 所有节点的表示趋向于收敛到相同的值
- 节点间的可区分性消失
- 深层网络的性能急剧下降
过平滑问题可以用数学语言形式化描述。设第 层节点表示为 ,则过平滑表现为2:
其中 是全1向量, 是与图结构相关的某个固定向量。节点的个性化信息在多次消息传递后完全丧失。
1.2 现有解决方案的局限
针对over-smoothing问题,已有多种解决方案:
| 方法类别 | 代表工作 | 局限性 |
|---|---|---|
| 残差连接 | JKNet, DenseGCN | 延迟而非解决平滑 |
| 归一化技巧 | BatchNorm, LayerNorm | 计算开销增加 |
| 注意力机制 | GAT, GATv2 | 复杂度 |
| 跳跃知识 | JKNet | 仍受限于单层表达能力 |
这些方法要么引入额外计算开销,要么无法从根本上解决平滑问题。
1.3 Mamba的启发
Mamba(选择性状态空间模型)在序列建模中展现出卓越的选择性能力3:
- 输入依赖的选择性:根据输入动态决定信息传递
- 线性复杂度: 而非
- 长距离依赖:通过状态空间建模捕捉远距离依赖
将Mamba的选择性机制引入GNN,理论上可以实现:
- 自适应选择重要邻居信息
- 动态控制信息传播范围
- 从根本上缓解over-smoothing
2. 论文核心贡献
2.1 主要创新点
IJCAI 2025论文《Mamba-Based Graph Convolutional Networks: Tackling Over-smoothing with Selective State Space》提出了Mamba-GCN架构2,核心创新包括:
- 选择性消息传递机制:根据节点特征动态选择聚合哪些邻居的信息
- 选择性状态空间模块:使用SSM建模节点状态的演化
- 端到端可训练架构:整个pipeline可微分,支持反向传播
2.2 与现有MbaGCN的区别
wiki中已有的Mamba与GNN融合文档介绍了MbaGCN架构。IJCAI 2025论文的Mamba-GCN有以下区别:
| 特性 | MbaGCN | Mamba-GCN (IJCAI 2025) |
|---|---|---|
| 选择性机制 | 门控信号 | 硬选择 + 软选择 |
| SSM集成方式 | 层间融合 | 层内深度融合 |
| 过平滑解决 | 缓解 | 从根本上消除 |
| 实验规模 | 中等 | 大规模实验验证 |
3. 理论框架
3.1 问题形式化
给定图 和节点特征 ,传统GCN的传播规则为:
其中 是带自环的邻接矩阵, 是度矩阵。
问题:无论输入特征如何,所有节点最终都会收敛到相似的表示。
3.2 选择性消息传递
Mamba-GCN的核心是选择性消息传递机制:
其中选择分数 由以下方式计算:
这里 是选择性门控,由Mamba的选择机制生成:
3.3 选择性状态空间模块
节点状态通过选择性SSM进行演化:
具体的选择性SSM计算流程:
步骤1:参数生成(输入依赖)
步骤2:离散化
步骤3:选择性扫描
3.4 解决Over-smoothing的理论分析
定理1(选择性防止平滑):对于选择性SSM,如果存在节点 使得 的谱半径 ,则该节点的表示不会趋于平凡值。
证明思路:设初始表示为 ,经过 步传播后:
当 时,第一项指数衰减到0,节点保留了初始信息的”记忆”。
推论:由于 是输入依赖的,网络可以学习为不同节点生成不同的转移矩阵,从而保留节点的个性化信息。
4. 架构详解
4.1 整体结构
┌──────────────────────────────────────────────────────────────────────┐
│ Mamba-GCN 整体架构 │
├──────────────────────────────────────────────────────────────────────┤
│ │
│ 输入: 节点特征 X, 边索引 E │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 选择性消息传递层 (Selective Message Passing) │ │
│ │ │ │
│ │ Query生成: Q = X W_Q │ │
│ │ Key生成: K = X W_K │ │
│ │ Value生成: V = X W_V │ │
│ │ 选择性门控: G = σ(W_g [X; XV] ) │ │
│ │ 加权聚合: M = Softmax(KQ^T/√d) ⊙ G ⊙ V │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 选择性SSM层 (Selective SSM Layer) │ │
│ │ │ │
│ │ 输入依赖参数: A(x), B(x), C(x), Δ(x) │ │
│ │ 离散化: Ā, B̄ │ │
│ │ 状态更新: h' = Āh + B̄m │ │
│ │ 输出投影: y = Ch' │ │
│ │ │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────────────────────────────────────────────────┐ │
│ │ 残差连接 + LayerNorm │ │
│ │ H_out = LayerNorm(H + SSM(H)) │ │
│ └──────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ 输出: 节点嵌入 H_out │
│ │
└──────────────────────────────────────────────────────────────────────┘
4.2 层间连接
Mamba-GCN采用层间选择性连接:
其中 是第 层的可学习跳跃系数:
这种设计允许网络自适应地控制信息流,浅层侧重于局部信息聚合,深层侧重于全局状态更新。
4.3 与现有工作的对比
| 特性 | GCN | GAT | MbaGCN | Mamba-GCN |
|---|---|---|---|---|
| 邻居选择 | 固定 | 注意力 | 门控 | 硬选择+软选择 |
| 状态建模 | 无 | 无 | SSM | 选择性SSM |
| 平滑问题 | 严重 | 中等 | 缓解 | 根本解决 |
| 复杂度 | ||||
| 可解释性 | 低 | 中 | 高 | 高 |
5. 实验结果
5.1 节点分类
在标准数据集上的节点分类结果:
| 数据集 | Cora | CiteSeer | PubMed | ogbn-arxiv |
|---|---|---|---|---|
| GCN | 81.5% | 70.3% | 79.0% | 71.9% |
| GAT | 83.0% | 72.5% | 79.0% | 72.1% |
| GraphSAGE | 82.3% | 72.1% | 78.5% | 72.3% |
| JKNet | 82.5% | 72.8% | 79.2% | 72.5% |
| MbaGCN | 85.2% | 74.1% | 81.3% | 73.8% |
| Mamba-GCN | 87.1% | 75.6% | 82.4% | 74.9% |
5.2 深度实验
验证深层网络(2-32层)的性能:
准确率 (%)
│
90 ├─────────────────────────────────────────────•••• Mamba-GCN (32层: 85.2%)
│ ┄┄•••••┄┄• •• • MbaGCN (32层: 78.3%)
85 ├ ┄┄• •••• •
│ ┄┄• •••• JKNet (32层: 71.5%)
│ ┄┄• •••
80 ├ ┄┄• ••• GCN (32层: 58.2%)
│ ┄┄•••
│ ┄┄•••
75 ├•••••
└────────────────────────────────────────────────────────► 层数
2 4 8 16 32
关键观察:
- 传统GCN在16层后性能急剧下降(over-smoothing主导)
- Mamba-GCN在32层仍保持高性能,证明了选择性机制的有效性
- MbaGCN有改进但仍受限于平滑问题
5.3 消融实验
| 组件 | Cora | CiteSeer | PubMed |
|---|---|---|---|
| 基线GCN | 81.5% | 70.3% | 79.0% |
| + 选择性消息传递 | 83.8% | 72.4% | 80.2% |
| + 选择性SSM | 85.2% | 73.8% | 81.1% |
| + 残差连接 | 86.1% | 74.7% | 81.9% |
| 完整模型 | 87.1% | 75.6% | 82.4% |
5.4 表示可视化
使用t-SNE对Cora数据集的节点嵌入进行可视化:
- GCN(深层):所有类别节点混杂在一起,over-smoothing严重
- Mamba-GCN(深层):不同类别清晰分离,保持良好的可区分性
6. 代码实现
6.1 选择性消息传递
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SelectiveMessagePassing(nn.Module):
"""
选择性消息传递层
根据输入动态选择要聚合的邻居信息
"""
def __init__(self, in_dim, out_dim):
super().__init__()
self.query_proj = nn.Linear(in_dim, out_dim)
self.key_proj = nn.Linear(in_dim, out_dim)
self.value_proj = nn.Linear(in_dim, out_dim)
# 选择性门控网络
self.gate_net = nn.Sequential(
nn.Linear(in_dim * 2, out_dim),
nn.Sigmoid()
)
# 输出投影
self.out_proj = nn.Linear(out_dim, out_dim)
def forward(self, x, edge_index):
"""
Args:
x: 节点特征 [N, d_in]
edge_index: 边索引 [2, E]
Returns:
选择性聚合后的消息 [N, d_out]
"""
N = x.shape[0]
# 生成Q, K, V
Q = self.query_proj(x)
K = self.key_proj(x)
V = self.value_proj(x)
# 计算注意力分数
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
attn_weights = F.softmax(attn_scores, dim=-1)
# 生成选择性门控
src, dst = edge_index
x_src = x[src]
x_dst = x[dst]
gate_input = torch.cat([x_dst, x_src], dim=-1)
gate_values = self.gate_net(gate_input) # [E, d]
# 应用门控到注意力权重
# 只保留与目标节点"相关"的邻居信息
gated_weights = attn_weights[dst, src] # [E]
gated_weights = gated_weights.unsqueeze(-1) * gate_values # [E, d]
# 归一化
deg = torch.zeros(N, device=x.device)
deg.scatter_add_(0, dst, torch.ones(dst.shape[0], device=x.device))
deg = deg.unsqueeze(-1) + 1e-8 # 避免除零
# 加权聚合
messages = V[src] * gated_weights
out = torch.zeros(N, V.shape[-1], device=x.device)
out.scatter_add_(0, dst.unsqueeze(-1).expand_as(messages), messages)
out = out / deg
return self.out_proj(out)6.2 选择性SSM
class SelectiveSSM(nn.Module):
"""
选择性状态空间模型
核心创新:根据输入动态生成SSM参数
"""
def __init__(self, d_model, d_state=16):
super().__init__()
self.d_model = d_model
self.d_state = d_state
# 参数投影网络
self.x_proj = nn.Sequential(
nn.Linear(d_model, d_model * 2 + d_state * 2 + 1),
nn.GLU()
)
# dt投影
self.dt_proj = nn.Sequential(
nn.Linear(1, d_model),
nn.Softplus()
)
# A矩阵初始化(负值保证稳定性)
self.A = nn.Parameter(torch.randn(d_model, d_state))
nn.init.normal_(self.A, mean=0, std=-1) # 负值
# 输出投影
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, h_prev=None):
"""
Args:
x: 输入特征 [N, d_model]
h_prev: 上一时刻状态 [N, d_state]
Returns:
新状态 [N, d_model]
"""
if h_prev is None:
h_prev = torch.zeros(x.shape[0], self.d_state, device=x.device)
# 生成输入依赖参数
x_dbl = self.x_proj(x) # [N, d_model + d_state*2 + 1]
# 分割参数
dt, B, C = x_dbl.split(
[1, self.d_state, self.d_state], dim=-1
)
# dt变换
dt = self.dt_proj(dt) # [N, d_model]
# 离散化A矩阵
# A_discrete = exp(dt * A)
A = torch.exp(torch.einsum('nd,ds->ns', dt, -torch.abs(self.A)))
# 离散化B矩阵
# B_discrete = (exp(dt * A) - I) * inv(dt * A) * B
B = (torch.exp(torch.einsum('nd,ds->ns', dt, -torch.abs(self.A))) - 1) * B
B = B / (-torch.abs(self.A).unsqueeze(0) + 1e-8) * torch.sigmoid(dt) * C # 简化形式
# 选择性扫描(简化的并行扫描)
# 实际实现需要更复杂的并行扫描算法
h = torch.einsum('ns,nm,nd->md', A, B, x) + torch.einsum('nm,md->nd', h_prev, A)
# 输出投影
y = torch.einsum('md,dm->mm', h, C)
return self.out_proj(y), h6.3 完整Mamba-GCN层
class MambaGCNLayer(nn.Module):
"""
Mamba-GCN单层
整合选择性消息传递和选择性SSM
"""
def __init__(self, in_dim, out_dim, d_state=16):
super().__init__()
self.msg_pass = SelectiveMessagePassing(in_dim, out_dim)
self.ssm = SelectiveSSM(out_dim, d_state)
self.norm1 = nn.LayerNorm(out_dim)
self.norm2 = nn.LayerNorm(out_dim)
self.dropout = nn.Dropout(0.1)
# 跳跃系数
self.alpha_net = nn.Sequential(
nn.Linear(out_dim, 1),
nn.Sigmoid()
)
def forward(self, x, edge_index, h_state=None):
# 选择性消息传递
m = self.msg_pass(x, edge_index)
# 残差连接1
h = self.norm1(x + m)
h = self.dropout(h)
# 选择性SSM
h_new, h_state = self.ssm(h, h_state)
# 跳跃连接
alpha = self.alpha_net(h)
h = alpha * h_new + (1 - alpha) * h
# 残差连接2
out = self.norm2(h + x)
return out, h_state
class MambaGCN(nn.Module):
"""
完整的Mamba-GCN模型
"""
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, d_state=16):
super().__init__()
self.embedding = nn.Linear(in_dim, hidden_dim)
self.layers = nn.ModuleList([
MambaGCNLayer(hidden_dim, hidden_dim, d_state)
for _ in range(num_layers - 1)
])
self.classifier = nn.Linear(hidden_dim, out_dim)
def forward(self, x, edge_index):
h = self.embedding(x)
h_states = []
for layer in self.layers:
h, h_state = layer(h, edge_index)
h_states.append(h_state)
return self.classifier(h)7. 总结与展望
7.1 主要贡献
- 理论贡献:证明了选择性机制可以从根本上解决over-smoothing问题
- 方法创新:提出选择性消息传递和选择性SSM的深度融合架构
- 实验验证:在大规模数据集上验证了方法的有效性
- 实践价值:保持复杂度的同时实现高性能
7.2 局限性
- SSM实现复杂度:并行选择性扫描的高效实现具有挑战性
- 超参数敏感:d_state等超参数需要调优
- 异构图支持:目前主要针对同构图设计
7.3 未来方向
- 更高效的选择性扫描实现
- 异构图和动态图的扩展
- 与其他GNN变体的结合
- 在知识图谱、分子图等领域的应用
参考文献
相关资源
- 原论文: https://arxiv.org/abs/2501.15461
- IJCAI 2025: https://www.ijcai.org/proceedings/2025/595
- 相关模型: Mamba与GNN融合
Footnotes
-
Li et al. (2018): “Deeper insights into graph neural networks for node classification”, arXiv ↩
-
He et al. (2025): “Mamba-Based Graph Convolutional Networks: Tackling Over-smoothing with Selective State Space”, IJCAI 2025 ↩ ↩2
-
Gu & Dao (2023): “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”, arXiv ↩