概述

图注意力网络(Graph Attention Network, GAT)通过注意力机制为不同邻居节点分配动态权重,克服了GCN中邻居权重固定的局限性。1 GATv2进一步改进注意力打分函数,使模型能够捕获更丰富的依赖关系。2

GAT vs GCN 核心区别

特性GCNGAT
邻居权重固定(由度矩阵决定)自适应学习
聚合方式归一化求和注意力加权
感受野固定拓扑可学习拓扑
表达能力较弱更强
计算复杂度

GAT架构详解

注意力系数计算

原始GAT的注意力打分函数:

其中:

  • :线性变换矩阵
  • :注意力参数向量
  • :拼接操作

LeakyReLU激活

使用LeakyReLU允许负值有较小的梯度,防止”死神经元”。

归一化注意力权重

节点更新

单头注意力的输出:


GATv2:动态注意力

原始GAT的问题

原始GAT的注意力是静态的——所有位置共享相同的注意力模式:

这导致:

  1. 排名不变性:无法区分不同邻居的相对重要性
  2. 表达能力受限:所有查询使用相同的键交互模式

GATv2的打分函数

GATv2提出动态注意力

关键改进

  1. 使用加性注意力而非拼接
  2. 先对query和key进行求和再加偏置
  3. 最后通过共享的线性层输出标量

数学对比

版本公式表达能力
GAT拼接后投影
GATv2求和后非线性

GATv2的注意力头可以表示任意排列不变的标量函数(如MLP),表达能力显著增强。

位置感知能力对比

原始GAT (Static):          GATv2 (Dynamic):
查询: "谁更重要?"           查询: "相对于我,你有多重要?"

注意力分数分布:            注意力分数分布:
[0.2, 0.2, 0.2, 0.2, 0.2]  [0.05, 0.6, 0.1, 0.05, 0.2]
   ↑ 所有邻居相同              ↑ 动态分配

多头注意力机制

聚合过程

使用 个独立的注意力头:

其中 表示向量拼接。

最后一层的特殊处理

对于回归任务,通常使用平均而非拼接:

超参数选择

参数推荐值说明
注意力头数 4-88在多数任务表现最佳
隐藏维度64-128每头的维度
Dropout0.0-0.6训练时使用
ELU/GELUELU激活函数

PyTorch实现

GAT层实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class GATLayer(nn.Module):
    """单层GAT实现"""
    def __init__(self, in_features, out_features, num_heads=1, concat=True, dropout=0.0, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.concat = concat
        self.dropout = dropout
        
        self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
        self.a = nn.Linear(2 * out_features, 1, bias=False)
        
        self.leakyrelu = nn.LeakyReLU(alpha)
        self.dropout_layer = nn.Dropout(dropout)
        
        nn.init.xavier_uniform_(self.W.weight)
        nn.init.xavier_uniform_(self.a.weight)
    
    def forward(self, h, adj):
        """
        h: (N, in_features) 节点特征
        adj: (N, N) 邻接矩阵
        """
        N = h.shape[0]
        
        # 线性变换: (N, in_features) -> (N, num_heads * out_features)
        Wh = self.W(h)
        # 分头: (N, num_heads, out_features)
        Wh = Wh.view(N, self.num_heads, self.out_features)
        
        # 计算注意力分数
        e = self._prepare_attentional_mechanism_input(Wh)
        
        # 应用掩码(将非邻居设为-inf)
        e = e.where(adj.unsqueeze(1) > 0, torch.tensor(-1e9).to(e.device))
        
        # Softmax归一化
        attention = F.softmax(e, dim=-1)
        attention = self.dropout_layer(attention)
        
        # 加权求和
        h_prime = torch.bmm(attention, Wh)  # (N, num_heads, out_features)
        
        if self.concat:
            h_prime = h_prime.reshape(N, self.num_heads * self.out_features)
        else:
            h_prime = h_prime.mean(dim=1)
        
        return h_prime
    
    def _prepare_attentional_mechanism_input(self, Wh):
        """计算所有节点对的注意力分数"""
        N = Wh.shape[0]
        
        # [Wh_i, Wh_j] 拼接: (N, N, num_heads, 2*out_features)
        whs = torch.cat([Wh.unsqueeze(1).expand(N, N, self.num_heads, self.out_features),
                         Wh.unsqueeze(0).expand(N, N, self.num_heads, self.out_features)], dim=-1)
        
        # e: (N, N, num_heads)
        e = self.a(whs).squeeze(-1)
        return self.leakyrelu(e)

GATv2层实现

class GATv2Layer(nn.Module):
    """GATv2层:动态注意力"""
    def __init__(self, in_features, out_features, num_heads=1, concat=True, dropout=0.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads
        self.concat = concat
        
        self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
        self.att = nn.Linear(out_features, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
        
        self.act = nn.ELU()
    
    def forward(self, h, adj):
        N = h.shape[0]
        M = self.num_heads
        d_h = self.out_features
        
        # 线性变换并分头
        Wh = self.W(h).view(N, M, d_h)
        
        # GATv2: 计算注意力分数
        e = self._compute_scores(Wh)  # (N, N, M)
        
        # 应用掩码
        e = e.where(adj.unsqueeze(-1) > 0, torch.tensor(-1e9).to(e.device))
        
        # 归一化
        attention = F.softmax(e, dim=1)
        attention = self.dropout(attention)
        
        # 加权聚合
        h_prime = torch.einsum('nmi,nmh->nih', attention, Wh)  # (N, M, d_h)
        
        if self.concat:
            return h_prime.reshape(N, M * d_h)
        else:
            return h_prime.mean(dim=1)
    
    def _compute_scores(self, Wh):
        """
        GATv2的核心改进:
        注意力分数 = w^T * σ(W*q + W*k)
        """
        N, M, d = Wh.shape
        
        # 扩展维度用于广播
        Wh_i = Wh.unsqueeze(1)  # (1, N, M, d)
        Wh_j = Wh.unsqueeze(0)  # (N, 1, M, d)
        
        # 求和形式 (不是拼接)
        combined = Wh_i + Wh_j  # (N, N, M, d)
        
        # 通过非线性 + 线性层得到标量
        e = self.att(self.act(combined))  # (N, N, M, 1)
        return e.squeeze(-1)

完整GAT模型

class GAT(nn.Module):
    """多层GAT模型"""
    def __init__(self, in_features, hidden_features, out_features, num_heads=8, dropout=0.6):
        super().__init__()
        self.conv1 = GATLayer(in_features, hidden_features, num_heads, concat=True, dropout=dropout)
        self.conv2 = GATLayer(hidden_features * num_heads, out_features, num_heads=1, concat=False, dropout=dropout)
        self.dropout = dropout
    
    def forward(self, x, adj):
        x = F.elu(self.conv1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.conv2(x, adj)
        return F.log_softmax(x, dim=1)

使用PyTorch Geometric

from torch_geometric.nn import GATConv, GATv2Conv
import torch.nn.functional as F
 
class PyG_GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        self.gat2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
    
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat2(x, edge_index)
        return x
 
# 使用GATv2
class PyG_GATv2(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.gat1 = GATv2Conv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        self.gat2 = GATv2Conv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
    
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat2(x, edge_index)
        return x

谱域视角分析

GAT与谱方法的关系

GAT通过数据驱动的方式学习谱域的滤波器响应。

谱GCN的滤波器

GAT的隐式滤波

频率响应分析

GAT的注意力机制等价于学习了一个位置相关的谱滤波器

其中 是由注意力机制学习的频率间交互函数。

低通 vs 高通特性

特性GCNGAT
频率响应低通(平滑)自适应
低频放大固定可学习
高频保留抑制可学习

实战调参指南

超参数敏感性排序

重要性参数调参建议
⭐⭐⭐⭐⭐注意力头数4-8头效果最佳
⭐⭐⭐⭐隐藏维度与特征维度匹配
⭐⭐⭐Dropout0.0-0.6,推荐0.6
⭐⭐层数2-4层,3层最常见
学习率0.005-0.01

训练技巧

1. 注意力归一化

# 归一化方式影响显著
# 推荐:行softmax(默认)
# 备选:对称归一化
attention = F.softmax(e, dim=-1)  # 行归一化

2. 残差连接

class ResidualGAT(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.gat = GATLayer(in_features, out_features)
        self.res_proj = nn.Linear(in_features, out_features)
    
    def forward(self, x, adj):
        return self.gat(x, adj) + self.res_proj(x)  # 残差连接

3. 层归一化

class NormGAT(nn.Module):
    def forward(self, x, adj):
        h = self.gat(x, adj)
        return F.layer_norm(h, h.shape)  # 层归一化

常见问题与解决方案

问题原因解决方案
训练不收敛学习率过大降低到0.001
验证集性能下降过拟合增加dropout
注意力分数全0LeakyReLU斜率问题使用ELU
内存不足注意力矩阵太大减少头数或维度

实战案例:论文引用网络

数据集

使用Cora数据集:

  • 2,708篇论文
  • 5,429条引用关系
  • 7个类别

完整训练代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import GATConv
 
# 加载数据
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
 
# 添加自环
data.edge_index = add_self_loops(data.edge_index)[0]
 
# 定义模型
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.gat1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
        self.gat2 = GATConv(8*8, dataset.num_classes, heads=1, concat=False, dropout=0.6)
    
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)
 
# 训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
 
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
 
model.train()
for epoch in range(500):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
 
# 测试
model.eval()
_, pred = model(data.x, data.edge_index).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum())
acc = correct / int(data.test_mask.sum())
print(f'测试准确率: {acc:.4f}')

预期结果

模型Cora准确率参数量
GCN~81.5%92K
GAT~83.0%93K
GATv2~83.5%93K

GAT的变体与发展

1. Semi-Supervised GAT

引入标签传播的半监督信号:

2. Edge Features GAT

处理边属性:

class EdgeGAT(nn.Module):
    def forward(self, x, edge_index, edge_attr):
        # 在注意力计算中融入边特征
        e = self.edge_mlp(torch.cat([x[edge_index[0]], edge_attr, x[edge_index[1]]], dim=-1))
        # ...

3. GraphSAT

使用非线性注意力:


参考


相关词条:图神经网络图卷积网络详解GNN表达力理论

Footnotes

  1. Veličković et al., “Graph Attention Networks”, ICLR 2018

  2. Brody et al., “How Attentive are Graph Attention Networks?”, ICLR 2022