1. 研究背景与问题定义

1.1 图神经网络的发展现状

图神经网络(GNN)在各种图学习任务上取得了显著成功1,但仍面临以下挑战:

  1. 泛化能力有限:在一个图上训练的模型难以直接迁移到另一个图
  2. 需要微调:通常需要对目标任务进行参数更新
  3. 数据依赖:模型性能高度依赖于训练数据的分布
  4. 缺乏通用表示:不同任务需要不同的模型架构

这些问题催生了**图基础模型(Graph Foundation Model)**的研究热潮。

1.2 现有方法的局限

LLM-based方法

  • 使用大语言模型处理图结构
  • 需要LLM和图编码器的复杂结合
  • 计算开销巨大
  • 依赖LLM的先验知识

传统预训练+微调

  • 在大规模图数据上预训练
  • 需要对目标任务进行微调
  • 微调仍需数据和计算资源
  • 可能存在灾难性遗忘

1.3 GILT的核心思想

ICLR 2026论文《GILT: An LLM-Free, Tuning-Free Graph Foundational Model for In-Context Learning》提出了一种全新的范式2

核心洞察:图结构本身包含了足够的信息来指导学习,关键是如何让模型”理解”图结构的语义。

GILT = Graph In-context Learning with Templates

┌─────────────────────────────────────────────────────────────────────────┐
│                         GILT 核心思想                                  │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│   传统范式:                                                              │
│   ┌──────────┐     ┌──────────┐     ┌──────────┐                      │
│   │ 预训练    │ ──► │  微调    │ ──► │  推理    │                      │
│   └──────────┘     └──────────┘     └──────────┘                      │
│   需要大量参数更新                                                      │
│                                                                          │
│   GILT范式:                                                             │
│   ┌──────────┐           ┌──────────┐                                   │
│   │ 模板构建  │ ────────► │ 上下文学习 │ ────────►  推理               │
│   └──────────┘           └──────────┘                                   │
│   无需任何参数更新                                                       │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

2. 技术框架

2.1 形式化定义

设图 ,节点特征 。GILT的目标是:

其中 任务模板(包含少量示例), 是目标节点的索引。

2.2 上下文学习机制

GILT的核心是图上下文编码器,将图结构和任务模板编码为统一的表示:

其中 是上下文长度(示例数量)。

任务模板的形式

每个示例包含:

  • 节点索引
  • 标签

2.3 无需微调的推理

GILT通过上下文推理直接生成预测:

其中:

  • 是目标节点的查询
  • 是标签向量

3. 架构详解

3.1 整体架构

┌─────────────────────────────────────────────────────────────────────────┐
│                           GILT 整体架构                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  输入:                                                                   │
│    - 目标图: G = (V, E, X)                                              │
│    - 任务模板: T = {(i₁,y₁), (i₂,y₂), ..., (iₖ,yₖ)}                    │
│                                                                          │
│    │                                                                     │
│    ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    图结构编码器 (Structure Encoder)                 │    │
│  │                                                                   │    │
│  │   节点嵌入: hᵢ = Linear(Xᵢ) + PE(i)                              │    │
│  │   结构感知: sᵢ = GraphAttention(h, A)                            │    │
│  │                                                                   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                     │
│    ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    模板编码器 (Template Encoder)                   │    │
│  │                                                                   │    │
│  │   示例嵌入: cₖ = [sᵢₖ; Encode(yₖ)]                                │    │
│  │   上下文:   C = [c₁, c₂, ..., cₖ]                                │    │
│  │                                                                   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                     │
│    ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    上下文推理层 (In-Context Reasoning)              │    │
│  │                                                                   │    │
│  │   注意力:  αₖ = softmax(qᵢ · W · cₖ)                             │    │
│  │   预测:    ŷᵢ = Σₖ αₖ · yₖ                                       │    │
│  │                                                                   │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│    │                                                                     │
│    ▼                                                                     │
│  输出: ŷᵢ                                                              │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

3.2 图结构编码器

节点嵌入

其中 是位置编码,捕捉节点在图中的结构位置。

结构感知聚合

GAT使用结构感知注意力

3.3 模板编码器

标签编码:将离散的标签映射为连续的嵌入

示例嵌入

3.4 上下文推理

双线性注意力

标签传播

其中 是可学习的标签解码器。

4. 理论分析

4.1 表达能力分析

定理(上下文学习表达能力):设 是GILT的函数空间,则对于任意任务 ,如果:

  1. 图编码器产生的表示是图同构不变的
  2. 上下文注意力可以区分不同的示例

在上下文长度 时可以表示任意标签映射。

证明思路:上下文注意力本质上实现了从示例空间到标签空间的线性组合,当 足够大时,可以覆盖所有可能的标签组合。

4.2 收敛性分析

引理(注意力收敛):设 是第 次迭代的注意力权重,则:

其中 是收缩系数。

4.3 泛化理论

