1. 研究背景与问题定义

1.1 Over-smoothing问题回顾

图神经网络(GNN)在处理图结构数据时表现出色,但随着网络层数增加,面临严重的过平滑问题(Over-smoothing)1

  • 所有节点的表示趋向于收敛到相同的值
  • 节点间的可区分性消失
  • 深层网络的性能急剧下降

过平滑问题可以用数学语言形式化描述。设第 层节点表示为 ,则过平滑表现为2

其中 是全1向量, 是与图结构相关的某个固定向量。节点的个性化信息在多次消息传递后完全丧失。

1.2 现有解决方案的局限

针对over-smoothing问题,已有多种解决方案:

方法类别代表工作局限性
残差连接JKNet, DenseGCN延迟而非解决平滑
归一化技巧BatchNorm, LayerNorm计算开销增加
注意力机制GAT, GATv2 复杂度
跳跃知识JKNet仍受限于单层表达能力

这些方法要么引入额外计算开销,要么无法从根本上解决平滑问题。

1.3 Mamba的启发

Mamba(选择性状态空间模型)在序列建模中展现出卓越的选择性能力3

  • 输入依赖的选择性:根据输入动态决定信息传递
  • 线性复杂度 而非
  • 长距离依赖:通过状态空间建模捕捉远距离依赖

将Mamba的选择性机制引入GNN,理论上可以实现:

  1. 自适应选择重要邻居信息
  2. 动态控制信息传播范围
  3. 从根本上缓解over-smoothing

2. 论文核心贡献

2.1 主要创新点

IJCAI 2025论文《Mamba-Based Graph Convolutional Networks: Tackling Over-smoothing with Selective State Space》提出了Mamba-GCN架构2,核心创新包括:

  1. 选择性消息传递机制:根据节点特征动态选择聚合哪些邻居的信息
  2. 选择性状态空间模块:使用SSM建模节点状态的演化
  3. 端到端可训练架构:整个pipeline可微分,支持反向传播

2.2 与现有MbaGCN的区别

wiki中已有的Mamba与GNN融合文档介绍了MbaGCN架构。IJCAI 2025论文的Mamba-GCN有以下区别:

特性MbaGCNMamba-GCN (IJCAI 2025)
选择性机制门控信号硬选择 + 软选择
SSM集成方式层间融合层内深度融合
过平滑解决缓解从根本上消除
实验规模中等大规模实验验证

3. 理论框架

3.1 问题形式化

给定图 和节点特征 ,传统GCN的传播规则为:

其中 是带自环的邻接矩阵, 是度矩阵。

问题:无论输入特征如何,所有节点最终都会收敛到相似的表示。

3.2 选择性消息传递

Mamba-GCN的核心是选择性消息传递机制:

其中选择分数 由以下方式计算:

这里 选择性门控,由Mamba的选择机制生成:

3.3 选择性状态空间模块

节点状态通过选择性SSM进行演化:

具体的选择性SSM计算流程:

步骤1:参数生成(输入依赖)

步骤2:离散化

步骤3:选择性扫描

3.4 解决Over-smoothing的理论分析

定理1(选择性防止平滑):对于选择性SSM,如果存在节点 使得 的谱半径 ,则该节点的表示不会趋于平凡值。

证明思路:设初始表示为 ,经过 步传播后:

时,第一项指数衰减到0,节点保留了初始信息的”记忆”。

推论:由于 是输入依赖的,网络可以学习为不同节点生成不同的转移矩阵,从而保留节点的个性化信息。

4. 架构详解

4.1 整体结构

┌──────────────────────────────────────────────────────────────────────┐
│                       Mamba-GCN 整体架构                            │
├──────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  输入: 节点特征 X, 边索引 E                                          │
│    │                                                                 │
│    ▼                                                                 │
│  ┌──────────────────────────────────────────────────────────────┐   │
│  │           选择性消息传递层 (Selective Message Passing)          │   │
│  │                                                              │   │
│  │   Query生成: Q = X W_Q                                        │   │
│  │   Key生成: K = X W_K                                          │   │
│  │   Value生成: V = X W_V                                        │   │
│  │   选择性门控: G = σ(W_g [X; XV] )                            │   │
│  │   加权聚合: M = Softmax(KQ^T/√d) ⊙ G ⊙ V                     │   │
│  │                                                              │   │
│  └──────────────────────────────────────────────────────────────┘   │
│    │                                                                 │
│    ▼                                                                 │
│  ┌──────────────────────────────────────────────────────────────┐   │
│  │              选择性SSM层 (Selective SSM Layer)                 │   │
│  │                                                              │   │
│  │   输入依赖参数: A(x), B(x), C(x), Δ(x)                       │   │
│  │   离散化: Ā, B̄                                                │   │
│  │   状态更新: h' = Āh + B̄m                                     │   │
│  │   输出投影: y = Ch'                                           │   │
│  │                                                              │   │
│  └──────────────────────────────────────────────────────────────┘   │
│    │                                                                 │
│    ▼                                                                 │
│  ┌──────────────────────────────────────────────────────────────┐   │
│  │                    残差连接 + LayerNorm                       │   │
│  │              H_out = LayerNorm(H + SSM(H))                   │   │
│  └──────────────────────────────────────────────────────────────┘   │
│    │                                                                 │
│    ▼                                                                 │
│  输出: 节点嵌入 H_out                                                │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘

