概述
变分图神经网络(Variational Graph Neural Networks, VGNNs)将变分推断与图神经网络相结合,用于学习图数据的概率潜在表示[1]。这种方法在无监督图表示学习、链接预测、图生成等任务中具有重要应用。本文档系统介绍变分GNN的核心方法、架构设计和最新进展。
变分图自编码器(VGAE)
经典VGAE架构
VGAE(Variational Graph Autoencoder)由Kipf和Welling在2016年提出,是最早的变分图表示学习方法之一[2]。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class VGAE(nn.Module):
    """Variational graph auto-encoder (VGAE).

    Reference: Kipf & Welling, "Variational Graph Auto-Encoders" (2016).

    Architecture:
        - Encoder: one shared GCN layer followed by two parallel GCN heads
          that output the mean and the log-variance of the node latents.
        - Decoder: inner-product decoder.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the shared GCN layer.
        out_channels: Dimensionality of the latent space.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Shared graph convolution. An earlier revision also registered a
        # second layer (`gcn2`) that `encode` never called; it only added
        # parameters that never receive gradients, so it is not created.
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        # Parallel heads for the posterior mean and log-variance.
        self.gcn_mu = GCNConv(hidden_channels, out_channels)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels)
        # Decoder: inner product between node latents.
        self.decoder = InnerProductDecoder()

    def encode(self, x, edge_index):
        """Return the parameters ``(mu, logvar)`` of the latent distribution."""
        h = F.relu(self.gcn1(x, edge_index))
        mu = self.gcn_mu(h, edge_index)
        logvar = self.gcn_logvar(h, edge_index)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Apply the reparameterization trick.

        In training mode a sample ``mu + eps * std`` is drawn with
        ``eps ~ N(0, I)``; in evaluation mode ``mu`` is returned unchanged.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def decode(self, z, edge_index):
        """Score the edges listed in ``edge_index`` from the latents ``z``."""
        return self.decoder(z, edge_index)

    def forward(self, x, edge_index):
        """Encode, sample and decode.

        Returns:
            Tuple ``(adj_recon, mu, logvar, z)`` where ``adj_recon`` holds the
            decoder output for the edges in ``edge_index``.
        """
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        adj_recon = self.decode(z, edge_index)
        return adj_recon, mu, logvar, z
class InnerProductDecoder(nn.Module):
    """Inner-product decoder.

    Edge probability: p(A_ij | z_i, z_j) = sigmoid(z_i^T z_j).
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, edge_index=None, sigmoid=True):
        """Score node pairs from their latent vectors.

        Args:
            z: Latent representations, indexed by node along the first axis.
            edge_index: Optional pair ``(row, col)`` of node indices. When
                given, only these pairs are scored; otherwise every pair is.
            sigmoid: If True (default) return probabilities, otherwise return
                the raw inner products (logits).

        Returns:
            One score per requested edge when ``edge_index`` is given, else
            the full ``z @ z.T`` score matrix.
        """
        if edge_index is not None:
            # Score only the requested pairs; this avoids materializing the
            # dense all-pairs matrix just to index into it.
            row, col = edge_index
            scores = (z[row] * z[col]).sum(dim=-1)
        else:
            scores = torch.mm(z, z.t())
        # Honor the `sigmoid` flag, which was previously accepted but ignored.
        return torch.sigmoid(scores) if sigmoid else scores

