概述

变分图神经网络(Variational Graph Neural Networks, VGNNs)将变分推断与图神经网络相结合,用于学习图数据的概率潜在表示[1]。这种方法在无监督图表示学习、链接预测、图生成等任务中具有重要应用。本文档系统介绍变分GNN的核心方法、架构设计和最新进展。


变分图自编码器(VGAE)

经典VGAE架构

VGAE(Variational Graph Autoencoder)由Kipf和Welling在2016年提出,是最早的变分图表示学习方法之一[2]。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
 
class VGAE(nn.Module):
    """Variational graph auto-encoder.

    Paper: "Variational Graph Auto-Encoders" (Kipf & Welling, 2016).

    Architecture:
    - Encoder: one shared GCN layer followed by two parallel GCN heads that
      output the mean and the log-variance of each node's latent Gaussian.
    - Decoder: inner-product decoder.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the shared GCN layer.
        out_channels: Dimensionality of the latent code.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        # Shared graph convolution feeding both variational heads.
        # A second layer (`gcn2`) used to be constructed here but was never
        # called by `encode`; it only added dead parameters (which also trips
        # DistributedDataParallel's unused-parameter check).  Checkpoints
        # saved with it carry extra `gcn2.*` keys: load with `strict=False`.
        self.gcn1 = GCNConv(in_channels, hidden_channels)

        # Mean and log-variance heads.
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)

        # Decoder: inner product of latent codes.
        self.decoder = InnerProductDecoder()

    def encode(self, x, edge_index):
        """Return ``(mu, logvar)`` of the per-node latent distribution."""
        h = F.relu(self.gcn1(x, edge_index))

        mu = self.gcn_mu(h, edge_index)
        logvar = self.gcn_logvar(h, edge_index)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` otherwise."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def decode(self, z, edge_index):
        """Run the decoder on latent codes ``z`` for the edges in ``edge_index``."""
        return self.decoder(z, edge_index)

    def forward(self, x, edge_index):
        """Encode, sample and decode.

        Returns:
            Tuple ``(adj_recon, mu, logvar, z)`` where ``adj_recon`` is the
            decoder output for the edges listed in ``edge_index``.
        """
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        adj_recon = self.decode(z, edge_index)

        return adj_recon, mu, logvar, z
 
 
class InnerProductDecoder(nn.Module):
    """Inner-product decoder.

    Edge probability: p(A_ij | z_i, z_j) = sigmoid(z_i^T z_j).
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, edge_index=None, sigmoid=True):
        """Score node pairs by the inner product of their latent codes.

        Args:
            z: Latent codes, one row per node.
            edge_index: Optional pair ``(row, col)`` of node-index tensors.
                When given, only those pairs are scored and a 1-D tensor with
                one entry per pair is returned.  When ``None``, the dense
                all-pairs matrix ``z @ z.T`` is returned.
            sigmoid: If ``True`` (default) return probabilities; if ``False``
                return the raw inner products (logits).

        Returns:
            Scores for the requested pairs, or for all pairs.
        """
        if edge_index is None:
            scores = torch.mm(z, z.t())
        else:
            # Score only the requested pairs instead of building the full
            # N x N matrix and indexing into it afterwards.
            row, col = edge_index
            scores = (z[row] * z[col]).sum(dim=-1)

        return torch.sigmoid(scores) if sigmoid else scores

损失函数

VGAE的损失函数包含两部分:

  1. 重建损失:衡量解码器重建图结构的能力
  2. KL散度:强制潜在变量接近先验分布
def vgae_loss(model, data, beta=1.0):
    """Compute the VGAE training loss.

    L = L_reconstruction + beta * L_KL

    where
    - L_reconstruction = -E[log p(A|z)], estimated on the observed edges
      (positives) and on an equal number of sampled non-edges (negatives);
    - L_KL = D_KL(q(z|X, A) || p(z)) with p(z) = N(0, I).

    Args:
        model: Module whose ``forward(x, edge_index)`` returns
            ``(edge_probs, mu, logvar, z)`` and whose ``decode(z, edge_index)``
            returns edge probabilities (i.e. values already passed through a
            sigmoid, as the VGAE / inner-product decoder in this file do).
        data: Object exposing ``x``, ``edge_index`` and ``num_nodes``.
        beta: Weight of the KL term.

    Returns:
        Tuple ``(loss, recon_loss, kl_loss)``.
    """
    edge_index = data.edge_index

    # Forward pass: probabilities of the observed (positive) edges.
    pos_prob, mu, logvar, z = model(data.x, edge_index)

    # Negatives: randomly sampled node pairs that are not edges.
    neg_edge_index = negative_sampling(edge_index, data.num_nodes)
    neg_prob = model.decode(z, neg_edge_index)

    # The decoder outputs are probabilities, so use plain BCE here; the
    # "with_logits" variant would apply a second sigmoid on top of them.
    pos_loss = F.binary_cross_entropy(pos_prob, torch.ones_like(pos_prob))
    neg_loss = F.binary_cross_entropy(neg_prob, torch.zeros_like(neg_prob))

    recon_loss = pos_loss + neg_loss

    # KL divergence between q(z) = N(mu, exp(logvar)) and p(z) = N(0, I):
    # D_KL = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2), averaged over all
    # nodes and latent dimensions.
    kl_loss = -0.5 * torch.mean(
        1 + logvar - mu.pow(2) - logvar.exp()
    )

    loss = recon_loss + beta * kl_loss

    return loss, recon_loss, kl_loss
 
 
def negative_sampling(edge_index, num_nodes, num_neg_samples=None):
    """Sample node pairs that are not connected in ``edge_index``.

    The graph is treated as undirected: a pair is rejected if it appears in
    ``edge_index`` in either direction.  Self-pairs ``(i, i)`` are never
    returned.  The same negative pair may be drawn more than once.

    Args:
        edge_index: Tensor of shape ``(2, num_edges)`` with the existing edges.
        num_nodes: Number of nodes; sampled indices lie in ``[0, num_nodes)``.
        num_neg_samples: Number of pairs to return.  Defaults to the number
            of edges in ``edge_index``.

    Returns:
        Long tensor of shape ``(2, num_neg_samples)`` on the same device as
        ``edge_index``.

    Raises:
        ValueError: If at least one sample is requested but every ordered
            pair of distinct nodes is already an edge (the rejection loop
            could otherwise never terminate).
    """
    num_edges = edge_index.shape[1]

    if num_neg_samples is None:
        num_neg_samples = num_edges

    device = edge_index.device
    if num_neg_samples <= 0:
        # Keep the (2, 0) layout so callers can still unpack `row, col`.
        return torch.empty((2, 0), dtype=torch.long, device=device)

    # Existing edges in both directions (one bulk transfer instead of
    # one `.item()` call per endpoint).
    existing_edges = set()
    for i, j in edge_index.t().tolist():
        existing_edges.add((i, j))
        existing_edges.add((j, i))

    # Count the ordered pairs of distinct, in-range nodes that are taken.
    occupied = sum(
        1
        for i, j in existing_edges
        if i != j and 0 <= i < num_nodes and 0 <= j < num_nodes
    )
    if num_nodes * (num_nodes - 1) - occupied <= 0:
        raise ValueError(
            "negative_sampling: no non-edge node pairs are available"
        )

    # Rejection sampling, drawing candidates in batches.
    neg_edges = []
    while len(neg_edges) < num_neg_samples:
        missing = num_neg_samples - len(neg_edges)
        candidates = torch.randint(0, num_nodes, (2 * missing, 2)).tolist()
        for i, j in candidates:
            if i != j and (i, j) not in existing_edges:
                neg_edges.append((i, j))
                if len(neg_edges) == num_neg_samples:
                    break

    return torch.tensor(neg_edges, dtype=torch.long, device=device).t()

高级变分GNN架构

1. 变分图Transformer(VGT)

class VariationalGraphTransformer(nn.Module):
    """Variational graph Transformer.

    Stacks graph Transformer layers, pools the node states with the
    ``readout`` module and maps the pooled state to the mean and
    log-variance of a latent distribution.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=4, num_layers=3):
        super().__init__()

        # Project raw node features to the hidden width.
        self.input_proj = nn.Linear(in_channels, hidden_channels)

        # Stack of graph Transformer layers, built one by one.
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GraphTransformerLayer(hidden_channels, num_heads))

        # Heads producing the variational parameters.
        self.mu_head = nn.Linear(hidden_channels, out_channels)
        self.logvar_head = nn.Linear(hidden_channels, out_channels)

        # Pooling module applied to the final node states.
        self.readout = GlobalAttentionPool(hidden_channels)

    def forward(self, x, edge_index, batch=None):
        """Return ``(mu, logvar)`` computed from the pooled node states."""
        hidden = F.relu(self.input_proj(x))

        for block in self.layers:
            hidden = block(hidden, edge_index)

        # Pool node states; `batch` is forwarded to the readout unchanged.
        pooled = self.readout(hidden, batch)

        mu = self.mu_head(pooled)
        logvar = self.logvar_head(pooled)
        return mu, logvar

    def sample(self, mu, logvar, num_samples=1):
        """Draw ``num_samples`` reparameterized samples for every row of ``mu``.

        Returns a tensor with a leading ``num_samples`` dimension.
        """
        rows, dims = mu.shape[0], mu.shape[1]
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn(num_samples, rows, dims, device=mu.device)
        return mu + noise * sigma
 
 
class GraphTransformerLayer(nn.Module):
    """Graph Transformer layer: edge-restricted multi-head attention + FFN.

    Each sub-layer is wrapped in a residual connection followed by
    LayerNorm.

    Args:
        hidden_channels: Width of the node states (input and output).
        num_heads: Number of attention heads; must divide ``hidden_channels``.

    Raises:
        ValueError: If ``hidden_channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_channels, num_heads):
        super().__init__()
        if hidden_channels % num_heads != 0:
            raise ValueError(
                "hidden_channels must be divisible by num_heads"
            )
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.head_dim = hidden_channels // num_heads

        # Attention projections.
        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
        self.v_proj = nn.Linear(hidden_channels, hidden_channels)

        # Output projection.
        self.out_proj = nn.Linear(hidden_channels, hidden_channels)

        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels * 4),
            nn.ReLU(),
            nn.Linear(hidden_channels * 4, hidden_channels)
        )

        # Normalization.
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)

    def forward(self, x, edge_index):
        """Apply attention and feed-forward sub-layers to node states ``x``."""
        # Attention sub-layer with residual connection.
        h = self._graph_attention(x, edge_index)
        h = self.norm1(x + h)

        # Feed-forward sub-layer with residual connection.
        h = self.norm2(h + self.ffn(h))

        return h

    def _graph_attention(self, x, edge_index):
        """Aggregate values along edges, weighted by per-edge attention.

        For every edge ``(row[e], col[e])`` the score is the scaled dot
        product of ``q[row[e]]`` and ``k[col[e]]``.  Scores are normalized
        with a softmax over the edges that share the same ``col`` node, and
        ``v[row[e]]`` is accumulated into ``col[e]``.  Nodes that are never a
        ``col`` endpoint receive a zero vector before the output projection.
        """
        row, col = edge_index
        num_nodes = x.size(0)

        # Per-head projections: (num_nodes, num_heads, head_dim).
        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)

        # Raw attention scores, one per edge and head: (num_edges, num_heads).
        scores = (q[row] * k[col]).sum(dim=-1) / (self.head_dim ** 0.5)

        # Softmax grouped by receiving node.  A softmax over dim 0 would
        # normalize across *all* edges of the graph, so the weights arriving
        # at a node would not sum to one.
        target = col.unsqueeze(-1).expand_as(scores)
        group_max = scores.new_full((num_nodes, self.num_heads), float("-inf"))
        # The max is only a numerical stabilizer, so keep it out of autograd.
        group_max = group_max.scatter_reduce(
            0, target, scores.detach(), reduce="amax"
        )
        weights = torch.exp(scores - group_max[col])
        denom = scores.new_zeros(num_nodes, self.num_heads).index_add(
            0, col, weights
        )
        alpha = weights / denom[col]

        # Aggregate.  `index_add` accumulates over repeated `col` entries;
        # `out[col] += ...` keeps only one write per duplicated index.
        out = v.new_zeros(num_nodes, self.num_heads, self.head_dim)
        out = out.index_add(0, col, alpha.unsqueeze(-1) * v[row])

        out = out.view(-1, self.hidden_channels)
        out = self.out_proj(out)

        return out

2. 半隐式变分GNN

class SemiImplicitVGAE(nn.Module):
    """Semi-implicit variational graph auto-encoder (building blocks).

    Provides a GCN encoder, a Gaussian posterior head with mixing weights,
    and a Gaussian-mixture prior with learnable means and variances.

    NOTE(review): the original citation here read "Semi-Implicity
    Variational Graph Auto-Encoders (AAAI 2022)"; the title looks misspelled
    and the venue could not be confirmed -- verify the reference.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the encoder layers.
        out_channels: Dimensionality of the latent code.
        num_components: Number of mixture components.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_components=10):
        super().__init__()

        # Deterministic encoder.  Kept as nn.Sequential so parameter names
        # stay `encoder.<idx>.*`, but it must be run through `encode()`:
        # calling the Sequential directly cannot forward `edge_index`.
        self.encoder = nn.Sequential(
            GCNConv(in_channels, hidden_channels),
            nn.ReLU(),
            GCNConv(hidden_channels, hidden_channels),
            nn.ReLU()
        )

        # Variational parameter heads.
        self.mu = nn.Linear(hidden_channels, out_channels)
        self.logvar = nn.Linear(hidden_channels, out_channels)

        # Mixture prior parameters.
        self.num_components = num_components
        self.prior_means = nn.Parameter(torch.randn(num_components, out_channels))
        self.prior_vars = nn.Parameter(torch.ones(num_components, out_channels))

        # Network producing the posterior mixing logits.
        self.post_mixing = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, num_components)
        )

    def encode(self, x, edge_index):
        """Run the encoder stack, passing ``edge_index`` to the GCN layers."""
        h = x
        for layer in self.encoder:
            if isinstance(layer, GCNConv):
                h = layer(h, edge_index)
            else:
                h = layer(h)
        return h

    def prior_distribution(self, num_nodes):
        """Return the parameters of the mixture prior.

        p(z) = sum_k pi_k N(z | mu_k, sigma_k^2), with uniform mixing
        weights pi_k = 1 / num_components.

        Returns:
            Tuple ``(pi, prior_means, prior_vars)`` where ``pi`` has shape
            ``(num_nodes, num_components)`` and lives on the same device as
            the prior parameters.
        """
        pi = torch.full(
            (num_nodes, self.num_components),
            1.0 / self.num_components,
            dtype=self.prior_means.dtype,
            device=self.prior_means.device,
        )

        return pi, self.prior_means, self.prior_vars

    def posterior_distribution(self, h):
        """Return ``(alpha, mu, logvar)`` of the posterior for hidden states ``h``.

        q(z) = sum_k alpha_k N(z | mu, sigma^2); ``alpha`` is a softmax over
        the mixing logits, ``mu`` / ``logvar`` are shared by all components.
        """
        alpha_logits = self.post_mixing(h)
        alpha = F.softmax(alpha_logits, dim=-1)

        mu = self.mu(h)
        logvar = self.logvar(h)

        return alpha, mu, logvar

    def sample_from_mixture(self, alpha, mu, logvar, num_samples=1):
        """Draw reparameterized samples from the posterior.

        All mixture components share ``mu`` and ``logvar``, so the component
        index has no effect on the sample; ``alpha`` is accepted for
        interface compatibility but is not used.

        Returns:
            Tensor of shape ``(num_samples, *mu.shape)``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn(
            (num_samples,) + tuple(mu.shape), dtype=mu.dtype, device=mu.device
        )
        return mu.unsqueeze(0) + eps * std.unsqueeze(0)

图生成的变分方法

GraphVAE与图生成

class GraphVAE(nn.Module):
    """Graph variational auto-encoder for graph generation.

    Encodes a graph given as padded node features into a single latent
    vector and decodes node and edge probabilities for up to ``max_nodes``
    nodes.

    Args:
        node_dim: Size of each node feature vector.
        edge_dim: Edge feature size (stored and forwarded to the decoder).
        latent_dim: Dimensionality of the latent code.
        max_nodes: Number of node slots after padding.
    """

    def __init__(self, node_dim, edge_dim, latent_dim, max_nodes=50):
        super().__init__()

        self.max_nodes = max_nodes
        self.node_dim = node_dim
        self.edge_dim = edge_dim

        # Encoder producing a graph-level representation.
        self.encoder = nn.Sequential(
            nn.Linear(node_dim * max_nodes, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )

        # Variational heads.
        self.mu_layer = nn.Linear(256, latent_dim)
        self.logvar_layer = nn.Linear(256, latent_dim)

        # Decoder.
        self.decoder = GraphDecoder(latent_dim, node_dim, edge_dim, max_nodes)

    def encode(self, x_padded, mask):
        """Return ``(mu, logvar)`` for a batch of padded node-feature tensors.

        ``x_padded`` is flattened from dimension 1 onward, so it must hold
        ``node_dim * max_nodes`` values per graph.  ``mask`` is accepted for
        interface compatibility but is currently not used.
        """
        # Flatten node features: (batch, node_dim * max_nodes).
        x_flat = x_padded.flatten(1)

        h = self.encoder(x_flat)

        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)

        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` in training mode; return ``mu`` otherwise.

        ``forward`` relies on this method; it was previously missing, so any
        forward pass raised ``AttributeError``.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def decode(self, z):
        """Decode latent codes ``z`` into ``(node_probs, edge_probs)``."""
        return self.decoder(z)

    def forward(self, x_padded, mask):
        """Encode, sample and decode.

        Returns:
            Tuple ``(x_recon, edge_recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x_padded, mask)
        z = self.reparameterize(mu, logvar)
        x_recon, edge_recon = self.decode(z)

        return x_recon, edge_recon, mu, logvar
 
 
class GraphDecoder(nn.Module):
    """Graph decoder producing node and edge probabilities from a latent code.

    Args:
        latent_dim: Dimensionality of the latent code.
        node_dim: Size of each decoded node feature vector.
        edge_dim: Accepted for interface compatibility; not used here.
        max_nodes: Number of node slots in the decoded graph.
    """

    def __init__(self, latent_dim, node_dim, edge_dim, max_nodes):
        super().__init__()

        self.max_nodes = max_nodes
        self.node_dim = node_dim

        # Node decoder.
        self.node_decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, max_nodes * node_dim)
        )

        # Edge decoder (adjacency matrix), conditioned on the node logits.
        self.edge_decoder = nn.Sequential(
            nn.Linear(latent_dim + max_nodes * node_dim, 512),
            nn.ReLU(),
            nn.Linear(512, max_nodes * max_nodes)
        )

    def forward(self, z):
        """Decode a batch of latent codes.

        Returns:
            Tuple ``(node_probs, edge_probs)`` with shapes
            ``(batch, max_nodes, node_dim)`` and
            ``(batch, max_nodes, max_nodes)``.  ``edge_probs`` is non-zero
            only strictly above the diagonal, i.e. each undirected node pair
            is reported once and self-loops are excluded.
        """
        # Decode nodes.
        node_logits = self.node_decoder(z)
        node_probs = torch.sigmoid(node_logits.view(-1, self.max_nodes, self.node_dim))

        # Decode edges from the latent code and the flat node logits.
        edge_input = torch.cat([z, node_logits], dim=-1)
        edge_logits = self.edge_decoder(edge_input)
        edge_probs = torch.sigmoid(edge_logits.view(-1, self.max_nodes, self.max_nodes))

        # Average (i, j) and (j, i) so both directions agree.
        edge_probs = (edge_probs + edge_probs.transpose(-2, -1)) / 2

        # Keep only the strict upper triangle.  The mask is created on the
        # same device as the predictions so the product also works when the
        # module lives on an accelerator.
        mask = torch.triu(
            torch.ones(self.max_nodes, self.max_nodes, device=edge_probs.device),
            diagonal=1,
        ).bool()
        edge_probs = edge_probs * mask.unsqueeze(0)

        return node_probs, edge_probs

潜在空间操作与图编辑

图插值与编辑

class GraphLatentSpace:
    """Helpers for exploring a model's latent space.

    The wrapped model is switched to eval mode on construction and must
    provide ``encode(graph) -> (mu, logvar)`` and
    ``decode(z) -> (node_probs, edge_probs)``.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def interpolate(self, graph1, graph2, num_steps=10):
        """Decode graphs along the straight line between two latent means.

        Returns a list of ``num_steps`` tuples ``(node_probs, edge_probs)``,
        from ``graph1`` (weight 0) to ``graph2`` (weight 1).
        """
        start_code, _ = self.model.encode(graph1)
        end_code, _ = self.model.encode(graph2)

        def decode_at(weight):
            # Convex combination of the two means, with a leading batch axis.
            blended = (1 - weight) * start_code + weight * end_code
            node_probs, edge_probs = self.model.decode(blended.unsqueeze(0))
            return (node_probs, edge_probs)

        return [decode_at(weight) for weight in torch.linspace(0, 1, num_steps)]

    @torch.no_grad()
    def random_walk(self, start_graph, num_steps=10, step_size=0.1):
        """Random walk in latent space starting at the mean of ``start_graph``.

        Each step adds Gaussian noise scaled by ``step_size``.  Returns the
        list of visited latent codes (``num_steps + 1`` entries, including
        the starting point); nothing is decoded.
        """
        mean, _ = self.model.encode(start_graph)
        position = mean.unsqueeze(0)

        trajectory = [position.clone()]
        step = 0
        while step < num_steps:
            position = position + torch.randn_like(position) * step_size
            trajectory.append(position.clone())
            step += 1

        return trajectory

    @torch.no_grad()
    def conditional_generation(self, constraint_graph, modify_mask, target_nodes=None):
        """Decode a graph after overwriting selected latent entries.

        Args:
            constraint_graph: Graph whose latent mean is the starting point.
            modify_mask: Index/mask selecting the latent entries to replace.
            target_nodes: Replacement values, indexed with the same mask.
                When ``None`` the latent mean is decoded unchanged.

        Returns:
            Tuple ``(node_probs, edge_probs)`` from the model's decoder.
        """
        mean, _ = self.model.encode(constraint_graph)

        latent = mean.clone()
        if target_nodes is not None:
            latent[modify_mask] = target_nodes[modify_mask]

        node_probs, edge_probs = self.model.decode(latent.unsqueeze(0))
        return node_probs, edge_probs

对比学习方法

图对比变分学习

class ContrastiveVGAE(nn.Module):
    """Contrastive variational graph auto-encoder.

    Combines the VGAE objective (reconstruction + KL) with an InfoNCE
    contrastive term computed on projected latent codes.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Hidden width of the encoder and projection head.
        out_channels: Dimensionality of the latent code.
        tau: Softmax temperature of the contrastive loss.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, tau=0.5):
        super().__init__()

        # Graph encoder returning (mu, logvar).  `GCNEncoder` is not defined
        # in this part of the file and is expected to be provided elsewhere.
        self.encoder = GCNEncoder(in_channels, hidden_channels, out_channels)

        # Projection head used only for the contrastive term.
        self.projection = nn.Sequential(
            nn.Linear(out_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )

        # Temperature.
        self.tau = tau

    def forward(self, x, edge_index):
        """Return ``(mu, logvar, z, h)``.

        ``z`` is a reparameterized sample in training mode and ``mu``
        otherwise; ``h`` is the projection of ``z``.
        """
        mu, logvar = self.encoder(x, edge_index)

        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std
        else:
            z = mu

        h = self.projection(z)

        return mu, logvar, z, h

    def contrastive_loss(self, h1, h2):
        """InfoNCE loss between two aligned sets of embeddings.

        Row ``i`` of ``h1`` and row ``i`` of ``h2`` form the positive pair;
        every other row of ``h2`` acts as a negative for ``h1[i]``.

        The positive similarity is the diagonal of the similarity matrix, so
        the matrix itself serves as the logits with targets ``0..N-1``.
        Prepending the positives as an extra column would count each positive
        a second time among the negatives.
        """
        h1 = F.normalize(h1, dim=-1)
        h2 = F.normalize(h2, dim=-1)

        # Temperature-scaled cosine similarities: (batch, batch).
        logits = torch.mm(h1, h2.t()) / self.tau
        labels = torch.arange(h1.size(0), device=h1.device)

        return F.cross_entropy(logits, labels)

    def reconstruction_loss(self, z, adj_target):
        """Binary cross-entropy between ``z @ z.T`` (logits) and ``adj_target``.

        NOTE(review): assumes ``adj_target`` is a dense
        ``(num_nodes, num_nodes)`` 0/1 adjacency matrix -- confirm against
        the callers of ``total_loss``.
        """
        logits = torch.mm(z, z.t())
        return F.binary_cross_entropy_with_logits(
            logits, adj_target.to(logits.dtype)
        )

    def total_loss(self, x, edge_index, adj_target, beta=1.0):
        """Total loss = reconstruction + beta * KL + contrastive."""
        mu, logvar, z, h = self.forward(x, edge_index)

        # Reconstruction term.
        recon_loss = self.reconstruction_loss(z, adj_target)

        # KL divergence to N(0, I), averaged over nodes and dimensions.
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        # Second view: projection of the posterior mean from another
        # encoder pass over the same input.
        h2 = self.projection(self.encoder(x, edge_index)[0])
        contrastive_loss = self.contrastive_loss(h, h2)

        loss = recon_loss + beta * kl_loss + contrastive_loss

        return loss

应用场景

1. 链接预测

class LinkPredictionWithVGAE:
    """Link prediction on top of a VGAE-style model.

    The model must provide ``encode(x, edge_index) -> (mu, logvar)``.
    """

    def __init__(self, model):
        self.model = model

    @torch.no_grad()
    def predict_link(self, x, edge_index, node_i, node_j):
        """Return the predicted probability of an edge between two nodes.

        The score is ``sigmoid(mu_i . mu_j)``, returned as a Python float.
        """
        mu, _ = self.model.encode(x, edge_index)

        sim = torch.sigmoid((mu[node_i] * mu[node_j]).sum())

        return sim.item()

    @torch.no_grad()
    def predict_all_links(self, x, edge_index, k=10):
        """Return the ``k`` highest-scoring unordered node pairs.

        Every pair ``i < j`` is scored, including pairs that are already
        connected in ``edge_index``.

        Args:
            k: Number of pairs requested.  If fewer pairs exist, all of them
                are returned.

        Returns:
            List of ``(i, j, score)`` tuples sorted by decreasing score.
        """
        mu, _ = self.model.encode(x, edge_index)

        num_nodes = mu.shape[0]

        # Similarity of every node pair.
        sim_matrix = torch.sigmoid(mu @ mu.t())

        # Strict upper-triangle indices, created on the embeddings' device so
        # they can later be indexed with the device-side top-k result.
        i, j = torch.triu_indices(num_nodes, num_nodes, offset=1, device=mu.device)
        sims = sim_matrix[i, j]

        # torch.topk raises when k exceeds the number of candidates.
        k = max(0, min(k, sims.numel()))
        top_k = torch.topk(sims, k)

        predicted_edges = list(zip(i[top_k.indices].tolist(),
                                   j[top_k.indices].tolist(),
                                   top_k.values.tolist()))

        return predicted_edges

2. 图分类

class GraphClassificationWithVGAE:
    """Graph classification on top of VGAE node embeddings.

    Args:
        vgae_model: Model providing either ``encode(x, edge_index) ->
            (mu, logvar)`` or a forward pass returning a 4-tuple whose first
            element is ``mu``.  May optionally provide
            ``readout(mu, batch)`` for graph-level pooling.
        classifier: Callable applied to the graph-level representation.
    """

    def __init__(self, vgae_model, classifier):
        self.vgae = vgae_model
        self.classifier = classifier

    @torch.no_grad()
    def predict(self, x, edge_index, batch=None):
        """Classify the graph(s) described by ``x`` / ``edge_index``.

        Args:
            batch: Optional per-node graph assignment.
                NOTE(review): the mean-pooling fallback assumes a 1-D long
                tensor of graph ids in ``[0, num_graphs)`` -- confirm.

        Returns:
            Whatever ``classifier`` returns for the pooled representation.
        """
        if hasattr(self.vgae, "encode"):
            # Prefer encode(): the VGAE in this file returns
            # (adj_recon, mu, logvar, z) from forward(), so taking the first
            # forward() output would pick up the reconstruction, not mu.
            mu, _ = self.vgae.encode(x, edge_index)
        else:
            mu, _, _, _ = self.vgae(x, edge_index)

        if batch is None:
            # Single graph: average all node embeddings.
            graph_repr = mu.mean(dim=0, keepdim=True)
        else:
            readout = getattr(self.vgae, "readout", None)
            if readout is not None:
                graph_repr = readout(mu, batch)
            else:
                # Fallback for models without a readout: per-graph mean.
                num_graphs = int(batch.max().item()) + 1 if batch.numel() else 0
                sums = mu.new_zeros(num_graphs, mu.size(-1)).index_add(0, batch, mu)
                counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
                graph_repr = sums / counts.unsqueeze(-1).to(sums.dtype)

        return self.classifier(graph_repr)

总结

变分GNN方法的总结:

| 方法 | 特点 | 适用场景 |
| --- | --- | --- |
| VGAE | 简单高效 | 链接预测、节点分类 |
| VGT | 表达能力更强 | 复杂图结构 |
| 半隐式VGAE | 更灵活的分布 | 异构图、多模态 |
| 对比VGAE | 更好的表示 | 无监督学习 |

参考


相关文章

Footnotes

  1. Variational Graph Auto-Encoders (Kipf & Welling, NIPS 2016 Workshop on Bayesian Deep Learning)

  2. Semi-Implicit Graph Variational Auto-Encoders (Hasanzadeh et al., NeurIPS 2019)