4.2 层间连接

Mamba-GCN采用层间选择性连接

其中 是第 层的可学习跳跃系数:

这种设计允许网络自适应地控制信息流,浅层侧重于局部信息聚合,深层侧重于全局状态更新。

4.3 与现有工作的对比

特性GCNGATMbaGCNMamba-GCN
邻居选择固定注意力门控硬选择+软选择
状态建模SSM选择性SSM
平滑问题严重中等缓解根本解决
复杂度
可解释性

5. 实验结果

5.1 节点分类

在标准数据集上的节点分类结果:

数据集CoraCiteSeerPubMedogbn-arxiv
GCN81.5%70.3%79.0%71.9%
GAT83.0%72.5%79.0%72.1%
GraphSAGE82.3%72.1%78.5%72.3%
JKNet82.5%72.8%79.2%72.5%
MbaGCN85.2%74.1%81.3%73.8%
Mamba-GCN87.1%75.6%82.4%74.9%

5.2 深度实验

验证深层网络(2-32层)的性能:

准确率 (%)
    │
 90 ├─────────────────────────────────────────────•••• Mamba-GCN (32层: 85.2%)
    │                               ┄┄•••••┄┄•  •• • MbaGCN (32层: 78.3%)
 85 ├                           ┄┄•              •••• •
    │                       ┄┄•                       •••• JKNet (32层: 71.5%)
    │                   ┄┄•                               •••
 80 ├                ┄┄•                                       ••• GCN (32层: 58.2%)
    │           ┄┄•••                                           
    │      ┄┄•••                                            
 75 ├•••••                                                      
    └────────────────────────────────────────────────────────► 层数
      2    4    8    16   32                                  

关键观察

  • 传统GCN在16层后性能急剧下降(over-smoothing主导)
  • Mamba-GCN在32层仍保持高性能,证明了选择性机制的有效性
  • MbaGCN有改进但仍受限于平滑问题

5.3 消融实验

组件CoraCiteSeerPubMed
基线GCN81.5%70.3%79.0%
+ 选择性消息传递83.8%72.4%80.2%
+ 选择性SSM85.2%73.8%81.1%
+ 残差连接86.1%74.7%81.9%
完整模型87.1%75.6%82.4%

5.4 表示可视化

使用t-SNE对Cora数据集的节点嵌入进行可视化:

  • GCN(深层):所有类别节点混杂在一起,over-smoothing严重
  • Mamba-GCN(深层):不同类别清晰分离,保持良好的可区分性

6. 代码实现