# [document section heading] 损失函数 (Loss Function)
VGAE的损失函数包含两部分:
- 重建损失:衡量解码器重建图结构的能力
- KL散度:强制潜在变量接近先验分布
def vgae_loss(model, data, beta=1.0):
    """Compute the VGAE training objective.

    L = L_reconstruction + beta * L_KL, where

    - L_reconstruction = -E[log p(A | z)]
    - L_KL = D_KL(q(z | X, A) || p(z))

    Args:
        model: Model whose ``forward(x, edge_index)`` returns
            ``(edge_scores, mu, logvar, z)`` and whose ``decode(z, edge_index)``
            returns edge scores. The scores are treated as probabilities in
            [0, 1], which is what the inner-product decoder in this file
            produces by default.
        data: Object exposing ``x``, ``edge_index`` and ``num_nodes``.
        beta: Weight of the KL term.

    Returns:
        Tuple ``(loss, recon_loss, kl_loss)``.
    """
    edge_index = data.edge_index

    # Forward pass: probabilities for the observed (positive) edges.
    pos_prob, mu, logvar, z = model(data.x, edge_index)

    # Negative examples: randomly sampled node pairs that are not edges.
    neg_edge_index = negative_sampling(edge_index, data.num_nodes)
    neg_prob = model.decode(z, neg_edge_index)

    # The decoder output is already a probability, so plain binary
    # cross-entropy is used. The "with_logits" variant would apply a second
    # sigmoid and distort the loss.
    pos_loss = F.binary_cross_entropy(pos_prob, torch.ones_like(pos_prob))
    neg_loss = F.binary_cross_entropy(neg_prob, torch.zeros_like(neg_prob))
    recon_loss = pos_loss + neg_loss

    # KL divergence between q(z) = N(mu, exp(logvar)) and p(z) = N(0, I):
    #   D_KL = -0.5 * (1 + log(sigma^2) - mu^2 - sigma^2)
    # averaged here over all nodes and latent dimensions.
    kl_loss = -0.5 * torch.mean(
        1 + logvar - mu.pow(2) - logvar.exp()
    )

    loss = recon_loss + beta * kl_loss
    return loss, recon_loss, kl_loss
def negative_sampling(edge_index, num_nodes, num_neg_samples=None):
    """Sample node pairs that are not edges of the graph.

    A pair ``(i, j)`` is a valid negative when ``i != j`` and neither
    ``(i, j)`` nor ``(j, i)`` appears in ``edge_index``. Pairs are drawn with
    replacement, so the result may contain duplicates.

    Args:
        edge_index: Integer tensor of shape ``(2, num_edges)``.
        num_nodes: Number of nodes; indices are drawn from ``[0, num_nodes)``.
        num_neg_samples: Number of pairs to return. Defaults to the number of
            edges in ``edge_index``.

    Returns:
        Long tensor of shape ``(2, num_neg_samples)`` on the device of
        ``edge_index``.

    Raises:
        ValueError: If samples are requested but no valid negative pair
            exists (the previous implementation looped forever in that case).
    """
    if num_neg_samples is None:
        num_neg_samples = edge_index.shape[1]

    device = edge_index.device
    if num_neg_samples <= 0:
        # Keep the (2, E) layout even when nothing is sampled so callers can
        # still unpack `row, col = result`.
        return torch.empty((2, 0), dtype=torch.long, device=device)

    # Existing edges, stored in both directions. A single bulk `tolist()`
    # replaces a per-element `.item()` call.
    existing_edges = set()
    for i, j in edge_index.t().tolist():
        existing_edges.add((i, j))
        existing_edges.add((j, i))

    # Count ordered non-self pairs inside the node range that are blocked.
    blocked = sum(
        1
        for i, j in existing_edges
        if i != j and 0 <= i < num_nodes and 0 <= j < num_nodes
    )
    if num_nodes * (num_nodes - 1) - blocked <= 0:
        raise ValueError(
            "negative_sampling: the graph has no non-edge node pairs to sample"
        )

    neg_edges = []
    while len(neg_edges) < num_neg_samples:
        remaining = num_neg_samples - len(neg_edges)
        # Draw candidates in batches and reject the invalid ones.
        candidates = torch.randint(0, num_nodes, (2 * remaining, 2)).tolist()
        for i, j in candidates:
            if i != j and (i, j) not in existing_edges:
                neg_edges.append([i, j])
                if len(neg_edges) == num_neg_samples:
                    break

    return torch.tensor(neg_edges, dtype=torch.long, device=device).t()

# [document section heading] 高级变分GNN架构 (Advanced Variational GNN Architectures)
1. 变分图Transformer(VGT)
class VariationalGraphTransformer(nn.Module):
    """Variational graph transformer.

    Stacks graph-transformer layers, pools the node states into a graph-level
    vector and maps that vector to the mean and log-variance of a latent
    distribution.

    NOTE(review): ``GlobalAttentionPool`` is not defined in the visible part
    of this file; it is presumably provided elsewhere - confirm.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=4, num_layers=3):
        super().__init__()
        # Projects raw node features to the hidden width.
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        # Stack of graph-transformer layers.
        blocks = []
        for _ in range(num_layers):
            blocks.append(GraphTransformerLayer(hidden_channels, num_heads))
        self.layers = nn.ModuleList(blocks)
        # Heads producing the variational parameters.
        self.mu_head = nn.Linear(hidden_channels, out_channels)
        self.logvar_head = nn.Linear(hidden_channels, out_channels)
        # Graph-level pooling.
        self.readout = GlobalAttentionPool(hidden_channels)

    def forward(self, x, edge_index, batch=None):
        """Return ``(mu, logvar)`` computed from the pooled node states."""
        hidden = F.relu(self.input_proj(x))
        for block in self.layers:
            hidden = block(hidden, edge_index)
        # Pool node states into the graph-level representation.
        pooled = self.readout(hidden, batch)
        return self.mu_head(pooled), self.logvar_head(pooled)

    def sample(self, mu, logvar, num_samples=1):
        """Draw ``num_samples`` latent samples; the sample index is the new leading axis."""
        rows, cols = mu.shape[0], mu.shape[1]
        noise = torch.randn(num_samples, rows, cols, device=mu.device)
        sigma = torch.exp(0.5 * logvar)
        return mu + noise * sigma
class GraphTransformerLayer(nn.Module):
    """Graph transformer layer: edge-restricted multi-head attention + FFN.

    Args:
        hidden_channels: Feature width of the layer input and output.
        num_heads: Number of attention heads; must divide ``hidden_channels``.

    Raises:
        ValueError: If ``hidden_channels`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_channels, num_heads):
        super().__init__()
        if hidden_channels % num_heads != 0:
            raise ValueError(
                "hidden_channels must be divisible by num_heads, got "
                f"{hidden_channels} and {num_heads}"
            )
        self.hidden_channels = hidden_channels
        self.num_heads = num_heads
        self.head_dim = hidden_channels // num_heads
        # Attention projections.
        self.q_proj = nn.Linear(hidden_channels, hidden_channels)
        self.k_proj = nn.Linear(hidden_channels, hidden_channels)
        self.v_proj = nn.Linear(hidden_channels, hidden_channels)
        # Output projection.
        self.out_proj = nn.Linear(hidden_channels, hidden_channels)
        # Position-wise feed-forward network.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels * 4),
            nn.ReLU(),
            nn.Linear(hidden_channels * 4, hidden_channels)
        )
        # Layer normalization after each residual connection.
        self.norm1 = nn.LayerNorm(hidden_channels)
        self.norm2 = nn.LayerNorm(hidden_channels)

    def forward(self, x, edge_index):
        """Apply attention and feed-forward sub-layers (residual + norm each)."""
        h = self._graph_attention(x, edge_index)
        h = self.norm1(x + h)
        h = self.norm2(h + self.ffn(h))
        return h

    def _graph_attention(self, x, edge_index):
        """Multi-head attention restricted to the edges in ``edge_index``.

        For every edge ``(row, col)`` a message built from ``v[row]`` is sent
        to node ``col``. Attention weights are normalized over the edges that
        share the same receiving node, so they sum to one per node and head.
        Nodes without incoming edges receive a zero vector before the output
        projection.
        """
        row, col = edge_index
        num_nodes = x.size(0)

        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)

        # Scaled dot-product score per edge and head: (num_edges, num_heads).
        scores = (q[row] * k[col]).sum(dim=-1) / (self.head_dim ** 0.5)

        # Softmax grouped by receiving node. A softmax over dim 0 would
        # normalize across every edge in the graph instead.
        index = col.unsqueeze(-1).expand_as(scores)
        group_max = scores.new_full((num_nodes, self.num_heads), float("-inf"))
        # The per-group maximum is only a numerical stabilizer, hence detached.
        group_max = group_max.scatter_reduce(
            0, index, scores.detach(), reduce="amax", include_self=True
        )
        weights = torch.exp(scores - group_max[col])
        denom = scores.new_zeros(num_nodes, self.num_heads).index_add(0, col, weights)
        alpha = weights / denom[col]

        # Accumulate messages with index_add: `out[col] += ...` keeps only one
        # contribution when a node index occurs more than once.
        out = v.new_zeros(num_nodes, self.num_heads, self.head_dim)
        out = out.index_add(0, col, alpha.unsqueeze(-1) * v[row])
        out = out.reshape(num_nodes, self.hidden_channels)
        return self.out_proj(out)

# [document section heading] 2. 半隐式变分GNN (2. Semi-Implicit Variational GNN)
class SemiImplicitVGAE(nn.Module):
    """Semi-implicit variational graph auto-encoder.

    Uses mixture-style distributions to make the latent model more expressive.

    NOTE(review): the source text cites "Semi-Implicity Variational Graph
    Auto-Encoders" (AAAI 2022) for this model; that citation has not been
    verified here.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the encoder layers.
        out_channels: Dimensionality of the latent space.
        num_components: Number of mixture components.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_components=10):
        super().__init__()
        # Deterministic encoder. The layers are stored in an nn.Sequential,
        # but a Sequential forwards a single input and cannot hand
        # `edge_index` to the GCN layers, so use `encode` to run them.
        self.encoder = nn.Sequential(
            GCNConv(in_channels, hidden_channels),
            nn.ReLU(),
            GCNConv(hidden_channels, hidden_channels),
            nn.ReLU()
        )
        # Heads for the posterior mean and log-variance.
        self.mu = nn.Linear(hidden_channels, out_channels)
        self.logvar = nn.Linear(hidden_channels, out_channels)
        # Learnable parameters of the mixture prior.
        self.num_components = num_components
        self.prior_means = nn.Parameter(torch.randn(num_components, out_channels))
        self.prior_vars = nn.Parameter(torch.ones(num_components, out_channels))
        # Network producing the posterior mixing logits.
        self.post_mixing = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, num_components)
        )

    def encode(self, x, edge_index):
        """Run the encoder stack, passing ``edge_index`` to each GCN layer."""
        h = x
        for layer in self.encoder:
            if isinstance(layer, GCNConv):
                h = layer(h, edge_index)
            else:
                h = layer(h)
        return h

    def prior_distribution(self, num_nodes):
        """Return the mixture prior p(z) = sum_k pi_k N(z | mu_k, sigma_k^2).

        The mixing weights are uniform (1 / num_components) for every node;
        the component means and variances are the learnable parameters.

        Returns:
            Tuple ``(pi, prior_means, prior_vars)`` with ``pi`` of shape
            ``(num_nodes, num_components)``.
        """
        # Create the weights on the device/dtype of the prior parameters so
        # the three returned tensors can be combined directly.
        pi = torch.full(
            (num_nodes, self.num_components),
            1.0 / self.num_components,
            device=self.prior_means.device,
            dtype=self.prior_means.dtype,
        )
        return pi, self.prior_means, self.prior_vars

    def posterior_distribution(self, h):
        """Return the posterior q(z) = sum_k alpha_k N(z | mu, sigma^2).

        Returns:
            Tuple ``(alpha, mu, logvar)``; ``alpha`` is a softmax over the
            mixing logits computed from ``h``.
        """
        alpha_logits = self.post_mixing(h)
        alpha = F.softmax(alpha_logits, dim=-1)
        mu = self.mu(h)
        logvar = self.logvar(h)
        return alpha, mu, logvar

    def sample_from_mixture(self, alpha, mu, logvar, num_samples=1):
        """Draw reparameterized samples from the posterior.

        All mixture components share the same ``mu`` and ``logvar``, so the
        component choice does not influence the sample. ``alpha`` is kept for
        interface compatibility and is not used.

        Returns:
            Tensor with ``num_samples`` stacked along a new leading axis,
            each entry shaped like ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        z_samples = [mu + std * torch.randn_like(mu) for _ in range(num_samples)]
        return torch.stack(z_samples, dim=0)

# [document section heading] 图生成的变分方法 (Variational Methods for Graph Generation)
GraphVAE与图生成
class GraphVAE(nn.Module):
    """Graph variational auto-encoder for graph generation.

    Graphs are represented with node features padded to ``max_nodes`` rows.

    Args:
        node_dim: Size of each node feature vector.
        edge_dim: Edge feature size (stored and forwarded to the decoder).
        latent_dim: Dimensionality of the graph-level latent vector.
        max_nodes: Number of node slots in the padded representation.
    """

    def __init__(self, node_dim, edge_dim, latent_dim, max_nodes=50):
        super().__init__()
        self.max_nodes = max_nodes
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        # Encoder producing a graph-level hidden vector.
        self.encoder = nn.Sequential(
            nn.Linear(node_dim * max_nodes, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        # Variational heads.
        self.mu_layer = nn.Linear(256, latent_dim)
        self.logvar_layer = nn.Linear(256, latent_dim)
        # Decoder.
        self.decoder = GraphDecoder(latent_dim, node_dim, edge_dim, max_nodes)

    def encode(self, x_padded, mask):
        """Return ``(mu, logvar)`` for a batch of padded graphs.

        Args:
            x_padded: Padded node features; flattened from dim 1 onward to
                ``node_dim * max_nodes`` values per graph.
            mask: Accepted for interface compatibility; currently unused.
        """
        x_flat = x_padded.flatten(1)
        h = self.encoder(x_flat)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Apply the reparameterization trick.

        ``forward`` calls this method, but it was missing from the class.
        In training mode a sample ``mu + eps * std`` is drawn; in evaluation
        mode ``mu`` is returned unchanged.
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def decode(self, z):
        """Generate node and edge outputs from the latent ``z``."""
        return self.decoder(z)

    def forward(self, x_padded, mask):
        """Encode, sample and decode; returns ``(x_recon, edge_recon, mu, logvar)``."""
        mu, logvar = self.encode(x_padded, mask)
        z = self.reparameterize(mu, logvar)
        x_recon, edge_recon = self.decode(z)
        return x_recon, edge_recon, mu, logvar
class GraphDecoder(nn.Module):
    """Graph decoder that generates node features and an adjacency matrix.

    Args:
        latent_dim: Size of the latent vector ``z``.
        node_dim: Size of each generated node feature vector.
        edge_dim: Accepted for interface compatibility; not used here.
        max_nodes: Number of node slots generated per graph.
    """

    def __init__(self, latent_dim, node_dim, edge_dim, max_nodes):
        super().__init__()
        self.max_nodes = max_nodes
        # Node decoder.
        self.node_decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, max_nodes * node_dim)
        )
        # Edge decoder (adjacency matrix), conditioned on z and node logits.
        self.edge_decoder = nn.Sequential(
            nn.Linear(latent_dim + max_nodes * node_dim, 512),
            nn.ReLU(),
            nn.Linear(512, max_nodes * max_nodes)
        )

    def forward(self, z):
        """Generate a batch of graphs from latent vectors.

        Returns:
            Tuple ``(node_probs, edge_probs)`` with shapes
            ``(batch, max_nodes, node_dim)`` and
            ``(batch, max_nodes, max_nodes)``. Only the strict upper triangle
            of ``edge_probs`` is non-zero.
        """
        # Decode nodes; node_dim is recovered from the decoder's output width.
        node_logits = self.node_decoder(z)
        node_dim = node_logits.shape[-1] // self.max_nodes
        node_probs = torch.sigmoid(node_logits.view(-1, self.max_nodes, node_dim))

        # Decode edges.
        edge_input = torch.cat([z, node_logits], dim=-1)
        edge_logits = self.edge_decoder(edge_input)
        edge_probs = torch.sigmoid(edge_logits.view(-1, self.max_nodes, self.max_nodes))

        # Symmetrize so that (i, j) and (j, i) share one probability.
        edge_probs = (edge_probs + edge_probs.transpose(-2, -1)) / 2

        # Keep the strict upper triangle only (one entry per undirected pair,
        # no self-loops). The mask is built on the device of `edge_probs` so
        # the product also works when the model is not on the CPU.
        keep = torch.ones(
            self.max_nodes,
            self.max_nodes,
            device=edge_probs.device,
            dtype=edge_probs.dtype,
        ).triu(diagonal=1)
        edge_probs = edge_probs * keep.unsqueeze(0)
        return node_probs, edge_probs

# [document section heading] 潜在空间操作与图编辑 (Latent-Space Operations and Graph Editing)
图插值与编辑
class GraphLatentSpace:
    """Helpers for working in the latent space of a trained model.

    The wrapped model is switched to evaluation mode on construction. It must
    provide ``encode(graph)`` returning a 2-tuple whose first element is the
    latent mean, and ``decode(z)`` returning ``(node_probs, edge_probs)``.
    """

    def __init__(self, model):
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def interpolate(self, graph1, graph2, num_steps=10):
        """Decode points on the straight line between two encoded graphs.

        Typical uses: graph animation, inspecting transformations, data
        augmentation.

        Returns:
            List of ``num_steps`` tuples ``(node_probs, edge_probs)``.
        """
        start, _ = self.model.encode(graph1)
        end, _ = self.model.encode(graph2)

        frames = []
        for weight in torch.linspace(0, 1, num_steps):
            # Linear blend of the two latent means, with a leading batch axis.
            latent = ((1 - weight) * start + weight * end).unsqueeze(0)
            nodes, edges = self.model.decode(latent)
            frames.append((nodes, edges))
        return frames

    @torch.no_grad()
    def random_walk(self, start_graph, num_steps=10, step_size=0.1):
        """Random walk in latent space starting from an encoded graph.

        Typical uses: exploring the graph space, diverse generation.

        Returns:
            List of ``num_steps + 1`` latent tensors, starting point first.
        """
        mu, logvar = self.model.encode(start_graph)
        position = mu.unsqueeze(0)
        trajectory = [position.clone()]
        for _ in range(num_steps):
            # Gaussian perturbation scaled by the step size.
            position = position + torch.randn_like(position) * step_size
            trajectory.append(position.clone())
        return trajectory

    @torch.no_grad()
    def conditional_generation(self, constraint_graph, modify_mask, target_nodes=None):
        """Generate a graph under constraints.

        Args:
            constraint_graph: Graph that is encoded to obtain the latent mean.
            modify_mask: Mask selecting the latent entries to overwrite.
            target_nodes: Optional replacement values; when given, the entries
                selected by ``modify_mask`` are copied into the latent.

        Returns:
            Tuple ``(node_probs, edge_probs)`` decoded from the edited latent.
        """
        mu, logvar = self.model.encode(constraint_graph)
        latent = mu.clone()
        if target_nodes is not None:
            # Overwrite only the selected entries.
            latent[modify_mask] = target_nodes[modify_mask]
        nodes, edges = self.model.decode(latent.unsqueeze(0))
        return nodes, edges

# [document section heading] 对比学习方法 (Contrastive Learning Methods)
图对比变分学习
class ContrastiveVGAE(nn.Module):
    """Contrastive variational graph auto-encoder.

    Combines the variational objective with an InfoNCE contrastive term.

    NOTE(review): ``GCNEncoder`` is not defined in the visible part of this
    file; it is expected to return ``(mu, logvar)`` when called with
    ``(x, edge_index)`` - confirm.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Hidden width of the encoder and projection head.
        out_channels: Dimensionality of the latent space.
        tau: Temperature of the contrastive loss.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, tau=0.5):
        super().__init__()
        # Graph encoder.
        self.encoder = GCNEncoder(in_channels, hidden_channels, out_channels)
        # Projection head used for the contrastive term.
        self.projection = nn.Sequential(
            nn.Linear(out_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, out_channels)
        )
        # Temperature.
        self.tau = tau

    def forward(self, x, edge_index):
        """Return ``(mu, logvar, z, h)``; ``h`` is the projection of ``z``."""
        mu, logvar = self.encoder(x, edge_index)
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            z = mu + eps * std
        else:
            z = mu
        h = self.projection(z)
        return mu, logvar, z, h

    def contrastive_loss(self, h1, h2):
        """InfoNCE loss between two views.

        Row ``i`` of ``h1`` and row ``i`` of ``h2`` form the positive pair;
        every other row of ``h2`` is a negative for row ``i`` of ``h1``.
        """
        h1 = F.normalize(h1, dim=-1)
        h2 = F.normalize(h2, dim=-1)
        # Temperature-scaled cosine similarity between all cross-view pairs.
        logits = torch.mm(h1, h2.t()) / self.tau
        # The positive of row i is on the diagonal. Prepending the positives
        # as an extra column would count each of them twice in the softmax
        # denominator, because the diagonal already contains them.
        labels = torch.arange(h1.shape[0], device=h1.device)
        return F.cross_entropy(logits, labels)

    def reconstruction_loss(self, z, adj_target):
        """Binary cross-entropy between inner-product logits and the target.

        ``total_loss`` calls this method, but it was missing from the class.
        Assumes ``adj_target`` is a dense 0/1 adjacency matrix of shape
        ``(num_nodes, num_nodes)`` - TODO confirm against callers.
        """
        logits = torch.mm(z, z.t())
        return F.binary_cross_entropy_with_logits(logits, adj_target.to(logits.dtype))

    def total_loss(self, x, edge_index, adj_target, beta=1.0):
        """Total loss = reconstruction + beta * KL + contrastive."""
        mu, logvar, z, h = self.forward(x, edge_index)
        # Reconstruction term.
        recon_loss = self.reconstruction_loss(z, adj_target)
        # KL term, averaged over all nodes and latent dimensions.
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        # Second view: projection of the posterior mean from a second
        # encoder pass.
        h2 = self.projection(self.encoder(x, edge_index)[0])
        contrastive_loss = self.contrastive_loss(h, h2)
        loss = recon_loss + beta * kl_loss + contrastive_loss
        return loss

# [document section heading] 应用场景 (Applications)
1. 链接预测
class LinkPredictionWithVGAE:
    """Link prediction on top of a trained VGAE-style model.

    The model must provide ``encode(x, edge_index)`` returning a 2-tuple whose
    first element holds the node embeddings (the posterior mean).
    """

    def __init__(self, model):
        self.model = model

    @torch.no_grad()
    def predict_link(self, x, edge_index, node_i, node_j):
        """Return the predicted probability of an edge between two nodes."""
        mu, _ = self.model.encode(x, edge_index)
        # Sigmoid of the inner product of the two node embeddings.
        sim = torch.sigmoid((mu[node_i] * mu[node_j]).sum())
        return sim.item()

    @torch.no_grad()
    def predict_all_links(self, x, edge_index, k=10):
        """Return the ``k`` highest-scoring node pairs.

        Every unordered pair ``i < j`` is scored, including pairs that are
        already connected in ``edge_index``.

        Returns:
            List of ``(i, j, score)`` tuples sorted by decreasing score. The
            list has fewer than ``k`` entries when the graph has fewer than
            ``k`` node pairs.
        """
        mu, _ = self.model.encode(x, edge_index)
        num_nodes = mu.shape[0]
        # Similarity of every node pair.
        sim_matrix = torch.sigmoid(mu @ mu.t())
        # Strict upper triangle: each unordered pair once, no self-pairs.
        i, j = torch.triu_indices(num_nodes, num_nodes, offset=1, device=mu.device)
        sims = sim_matrix[i, j]
        # topk raises when k exceeds the number of candidates, so clamp it.
        k = max(0, min(k, sims.numel()))
        top_k = torch.topk(sims, k)
        predicted_edges = list(zip(i[top_k.indices].tolist(),
                                   j[top_k.indices].tolist(),
                                   top_k.values.tolist()))
        return predicted_edges

# [document section heading] 2. 图分类 (2. Graph Classification)
class GraphClassificationWithVGAE:
    """Graph classification on top of VGAE node embeddings.

    Args:
        vgae_model: Model that yields node-level posterior means, either via
            ``encode(x, edge_index) -> (mu, logvar)`` or via a forward pass
            returning ``(mu, logvar, z, h)``.
        classifier: Callable applied to the graph-level representation.
    """

    def __init__(self, vgae_model, classifier):
        self.vgae = vgae_model
        self.classifier = classifier

    @torch.no_grad()
    def predict(self, x, edge_index, batch=None):
        """Classify one graph, or a batch of graphs.

        Args:
            x: Node features.
            edge_index: Edge indices.
            batch: Optional long tensor assigning each node to a graph index.
                When omitted, all nodes are treated as a single graph.

        Returns:
            The classifier output for the pooled graph representation(s).
        """
        mu = self._posterior_mean(x, edge_index)
        if batch is None:
            graph_repr = mu.mean(dim=0, keepdim=True)
        elif hasattr(self.vgae, "readout"):
            # Use the model's own pooling when it provides one.
            graph_repr = self.vgae.readout(mu, batch)
        else:
            # The VGAE-style models in this file have no `readout`, so fall
            # back to per-graph mean pooling.
            graph_repr = self._mean_pool(mu, batch)
        return self.classifier(graph_repr)

    def _posterior_mean(self, x, edge_index):
        """Return the node-level posterior mean from the wrapped model."""
        if hasattr(self.vgae, "encode"):
            # `encode` has one unambiguous return layout, (mu, logvar).
            # The forward pass of the VGAE class in this file returns
            # (adj_recon, mu, logvar, z), so unpacking it positionally as
            # (mu, _, z, _) would pick up the edge scores instead of `mu`.
            mu, _ = self.vgae.encode(x, edge_index)
            return mu
        mu, _, _, _ = self.vgae(x, edge_index)
        return mu

    @staticmethod
    def _mean_pool(mu, batch):
        """Average node embeddings per graph index given in ``batch``."""
        num_graphs = int(batch.max()) + 1 if batch.numel() else 0
        sums = mu.new_zeros(num_graphs, mu.size(-1)).index_add(0, batch, mu)
        # clamp_min guards against division by zero for unused graph indices.
        counts = torch.bincount(batch, minlength=num_graphs).clamp_min(1)
        return sums / counts.unsqueeze(-1).to(mu.dtype)

# [document section heading] 总结 (Summary)
变分GNN方法的总结:
| 方法 | 特点 | 适用场景 |
|---|---|---|
| VGAE | 简单高效 | 链接预测、节点分类 |
| VGT | 表达能力更强 | 复杂图结构 |
| 半隐式VGAE | 更灵活的分布 | 异构图、多模态 |
| 对比VGAE | 更好的表示 | 无监督学习 |
参考
相关文章
- graph-convolutional-network — 图卷积网络基础
- gnn-probabilistic-inference — GNN概率推断
- variational-inference — 变分推断基础
- generative-adversarial-networks — GAN基础