概述
图神经网络(GNN)在图结构数据上表现优异,但欧几里得空间的 GNN 难以捕获图的层次结构(如社交网络的社区层次、组织架构)。双曲图神经网络(Hyperbolic Graph Neural Networks, HGNN)将消息传递机制推广到双曲空间,利用双曲空间的指数增长特性自然编码层次关系。
为什么需要双曲图神经网络?
层次结构的挑战
真实世界的图常呈现树状层次:
| 图类型 | 层次结构示例 |
|---|---|
| 社交网络 | 个人 → 群组 → 社区 → 整个网络 |
| 知识图谱 | 实体 → 概念 → 上位概念 → 根概念 |
| 生物网络 | 蛋白质 → 复合物 → 通路 → 细胞 |
欧几里得 vs 双曲嵌入
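欧几里得空间:嵌入层次需要指数级维度
- 深度为 $d$ 的二叉树有 $2^{d+1}-1$ 个节点,而欧几里得空间的体积只随半径多项式增长;要以低失真嵌入,所需维度随节点数增长(通常表述为约 $O(2^d)$ 维)
双曲空间:对数级维度即可
- 同样的树结构只需 $O(d)$ 维(即节点数的对数量级)的 Poincaré ball,因为双曲空间的体积随半径指数增长,与树的节点数增长方式一致;事实上任意树都能以任意小的失真嵌入二维双曲平面(Sarkar, 2011)
这意味着双曲 GNN(如 HGCN)能在更低维度捕获更深的层次。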
Poincaré GCN
核心思想
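Chami 等人(NeurIPS 2019)提出 HGCN,将图卷积推广到双曲空间(此处以 Poincaré ball 为例):
$$\mathbf{h}_i^{(\ell+1)} = \exp^c_{\mathbf{h}_i^{(\ell)}}\Big(\frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}\log^c_{\mathbf{h}_i^{(\ell)}}\big(\mathbf{h}_j^{(\ell)}\big)\Big)$$
其中:
- $\mathbf{h}_i^{(\ell)}$ 是节点 $i$ 在第 $\ell$ 层的双曲嵌入
- $\exp^c_{\mathbf{x}}(\cdot)$ 是从点 $\mathbf{x}$ 出发的指数映射
- $\log^c_{\mathbf{x}}(\cdot)$ 是从点 $\mathbf{x}$ 出发的对数映射
- $\mathcal{N}(i)$ 是节点 $i$ 的邻居集合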
直觉解释
- 将所有邻居映射到节点 $i$ 的切空间(对数映射)
- 在切空间中执行欧几里得聚合(平均)
- 通过指数映射将结果移回双曲空间
完整的消息传递框架
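消息阶段:
$$\mathbf{m}_i^{(\ell)} = \sum_{j\in\mathcal{N}(i)} w_{ij}\,\log^c_{\mathbf{h}_i^{(\ell)}}\big(\mathbf{h}_j^{(\ell)}\big)$$
其中 $w_{ij}$ 为聚合权重(如 $1/|\mathcal{N}(i)|$ 或注意力系数)。
更新阶段:
$$\mathbf{h}_i^{(\ell+1)} = \sigma^{\otimes c}\Big(\exp^c_{\mathbf{h}_i^{(\ell)}}\big(\mathbf{m}_i^{(\ell)}\big)\Big), \qquad \sigma^{\otimes c}(\mathbf{x}) = \exp^c_{\mathbf{0}}\big(\sigma(\log^c_{\mathbf{0}}(\mathbf{x}))\big)$$
即先在 $\mathbf{h}_i$ 处用指数映射回到流形,再在原点切空间中施加非线性激活 $\sigma$。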
Poincaré GCN 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class PoincaréGCNLayer(nn.Module):
    """Graph convolution layer on the Poincaré ball of curvature ``-c``.

    All maps are taken at the origin of the ball. The pipeline is:
    log-map the input points to the tangent space, apply dropout and the
    affine map ``x @ W + b``, average every node with its neighbours (the
    node itself is always included), exp-map the result back, and project
    it strictly inside the ball.

    Args:
        in_features: dimension of the input points.
        out_features: dimension of the output points.
        c: positive curvature magnitude; the ball has radius ``1 / sqrt(c)``.
        dropout: dropout probability applied to the tangent-space features.
    """

    def __init__(self, in_features, out_features, c=1.0, dropout=0.0):
        super().__init__()
        self.c = c
        self.in_features = in_features
        self.out_features = out_features
        # Learnable affine map in the tangent space at the origin.
        # W is stored as [in_features, out_features] and applied as ``x @ W``.
        self.W = nn.Parameter(torch.randn(in_features, out_features))
        self.b = nn.Parameter(torch.zeros(out_features))
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise ``W`` and zero the bias."""
        nn.init.xavier_uniform_(self.W)
        nn.init.zeros_(self.b)

    def _sqrt_c(self):
        # ``self.c`` is normally a Python float, and ``torch.sqrt`` rejects floats.
        return self.c ** 0.5

    def _lambda_x(self, x):
        """Conformal factor ``2 / (1 - c * ||x||^2)`` of the ball at ``x``."""
        x_norm_sq = torch.sum(x * x, dim=-1, keepdim=True)
        return 2 / (1 - self.c * x_norm_sq).clamp(min=1e-10)

    def exp_map(self, x, v):
        """Exponential map: the point reached from ``x`` along tangent vector ``v``.

        ``exp_x(v) = x ⊕ tanh(sqrt(c) * λ_x * ||v|| / 2) * v / (sqrt(c) * ||v||)``
        """
        sqrt_c = self._sqrt_c()
        v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
        scale = torch.tanh(sqrt_c * self._lambda_x(x) * v_norm / 2) / (sqrt_c * v_norm)
        return self._mobius_add(x, scale * v)

    def log_map(self, x, y):
        """Logarithmic map: the tangent vector at ``x`` pointing to ``y``.

        This is the inverse of :meth:`exp_map`:
        ``log_x(y) = 2 / (sqrt(c) * λ_x) * artanh(sqrt(c) * ||d||) * d / ||d||``
        with ``d = (-x) ⊕ y``.
        """
        sqrt_c = self._sqrt_c()
        diff = self._mobius_add(-x, y)
        diff_norm = torch.norm(diff, dim=-1, keepdim=True).clamp(min=1e-10)
        # Keep the artanh argument below 1 so that boundary points do not give inf.
        arg = (sqrt_c * diff_norm).clamp(max=1 - 1e-5)
        scale = 2 / (sqrt_c * self._lambda_x(x)) * torch.atanh(arg) / diff_norm
        return scale * diff

    def _mobius_add(self, u, v):
        """Möbius addition ``u ⊕_c v`` on the Poincaré ball."""
        c = self.c
        uv = torch.sum(u * v, dim=-1, keepdim=True)
        u_norm_sq = torch.sum(u * u, dim=-1, keepdim=True)
        v_norm_sq = torch.sum(v * v, dim=-1, keepdim=True)
        numerator = (1 + 2 * c * uv + c * v_norm_sq) * u + (1 - c * u_norm_sq) * v
        denominator = 1 + 2 * c * uv + c ** 2 * u_norm_sq * v_norm_sq
        return numerator / denominator.clamp(min=1e-10)

    def _project(self, x):
        """Rescale points so they lie strictly inside the ball of radius ``1 / sqrt(c)``."""
        max_norm = (1 - 1e-5) / self._sqrt_c()
        norm = torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-10)
        return x * (max_norm / norm).clamp(max=1.0)

    def forward(self, x, adj):
        """
        Args:
            x: node points in the Poincaré ball, ``[num_nodes, in_features]``.
                Points outside the ball are first projected inside it.
            adj: dense adjacency matrix ``[num_nodes, num_nodes]``. The node's
                own feature is added explicitly, so any self loops already in
                ``adj`` are counted twice.

        Returns:
            Points in the ball, ``[num_nodes, out_features]``.
        """
        # Move to the tangent space at the origin, where Euclidean ops are valid.
        x_tan = self.log_map(torch.zeros_like(x), self._project(x))
        x_tan = self.dropout(x_tan)
        # Affine transform in the tangent space.
        h = x_tan @ self.W + self.b  # [num_nodes, out_features]
        # Mean over the neighbours plus the node itself.
        deg = adj.sum(dim=1, keepdim=True) + 1
        agg = (adj @ h + h) / deg
        # Map back to the ball; the origin must have the output dimension.
        x_out = self.exp_map(torch.zeros_like(agg), agg)
        return self._project(x_out)
class PoincaréGCN(nn.Module):
    """Two-layer Poincaré GCN for Euclidean input features.

    The inputs are lifted onto the Poincaré ball with the exponential map at
    the origin. The ReLU and dropout between the two layers are applied in
    the tangent space at the origin, so they cannot push points outside the
    ball.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, c=1.0, dropout=0.5):
        super().__init__()
        self.c = c
        self.conv1 = PoincaréGCNLayer(in_channels, hidden_channels, c)
        self.conv2 = PoincaréGCNLayer(hidden_channels, out_channels, c)
        self.dropout = dropout
        self.act = nn.ReLU()

    def _expmap0(self, v):
        """Exponential map at the origin: tangent vector -> point in the ball."""
        sqrt_c = self.c ** 0.5
        v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
        return torch.tanh(sqrt_c * v_norm) * v / (sqrt_c * v_norm)

    def _logmap0(self, y):
        """Logarithmic map at the origin: point in the ball -> tangent vector."""
        sqrt_c = self.c ** 0.5
        y_norm = torch.norm(y, dim=-1, keepdim=True).clamp(min=1e-10)
        # The clamp keeps artanh finite for points numerically on the boundary.
        arg = (sqrt_c * y_norm).clamp(max=1 - 1e-5)
        return torch.atanh(arg) * y / (sqrt_c * y_norm)

    def forward(self, x, adj):
        """
        Args:
            x: Euclidean node features ``[num_nodes, in_channels]``.
            adj: dense adjacency matrix ``[num_nodes, num_nodes]``.

        Returns:
            Output of ``conv2``: node embeddings, ``[num_nodes, out_channels]``.
        """
        x = self.conv1(self._expmap0(x), adj)
        # Nonlinearity and dropout happen in tangent space. Dropout rescales
        # by 1 / (1 - p), which could otherwise move points outside the ball.
        h = self.act(self._logmap0(x))
        h = F.dropout(h, self.dropout, training=self.training)
        x = self.conv2(self._expmap0(h), adj)
        return x


# Lorentz Graph Network (LGCN)
Lorentz 模型优势
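Lorentz 模型相比 Poincaré ball 具有更好的数值稳定性:
- 线性结构:Lorentz 内积 $\langle\mathbf{x},\mathbf{y}\rangle_{\mathcal{L}} = -x_0y_0 + \sum_{i=1}^{n}x_iy_i$ 是双线性形式,主要运算都是矩阵乘法,适合 GPU 并行
- 更好的条件数:不存在 Poincaré ball 中共形因子 $\lambda^c_{\mathbf{x}} = 2/(1-c\|\mathbf{x}\|^2)$ 在边界附近发散的问题,梯度流更稳定
- 更简单的距离计算(在双曲面 $\langle\mathbf{x},\mathbf{x}\rangle_{\mathcal{L}} = -1/c$ 上):
$$d^c_{\mathcal{L}}(\mathbf{x},\mathbf{y}) = \frac{1}{\sqrt{c}}\operatorname{arcosh}\big(-c\,\langle\mathbf{x},\mathbf{y}\rangle_{\mathcal{L}}\big)$$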
Lorentz 图卷积
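消息函数(在 Lorentz 空间):
$$\mathbf{m}_j = \mathbf{W} \otimes^c_{\mathcal{L}} \mathbf{h}_j$$
其中 $\otimes^c_{\mathcal{L}}$ 是 Lorentz 矩阵乘法。它先对数映射到原点 $\mathbf{o} = (1/\sqrt{c}, 0, \dots, 0)$ 的切空间,再对切向量的空间分量做欧几里得矩阵乘法,最后指数映射回双曲面:
$$\mathbf{W} \otimes^c_{\mathcal{L}} \mathbf{x} = \exp^c_{\mathbf{o}}\big(\mathbf{W}\log^c_{\mathbf{o}}(\mathbf{x})\big)$$
聚合函数(使用 Lorentz 加权平均,即先加权求和再归一化回双曲面):
$$\mathbf{h}_i' = \frac{\sum_{j \in \mathcal{N}(i)\cup\{i\}} \alpha_{ij}\,\mathbf{m}_j}{\sqrt{c}\,\big\|\sum_{j \in \mathcal{N}(i)\cup\{i\}} \alpha_{ij}\,\mathbf{m}_j\big\|_{\mathcal{L}}}, \qquad \|\mathbf{z}\|_{\mathcal{L}} = \sqrt{|\langle\mathbf{z},\mathbf{z}\rangle_{\mathcal{L}}|}$$
其中 $\alpha_{ij} = \operatorname{softmax}_{j}\big(-d^2_{\mathcal{L}}(\mathbf{m}_i,\mathbf{m}_j)\big)$ 是基于 Lorentz 距离的注意力权重,$d^2_{\mathcal{L}}(\mathbf{x},\mathbf{y}) = -2/c - 2\langle\mathbf{x},\mathbf{y}\rangle_{\mathcal{L}}$ 为平方 Lorentz 距离。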
完整 Lorentz GNN 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class LorentzGNNLayer(nn.Module):
    """Graph attention layer on the Lorentz (hyperboloid) model.

    Points satisfy ``<x, x>_L = -1 / c`` with ``x[..., 0] > 0``; index 0 is
    the time coordinate. The origin is ``o = (1 / sqrt(c), 0, ..., 0)``.

    Args:
        in_features: dimension of the Euclidean input features.
        out_features: spatial dimension of the output. Outputs have
            ``out_features + 1`` coordinates (time + space).
        c: positive curvature magnitude; the curvature is ``-c``.
    """

    def __init__(self, in_features, out_features, c=1.0):
        super().__init__()
        self.c = c
        self.in_features = in_features
        self.out_features = out_features
        # Linear map on the spatial part of tangent vectors at the origin,
        # stored as [in_features, out_features] and applied as ``v @ W``.
        self.W = nn.Parameter(torch.randn(in_features, out_features))
        # Tangent-space bias at the origin. Entry 0 (time) is kept so the
        # parameter shape stays the same, but it is unused: tangent vectors
        # at the origin have a zero time component.
        self.bias = nn.Parameter(torch.zeros(out_features + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Orthogonal init for ``W`` and zeros for the bias."""
        nn.init.orthogonal_(self.W)
        nn.init.zeros_(self.bias)

    def lorentz_inner(self, x, y):
        """Lorentz inner product ``-x0*y0 + sum_i xi*yi``, keeping the last dim."""
        return -x[..., 0:1] * y[..., 0:1] + torch.sum(x[..., 1:] * y[..., 1:], dim=-1, keepdim=True)

    def lorentz_norm(self, x):
        """``sqrt(|<x, x>_L|)``, valid for time-like points and space-like tangent vectors."""
        return torch.sqrt(torch.clamp(self.lorentz_inner(x, x).abs(), min=1e-10))

    def exp_map(self, x, v):
        """Exponential map at ``x`` of the tangent vector ``v``.

        ``exp_x(v) = cosh(sqrt(c)|v|) x + sinh(sqrt(c)|v|) v / (sqrt(c)|v|)``
        """
        sqrt_c = self.c ** 0.5
        v_norm = self.lorentz_norm(v)
        return torch.cosh(sqrt_c * v_norm) * x + torch.sinh(sqrt_c * v_norm) / (sqrt_c * v_norm) * v

    def log_map(self, x, y):
        """Logarithmic map at ``x``: the tangent vector pointing to ``y`` (inverse of exp_map)."""
        sqrt_c = self.c ** 0.5
        xy = self.lorentz_inner(x, y)
        # Geodesic distance; the clamp keeps the acosh argument in its domain.
        dist = torch.acosh((-self.c * xy).clamp(min=1 + 1e-7)) / sqrt_c
        # Component of y that is Lorentz-orthogonal to x (tangent at x).
        direction = y + self.c * xy * x
        return dist * direction / self.lorentz_norm(direction)

    def lorentz_add(self, x, y):
        """Return ``x + y`` rescaled onto the hyperboloid (Lorentzian centroid of x and y)."""
        return self._normalize(x + y)

    def lorentz_matmul(self, W, x):
        """Lorentz matrix product ``exp_o(W log_o(x))``.

        ``W`` has shape ``[in, out]`` and acts on the spatial part of the tangent
        vector at the origin. The result has ``out + 1`` coordinates.
        """
        return self._expmap0(self._logmap0(x) @ W)

    def _normalize(self, z):
        """Rescale a future time-like vector so that ``<z, z>_L = -1 / c``."""
        return z / (self.c ** 0.5 * self.lorentz_norm(z))

    def _lift(self, x):
        """Use ``x`` as spatial coordinates and add the matching time coordinate."""
        time = torch.sqrt(1.0 / self.c + torch.sum(x * x, dim=-1, keepdim=True))
        return torch.cat([time, x], dim=-1)

    def _expmap0(self, v):
        """Exp map at the origin of spatial tangent coordinates ``v`` (time part is 0)."""
        sqrt_c = self.c ** 0.5
        v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
        spatial = torch.sinh(sqrt_c * v_norm) / (sqrt_c * v_norm) * v
        # Recomputing the time coordinate keeps the result exactly on the hyperboloid.
        return self._lift(spatial)

    def _logmap0(self, x):
        """Log map at the origin; returns only the spatial tangent coordinates.

        Uses ``arcosh(sqrt(c) * x0) = arsinh(sqrt(c) * |x_spatial|)``, which is
        numerically stable near the origin.
        """
        sqrt_c = self.c ** 0.5
        spatial = x[..., 1:]
        s_norm = torch.norm(spatial, dim=-1, keepdim=True).clamp(min=1e-10)
        return torch.asinh(sqrt_c * s_norm) / (sqrt_c * s_norm) * spatial

    def forward(self, x, adj):
        """
        Args:
            x: Euclidean node features ``[num_nodes, in_features]``, used as the
                spatial coordinates of hyperboloid points.
            adj: dense adjacency matrix ``[num_nodes, num_nodes]``. Nonzero
                entries mark neighbours, and every node also attends to itself.

        Returns:
            Points on the hyperboloid, ``[num_nodes, out_features + 1]``.
        """
        num_nodes = x.size(0)
        x = self._lift(x)  # [N, in + 1]
        # Linear map and bias in the tangent space at the origin.
        h_tan = self._logmap0(x) @ self.W + self.bias[1:]  # [N, out]
        h = self._expmap0(h_tan)  # [N, out + 1]
        # Attention from the squared Lorentzian distance -2/c - 2<h_i, h_j>_L,
        # restricted to graph neighbours plus the node itself.
        inner = h[:, 1:] @ h[:, 1:].T - h[:, :1] @ h[:, :1].T  # [N, N]
        sq_dist = (-2.0 / self.c - 2.0 * inner).clamp(min=0.0)
        mask = (adj != 0) | torch.eye(num_nodes, dtype=torch.bool, device=adj.device)
        scores = (-sq_dist).masked_fill(~mask, float("-inf"))
        alpha = F.softmax(scores, dim=-1)
        # Lorentzian weighted centroid of the attended points.
        agg = self._normalize(alpha @ h)
        # Nonlinearity in the tangent space at the origin, then back onto the manifold.
        return self._expmap0(torch.tanh(self._logmap0(agg)))


# 双曲注意力图网络
Hyperbolic Graph Attention
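将注意力机制引入双曲图网络(注意力与聚合均在原点切空间中计算):
$$\mathbf{h}_i' = \exp^c_{\mathbf{0}}\Big(\sum_{j\in\mathcal{N}(i)\cup\{i\}} \alpha_{ij}\,\mathbf{W}\log^c_{\mathbf{0}}(\mathbf{h}_j)\Big)$$
注意力权重计算:
$$\alpha_{ij} = \operatorname{softmax}_{j}\Big(\operatorname{LeakyReLU}\big(\mathbf{a}^{\top}\big[\mathbf{W}\log^c_{\mathbf{0}}(\mathbf{h}_i)\,\big\|\,\mathbf{W}\log^c_{\mathbf{0}}(\mathbf{h}_j)\big]\big)\Big)$$
其中 $\|$ 表示拼接(concat)操作。拼接在原点切空间中完成,可视为双曲空间中的 Möbius 拼接。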
双曲多头注意力
class HyperbolicGATLayer(nn.Module):
    """Multi-head graph attention on the Poincaré ball of curvature ``-c``.

    Points are log-mapped to the tangent space at the origin, linearly
    transformed and split into heads. GAT-style attention over each node's
    neighbours (plus the node itself) aggregates the head features, and the
    concatenated result is exp-mapped back into the ball.

    Args:
        in_features: input dimension.
        out_features: output dimension; must be divisible by ``num_heads``.
        c: positive curvature magnitude; the ball has radius ``1 / sqrt(c)``.
        num_heads: number of attention heads.
        dropout: dropout probability on the attention coefficients.

    Raises:
        ValueError: if ``out_features`` is not divisible by ``num_heads``.
    """

    def __init__(self, in_features, out_features, c=1.0, num_heads=4, dropout=0.6):
        super().__init__()
        if out_features % num_heads != 0:
            raise ValueError(
                f"out_features ({out_features}) must be divisible by num_heads ({num_heads})"
            )
        self.c = c
        self.num_heads = num_heads
        self.head_dim = out_features // num_heads
        self.W = nn.Linear(in_features, out_features, bias=False)
        # Column h is head h's attention vector. The first head_dim rows score
        # the target node i, and the last head_dim rows score the source node j.
        self.att = nn.Parameter(torch.randn(2 * self.head_dim, num_heads))
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, adj):
        """
        Args:
            x: node points in the Poincaré ball, ``[num_nodes, in_features]``.
                Points outside the ball are first projected inside it.
            adj: dense adjacency matrix ``[num_nodes, num_nodes]``. Nonzero
                entries mark neighbours, and every node also attends to itself.

        Returns:
            Points in the ball, ``[num_nodes, out_features]``.
        """
        num_nodes = x.size(0)
        # Tangent space at the origin; project first so log_map stays finite.
        x_tan = self.log_map(torch.zeros_like(x), self._project(x))
        h = self.W(x_tan).view(num_nodes, self.num_heads, self.head_dim)  # [N, H, D]
        # a^T [h_i || h_j] = a_i . h_i + a_j . h_j, computed without an N x N x 2D tensor.
        score_i = torch.einsum("nhd,dh->nh", h, self.att[: self.head_dim])
        score_j = torch.einsum("nhd,dh->nh", h, self.att[self.head_dim :])
        scores = self.leaky_relu(score_i.unsqueeze(1) + score_j.unsqueeze(0))  # [N, N, H]
        # Attend only to neighbours and self; the self loop prevents all -inf rows.
        mask = (adj != 0) | torch.eye(num_nodes, dtype=torch.bool, device=adj.device)
        scores = scores.masked_fill(~mask.unsqueeze(-1), float("-inf"))
        # Normalise over the neighbours j (dim 1), separately for each head.
        attn = self.dropout(F.softmax(scores, dim=1))
        x_agg = torch.einsum("ijh,jhd->ihd", attn, h).reshape(num_nodes, -1)
        x_out = self.exp_map(torch.zeros_like(x_agg), x_agg)
        return self._project(x_out)

    def _lambda_x(self, x):
        """Conformal factor ``2 / (1 - c * ||x||^2)`` of the ball at ``x``."""
        x_norm_sq = torch.sum(x * x, dim=-1, keepdim=True)
        return 2 / (1 - self.c * x_norm_sq).clamp(min=1e-10)

    def _mobius_add(self, u, v):
        """Möbius addition ``u ⊕_c v`` on the Poincaré ball."""
        c = self.c
        uv = torch.sum(u * v, dim=-1, keepdim=True)
        u_norm_sq = torch.sum(u * u, dim=-1, keepdim=True)
        v_norm_sq = torch.sum(v * v, dim=-1, keepdim=True)
        numerator = (1 + 2 * c * uv + c * v_norm_sq) * u + (1 - c * u_norm_sq) * v
        denominator = 1 + 2 * c * uv + c ** 2 * u_norm_sq * v_norm_sq
        return numerator / denominator.clamp(min=1e-10)

    def _project(self, x):
        """Rescale points so they lie strictly inside the ball of radius ``1 / sqrt(c)``."""
        max_norm = (1 - 1e-5) / self.c ** 0.5
        norm = torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-10)
        return x * (max_norm / norm).clamp(max=1.0)

    def log_map(self, base, x):
        """Logarithmic map at ``base``: the tangent vector pointing to ``x``."""
        sqrt_c = self.c ** 0.5
        diff = self._mobius_add(-base, x)
        diff_norm = torch.norm(diff, dim=-1, keepdim=True).clamp(min=1e-10)
        # Keep the artanh argument below 1 so that boundary points do not give inf.
        arg = (sqrt_c * diff_norm).clamp(max=1 - 1e-5)
        return 2 / (sqrt_c * self._lambda_x(base)) * torch.atanh(arg) / diff_norm * diff

    def exp_map(self, base, v):
        """Exponential map at ``base`` of the tangent vector ``v`` (inverse of log_map)."""
        sqrt_c = self.c ** 0.5
        v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
        second_term = torch.tanh(sqrt_c * self._lambda_x(base) * v_norm / 2) / (sqrt_c * v_norm) * v
        return self._mobius_add(base, second_term)


# 层次聚合与池化
双曲图池化
在双曲空间中对图进行层次化池化:
class HyperbolicGraphPool(nn.Module):
    """Top-k hierarchical pooling for graphs embedded in hyperbolic space.

    Keeps the ``ratio`` fraction of nodes with the highest summed cluster
    score. Each kept node is replaced by the ``frechet_mean`` of itself and
    its *dropped* neighbours, so information from removed nodes is merged
    into adjacent survivors rather than discarded.

    Args:
        ratio: fraction of nodes to keep (at least one node is always kept).
        c: curvature parameter forwarded to the hyperbolic helpers.
    """

    def __init__(self, ratio=0.5, c=1.0):
        super().__init__()
        self.ratio = ratio
        self.c = c

    def compute_cluster_score(self, x):
        """Return ``sigmoid(L @ L.T)`` where ``L = log_map_0(x, self.c)``.

        ``log_map_0`` is defined outside this section; presumably it is the
        log map at the origin (verify), so the scores are tangent-space
        similarities.
        """
        x_log = log_map_0(x, self.c)
        scores = torch.sigmoid(x_log @ x_log.T)
        return scores

    def forward(self, x, adj):
        """
        Args:
            x: node embeddings, ``[num_nodes, dim]``.
            adj: dense adjacency matrix ``[num_nodes, num_nodes]``; nonzero
                entries mark neighbours.

        Returns:
            ``(x_pooled, adj_pooled, top_indices)``: the embeddings of the kept
            nodes after merging, the adjacency induced on the kept nodes, and
            their indices in the input graph.
        """
        num_nodes = x.size(0)
        num_keep = max(1, int(num_nodes * self.ratio))
        scores = self.compute_cluster_score(x)
        _, top_indices = torch.topk(scores.sum(dim=1), num_keep)
        adj_pooled = adj[top_indices][:, top_indices]
        # Nodes removed by pooling; built once, on the same device as x.
        dropped = torch.ones(num_nodes, dtype=torch.bool, device=x.device)
        dropped[top_indices] = False
        pooled_rows = []
        for node in top_indices.tolist():
            # Merge only the dropped nodes that are adjacent to this kept node.
            merged = dropped & (adj[node] != 0).to(x.device)
            if merged.any():
                # Boolean indexing always yields a 2-D tensor, even for one neighbour.
                group = torch.cat([x[node : node + 1], x[merged]])
                # Assumes frechet_mean returns a [dim] tensor like x[node] — verify.
                pooled_rows.append(frechet_mean(group, self.c))
            else:
                pooled_rows.append(x[node])
        x_pooled = torch.stack(pooled_rows)
        return x_pooled, adj_pooled, top_indices
def frechet_mean(points, c=1.0, lr=0.1, max_iter=100):
    """Approximate the Fréchet mean (Riemannian barycentre) of ``points``.

    The estimate starts at ``points[0]`` and is updated exactly ``max_iter``
    times, with no early stopping. Each update averages the ``log_map``
    vectors from the current estimate to every point, moves along ``lr``
    times that average with ``exp_map``, and applies ``project_to_ball``.
    ``log_map``, ``exp_map`` and ``project_to_ball`` are defined outside
    this section.

    Args:
        points: indexable, iterable collection of points.
        c: curvature parameter forwarded to the helpers.
        lr: step size of each update.
        max_iter: number of updates.

    Returns:
        The final estimate.
    """
    estimate = points[0]
    for _ in range(max_iter):
        tangents = torch.stack([log_map(estimate, point, c) for point in points])
        step = lr * tangents.mean(dim=0)
        estimate = project_to_ball(exp_map(estimate, step, c), c)
    return estimate


# 实验对比
节点分类性能
| 数据集 | 欧几里得 GCN | Poincaré GCN | Lorentz GNN |
|---|---|---|---|
| Cora | 81.5% | 82.1% | 82.3% |
| CiteSeer | 70.3% | 71.2% | 71.8% |
| PubMed | 79.0% | 79.5% | 79.8% |
| PPI | 98.6% | 99.1% | 99.2% |
层次结构捕获能力
在合成层次图上的表现:
层次深度 = 10 的二叉树
嵌入维度 = 16
欧几里得嵌入:
- 能嵌入的层次数: 3-4
- 最近邻精度: 45%
Poincaré嵌入:
- 能嵌入的层次数: 10+
- 最近邻精度: 98%
与标准GNN的关系
极限情况
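当双曲空间曲率 $-c \to 0$ 时,双曲 GNN 退化为标准欧几里得 GNN:
$$\lim_{c\to 0}\exp^c_{\mathbf{x}}(\mathbf{v}) = \mathbf{x}+\mathbf{v}, \qquad \lim_{c\to 0}\log^c_{\mathbf{x}}(\mathbf{y}) = \mathbf{y}-\mathbf{x}$$
因此 Poincaré GCN 的聚合式变为欧几里得邻居平均 $\mathbf{h}_i^{(\ell+1)} = \frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}\mathbf{h}_j^{(\ell)}$。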
架构选择指南
| 数据特性 | 推荐架构 |
|---|---|
| 弱层次结构 | 标准 GCN/GAT |
| 强层次结构 | Poincaré GCN |
| 需要数值稳定性 | Lorentz GNN |
| 需要注意力机制 | Hyperbolic GAT |
| 极深层次 | 混合双曲-欧几里得 |