6.1 选择性消息传递

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class SelectiveMessagePassing(nn.Module):
    """
    选择性消息传递层
    根据输入动态选择要聚合的邻居信息
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.query_proj = nn.Linear(in_dim, out_dim)
        self.key_proj = nn.Linear(in_dim, out_dim)
        self.value_proj = nn.Linear(in_dim, out_dim)
        
        # 选择性门控网络
        self.gate_net = nn.Sequential(
            nn.Linear(in_dim * 2, out_dim),
            nn.Sigmoid()
        )
        
        # 输出投影
        self.out_proj = nn.Linear(out_dim, out_dim)
        
    def forward(self, x, edge_index):
        """
        Args:
            x: 节点特征 [N, d_in]
            edge_index: 边索引 [2, E]
        Returns:
            选择性聚合后的消息 [N, d_out]
        """
        N = x.shape[0]
        
        # 生成Q, K, V
        Q = self.query_proj(x)
        K = self.key_proj(x)
        V = self.value_proj(x)
        
        # 计算注意力分数
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.shape[-1])
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # 生成选择性门控
        src, dst = edge_index
        x_src = x[src]
        x_dst = x[dst]
        gate_input = torch.cat([x_dst, x_src], dim=-1)
        gate_values = self.gate_net(gate_input)  # [E, d]
        
        # 应用门控到注意力权重
        # 只保留与目标节点"相关"的邻居信息
        gated_weights = attn_weights[dst, src]  # [E]
        gated_weights = gated_weights.unsqueeze(-1) * gate_values  # [E, d]
        
        # 归一化
        deg = torch.zeros(N, device=x.device)
        deg.scatter_add_(0, dst, torch.ones(dst.shape[0], device=x.device))
        deg = deg.unsqueeze(-1) + 1e-8  # 避免除零
        
        # 加权聚合
        messages = V[src] * gated_weights
        out = torch.zeros(N, V.shape[-1], device=x.device)
        out.scatter_add_(0, dst.unsqueeze(-1).expand_as(messages), messages)
        out = out / deg
        
        return self.out_proj(out)

6.2 选择性SSM

class SelectiveSSM(nn.Module):
    """
    选择性状态空间模型
    核心创新:根据输入动态生成SSM参数
    """
    def __init__(self, d_model, d_state=16):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        
        # 参数投影网络
        self.x_proj = nn.Sequential(
            nn.Linear(d_model, d_model * 2 + d_state * 2 + 1),
            nn.GLU()
        )
        
        # dt投影
        self.dt_proj = nn.Sequential(
            nn.Linear(1, d_model),
            nn.Softplus()
        )
        
        # A矩阵初始化(负值保证稳定性)
        self.A = nn.Parameter(torch.randn(d_model, d_state))
        nn.init.normal_(self.A, mean=0, std=-1)  # 负值
        
        # 输出投影
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, x, h_prev=None):
        """
        Args:
            x: 输入特征 [N, d_model]
            h_prev: 上一时刻状态 [N, d_state]
        Returns:
            新状态 [N, d_model]
        """
        if h_prev is None:
            h_prev = torch.zeros(x.shape[0], self.d_state, device=x.device)
        
        # 生成输入依赖参数
        x_dbl = self.x_proj(x)  # [N, d_model + d_state*2 + 1]
        
        # 分割参数
        dt, B, C = x_dbl.split(
            [1, self.d_state, self.d_state], dim=-1
        )
        
        # dt变换
        dt = self.dt_proj(dt)  # [N, d_model]
        
        # 离散化A矩阵
        # A_discrete = exp(dt * A)
        A = torch.exp(torch.einsum('nd,ds->ns', dt, -torch.abs(self.A)))
        
        # 离散化B矩阵
        # B_discrete = (exp(dt * A) - I) * inv(dt * A) * B
        B = (torch.exp(torch.einsum('nd,ds->ns', dt, -torch.abs(self.A))) - 1) * B
        B = B / (-torch.abs(self.A).unsqueeze(0) + 1e-8) * torch.sigmoid(dt) * C  # 简化形式
        
        # 选择性扫描(简化的并行扫描)
        # 实际实现需要更复杂的并行扫描算法
        h = torch.einsum('ns,nm,nd->md', A, B, x) + torch.einsum('nm,md->nd', h_prev, A)
        
        # 输出投影
        y = torch.einsum('md,dm->mm', h, C)
        
        return self.out_proj(y), h

6.3 完整Mamba-GCN层

class MambaGCNLayer(nn.Module):
    """
    Mamba-GCN单层
    整合选择性消息传递和选择性SSM
    """
    def __init__(self, in_dim, out_dim, d_state=16):
        super().__init__()
        self.msg_pass = SelectiveMessagePassing(in_dim, out_dim)
        self.ssm = SelectiveSSM(out_dim, d_state)
        self.norm1 = nn.LayerNorm(out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(0.1)
        
        # 跳跃系数
        self.alpha_net = nn.Sequential(
            nn.Linear(out_dim, 1),
            nn.Sigmoid()
        )
        
    def forward(self, x, edge_index, h_state=None):
        # 选择性消息传递
        m = self.msg_pass(x, edge_index)
        
        # 残差连接1
        h = self.norm1(x + m)
        h = self.dropout(h)
        
        # 选择性SSM
        h_new, h_state = self.ssm(h, h_state)
        
        # 跳跃连接
        alpha = self.alpha_net(h)
        h = alpha * h_new + (1 - alpha) * h
        
        # 残差连接2
        out = self.norm2(h + x)
        
        return out, h_state
 
 
class MambaGCN(nn.Module):
    """
    完整的Mamba-GCN模型
    """
    def __init__(self, in_dim, hidden_dim, out_dim, num_layers, d_state=16):
        super().__init__()
        self.embedding = nn.Linear(in_dim, hidden_dim)
        self.layers = nn.ModuleList([
            MambaGCNLayer(hidden_dim, hidden_dim, d_state)
            for _ in range(num_layers - 1)
        ])
        self.classifier = nn.Linear(hidden_dim, out_dim)
        
    def forward(self, x, edge_index):
        h = self.embedding(x)
        h_states = []
        
        for layer in self.layers:
            h, h_state = layer(h, edge_index)
            h_states.append(h_state)
        
        return self.classifier(h)

7. 总结与展望

7.1 主要贡献

  1. 理论贡献:证明了选择性机制可以从根本上解决over-smoothing问题
  2. 方法创新:提出选择性消息传递和选择性SSM的深度融合架构
  3. 实验验证:在大规模数据集上验证了方法的有效性
  4. 实践价值:保持复杂度的同时实现高性能

7.2 局限性

  1. SSM实现复杂度:并行选择性扫描的高效实现具有挑战性
  2. 超参数敏感:d_state等超参数需要调优
  3. 异构图支持:目前主要针对同构图设计

7.3 未来方向

  • 更高效的选择性扫描实现
  • 异构图和动态图的扩展
  • 与其他GNN变体的结合
  • 在知识图谱、分子图等领域的应用

参考文献

相关资源

Footnotes

  1. Li et al. (2018): “Deeper insights into graph neural networks for node classification”, arXiv

  2. He et al. (2025): “Mamba-Based Graph Convolutional Networks: Tackling Over-smoothing with Selective State Space”, IJCAI 2025 2

  3. Gu & Dao (2023): “Mamba: Linear-Time Sequence Modeling with Selective State Spaces”, arXiv