定理(零样本泛化界):设 是期望风险, 是经验风险,则以概率

其中 是上下文表示空间的VC维度, 是示例数量。

5. 实验设置

5.1 基准数据集

数据集类型节点数边数类别数描述
Cora半监督分类2,7085,4297引文网络
CiteSeer半监督分类3,3274,7326引文网络
PubMed半监督分类19,71744,3383引文网络
Wiki-CS半监督分类11,701215,86310Wikipedia图
ogbn-arxiv半监督分类169,3431,166,24340学术网络

5.2 跨域设置

跨数据集泛化

  • 训练:Cora, CiteSeer, PubMed
  • 测试:Wiki-CS, ogbn-arxiv

跨任务泛化

  • 训练:节点分类
  • 测试:链接预测、图分类

5.3 主要结果

5.3.1 标准设置

模型CoraCiteSeerPubMedWiki-CS
GCN81.5%70.3%79.0%77.3%
GAT83.0%72.5%79.0%78.1%
GraphSAINT82.8%71.9%78.6%77.8%
GPR-GNN82.5%72.1%78.8%78.2%
GILT (K=4)83.4%73.2%80.1%79.5%
GILT (K=8)84.1%74.0%81.2%80.3%
GILT (K=16)84.6%74.5%81.8%81.1%

5.3.2 跨域泛化

模型Wiki-CSogbn-arxiv平均
GCN (Transductive)77.3%71.9%-
GCN (Fine-tuned)74.2%68.5%-
GIN (Fine-tuned)75.1%69.2%-
GILT (K=4)76.8%70.1%73.5%
GILT (K=8)78.5%72.3%75.4%
GILT (K=16)79.2%73.8%76.5%

关键发现:GILT在跨域设置下显著优于微调方法!

5.3.3 计算效率

模型参数量推理时间 (ms)内存 (MB)
GCN23K12.5128
GAT54K18.3156
GIN82K15.2142
GILT (K=8)45K8.298

优势:GILT无需微调,推理时只需要编码图结构,时间反而更短。

6. 消融分析

6.1 上下文长度的影响

KCoraCiteSeerPubMed
181.2%71.5%78.3%
282.5%72.4%79.5%
483.4%73.2%80.1%
884.1%74.0%81.2%
1684.6%74.5%81.8%
3284.8%74.8%82.0%

观察:性能随K增加而提升,但边际收益递减。

6.2 模板选择策略

选择策略CoraCiteSeer
随机82.8%72.6%
标签平衡83.6%73.5%
难度平衡84.1%74.1%
多样性+难度84.3%74.4%

6.3 各组件贡献

组件Cora贡献
基线(仅图编码器)79.2%-
+ 结构编码81.5%+2.3%
+ 模板编码82.9%+1.4%
+ 双线性注意力83.8%+0.9%
完整模型84.1%+4.9%

7. 代码实现

7.1 图结构编码器

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class GraphStructureEncoder(nn.Module):
    """
    图结构编码器
    使用结构感知的注意力机制
    """
    def __init__(self, input_dim, hidden_dim, num_layers=2):
        super().__init__()
        self.node_embedding = nn.Linear(input_dim, hidden_dim)
        
        # 位置编码
        self.pe_embedding = nn.Parameter(torch.randn(1, hidden_dim))
        
        # 结构感知注意力层
        self.layers = nn.ModuleList([
            GATLayer(hidden_dim, hidden_dim)
            for _ in range(num_layers)
        ])
        
        self.norm = nn.LayerNorm(hidden_dim)
        
    def forward(self, x, edge_index):
        """
        Args:
            x: 节点特征 [N, d_in]
            edge_index: 边索引 [2, E]
        Returns:
            结构化节点嵌入 [N, d_hidden]
        """
        # 节点嵌入 + 位置编码
        h = self.node_embedding(x)
        h = h + self.pe_embedding
        
        # 结构感知聚合
        for layer in self.layers:
            h = layer(h, edge_index)
            h = self.norm(h)
        
        return h
 
 
class GATLayer(nn.Module):
    """结构感知GAT层"""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim)
        self.a = nn.Linear(out_dim * 2 + 1, 1)  # 包含边特征
        
        self.leaky_relu = nn.LeakyReLU(0.2)
        
    def forward(self, h, edge_index):
        N = h.shape[0]
        
        # 线性变换
        Wh = self.W(h)  # [N, d]
        
        # 计算注意力
        src, dst = edge_index
        
        # 构建边特征(这里简化处理)
        e_self = torch.zeros(N, 1, device=h.device)
        e = torch.cat([Wh[src], Wh[dst], e_self[src]], dim=-1)
        
        # 注意力分数
        alpha = self.leaky_relu(self.a(e))
        alpha = F.softmax(alpha, dim=0)
        
        # 加权聚合
        out = torch.zeros(N, Wh.shape[1], device=h.device)
        out = out.scatter_add(0, dst.unsqueeze(-1).expand_as(Wh[src]), Wh[src] * alpha)
        
        return out

7.2 模板编码器

class TemplateEncoder(nn.Module):
    """
    模板编码器
    将示例编码为上下文向量
    """
    def __init__(self, node_dim, label_dim, hidden_dim):
        super().__init__()
        
        # 标签编码
        self.label_embedding = nn.Embedding(10, hidden_dim)  # 假设最多10类
        self.label_projection = nn.Linear(hidden_dim, hidden_dim)
        
        # 示例融合
        self.fusion = nn.Sequential(
            nn.Linear(node_dim + hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
    def forward(self, node_embeddings, labels, indices):
        """
        Args:
            node_embeddings: 节点嵌入 [N, d]
            labels: 标签张量 [K]
            indices: 示例节点索引 [K]
        Returns:
            上下文嵌入 [K, d]
        """
        # 获取示例节点嵌入
        example_embeddings = node_embeddings[indices]  # [K, d]
        
        # 标签嵌入
        label_embeddings = self.label_embedding(labels)  # [K, d_label]
        label_embeddings = self.label_projection(label_embeddings)
        
        # 融合
        fused = torch.cat([example_embeddings, label_embeddings], dim=-1)
        context = self.fusion(fused)
        
        return context

7.3 完整GILT模型

class GILT(nn.Module):
    """
    GILT: Graph In-context Learning with Templates
    """
    def __init__(self, input_dim, hidden_dim, num_classes, num_layers=2):
        super().__init__()
        
        self.graph_encoder = GraphStructureEncoder(input_dim, hidden_dim, num_layers)
        self.template_encoder = TemplateEncoder(hidden_dim, hidden_dim, hidden_dim)
        
        # 注意力参数
        self.W_attn = nn.Linear(hidden_dim, hidden_dim)
        
        # 标签解码器
        self.label_decoder = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x, edge_index, template_indices, template_labels, target_indices):
        """
        Args:
            x: 节点特征 [N, d_in]
            edge_index: 边索引 [2, E]
            template_indices: 示例节点索引 [K]
            template_labels: 示例标签 [K]
            target_indices: 目标节点索引 [M]
        Returns:
            预测 logits [M, num_classes]
        """
        # 1. 图结构编码
        node_embeddings = self.graph_encoder(x, edge_index)  # [N, d]
        
        # 2. 模板编码
        context = self.template_encoder(
            node_embeddings, 
            template_labels, 
            template_indices
        )  # [K, d]
        
        # 3. 上下文推理
        target_embeddings = node_embeddings[target_indices]  # [M, d]
        
        # 双线性注意力
        scores = torch.matmul(
            target_embeddings @ self.W_attn,
            context.T
        )  # [M, K]
        alpha = F.softmax(scores, dim=-1)  # [M, K]
        
        # 标签传播
        template_logits = self.label_decoder(context)  # [K, num_classes]
        predictions = torch.matmul(alpha, template_logits)  # [M, num_classes]
        
        return predictions

7.4 使用示例

# 创建模型
model = GILT(input_dim=1433, hidden_dim=256, num_classes=7)
 
# 准备数据
x = ...  # 节点特征
edge_index = ...  # 边索引
 
# 定义模板(示例)
template_indices = [0, 1, 2, 3]  # 4个示例节点
template_labels = torch.tensor([0, 1, 2, 3])  # 对应标签
 
# 目标节点
target_indices = [100, 101, 102]
 
# 前向传播(无需任何参数更新!)
logits = model(x, edge_index, template_indices, template_labels, target_indices)

8. 与其他方法的对比

8.1 与LLM-based方法的对比

特性LLM-basedGILT
参数需求数十亿数十万
计算资源GPU/TPUCPU可运行
训练数据文本+图仅图
微调需求可选无需
推理速度

8.2 与传统预训练方法的对比

特性预训练+微调GILT
参数更新全部/部分
遗忘风险存在
部署复杂度
适应速度即时

9. 总结与展望

9.1 主要贡献

  1. 新范式:首次提出无需微调的图基础模型
  2. 高效:显著减少计算和存储开销
  3. 通用:适用于多种图任务和数据集
  4. 理论支撑:提供了表达能力和泛化理论分析

9.2 局限性

  1. 上下文长度限制:需要足够的示例
  2. 模板设计:需要精心设计模板
  3. 复杂任务:对于需要强推理的任务可能不足

9.3 未来方向

  • 更强的上下文推理机制
  • 自动模板选择
  • 多跳推理能力
  • 与外部知识的结合

参考文献

相关资源

Footnotes

  1. Kipf & Welling (2017): “Semi-Supervised Classification with Graph Convolutional Networks”, ICLR

  2. Ma et al. (2026): “GILT: An LLM-Free, Tuning-Free Graph Foundational Model for In-Context Learning”, ICLR 2026