Overview

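Graph neural networks (GNNs) perform well on graph-structured data, but GNNs operating in Euclidean space struggle to capture the hierarchical structure of graphs (e.g., community hierarchies in social networks, organizational charts). Hyperbolic Graph Neural Networks (HGNNs) generalize the message-passing mechanism to hyperbolic space, whose volume grows exponentially with radius, and use this property to encode hierarchical relationships naturally.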


Why Hyperbolic Graph Neural Networks?

The Challenge of Hierarchical Structure

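Real-world graphs often exhibit tree-like hierarchies:

Graph type | Hierarchy example
Social network | Individual → group → community → entire network
Knowledge graph | Entity → concept → hypernym concept → root concept
Biological network | Protein → complex → pathway → cell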

Euclidean vs. Hyperbolic Embeddings

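Euclidean space: the volume of a ball grows only polynomially with its radius, so a hierarchy quickly runs out of room

  • A complete binary tree of depth $d$ has $2^d$ nodes on its deepest level; even with many dimensions, a Euclidean embedding of such a tree cannot be distortion-free, and the deep levels get crowded together

Hyperbolic space: the volume of a ball grows exponentially with its radius, so a few dimensions suffice

  • Any finite tree can be embedded in the two-dimensional Poincaré disk with arbitrarily small distortion

This means hyperbolic GNNs such as HGCN can capture deeper hierarchies in fewer dimensions.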
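The growth argument can be checked with a few numbers. A minimal sketch, assuming the hyperbolic plane has curvature $-1$: it compares the number of nodes on level $d$ of a binary tree ($2^d$) with the circumference of a circle of radius $d$ in the Euclidean plane ($2\pi d$) and in the hyperbolic plane ($2\pi\sinh d$). The function name `circumference_growth` is illustrative.

```python
import math


def circumference_growth(max_depth):
    """For each depth d: (d, tree nodes on level d, Euclidean circumference, hyperbolic circumference)."""
    rows = []
    for d in range(1, max_depth + 1):
        nodes = 2 ** d                       # nodes on level d of a binary tree
        euclid = 2 * math.pi * d             # circumference of a Euclidean circle of radius d
        hyper = 2 * math.pi * math.sinh(d)   # circumference of a hyperbolic circle (curvature -1)
        rows.append((d, nodes, euclid, hyper))
    return rows


for d, nodes, euclid, hyper in circumference_growth(10):
    print(f"d={d:2d}  nodes={nodes:5d}  euclidean={euclid:8.1f}  hyperbolic={hyper:10.1f}")
```

At $d = 10$ the Euclidean circle offers about 63 units of arc for 1,024 nodes, while the hyperbolic circle offers about 69,000; since $2\pi\sinh d \approx \pi e^d$, the gap keeps widening with depth.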


Poincaré GCN

Core Idea

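Chami et al. (2019) proposed generalizing graph convolution to hyperbolic space. Written on the Poincaré ball, the update for node $i$ is:

$$h_i^{(l+1)} = \exp_{h_i^{(l)}}^{c}\left(\frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}\log_{h_i^{(l)}}^{c}\big(h_j^{(l)}\big)\right)$$

where:

  • $h_i^{(l)}$ is the hyperbolic embedding of node $i$ at layer $l$
  • $\exp_{x}^{c}(\cdot)$ is the exponential map at point $x$ (curvature $-c$)
  • $\log_{x}^{c}(\cdot)$ is the logarithmic map at point $x$
  • $\mathcal{N}(i)$ is the set of neighbors of node $i$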

Intuition

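  1. Map every neighbor into the tangent space at node $i$ (logarithmic map)
  2. Perform ordinary Euclidean aggregation (averaging) in that tangent space
  3. Move the result back into hyperbolic space with the exponential map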
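The three steps can be traced on concrete numbers. A minimal pure-Python sketch, assuming the Poincaré ball with curvature $-c$, Möbius addition $\oplus_c$, and the standard maps built on the conformal factor $\lambda_x^c = 2/(1-c\|x\|^2)$; the helper names (`mobius_add`, `exp_map`, `log_map`, `aggregate`) are illustrative and not taken from any library.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def scale(s, u):
    return [s * a for a in u]


def mobius_add(u, v, c):
    """Möbius addition u ⊕_c v on the Poincaré ball."""
    uv, u2, v2 = dot(u, v), dot(u, u), dot(v, v)
    den = 1 + 2 * c * uv + c * c * u2 * v2
    return [((1 + 2 * c * uv + c * v2) * a + (1 - c * u2) * b) / den for a, b in zip(u, v)]


def conformal(x, c):
    """Conformal factor lambda_x = 2 / (1 - c * ||x||^2)."""
    return 2 / (1 - c * dot(x, x))


def exp_map(x, v, c):
    """Exponential map at x: the point reached from x along tangent vector v."""
    n = math.sqrt(dot(v, v))
    if n < 1e-15:
        return list(x)
    sc = math.sqrt(c)
    return mobius_add(x, scale(math.tanh(sc * conformal(x, c) * n / 2) / (sc * n), v), c)


def log_map(x, y, c):
    """Logarithmic map at x: the tangent vector at x pointing toward y."""
    diff = mobius_add(scale(-1, x), y, c)
    n = math.sqrt(dot(diff, diff))
    if n < 1e-15:
        return [0.0] * len(x)
    sc = math.sqrt(c)
    return scale(2 / (sc * conformal(x, c)) * math.atanh(sc * n) / n, diff)


def aggregate(h_i, neighbors, c=1.0):
    """One Poincaré GCN update for node i."""
    tangents = [log_map(h_i, h_j, c) for h_j in neighbors]       # step 1: log-map neighbors
    mean = [sum(col) / len(tangents) for col in zip(*tangents)]   # step 2: Euclidean mean
    return exp_map(h_i, mean, c)                                  # step 3: exp-map back


h_i = [0.1, 0.2]
neighbors = [[0.5, 0.1], [-0.3, 0.4], [0.2, -0.6]]
h_new = aggregate(h_i, neighbors)
print(h_new, dot(h_new, h_new) < 1.0)
```

Because $\exp_x^c$ and $\log_x^c$ are inverses, a node with a single neighbor moves exactly onto that neighbor, and the result always stays inside the unit ball.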

The Full Message-Passing Framework

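Message phase: each neighbor $j$ sends a tangent vector at $h_i^{(l)}$

$$m_{ij}^{(l)} = \log_{h_i^{(l)}}^{c}\big(h_j^{(l)}\big)$$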

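Update phase: average the messages and map the result back onto the manifold

$$h_i^{(l+1)} = \exp_{h_i^{(l)}}^{c}\Big(\frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)} m_{ij}^{(l)}\Big)$$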

Poincaré GCN implementation

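import torch
import torch.nn as nn
import torch.nn.functional as F
 
class PoincaréGCNLayer(nn.Module):
    """Graph convolution layer on the Poincaré ball with curvature -c."""
    
    def __init__(self, in_features, out_features, c=1.0, dropout=0.0):
        super().__init__()
        self.c = c
        self.sqrt_c = c ** 0.5  # c is a Python float, so torch.sqrt(self.c) would fail
        self.in_features = in_features
        self.out_features = out_features
        
        # Learnable parameters, applied in the tangent space at the origin
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        self.b = nn.Parameter(torch.zeros(out_features))
        
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W)
        nn.init.zeros_(self.b)
    
    def _lambda(self, x):
        """Conformal factor lambda_x = 2 / (1 - c * ||x||^2)."""
        x_norm_sq = torch.sum(x * x, dim=-1, keepdim=True)
        return 2 / (1 - self.c * x_norm_sq).clamp(min=1e-10)
    
    def exp_map(self, x, v):
        """Exponential map: the point reached from x along tangent vector v."""
        v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
        second_term = (torch.tanh(self.sqrt_c * self._lambda(x) * v_norm / 2) /
                       (self.sqrt_c * v_norm)) * v
        return self._mobius_add(x, second_term)
    
    def log_map(self, x, y):
        """Logarithmic map: the tangent vector at x pointing toward y."""
        diff = self._mobius_add(-x, y)
        diff_norm = torch.norm(diff, dim=-1, keepdim=True).clamp(min=1e-10)
        # Clamping keeps atanh finite for points numerically on the boundary
        artanh = torch.atanh((self.sqrt_c * diff_norm).clamp(max=1 - 1e-5))
        return (2 / (self.sqrt_c * self._lambda(x)) * artanh / diff_norm) * diff
    
    def _mobius_add(self, u, v):
        """Möbius addition u ⊕_c v."""
        uv = torch.sum(u * v, dim=-1, keepdim=True)
        u_norm_sq = torch.sum(u * u, dim=-1, keepdim=True)
        v_norm_sq = torch.sum(v * v, dim=-1, keepdim=True)
        
        numerator = (1 + 2 * self.c * uv + self.c * v_norm_sq) * u + (1 - self.c * u_norm_sq) * v
        denominator = 1 + 2 * self.c * uv + self.c ** 2 * u_norm_sq * v_norm_sq
        
        return numerator / denominator.clamp(min=1e-10)
    
    def _project(self, x):
        """Project back strictly inside the ball of radius 1/sqrt(c)."""
        norm = torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-10)
        max_norm = (1 - 1e-5) / self.sqrt_c
        return torch.where(norm > max_norm, x / norm * max_norm, x)
    
    def forward(self, x, adj):
        """
        Args:
            x: node embeddings on the Poincaré ball [num_nodes, in_features]
            adj: dense adjacency matrix [num_nodes, num_nodes]
        """
        # Map every node to the tangent space at the origin (a simplification:
        # the core equation aggregates in the tangent space at each h_i instead)
        x_log = self.log_map(torch.zeros_like(x), x)
        
        # Linear transform in the tangent space (W has shape [in_features, out_features])
        x_log = self.dropout(x_log) @ self.W + self.b
        
        # Message passing: sum of neighbor tangent vectors
        agg = adj @ x_log  # [num_nodes, out_features]
        
        # Aggregation including the node itself (mean over neighbors + self)
        deg = adj.sum(dim=1, keepdim=True) + 1  # +1 for the self-loop
        agg = (agg + x_log) / deg
        
        # Exponential map back to hyperbolic space (the origin now has out_features dims)
        x_out = self.exp_map(torch.zeros_like(agg), agg)
        
        # Projection
        return self._project(x_out)
 
 
class PoincaréGCN(nn.Module):
    """Two-layer Poincaré GCN."""
    
    def __init__(self, in_channels, hidden_channels, out_channels, c=1.0, dropout=0.5):
        super().__init__()
        self.c = c
        
        # Dropout runs inside each layer, in the tangent space, so points never leave the ball
        self.conv1 = PoincaréGCNLayer(in_channels, hidden_channels, c, dropout)
        self.conv2 = PoincaréGCNLayer(hidden_channels, out_channels, c, dropout)
    
    def forward(self, x, adj):
        # Lift Euclidean input features onto the ball with the exp map at the origin
        x = self.conv1._project(self.conv1.exp_map(torch.zeros_like(x), x))
        x = self.conv1(x, adj)
        # Apply ReLU in the tangent space at the origin, then map back
        origin = torch.zeros_like(x)
        x = self.conv1.exp_map(origin, F.relu(self.conv1.log_map(origin, x)))
        x = self.conv2(x, adj)
        return x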

Lorentz Graph Network (LGCN)

Advantages of the Lorentz model

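Compared with the Poincaré ball, the Lorentz (hyperboloid) model offers better numerical stability:

  1. Simple bilinear structure: the Lorentz inner product $\langle x, y\rangle_{\mathcal{L}} = -x_0 y_0 + \sum_{k\ge 1} x_k y_k$ is a plain bilinear form that maps well onto batched GPU operations
  2. Better conditioning: there is no ball boundary where a conformal factor blows up, so gradients flow more stably
  3. Simpler distance computation: $d_{\mathcal{L}}(x, y) = \frac{1}{\sqrt{c}}\operatorname{arcosh}\big(-c\,\langle x, y\rangle_{\mathcal{L}}\big)$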
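A minimal sketch comparing the two distance formulas, assuming curvature $-1$ ($c = 1$). It maps a Poincaré-ball point $p$ to the hyperboloid with the standard isometry $x = \left(\frac{1+\|p\|^2}{1-\|p\|^2}, \frac{2p}{1-\|p\|^2}\right)$ and checks that both models give the same distance; the function names are illustrative.

```python
import math


def lorentz_inner(x, y):
    """Lorentz inner product: -x0*y0 + sum_k xk*yk."""
    return -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))


def lorentz_distance(x, y):
    """Geodesic distance on the hyperboloid <x, x>_L = -1."""
    return math.acosh(max(1.0, -lorentz_inner(x, y)))


def poincare_distance(u, v):
    """Geodesic distance on the unit Poincaré ball."""
    diff2 = sum((a - b) ** 2 for a, b in zip(u, v))
    u2, v2 = sum(a * a for a in u), sum(b * b for b in v)
    return math.acosh(1 + 2 * diff2 / ((1 - u2) * (1 - v2)))


def poincare_to_lorentz(p):
    """Isometry from the unit Poincaré ball onto the hyperboloid."""
    s = sum(a * a for a in p)
    return [(1 + s) / (1 - s)] + [2 * a / (1 - s) for a in p]


u, v = [0.1, 0.5], [-0.7, 0.2]
print(poincare_distance(u, v), lorentz_distance(poincare_to_lorentz(u), poincare_to_lorentz(v)))
```

Both values agree up to rounding. The Lorentz version needs only an inner product and one arcosh, while the Poincaré version divides by $1-\|u\|^2$, which approaches zero near the boundary of the ball.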

Lorentz graph convolution

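Message function (in Lorentz space):

$$m_{ij}^{(l)} = W^{(l)} \otimes^{c} h_j^{(l)}$$

where Lorentz matrix multiplication passes through the tangent space at the origin $o = (1/\sqrt{c}, 0, \dots, 0)$, with $W$ acting on the spatial coordinates:

$$W \otimes^{c} x = \exp_{o}^{c}\big(W \log_{o}^{c}(x)\big)$$

Aggregation function (using a Lorentz weighted mean):

$$h_i^{(l+1)} = \frac{\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\, m_{ij}^{(l)}}{\sqrt{c}\,\big\|\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\, m_{ij}^{(l)}\big\|_{\mathcal{L}}}, \qquad \|z\|_{\mathcal{L}} = \sqrt{|\langle z, z\rangle_{\mathcal{L}}|}$$

where $\alpha_{ij}$ is an attention weight based on the Lorentz distance.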
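A minimal pure-Python sketch of the two operations above on the hyperboloid with curvature $-c$: Lorentz matrix multiplication through the tangent space at the origin, and the Lorentz weighted mean. The attention weights are passed in directly (assumption: any nonnegative weights summing to one), so this shows only the geometry, not how $\alpha_{ij}$ is learned; all function names are illustrative.

```python
import math


def lorentz_inner(x, y):
    return -x[0] * y[0] + sum(a * b for a, b in zip(x[1:], y[1:]))


def log_origin(x, c):
    """log_o(x) at o = (1/sqrt(c), 0, ..., 0); returns the spatial part (the time part is 0)."""
    sc = math.sqrt(c)
    spatial = x[1:]
    n = math.sqrt(sum(a * a for a in spatial))
    if n < 1e-15:
        return [0.0] * len(spatial)
    dist = math.acosh(max(1.0, sc * x[0])) / sc   # geodesic distance from the origin
    return [dist * a / n for a in spatial]


def exp_origin(v, c):
    """exp_o(v) for a tangent vector at the origin given by its spatial part v."""
    sc = math.sqrt(c)
    n = math.sqrt(sum(a * a for a in v))
    time = math.cosh(sc * n) / sc
    if n < 1e-15:
        return [time] + [0.0] * len(v)
    return [time] + [math.sinh(sc * n) * a / (sc * n) for a in v]


def lorentz_matmul(W, x, c):
    """W ⊗_c x = exp_o(W log_o(x)); W is a list of rows (out_dim x in_dim)."""
    v = log_origin(x, c)
    return exp_origin([sum(w * a for w, a in zip(row, v)) for row in W], c)


def lorentz_centroid(points, weights, c):
    """Lorentz weighted mean: rescale sum_j alpha_j m_j back onto <x, x>_L = -1/c."""
    z = [sum(w * p[k] for w, p in zip(weights, points)) for k in range(len(points[0]))]
    norm = math.sqrt(abs(lorentz_inner(z, z)))
    return [a / (math.sqrt(c) * norm) for a in z]


c = 0.5
W = [[0.6, -0.2], [0.1, 0.9], [0.3, 0.3]]   # maps 2 spatial dims to 3
msgs = [lorentz_matmul(W, exp_origin(v, c), c) for v in ([0.4, -0.1], [0.2, 0.3], [-0.5, 0.1])]
h_new = lorentz_centroid(msgs, [0.5, 0.3, 0.2], c)
print(h_new, lorentz_inner(h_new, h_new))   # second value is -1/c = -2.0 up to rounding
```

Every message and the aggregated point satisfy $\langle x, x\rangle_{\mathcal{L}} = -1/c$, so no separate projection step is needed after aggregation.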

Full Lorentz GNN implementation

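import torch
import torch.nn as nn
import torch.nn.functional as F
 
class LorentzGNNLayer(nn.Module):
    """GNN layer on the Lorentz (hyperboloid) model with curvature -c.

    Points satisfy <x, x>_L = -1/c with x_0 > 0; index 0 is the time dimension.
    """
    
    def __init__(self, in_features, out_features, c=1.0):
        super().__init__()
        self.c = c
        self.sqrt_c = c ** 0.5
        self.in_features = in_features
        self.out_features = out_features
        
        # Lorentz linear transform on the spatial part of tangent vectors at the origin
        self.W = nn.Parameter(torch.empty(in_features, out_features))
        # Bias added in the tangent space (a tangent vector at the origin has time component 0)
        self.bias = nn.Parameter(torch.zeros(out_features))
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.orthogonal_(self.W)
        nn.init.zeros_(self.bias)
    
    def origin(self, num, dim, device):
        """Hyperboloid origin o = (1/sqrt(c), 0, ..., 0), repeated num times."""
        o = torch.zeros(num, dim, device=device)
        o[:, 0] = 1 / self.sqrt_c
        return o
    
    def lorentz_inner(self, x, y):
        """Lorentz inner product <x, y>_L = -x_0 y_0 + sum_k x_k y_k."""
        return -x[..., 0:1] * y[..., 0:1] + torch.sum(x[..., 1:] * y[..., 1:], dim=-1, keepdim=True)
    
    def lorentz_norm(self, v):
        """Norm of a tangent vector; tangent vectors are spacelike, so <v, v>_L >= 0."""
        return torch.sqrt(torch.clamp(self.lorentz_inner(v, v), min=1e-10))
    
    def distance(self, x, y):
        """Geodesic distance d(x, y) = arcosh(-c <x, y>_L) / sqrt(c)."""
        return torch.acosh(torch.clamp(-self.c * self.lorentz_inner(x, y), min=1 + 1e-6)) / self.sqrt_c
    
    def exp_map(self, x, v):
        """Lorentz exponential map at x."""
        theta = self.sqrt_c * self.lorentz_norm(v)
        return torch.cosh(theta) * x + (torch.sinh(theta) / theta) * v
    
    def log_map(self, x, y):
        """Lorentz logarithmic map: the tangent vector at x pointing toward y."""
        xy = self.c * self.lorentz_inner(x, y)
        theta = torch.acosh(torch.clamp(-xy, min=1 + 1e-6))
        return (theta / torch.sinh(theta)) * (y + xy * x)
    
    def lorentz_matmul(self, W, x, bias=None):
        """Lorentz matrix multiplication W ⊗_c x = exp_o(W log_o(x))."""
        num = x.size(0)
        # Map to the tangent space at the origin (time component becomes 0)
        x_log = self.log_map(self.origin(num, x.size(-1), x.device), x)
        # Euclidean matrix multiplication on the spatial part; skip the time dimension
        x_trans = x_log[..., 1:] @ W
        if bias is not None:
            x_trans = x_trans + bias
        # Map back onto the manifold
        zero_time = torch.zeros(num, 1, device=x.device)
        x_with_time = torch.cat([zero_time, x_trans], dim=-1)
        return self.exp_map(self.origin(num, W.size(1) + 1, x.device), x_with_time)
    
    def forward(self, x, adj):
        """
        Args:
            x: Euclidean node features [num_nodes, in_features]
            adj: dense adjacency matrix [num_nodes, num_nodes]
        Returns:
            points on the hyperboloid [num_nodes, out_features + 1]
        """
        # Add the time dimension so that <x, x>_L = -1/c (lifts x onto the Lorentz manifold)
        time_dim = torch.sqrt(1 / self.c + torch.sum(x * x, dim=-1, keepdim=True))
        x = torch.cat([time_dim, x], dim=-1)
        
        # Lorentz transform (bias applied in the tangent space, not to manifold points)
        x_trans = self.lorentz_matmul(self.W, x, self.bias)
        
        # Message passing: attention from negative Lorentz distance, restricted to neighbors and self
        scores = -self.distance(x_trans.unsqueeze(1), x_trans.unsqueeze(0)).squeeze(-1)
        self_loop = torch.eye(adj.size(0), dtype=torch.bool, device=adj.device)
        scores = scores.masked_fill((adj == 0) & ~self_loop, float('-inf'))
        scores = F.softmax(scores, dim=-1)
        
        # Weighted aggregation, then rescale onto the hyperboloid (Lorentz weighted mean)
        x_agg = torch.matmul(scores, x_trans)
        x_agg = x_agg / (self.sqrt_c * torch.sqrt(torch.clamp(self.lorentz_inner(x_agg, x_agg).abs(), min=1e-10)))
        
        # Nonlinear activation in the tangent space at the origin (tanh(0) = 0 keeps the time part 0)
        o = self.origin(x_agg.size(0), x_agg.size(-1), x_agg.device)
        return self.exp_map(o, torch.tanh(self.log_map(o, x_agg)))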

Hyperbolic Graph Attention Networks

Hyperbolic Graph Attention

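Introducing the attention mechanism into hyperbolic graph networks:

Attention weight computation

$$e_{ij} = \mathrm{LeakyReLU}\Big(a^{\top}\big[W\log_{0}^{c}(h_i)\,\Vert\,W\log_{0}^{c}(h_j)\big]\Big), \qquad \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k\in\mathcal{N}(i)}\exp(e_{ik})}$$

where $\Vert$ denotes the (Möbius) concatenation operation, carried out here in the tangent space at the origin.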
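A minimal pure-Python sketch of this attention rule for a single node. Assumptions: the concatenation is taken in the tangent space at the origin of the unit Poincaré ball, the weight matrix $W$ is the identity for brevity, the attention vector `a` is made up, and the neighbor set includes the node itself. Function names are illustrative.

```python
import math


def log0(x, c=1.0):
    """Logarithmic map at the origin of the Poincaré ball."""
    n = math.sqrt(sum(a * a for a in x))
    if n < 1e-15:
        return [0.0] * len(x)
    sc = math.sqrt(c)
    return [math.atanh(sc * n) * a / (sc * n) for a in x]


def leaky_relu(z, slope=0.2):
    return z if z > 0 else slope * z


def attention_weights(h, i, neighbors, a, c=1.0):
    """alpha_ij over j in neighbors: softmax_j LeakyReLU(a . [log0(h_i) || log0(h_j)])."""
    hi = log0(h[i], c)
    scores = {j: leaky_relu(sum(w * z for w, z in zip(a, hi + log0(h[j], c)))) for j in neighbors}
    m = max(scores.values())                      # subtract the max for numerical stability
    exps = {j: math.exp(s - m) for j, s in scores.items()}
    total = sum(exps.values())
    return {j: e / total for j, e in exps.items()}


h = [[0.1, 0.2], [0.4, -0.3], [-0.5, 0.1], [0.0, 0.6]]
a = [0.5, -0.4, 0.8, 0.3]                         # length 2 * dim
print(attention_weights(h, 0, neighbors=[0, 1, 3], a=a))
```

Non-neighbors (node 2 here) receive no weight at all, which is what masking with $-\infty$ before the softmax achieves in the tensor implementation.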

Hyperbolic multi-head attention

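import torch
import torch.nn as nn
import torch.nn.functional as F


class HyperbolicGATLayer(nn.Module):
    """Hyperbolic graph attention layer (Poincaré ball; attention in the tangent space at the origin)."""
    
    def __init__(self, in_features, out_features, c=1.0, num_heads=4, dropout=0.6):
        super().__init__()
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"
        self.c = c
        self.sqrt_c = c ** 0.5
        self.num_heads = num_heads
        self.head_dim = out_features // num_heads
        
        self.W = nn.Linear(in_features, out_features, bias=False)
        # One attention vector of length 2 * head_dim per head, matching the einsum below
        self.att = nn.Parameter(torch.empty(num_heads, 2 * self.head_dim))
        nn.init.xavier_uniform_(self.att)
        self.leaky_relu = nn.LeakyReLU(0.2)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, adj):
        """
        Args:
            x: node embeddings on the Poincaré ball [num_nodes, in_features]
            adj: dense adjacency matrix [num_nodes, num_nodes]
        """
        n = x.size(0)
        # Map to the tangent space at the origin first, then apply the linear transform there
        x_log = self.W(self.log_map(x))
        
        # Split into heads: [n, num_heads, head_dim]
        x_log = x_log.view(n, self.num_heads, self.head_dim)
        
        # Pairwise concatenation [x_i || x_j]: [n, n, num_heads, 2 * head_dim]
        combined = torch.cat([x_log.unsqueeze(1).expand(-1, n, -1, -1),
                              x_log.unsqueeze(0).expand(n, -1, -1, -1)], dim=-1)
        
        attn_weights = torch.einsum('ijhd,hd->ijh', combined, self.att)
        attn_weights = self.leaky_relu(attn_weights)
        
        # Masked attention: keep neighbors plus self-loops, normalize over neighbors j (dim=1)
        self_loop = torch.eye(n, dtype=torch.bool, device=adj.device)
        mask = ((adj == 0) & ~self_loop).unsqueeze(-1)
        attn_weights = attn_weights.masked_fill(mask, float('-inf'))
        # Dropout on attention weights, not on hyperbolic points (which could leave the ball)
        attn_weights = self.dropout(F.softmax(attn_weights, dim=1))
        
        # Aggregation: [n, num_heads, head_dim]
        x_agg = torch.einsum('ijh,jhd->ihd', attn_weights, x_log)
        
        # Concatenate heads
        x_agg = x_agg.reshape(n, -1)
        
        # Exponential map back to hyperbolic space
        return self.exp_map(x_agg)
    
    def log_map(self, x):
        """Logarithmic map at the origin: artanh(sqrt(c)||x||) x / (sqrt(c)||x||)."""
        x_norm = torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-10)
        artanh = torch.atanh((self.sqrt_c * x_norm).clamp(max=1 - 1e-5))
        return artanh * x / (self.sqrt_c * x_norm)
    
    def exp_map(self, v):
        """Exponential map at the origin: tanh(sqrt(c)||v||) v / (sqrt(c)||v||)."""
        v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
        return torch.tanh(self.sqrt_c * v_norm) * v / (self.sqrt_c * v_norm)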

Hierarchical Aggregation and Pooling

Hyperbolic graph pooling

Pooling a graph hierarchically in hyperbolic space:

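import torch
import torch.nn as nn
 
 
def mobius_add(u, v, c):
    """Möbius addition u ⊕_c v on the Poincaré ball."""
    uv = torch.sum(u * v, dim=-1, keepdim=True)
    u2 = torch.sum(u * u, dim=-1, keepdim=True)
    v2 = torch.sum(v * v, dim=-1, keepdim=True)
    num = (1 + 2 * c * uv + c * v2) * u + (1 - c * u2) * v
    return num / (1 + 2 * c * uv + c ** 2 * u2 * v2).clamp(min=1e-10)
 
 
def project_to_ball(x, c, eps=1e-5):
    """Keep points strictly inside the ball of radius 1/sqrt(c)."""
    norm = torch.norm(x, dim=-1, keepdim=True).clamp(min=1e-10)
    max_norm = (1 - eps) / c ** 0.5
    return torch.where(norm > max_norm, x / norm * max_norm, x)
 
 
def exp_map(x, v, c):
    """Exponential map at x."""
    sqrt_c = c ** 0.5
    lam = 2 / (1 - c * torch.sum(x * x, dim=-1, keepdim=True)).clamp(min=1e-10)
    v_norm = torch.norm(v, dim=-1, keepdim=True).clamp(min=1e-10)
    return mobius_add(x, torch.tanh(sqrt_c * lam * v_norm / 2) * v / (sqrt_c * v_norm), c)
 
 
def log_map(x, y, c):
    """Logarithmic map at x: the tangent vector pointing toward y."""
    sqrt_c = c ** 0.5
    lam = 2 / (1 - c * torch.sum(x * x, dim=-1, keepdim=True)).clamp(min=1e-10)
    diff = mobius_add(-x, y, c)
    d_norm = torch.norm(diff, dim=-1, keepdim=True).clamp(min=1e-10)
    artanh = torch.atanh((sqrt_c * d_norm).clamp(max=1 - 1e-5))
    return 2 / (sqrt_c * lam) * artanh * diff / d_norm
 
 
def log_map_0(x, c):
    """Logarithmic map at the origin."""
    return log_map(torch.zeros_like(x), x, c)
 
 
def frechet_mean(points, c=1.0, lr=0.1, max_iter=100):
    """Fréchet mean (Riemannian barycenter) in hyperbolic space via Riemannian gradient descent."""
    mean = points[0]
    
    for _ in range(max_iter):
        # The average of log_mean(p) is the negative Riemannian gradient of (1/2k) * sum d(mean, p)^2
        grad = log_map(mean.expand_as(points), points, c).mean(dim=0)
        
        mean = exp_map(mean, lr * grad, c)
        mean = project_to_ball(mean, c)
    
    return mean
 
 
class HyperbolicGraphPool(nn.Module):
    """Hierarchical graph pooling in hyperbolic space."""
    
    def __init__(self, ratio=0.5, c=1.0):
        super().__init__()
        self.ratio = ratio
        self.c = c
    
    def compute_cluster_score(self, x):
        """Score how strongly each node acts as a cluster center."""
        # Self-attention computed in the tangent space at the origin
        x_log = log_map_0(x, self.c)
        scores = torch.sigmoid(x_log @ x_log.T)
        return scores
    
    def forward(self, x, adj):
        num_nodes = x.size(0)
        num_keep = max(1, int(num_nodes * self.ratio))
        
        # Cluster scores
        scores = self.compute_cluster_score(x)
        
        # Select the top-k nodes
        _, top_indices = torch.topk(scores.sum(dim=1), num_keep)
        
        # Extract the subgraph (advanced indexing returns a copy, so x is not modified below)
        x_pooled = x[top_indices]
        
        # Update the adjacency matrix
        adj_pooled = adj[top_indices][:, top_indices]
        
        # Fold the information of dropped neighbors into each kept node
        dropped = torch.ones(num_nodes, dtype=torch.bool, device=x.device)
        dropped[top_indices] = False
        for idx, node in enumerate(top_indices.tolist()):
            neighbors = (dropped & (adj[node] != 0)).nonzero(as_tuple=True)[0]
            if neighbors.numel() > 0:
                # Fréchet mean of the kept node and its dropped neighbors
                group = torch.cat([x[node:node + 1], x[neighbors]])
                x_pooled[idx] = frechet_mean(group, self.c)
        
        return x_pooled, adj_pooled, top_indices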

Experimental Comparison

Node classification performance

Dataset | Euclidean GCN | Poincaré GCN | Lorentz GNN
Cora | 81.5% | 82.1% | 82.3%
CiteSeer | 70.3% | 71.2% | 71.8%
PubMed | 79.0% | 79.5% | 79.8%
PPI | 98.6% | 99.1% | 99.2%

Capturing hierarchical structure

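Results on a synthetic hierarchical graph:

Binary tree of depth 10
Embedding dimension = 16

Euclidean embedding:
  - Levels that can be embedded: 3-4
  - Nearest-neighbor precision: 45%

Poincaré embedding:
  - Levels that can be embedded: 10+
  - Nearest-neighbor precision: 98%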

Relationship to Standard GNNs

The limiting case

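As the curvature approaches zero ($c \to 0$), a hyperbolic GNN reduces to a standard Euclidean GNN:

$$\lim_{c\to 0} u \oplus_c v = u + v, \qquad \lim_{c\to 0}\exp_x^c(v) = x + v, \qquad \lim_{c\to 0}\log_x^c(y) = y - x$$

so the Poincaré GCN update becomes $h_i^{(l+1)} = h_i^{(l)} + \frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)}\big(h_j^{(l)} - h_i^{(l)}\big) = \frac{1}{|\mathcal{N}(i)|}\sum_{j\in\mathcal{N}(i)} h_j^{(l)}$, i.e., plain mean aggregation.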
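The limit can be checked numerically. A minimal sketch, assuming the Poincaré-ball Möbius addition and the exponential map at a base point $x$ with conformal factor $\lambda_x^c = 2/(1-c\|x\|^2)$: as $c$ shrinks, $\exp_x^c(v)$ approaches the Euclidean step $x + v$. The base point, tangent vector, and curvature values are arbitrary choices for illustration.

```python
import math


def mobius_add(u, v, c):
    """Möbius addition u ⊕_c v."""
    uv = sum(a * b for a, b in zip(u, v))
    u2 = sum(a * a for a in u)
    v2 = sum(b * b for b in v)
    den = 1 + 2 * c * uv + c * c * u2 * v2
    return [((1 + 2 * c * uv + c * v2) * a + (1 - c * u2) * b) / den for a, b in zip(u, v)]


def exp_map(x, v, c):
    """Poincaré-ball exponential map at x with curvature -c."""
    sc = math.sqrt(c)
    lam = 2 / (1 - c * sum(a * a for a in x))
    n = math.sqrt(sum(a * a for a in v))
    return mobius_add(x, [math.tanh(sc * lam * n / 2) * a / (sc * n) for a in v], c)


x, v = [0.3, -0.2], [0.5, 0.4]
euclid = [a + b for a, b in zip(x, v)]
for c in (1.0, 0.1, 0.01, 0.0001):
    err = math.dist(exp_map(x, v, c), euclid)
    print(f"c={c:<7} distance to x+v = {err:.2e}")
```

The printed gap shrinks roughly in proportion to $c$, which is why a learnable curvature lets a model interpolate between hyperbolic and Euclidean behavior.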

Architecture selection guide

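Data characteristics | Recommended architecture
Weak hierarchical structure | Standard GCN/GAT
Strong hierarchical structure | Poincaré GCN
Numerical stability required | Lorentz GNN
Attention mechanism required | Hyperbolic GAT
Very deep hierarchies | Hybrid hyperbolic-Euclidean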

References