概述
图注意力网络(Graph Attention Network, GAT)通过注意力机制为不同邻居节点分配动态权重,克服了GCN中邻居权重固定的局限性。1 GATv2进一步改进注意力打分函数,使模型能够捕获更丰富的依赖关系。2
GAT vs GCN 核心区别
| 特性 | GCN | GAT |
|---|---|---|
| 邻居权重 | 固定(由度矩阵决定) | 自适应学习 |
| 聚合方式 | 归一化求和 | 注意力加权 |
| 感受野 | 固定拓扑 | 可学习拓扑 |
| 表达能力 | 较弱 | 更强 |
| 计算复杂度 |
GAT架构详解
注意力系数计算
原始GAT的注意力打分函数:
其中:
- :线性变换矩阵
- :注意力参数向量
- :拼接操作
LeakyReLU激活
使用LeakyReLU允许负值有较小的梯度,防止”死神经元”。
归一化注意力权重
节点更新
单头注意力的输出:
GATv2:动态注意力
原始GAT的问题
原始GAT的注意力是静态的——所有位置共享相同的注意力模式:
这导致:
- 排名不变性:无法区分不同邻居的相对重要性
- 表达能力受限:所有查询使用相同的键交互模式
GATv2的打分函数
GATv2提出动态注意力:
关键改进:
- 使用加性注意力而非拼接
- 先对query和key进行求和再加偏置
- 最后通过共享的线性层输出标量
数学对比
| 版本 | 公式 | 表达能力 |
|---|---|---|
| GAT | 拼接后投影 | |
| GATv2 | 求和后非线性 |
GATv2的注意力头可以表示任意排列不变的标量函数(如MLP),表达能力显著增强。
位置感知能力对比
原始GAT (Static): GATv2 (Dynamic):
查询: "谁更重要?" 查询: "相对于我,你有多重要?"
注意力分数分布: 注意力分数分布:
[0.2, 0.2, 0.2, 0.2, 0.2] [0.05, 0.6, 0.1, 0.05, 0.2]
↑ 所有邻居相同 ↑ 动态分配
多头注意力机制
聚合过程
使用 个独立的注意力头:
其中 表示向量拼接。
最后一层的特殊处理
对于回归任务,通常使用平均而非拼接:
超参数选择
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 注意力头数 | 4-8 | 8在多数任务表现最佳 |
| 隐藏维度 | 64-128 | 每头的维度 |
| Dropout | 0.0-0.6 | 训练时使用 |
| ELU/GELU | ELU | 激活函数 |
PyTorch实现
GAT层实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class GATLayer(nn.Module):
"""单层GAT实现"""
def __init__(self, in_features, out_features, num_heads=1, concat=True, dropout=0.0, alpha=0.2):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.concat = concat
self.dropout = dropout
self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
self.a = nn.Linear(2 * out_features, 1, bias=False)
self.leakyrelu = nn.LeakyReLU(alpha)
self.dropout_layer = nn.Dropout(dropout)
nn.init.xavier_uniform_(self.W.weight)
nn.init.xavier_uniform_(self.a.weight)
def forward(self, h, adj):
"""
h: (N, in_features) 节点特征
adj: (N, N) 邻接矩阵
"""
N = h.shape[0]
# 线性变换: (N, in_features) -> (N, num_heads * out_features)
Wh = self.W(h)
# 分头: (N, num_heads, out_features)
Wh = Wh.view(N, self.num_heads, self.out_features)
# 计算注意力分数
e = self._prepare_attentional_mechanism_input(Wh)
# 应用掩码(将非邻居设为-inf)
e = e.where(adj.unsqueeze(1) > 0, torch.tensor(-1e9).to(e.device))
# Softmax归一化
attention = F.softmax(e, dim=-1)
attention = self.dropout_layer(attention)
# 加权求和
h_prime = torch.bmm(attention, Wh) # (N, num_heads, out_features)
if self.concat:
h_prime = h_prime.reshape(N, self.num_heads * self.out_features)
else:
h_prime = h_prime.mean(dim=1)
return h_prime
def _prepare_attentional_mechanism_input(self, Wh):
"""计算所有节点对的注意力分数"""
N = Wh.shape[0]
# [Wh_i, Wh_j] 拼接: (N, N, num_heads, 2*out_features)
whs = torch.cat([Wh.unsqueeze(1).expand(N, N, self.num_heads, self.out_features),
Wh.unsqueeze(0).expand(N, N, self.num_heads, self.out_features)], dim=-1)
# e: (N, N, num_heads)
e = self.a(whs).squeeze(-1)
return self.leakyrelu(e)GATv2层实现
class GATv2Layer(nn.Module):
"""GATv2层:动态注意力"""
def __init__(self, in_features, out_features, num_heads=1, concat=True, dropout=0.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_heads = num_heads
self.concat = concat
self.W = nn.Linear(in_features, out_features * num_heads, bias=False)
self.att = nn.Linear(out_features, 1, bias=False)
self.dropout = nn.Dropout(dropout)
self.act = nn.ELU()
def forward(self, h, adj):
N = h.shape[0]
M = self.num_heads
d_h = self.out_features
# 线性变换并分头
Wh = self.W(h).view(N, M, d_h)
# GATv2: 计算注意力分数
e = self._compute_scores(Wh) # (N, N, M)
# 应用掩码
e = e.where(adj.unsqueeze(-1) > 0, torch.tensor(-1e9).to(e.device))
# 归一化
attention = F.softmax(e, dim=1)
attention = self.dropout(attention)
# 加权聚合
h_prime = torch.einsum('nmi,nmh->nih', attention, Wh) # (N, M, d_h)
if self.concat:
return h_prime.reshape(N, M * d_h)
else:
return h_prime.mean(dim=1)
def _compute_scores(self, Wh):
"""
GATv2的核心改进:
注意力分数 = w^T * σ(W*q + W*k)
"""
N, M, d = Wh.shape
# 扩展维度用于广播
Wh_i = Wh.unsqueeze(1) # (1, N, M, d)
Wh_j = Wh.unsqueeze(0) # (N, 1, M, d)
# 求和形式 (不是拼接)
combined = Wh_i + Wh_j # (N, N, M, d)
# 通过非线性 + 线性层得到标量
e = self.att(self.act(combined)) # (N, N, M, 1)
return e.squeeze(-1)完整GAT模型
class GAT(nn.Module):
"""多层GAT模型"""
def __init__(self, in_features, hidden_features, out_features, num_heads=8, dropout=0.6):
super().__init__()
self.conv1 = GATLayer(in_features, hidden_features, num_heads, concat=True, dropout=dropout)
self.conv2 = GATLayer(hidden_features * num_heads, out_features, num_heads=1, concat=False, dropout=dropout)
self.dropout = dropout
def forward(self, x, adj):
x = F.elu(self.conv1(x, adj))
x = F.dropout(x, self.dropout, training=self.training)
x = self.conv2(x, adj)
return F.log_softmax(x, dim=1)使用PyTorch Geometric
from torch_geometric.nn import GATConv, GATv2Conv
import torch.nn.functional as F
class PyG_GAT(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
super().__init__()
self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
self.gat2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
x = self.gat1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=0.6, training=self.training)
x = self.gat2(x, edge_index)
return x
# 使用GATv2
class PyG_GATv2(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
super().__init__()
self.gat1 = GATv2Conv(in_channels, hidden_channels, heads=heads, dropout=0.6)
self.gat2 = GATv2Conv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
x = self.gat1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=0.6, training=self.training)
x = self.gat2(x, edge_index)
return x谱域视角分析
GAT与谱方法的关系
GAT通过数据驱动的方式学习谱域的滤波器响应。
谱GCN的滤波器:
GAT的隐式滤波:
频率响应分析
GAT的注意力机制等价于学习了一个位置相关的谱滤波器:
其中 是由注意力机制学习的频率间交互函数。
低通 vs 高通特性
| 特性 | GCN | GAT |
|---|---|---|
| 频率响应 | 低通(平滑) | 自适应 |
| 低频放大 | 固定 | 可学习 |
| 高频保留 | 抑制 | 可学习 |
实战调参指南
超参数敏感性排序
| 重要性 | 参数 | 调参建议 |
|---|---|---|
| ⭐⭐⭐⭐⭐ | 注意力头数 | 4-8头效果最佳 |
| ⭐⭐⭐⭐ | 隐藏维度 | 与特征维度匹配 |
| ⭐⭐⭐ | Dropout | 0.0-0.6,推荐0.6 |
| ⭐⭐ | 层数 | 2-4层,3层最常见 |
| ⭐ | 学习率 | 0.005-0.01 |
训练技巧
1. 注意力归一化
# 归一化方式影响显著
# 推荐:行softmax(默认)
# 备选:对称归一化
attention = F.softmax(e, dim=-1) # 行归一化2. 残差连接
class ResidualGAT(nn.Module):
def __init__(self, in_features, hidden_features, out_features):
super().__init__()
self.gat = GATLayer(in_features, out_features)
self.res_proj = nn.Linear(in_features, out_features)
def forward(self, x, adj):
return self.gat(x, adj) + self.res_proj(x) # 残差连接3. 层归一化
class NormGAT(nn.Module):
def forward(self, x, adj):
h = self.gat(x, adj)
return F.layer_norm(h, h.shape) # 层归一化常见问题与解决方案
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 训练不收敛 | 学习率过大 | 降低到0.001 |
| 验证集性能下降 | 过拟合 | 增加dropout |
| 注意力分数全0 | LeakyReLU斜率问题 | 使用ELU |
| 内存不足 | 注意力矩阵太大 | 减少头数或维度 |
实战案例:论文引用网络
数据集
使用Cora数据集:
- 2,708篇论文
- 5,429条引用关系
- 7个类别
完整训练代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import add_self_loops
from torch_geometric.nn import GATConv
# 加载数据
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
# 添加自环
data.edge_index = add_self_loops(data.edge_index)[0]
# 定义模型
class Net(nn.Module):
def __init__(self):
super().__init__()
self.gat1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)
self.gat2 = GATConv(8*8, dataset.num_classes, heads=1, concat=False, dropout=0.6)
def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.gat1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.gat2(x, edge_index)
return F.log_softmax(x, dim=1)
# 训练
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
model.train()
for epoch in range(500):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
# 测试
model.eval()
_, pred = model(data.x, data.edge_index).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum())
acc = correct / int(data.test_mask.sum())
print(f'测试准确率: {acc:.4f}')预期结果
| 模型 | Cora准确率 | 参数量 |
|---|---|---|
| GCN | ~81.5% | 92K |
| GAT | ~83.0% | 93K |
| GATv2 | ~83.5% | 93K |
GAT的变体与发展
1. Semi-Supervised GAT
引入标签传播的半监督信号:
2. Edge Features GAT
处理边属性:
class EdgeGAT(nn.Module):
def forward(self, x, edge_index, edge_attr):
# 在注意力计算中融入边特征
e = self.edge_mlp(torch.cat([x[edge_index[0]], edge_attr, x[edge_index[1]]], dim=-1))
# ...3. GraphSAT
使用非线性注意力: