概述
图神经网络(Graph Neural Network, GNN)通过消息传递机制聚合邻居信息,其底层逻辑与置信传播(Belief Propagation, BP)和马尔可夫随机场(Markov Random Field, MRF)有着深刻的联系[1]。
本章从概率推断的角度重新审视GNN,探讨其与变分推断、贝叶斯学习以及知识图谱补全的关系。
消息传递的概率解释
从置信传播到GNN
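在图卷积网络中,消息传递可以形式化为:

$$\mathbf{h}_i^{(l+1)} = \sigma\Big(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\hat{d}_i \hat{d}_j}}\, \mathbf{W}^{(l)} \mathbf{h}_j^{(l)}\Big)$$

其中 $\hat{d}_i$ 是加入自环后节点 $i$ 的度数,$\mathbf{W}^{(l)}$ 是第 $l$ 层的权重,$\sigma$ 是非线性激活函数。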
概率图模型的类比
将图结构数据视为一个马尔可夫随机场(MRF)的实例,GNN的各组件可以与概率推断一一对应:
| GNN组件 | 概率解释 |
|---|---|
| 节点表示 | 节点的后验信念 |
| 消息 $m_{j \to i}$ | 从节点 $j$ 传递到节点 $i$ 的信息 |
| 聚合操作 | 置信传播中的消息组合 |
| 更新函数 | 信念更新规则 |
和积算法的GNN版本
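考虑一个图上的联合分布 $p(\mathbf{x}_1, \dots, \mathbf{x}_N) \propto \prod_{i} \phi_i(\mathbf{x}_i) \prod_{(i,j) \in \mathcal{E}} \psi_{ij}(\mathbf{x}_i, \mathbf{x}_j)$,其中 $\mathbf{x}_i$ 是节点特征,$\phi_i$ 和 $\psi_{ij}$ 分别是节点势函数和边势函数。
边缘分布可以通过消息传递计算:

$$m_{j \to i}(\mathbf{x}_i) = \sum_{\mathbf{x}_j} \psi_{ij}(\mathbf{x}_i, \mathbf{x}_j)\, \phi_j(\mathbf{x}_j) \prod_{k \in \mathcal{N}(j) \setminus \{i\}} m_{k \to j}(\mathbf{x}_j)$$

$$p(\mathbf{x}_i) \propto \phi_i(\mathbf{x}_i) \prod_{j \in \mathcal{N}(i)} m_{j \to i}(\mathbf{x}_i)$$

在树结构图上该结果是精确的;在有环图上则是近似(即 loopy BP)。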
消息作为后验估计
GNN的消息可以理解为对邻居信息的后验估计:
贝叶斯图神经网络
权重不确定性
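标准GNN假设权重是确定性参数。贝叶斯GNN则为权重引入先验 $p(\mathbf{W}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$,并用对角高斯变分后验 $q_\phi(\mathbf{W}) = \mathcal{N}\big(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2)\big)$ 近似真实后验 $p(\mathbf{W} \mid \mathcal{D})$。
预测时,对权重分布边缘化,并用 $S$ 次蒙特卡洛采样近似:

$$p(y \mid \mathbf{X}, \mathbf{A}, \mathcal{D}) \approx \int p(y \mid \mathbf{X}, \mathbf{A}, \mathbf{W})\, q_\phi(\mathbf{W})\, d\mathbf{W} \approx \frac{1}{S} \sum_{s=1}^{S} p\big(y \mid \mathbf{X}, \mathbf{A}, \mathbf{W}^{(s)}\big), \quad \mathbf{W}^{(s)} \sim q_\phi(\mathbf{W})$$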
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli
class BayesianGNNLayer(nn.Module):
    """Bayesian graph neural network layer.

    Every weight and bias entry has a diagonal Gaussian variational posterior
    ``N(mu, exp(logvar))``. A fresh sample is drawn on every forward pass via
    the reparameterization trick, so repeated calls give different outputs.

    Args:
        in_features: Dimension of the input node features.
        out_features: Dimension of the output node features.
        edge_dim: Accepted for interface compatibility; currently unused.
    """

    def __init__(self, in_features, out_features, edge_dim=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Variational parameters: mean and log-variance of every weight / bias.
        # NOTE(review): logvar starts at 0 (std = 1), which is large relative to
        # the 0.1-scale means; a more negative initial logvar is common.
        self.weight_mu = nn.Parameter(torch.randn(in_features, out_features) * 0.1)
        self.weight_logvar = nn.Parameter(torch.zeros(in_features, out_features))
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_logvar = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, edge_index):
        """Run one stochastic message-passing step.

        Args:
            x: Node features, shape (num_nodes, in_features).
            edge_index: Integer tensor of shape (2, num_edges). In this layer
                messages are built from ``x[edge_index[1]]`` and summed at
                ``edge_index[0]``.

        Returns:
            Tensor of shape (num_nodes, out_features): the mean of the incoming
            transformed messages plus the sampled bias. Nodes with no incoming
            edge receive only the bias.
        """
        weight = self.sample_weight(self.weight_mu, self.weight_logvar)
        bias = self.sample_weight(self.bias_mu, self.bias_logvar)

        row, col = edge_index
        messages = x[col] @ weight  # transform the sending node's features

        num_nodes = x.size(0)
        # Sum messages per receiving node. The buffer is sized from
        # out_features (not sliced from x), so out_features > in_features works.
        aggr = messages.new_zeros(num_nodes, weight.shape[1])
        aggr.index_add_(0, row, messages)

        # Incoming-edge count per node, kept 1-D so its rank matches the 1-D
        # index; clamp to 1 so isolated nodes do not divide by zero.
        deg = messages.new_zeros(num_nodes)
        deg.index_add_(0, row, messages.new_ones(row.shape[0]))
        deg = deg.clamp(min=1)

        aggr = aggr / deg.unsqueeze(1)
        return aggr + bias

    def sample_weight(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)`` (reparameterization)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def kl_divergence(self):
        """KL(q || p) between the variational posterior and a N(0, 1) prior.

        Summed over every weight AND bias entry: both are sampled in
        ``forward``, so both contribute to the ELBO's KL term.
        """

        def _gaussian_kl(mu, logvar):
            # Closed form of KL(N(mu, sigma^2) || N(0, 1)), summed elementwise.
            return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return _gaussian_kl(self.weight_mu, self.weight_logvar) + _gaussian_kl(
            self.bias_mu, self.bias_logvar
        )
class BayesianGNN(nn.Module):
    """Full Bayesian GNN.

    A stack of ``BayesianGNNLayer`` modules, each followed by ReLU, feeding a
    deterministic linear classifier. Predictive uncertainty comes from Monte
    Carlo sampling of the layer weights.
    """

    def __init__(self, num_features, hidden_dim, num_classes, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                BayesianGNNLayer(num_features if i == 0 else hidden_dim, hidden_dim)
                for i in range(num_layers)
            ]
        )
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def _sample_logits(self, x, edge_index):
        """Run one stochastic pass (one weight sample per layer) and return logits."""
        h = x
        for layer in self.layers:
            h = F.relu(layer(h, edge_index))
        return self.classifier(h)

    def forward(self, x, edge_index, num_samples=1):
        """Bayesian forward pass.

        Args:
            x: Node features.
            edge_index: Edge index tensor.
            num_samples: Number of Monte Carlo weight samples (>= 1).

        Returns:
            mean: Mean of the sampled logits.
            variance: Unbiased sample variance of the logits. It is returned as
                zeros when ``num_samples == 1``, because the unbiased variance
                of a single sample is undefined (NaN).

        Raises:
            ValueError: If ``num_samples < 1``.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1")
        predictions = torch.stack(
            [self._sample_logits(x, edge_index) for _ in range(num_samples)]
        )
        mean = predictions.mean(dim=0)
        if num_samples == 1:
            variance = torch.zeros_like(mean)
        else:
            variance = predictions.var(dim=0)
        return mean, variance

    def elbo_loss(self, x, edge_index, y, num_samples=5, kl_weight=0.01):
        """Training loss: a KL-weighted negative ELBO.

        loss = (1/S) * sum_s CE(logits_s, y) + kl_weight * sum_layers KL(q(W) || p(W))

        Minimizing this maximizes a (KL-weighted) evidence lower bound.

        Args:
            x: Node features.
            edge_index: Edge index tensor.
            y: Target class indices for ``F.cross_entropy``.
            num_samples: Monte Carlo samples for the data term (>= 1).
            kl_weight: Weight of the KL term (default 0.01, as before).

        Raises:
            ValueError: If ``num_samples < 1``.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1")
        recon_loss = (
            sum(
                F.cross_entropy(self._sample_logits(x, edge_index), y)
                for _ in range(num_samples)
            )
            / num_samples
        )
        kl_loss = sum(layer.kl_divergence() for layer in self.layers)
        return recon_loss + kl_weight * kl_loss


# 变分图自编码器(VGAE)
模型结构
VGAE使用变分推断学习图结构的潜在表示:
- 编码器:GNN生成潜在变量分布
- 解码器:从潜在变量重建邻接矩阵
概率模型
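先验:$p(\mathbf{Z}) = \prod_{i=1}^{N} \mathcal{N}(\mathbf{z}_i \mid \mathbf{0}, \mathbf{I})$
似然:$p(\mathbf{A} \mid \mathbf{Z}) = \prod_{i=1}^{N} \prod_{j=1}^{N} p(A_{ij} \mid \mathbf{z}_i, \mathbf{z}_j)$,其中 $p(A_{ij} = 1 \mid \mathbf{z}_i, \mathbf{z}_j) = \sigma(\mathbf{z}_i^\top \mathbf{z}_j)$(下方实现的 `decode` 使用双线性形式 $\sigma(\mathbf{z}_i^\top \mathbf{W} \mathbf{z}_j + b)$)
变分后验:$q(\mathbf{Z} \mid \mathbf{X}, \mathbf{A}) = \prod_{i=1}^{N} \mathcal{N}\big(\mathbf{z}_i \mid \boldsymbol{\mu}_i, \operatorname{diag}(\boldsymbol{\sigma}_i^2)\big)$,其中 $\boldsymbol{\mu} = \mathrm{GCN}_{\mu}(\mathbf{X}, \mathbf{A})$,$\log \boldsymbol{\sigma}^2 = \mathrm{GCN}_{\sigma}(\mathbf{X}, \mathbf{A})$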
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class VariationalGraphAutoEncoder(nn.Module):
    """Variational graph auto-encoder (VGAE).

    The encoder uses two GCN heads to produce, per node, the mean and
    log-variance of a diagonal Gaussian over the latent embedding. The decoder
    scores node pairs with a learned bilinear form,
    ``sigmoid(z_i^T W z_j + b)``.

    Args:
        num_features: Input node-feature dimension.
        latent_dim: Dimension of the latent node embeddings.
        hidden_dim: Accepted for interface compatibility; unused by this
            single-layer encoder.
    """

    def __init__(self, num_features, latent_dim, hidden_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: GCN heads for the latent mean and log-variance.
        self.gcn_mean = GCNConv(num_features, latent_dim)
        self.gcn_logvar = GCNConv(num_features, latent_dim)
        # Decoder: bilinear edge scorer.
        self.decoder = nn.Bilinear(latent_dim, latent_dim, 1)

    def encode(self, x, edge_index):
        """Encode nodes into the latent space.

        Returns:
            mu: Latent means.
            logvar: Latent log-variances.
        """
        mu = self.gcn_mean(x, edge_index)
        logvar = self.gcn_logvar(x, edge_index)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu

    def _edge_logits(self, z, edge_index):
        """Return bilinear decoder logits for the node pairs in ``edge_index``."""
        row, col = edge_index
        return self.decoder(z[row], z[col]).squeeze(-1)

    def decode(self, z, edge_index):
        """Return edge probabilities for the node pairs in ``edge_index``."""
        return torch.sigmoid(self._edge_logits(z, edge_index))

    def decode_all(self, z):
        """Decode the full (num_nodes, num_nodes) adjacency-probability matrix.

        Uses the same learned bilinear decoder as ``decode``, so entry (i, j)
        equals ``decode(z, [[i], [j]])``.
        """
        # nn.Bilinear with out_features=1 computes z_i^T W[0] z_j + b[0].
        logits = z @ self.decoder.weight[0] @ z.T
        if self.decoder.bias is not None:
            logits = logits + self.decoder.bias[0]
        return torch.sigmoid(logits)

    def forward(self, x, edge_index):
        """Encode, then reparameterize; returns ``(z, mu, logvar)``."""
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def loss(self, x, edge_index, edge_index_neg):
        """VGAE loss: positive/negative edge reconstruction BCE plus KL.

        The KL term to N(0, I) is averaged over all node x latent entries.
        """
        z, mu, logvar = self.forward(x, edge_index)

        pos_logits = self._edge_logits(z, edge_index)
        neg_logits = self._edge_logits(z, edge_index_neg)
        # BCE-with-logits is the numerically stable form of
        # binary_cross_entropy(sigmoid(logits), target).
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits)
        )
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits)
        )
        recon_loss = pos_loss + neg_loss

        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_loss
class GraphAutoEncoder(nn.Module):
    """Non-variational graph auto-encoder.

    Encoder: GCN -> ReLU -> GCN. Decoder: bilinear edge scorer,
    ``sigmoid(z_i^T W z_j + b)``.
    """

    def __init__(self, num_features, latent_dim, hidden_dim=32):
        super().__init__()
        # nn.Sequential cannot be used here: its forward passes a single
        # argument, but GCNConv needs (x, edge_index). A ModuleList keeps the
        # same child indices (0, 1, 2), so state_dict keys are unchanged.
        self.encoder = nn.ModuleList(
            [
                GCNConv(num_features, hidden_dim),
                nn.ReLU(),
                GCNConv(hidden_dim, latent_dim),
            ]
        )
        self.decoder = nn.Bilinear(latent_dim, latent_dim, 1)

    def forward(self, x, edge_index):
        """Encode nodes into latent embeddings ``z``."""
        conv1, activation, conv2 = self.encoder
        h = activation(conv1(x, edge_index))
        return conv2(h, edge_index)

    def _edge_logits(self, z, edge_index):
        """Return bilinear decoder logits for the node pairs in ``edge_index``."""
        row, col = edge_index
        return self.decoder(z[row], z[col]).squeeze(-1)

    def decode(self, z, edge_index):
        """Return edge probabilities for the node pairs in ``edge_index``."""
        return torch.sigmoid(self._edge_logits(z, edge_index))

    def loss(self, x, edge_index, edge_index_neg):
        """Binary cross-entropy on positive edges (target 1) plus negative edges (target 0)."""
        z = self.forward(x, edge_index)
        pos_logits = self._edge_logits(z, edge_index)
        neg_logits = self._edge_logits(z, edge_index_neg)
        # Logit-space BCE: numerically stable equivalent of BCE on sigmoid outputs.
        return F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits)
        ) + F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits)
        )


# 图结构学习
联合学习框架
传统GNN假设图结构是给定且固定的。图结构学习(Graph Structure Learning, GSL)则将图结构与GNN参数联合学习,以得到对下游任务更优的图结构:
class GraphStructureLearner(nn.Module):
    """Graph structure learner.

    Scores every ordered node pair with an MLP on the concatenated features
    ``[x_i, x_j]`` and returns a dense, symmetric learned adjacency matrix.

    Args:
        node_dim: Node-feature dimension.
        hidden_dim: Hidden width of the pair-scoring MLP.
    """

    def __init__(self, node_dim, hidden_dim):
        super().__init__()
        # Pairwise similarity MLP, output in (0, 1).
        self.similarity = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        # Sparsification threshold; not applied by the methods of this class.
        self.threshold = 0.5

    def forward(self, x):
        """Learn a dense adjacency matrix.

        Args:
            x: (num_nodes, node_dim) node features.

        Returns:
            adj: (num_nodes, num_nodes) symmetric matrix with a zero diagonal,
                ``adj[i, j] = (s(x_i, x_j) + s(x_j, x_i)) / 2`` for i != j.
        """
        n = x.shape[0]
        # Score all ordered pairs in one batched MLP call instead of n^2
        # separate calls: pairs[i, j] = [x_i, x_j], shape (n, n, 2 * node_dim).
        x_i = x.unsqueeze(1).expand(n, n, -1)
        x_j = x.unsqueeze(0).expand(n, n, -1)
        pairs = torch.cat([x_i, x_j], dim=-1)
        adj = self.similarity(pairs).squeeze(-1)

        # No self-loops.
        self_loops = torch.eye(n, dtype=torch.bool, device=x.device)
        adj = adj.masked_fill(self_loops, 0.0)

        # Symmetrize.
        return (adj + adj.T) / 2

    def forward_with_learned_graph(self, x, edge_index):
        """Run one round of mean message passing over the learned graph.

        Args:
            x: (num_nodes, node_dim) node features.
            edge_index: Unused; accepted for interface compatibility.

        Returns:
            Row-normalized ``adj @ x``.
        """
        adj = self.forward(x)
        # Row-normalize. The clamp only matters for all-zero rows (e.g. a
        # single-node graph), which would otherwise produce 0/0 = NaN.
        deg = adj.sum(dim=1, keepdim=True).clamp(min=torch.finfo(adj.dtype).tiny)
        return (adj / deg) @ x


# 概率图结构学习
class ProbabilisticGraphLearner(nn.Module):
    """Probabilistic graph structure learner.

    Scores every ordered node pair (i, j), i != j, with a linear layer on
    ``[x_i, x_j]`` and produces an edge mask, either by thresholding or via a
    relaxed-Bernoulli (binary Concrete / Gumbel-sigmoid) sample.

    Args:
        node_dim: Node-feature dimension.
        temperature: Relaxation temperature for the soft sample.
    """

    def __init__(self, node_dim, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        # Edge-existence logit from concatenated pair features.
        self.edge_logits = nn.Linear(node_dim * 2, 1)

    def forward(self, x, hard=True):
        """Score candidate edges and build an edge mask.

        Args:
            x: (num_nodes, node_dim) node features.
            hard: If True, the mask is the deterministic threshold
                ``1[sigmoid(logit) > 0.5]``, with straight-through gradients
                through the probabilities. If False, the mask is a relaxed
                Bernoulli sample ``sigmoid((logit + L) / temperature)``, where L
                is Logistic(0, 1) noise.

        Returns:
            probs: Edge probabilities ``sigmoid(logit)``, one per ordered pair.
            edge_mask: Edge mask aligned with ``probs``.
            (rows, cols): Source/target node indices of each ordered pair
                (i != j), in row-major order.
        """
        n = x.shape[0]
        # All ordered pairs with i != j, i-major then j (same order as a nested loop).
        nodes = torch.arange(n, device=x.device)
        rows = nodes.repeat_interleave(n)
        cols = nodes.repeat(n)
        off_diagonal = rows != cols
        rows, cols = rows[off_diagonal], cols[off_diagonal]

        pair = torch.cat([x[rows], x[cols]], dim=1)
        logits = self.edge_logits(pair).squeeze(-1)
        probs = torch.sigmoid(logits)

        if hard:
            hard_mask = (probs > 0.5).float()
            # Straight-through: forward value is exactly hard_mask, gradient flows via probs.
            edge_mask = hard_mask + probs - probs.detach()
        else:
            # Binary Concrete relaxation: Logistic(0, 1) noise log(u) - log(1 - u)
            # makes P(mask > 0.5) equal sigmoid(logit). A single Gumbel sample
            # would bias that probability.
            u = torch.rand_like(logits).clamp(min=torch.finfo(logits.dtype).tiny)
            noise = torch.log(u) - torch.log1p(-u)
            edge_mask = torch.sigmoid((logits + noise) / self.temperature)

        return probs, edge_mask, (rows, cols)


# 知识图谱补全的概率视角
知识图谱的表示学习
知识图谱由三元组 $(h, r, t)$ 组成,表示头实体 $h$ 与尾实体 $t$ 之间存在关系 $r$。
概率模型
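翻译模型(TransE)可以概率化:$p\big((h, r, t)\big) \propto \exp\big(-\lVert \mathbf{e}_h + \mathbf{r} - \mathbf{e}_t \rVert_2\big)$
双线性模型:$p\big((h, r, t)\big) = \sigma\big(\mathbf{e}_h^\top \mathbf{M}_r \mathbf{e}_t\big)$,其中 $\mathbf{M}_r$ 是关系 $r$ 对应的矩阵(DistMult 将其限制为对角矩阵)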
变分知识图谱嵌入
class VariationalKnowledgeGraphEmbedding(nn.Module):
    """Variational knowledge-graph embedding.

    Each entity has a diagonal Gaussian embedding ``N(mu, exp(logvar))``, and
    relations have point embeddings. Triples are scored with TransE,
    ``-||e_h + r - e_t||_2``.

    Args:
        num_entities: Number of entities.
        num_relations: Number of relations.
        latent_dim: Embedding dimension.
    """

    def __init__(self, num_entities, num_relations, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        # Entity embeddings: mean and log-variance.
        self.entity_mu = nn.Embedding(num_entities, latent_dim)
        self.entity_logvar = nn.Embedding(num_entities, latent_dim)
        # Relation embeddings.
        self.relation = nn.Embedding(num_relations, latent_dim)
        # Standard Gaussian prior; not read by the methods below.
        self.prior = torch.distributions.Normal(0, 1)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _entity(self, idx):
        """Entity embedding: a posterior sample in training mode, the mean in eval mode."""
        mu = self.entity_mu(idx)
        if not self.training:
            return mu
        # Sampling is what lets entity_logvar affect the data term; without it
        # the variance would only ever be shaped by the KL regularizer.
        return self.reparameterize(mu, self.entity_logvar(idx))

    def score(self, h, r, t):
        """TransE score of triples (h, r, t); higher is better."""
        h_emb = self._entity(h)
        r_emb = self.relation(r)
        t_emb = self._entity(t)
        return -torch.norm(h_emb + r_emb - t_emb, p=2, dim=-1)

    def loss(self, pos_triples, neg_triples, margin=1.0, kl_weight=0.01):
        """Margin ranking loss on positive vs negative triples, plus weighted KL.

        Args:
            pos_triples: Tuple ``(h, r, t)`` of index tensors.
            neg_triples: Tuple ``(h, r, t)`` of index tensors for negatives.
            margin: Ranking margin (default 1.0).
            kl_weight: Weight of the KL term (default 0.01).
        """
        h, r, t = pos_triples
        pos_score = self.score(h, r, t)

        h_neg, r_neg, t_neg = neg_triples
        neg_score = self.score(h_neg, r_neg, t_neg)

        rank_loss = torch.clamp(neg_score - pos_score + margin, min=0).mean()

        # KL to N(0, 1) for the positive heads and tails, averaged elementwise.
        h_mu = self.entity_mu(h)
        h_logvar = self.entity_logvar(h)
        t_mu = self.entity_mu(t)
        t_logvar = self.entity_logvar(t)
        kl_loss = -0.5 * torch.mean(
            1 + h_logvar - h_mu.pow(2) - h_logvar.exp()
            + 1 + t_logvar - t_mu.pow(2) - t_logvar.exp()
        )
        return rank_loss + kl_weight * kl_loss


# 图上的变分推断
消息传递变分推断
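将GNN视为变分推断中的一次迭代更新:每一轮消息传递对应一次变分参数的更新,目标是最大化证据下界(ELBO):

$$\mathcal{L}(q) = \mathbb{E}_{q(\mathbf{Z})}\big[\log p(\mathbf{X} \mid \mathbf{Z})\big] - \mathrm{KL}\big(q(\mathbf{Z}) \,\Vert\, p(\mathbf{Z})\big)$$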
class MessagePassingVI(nn.Module):
    """Message-passing variational inference.

    Holds one free variational-parameter vector per node (``q_params``) and
    refines it with a GNN-style update: edge messages, mean aggregation, then
    an MLP that combines the current parameters with the aggregate.

    Args:
        num_nodes: Number of nodes (rows of ``q_params``).
        node_dim: Node-feature dimension.
        hidden_dim: Dimension of the per-node variational parameters.
    """

    def __init__(self, num_nodes, node_dim, hidden_dim):
        super().__init__()
        self.num_nodes = num_nodes
        # Variational parameters, one row per node.
        self.q_params = nn.Parameter(torch.randn(num_nodes, hidden_dim))
        # Message network on concatenated endpoint features.
        self.message_net = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        # Update network on [current params, aggregated message].
        self.update_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
        )

    def e_step(self, x, edge_index):
        """E-step: replace ``q_params`` by one message-passing update.

        This is a fixed-point assignment, not a gradient step. It runs under
        ``torch.no_grad()``, so ``message_net`` and ``update_net`` receive no
        gradient from it.

        Args:
            x: (num_nodes, node_dim) node features.
            edge_index: (2, num_edges) indices; messages built from
                ``[x[edge_index[0]], x[edge_index[1]]]`` are averaged at
                ``edge_index[0]``.

        Returns:
            The (updated) ``q_params`` parameter.
        """
        row, col = edge_index
        with torch.no_grad():
            messages = self.message_net(torch.cat([x[row], x[col]], dim=1))

            # Allocate on the same device/dtype as the messages.
            aggr = messages.new_zeros(self.num_nodes, messages.shape[1])
            aggr.index_add_(0, row, messages)

            # Mean aggregation; nodes without incoming messages keep a zero aggregate.
            deg = messages.new_zeros(self.num_nodes)
            deg.index_add_(0, row, messages.new_ones(row.shape[0]))
            aggr = aggr / deg.clamp(min=1).unsqueeze(1)

            updated = self.update_net(torch.cat([self.q_params, aggr], dim=1))
        # Assign through .data (as before) so the update bypasses autograd tracking.
        self.q_params.data = updated
        return self.q_params

    def m_step(self, x):
        """M-step placeholder: currently a no-op.

        In EM, this step would update the model parameters (here
        ``message_net`` / ``update_net``) given the current variational
        distribution. That update is not implemented.
        """
        pass

    def elbo(self, x, edge_index):
        """Return a loss to minimize: reconstruction MSE + 0.1 * KL.

        This is a (simplified, scaled) negative ELBO, not the ELBO itself.
        ``q(z_i) = N(q_params_i, 1)`` and the prior is ``N(0, 1)``.
        ``edge_index`` is unused.

        NOTE(review): ``q_params`` (num_nodes, hidden_dim) is compared directly
        with ``x`` (num_nodes, node_dim), so this only works as intended when
        ``hidden_dim == node_dim``. Confirm or add a decoder.
        """
        recon_loss = F.mse_loss(self.q_params, x)
        prior = torch.distributions.Normal(0, 1)
        q = torch.distributions.Normal(self.q_params, 1)
        kl_loss = torch.distributions.kl.kl_divergence(q, prior).mean()
        return recon_loss + 0.1 * kl_loss


# 图神经网络的不确定性量化
集成方法
class EnsembleGNN(nn.Module):
    """Deep ensemble of GNNs for uncertainty quantification.

    NOTE(review): ``GNN`` is not defined in this section. It presumably comes
    from elsewhere, and must be constructible as
    ``GNN(num_features, hidden_dim, num_classes)`` and callable as
    ``model(x, edge_index)``. Confirm.

    Args:
        num_features: Input node-feature dimension.
        hidden_dim: Hidden width of each member.
        num_classes: Output dimension of each member.
        num_models: Ensemble size.
    """

    def __init__(self, num_features, hidden_dim, num_classes, num_models=5):
        super().__init__()
        self.models = nn.ModuleList(
            [GNN(num_features, hidden_dim, num_classes) for _ in range(num_models)]
        )

    def forward(self, x, edge_index):
        """Return the ensemble mean and variance of the member outputs.

        Each member runs in eval mode without gradients. Afterwards, each
        member's previous train/eval mode is restored, so calling this does not
        silently disable training-mode layers. With a single member the
        variance is returned as zeros (the unbiased variance is undefined).
        """
        previous_modes = [model.training for model in self.models]
        predictions = []
        try:
            with torch.no_grad():
                for model in self.models:
                    model.eval()
                    predictions.append(model(x, edge_index))
        finally:
            for model, was_training in zip(self.models, previous_modes):
                model.train(was_training)

        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        if len(predictions) > 1:
            variance = predictions.var(dim=0)
        else:
            variance = torch.zeros_like(mean)
        return mean, variance

    def predict_with_uncertainty(self, x, edge_index):
        """Return ``(mean, std)``, where std is the square root of the ensemble variance."""
        mean, variance = self.forward(x, edge_index)
        std = torch.sqrt(variance)
        return mean, std


# Monte Carlo Dropout
class MCDropoutGNN(nn.Module):
    """Monte Carlo dropout GNN.

    Estimates predictive uncertainty by running several forward passes with
    independent dropout masks.

    Args:
        num_features: Input node-feature dimension.
        hidden_dim: Hidden width.
        num_classes: Output dimension.
        dropout: Dropout probability.
    """

    def __init__(self, num_features, hidden_dim, num_classes, dropout=0.5):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        # Holds the dropout probability (read as self.dropout.p in forward).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, num_samples=10):
        """MC-dropout forward pass.

        Dropout is applied with ``training=True`` regardless of the module's
        mode. A plain ``nn.Dropout`` is a no-op after ``.eval()``, which would
        make every sample identical at inference time.

        Args:
            x: Node features.
            edge_index: Edge index tensor.
            num_samples: Number of dropout samples (>= 1).

        Returns:
            mean: Mean of the sampled logits.
            variance: Unbiased sample variance of the logits, or zeros when
                ``num_samples == 1``.

        Raises:
            ValueError: If ``num_samples < 1``.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be >= 1")
        p = self.dropout.p
        # conv1 + ReLU is deterministic, so compute it once; only the dropout
        # masks (and everything after them) change between samples.
        h1 = F.relu(self.conv1(x, edge_index))

        predictions = []
        for _ in range(num_samples):
            h = F.dropout(h1, p=p, training=True)
            h = F.relu(self.conv2(h, edge_index))
            h = F.dropout(h, p=p, training=True)
            predictions.append(self.classifier(h))

        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        if num_samples == 1:
            variance = torch.zeros_like(mean)
        else:
            variance = predictions.var(dim=0)
        return mean, variance


# 与现有wiki内容的联系
参考
相关阅读
- 图卷积网络详解 — GCN基础
- 图注意力网络 — 自适应邻居权重
- 知识图谱与GNN — 知识图谱补全
- 变分推断深度学习应用 — 变分框架
- 因子图与置信传播 — 推断算法
Footnotes
-
Kipf, T. N., & Welling, M. (2017). Semi-supervised classification with graph convolutional networks. ICLR. ↩