概述

图神经网络(Graph Neural Network, GNN)通过消息传递机制聚合邻居信息,其底层逻辑与置信传播(Belief Propagation, BP)和马尔可夫随机场(Markov Random Field, MRF)有着深刻的联系。[1]

本章从概率推断的角度重新审视GNN,探讨其与变分推断、贝叶斯学习、以及知识图谱补全的关系。


消息传递的概率解释

从置信传播到GNN

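图卷积网络中,第 $l$ 层的消息传递可以形式化为:

$$
m_{u\to v}^{(l)} = \phi^{(l)}\big(h_u^{(l)}, h_v^{(l)}\big),\qquad
h_v^{(l+1)} = \gamma^{(l)}\Big(h_v^{(l)},\ \bigoplus_{u\in\mathcal{N}(v)} m_{u\to v}^{(l)}\Big)
$$

其中 $\bigoplus$ 为置换不变的聚合算子(求和、均值或最大值)。GCN 是其特例:

$$
h_v^{(l+1)} = \sigma\Big(\sum_{u\in\mathcal{N}(v)\cup\{v\}} \frac{1}{\sqrt{\hat d_u \hat d_v}}\, W^{(l)} h_u^{(l)}\Big)
$$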

概率图模型的类比

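将图结构数据视为MRF的一个实现:

$$
p(\mathbf{x}) = \frac{1}{Z} \prod_{v\in V} \psi_v(x_v) \prod_{(u,v)\in E} \psi_{uv}(x_u, x_v)
$$

其中 $\psi_v$ 为节点势函数,$\psi_{uv}$ 为边势函数,$Z$ 为配分函数。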

| GNN组件 | 概率解释 |
|---|---|
| 节点表示 $h_v$ | 节点 $v$ 的后验信念 $b_v(x_v)$ |
| 消息 $m_{u\to v}$ | 从节点 $u$ 传递到 $v$ 的信息 |
| 聚合操作 | 置信传播中的消息组合 |
| 更新函数 | 信念更新规则 |

和积算法的GNN版本

考虑一个图上的联合分布 $p(x_1, \dots, x_N)$,其中 $x_i$ 是节点 $i$ 的特征(随机变量)。

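边缘分布可以通过消息传递(和积算法)计算:

$$
m_{u\to v}(x_v) = \sum_{x_u} \psi_u(x_u)\,\psi_{uv}(x_u, x_v) \prod_{w\in\mathcal{N}(u)\setminus\{v\}} m_{w\to u}(x_u)
$$

$$
p(x_v) \approx b_v(x_v) \propto \psi_v(x_v) \prod_{u\in\mathcal{N}(v)} m_{u\to v}(x_v)
$$

在树结构图上该结果是精确的;在有环图上即为循环置信传播(loopy BP)的近似。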

消息作为后验估计

GNN的消息可以理解为对邻居信息的后验估计。


贝叶斯图神经网络

权重不确定性

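标准GNN假设权重是确定性参数。贝叶斯GNN引入权重的先验 $p(W)$,并用变分分布近似其后验:

$$
q_\phi(W) = \prod_{k} \mathcal{N}\big(w_k \mid \mu_k, \sigma_k^2\big) \approx p(W \mid \mathcal{D})
$$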

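预测时,对权重分布边缘化(用蒙特卡洛采样近似):

$$
p(y \mid x, G, \mathcal{D}) = \int p(y \mid x, G, W)\, p(W \mid \mathcal{D})\, dW \approx \frac{1}{S} \sum_{s=1}^{S} p\big(y \mid x, G, W^{(s)}\big),\quad W^{(s)} \sim q_\phi(W)
$$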

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli
 
class BayesianGNNLayer(nn.Module):
    """
    Bayesian graph neural network layer.

    Holds a fully factorised Gaussian variational posterior
    q(W) = N(weight_mu, exp(weight_logvar)) over the weight matrix (and the
    same for the bias). Every forward call draws one reparameterised sample
    and mean-aggregates transformed source-node features onto target nodes.
    """
    def __init__(self, in_features, out_features, edge_dim=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # edge_dim is accepted for interface compatibility; edge features are
        # not used by this layer.

        # Variational parameters: per-entry mean and log-variance.
        self.weight_mu = nn.Parameter(torch.randn(in_features, out_features) * 0.1)
        self.weight_logvar = nn.Parameter(torch.zeros(in_features, out_features))

        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_logvar = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, edge_index):
        """
        Stochastic forward pass using one Monte Carlo sample of the weights.

        Args:
            x: node features, shape (num_nodes, in_features).
            edge_index: integer tensor of shape (2, num_edges). Row 0 is the
                node that receives the message, row 1 the node it comes from.

        Returns:
            Tensor of shape (num_nodes, out_features): the mean of incoming
            messages per node plus the sampled bias. Nodes with no incoming
            edge get only the bias.
        """
        weight = self.sample_weight(self.weight_mu, self.weight_logvar)
        bias = self.sample_weight(self.bias_mu, self.bias_logvar)

        row, col = edge_index
        messages = x[col] @ weight  # transformed source-node features

        num_nodes = x.size(0)
        # Allocate with out_features columns explicitly: slicing x would give
        # the wrong width whenever out_features > in_features.
        aggr = messages.new_zeros(num_nodes, weight.size(1))
        aggr.index_add_(0, row, messages)

        # In-degree as a 1-D count (scattering a 1-D index into a 2-D tensor
        # is invalid); clamp so isolated nodes do not divide by zero.
        deg = messages.new_zeros(num_nodes)
        deg.index_add_(0, row, messages.new_ones(row.numel()))
        deg = deg.clamp(min=1)

        aggr = aggr / deg.unsqueeze(-1)

        return aggr + bias

    def sample_weight(self, mu, logvar):
        """
        Draw mu + sigma * eps with eps ~ N(0, I) (reparameterisation trick),
        so gradients flow to both mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    @staticmethod
    def _gaussian_kl(mu, logvar):
        """Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over entries."""
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    def kl_divergence(self):
        """
        KL divergence between the variational posterior and a standard
        normal prior, summed over all weight AND bias entries. The bias is
        also sampled in forward, so it is regularised as well.
        """
        return (self._gaussian_kl(self.weight_mu, self.weight_logvar)
                + self._gaussian_kl(self.bias_mu, self.bias_logvar))
 
 
class BayesianGNN(nn.Module):
    """
    Full Bayesian GNN: a stack of BayesianGNNLayer blocks (ReLU after each)
    followed by a deterministic linear classifier.
    """
    def __init__(self, num_features, hidden_dim, num_classes, num_layers=3):
        super().__init__()
        self.layers = nn.ModuleList([
            BayesianGNNLayer(
                num_features if i == 0 else hidden_dim,
                hidden_dim
            )
            for i in range(num_layers)
        ])
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index, num_samples=1):
        """
        Monte Carlo forward pass.

        Args:
            x: node features.
            edge_index: edge index tensor passed to every layer.
            num_samples: number of weight samples (Monte Carlo draws), >= 1.

        Returns:
            mean: average logits over the samples.
            variance: unbiased sample variance of the logits; all zeros when
                num_samples == 1 (the unbiased estimator would be NaN there).

        Raises:
            ValueError: if num_samples < 1.
        """
        if num_samples < 1:
            raise ValueError(f"num_samples must be >= 1, got {num_samples}")

        predictions = []
        for _ in range(num_samples):
            h = x
            for layer in self.layers:
                h = F.relu(layer(h, edge_index))
            predictions.append(self.classifier(h))

        predictions = torch.stack(predictions)

        mean = predictions.mean(dim=0)
        if num_samples > 1:
            variance = predictions.var(dim=0)
        else:
            variance = torch.zeros_like(mean)

        return mean, variance

    def elbo_loss(self, x, edge_index, y, num_samples=5, kl_weight=0.01):
        """
        Negative-ELBO training loss (to minimise): a Monte Carlo estimate of
        the cross-entropy plus kl_weight * sum of the layers' KL terms.

        Args:
            x, edge_index: graph inputs.
            y: integer class labels, one per node.
            num_samples: number of weight samples for the data term.
            kl_weight: scale of the KL term (default 0.01).
        """
        recon_loss = 0
        for _ in range(num_samples):
            pred, _ = self.forward(x, edge_index, num_samples=1)
            recon_loss += F.cross_entropy(pred, y)
        recon_loss /= num_samples

        kl_loss = sum(layer.kl_divergence() for layer in self.layers)

        return recon_loss + kl_weight * kl_loss

变分图自编码器(VGAE)

模型结构

VGAE使用变分推断学习图结构的潜在表示:

  • 编码器:GNN生成潜在变量分布
  • 解码器:从潜在变量重建邻接矩阵

概率模型

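先验:

$$
p(\mathbf{Z}) = \prod_{i=1}^{N} \mathcal{N}(\mathbf{z}_i \mid \mathbf{0}, \mathbf{I})
$$

似然(解码器):

$$
p(\mathbf{A} \mid \mathbf{Z}) = \prod_{i=1}^{N} \prod_{j=1}^{N} p(A_{ij} \mid \mathbf{z}_i, \mathbf{z}_j),\qquad
p(A_{ij} = 1 \mid \mathbf{z}_i, \mathbf{z}_j) = \sigma(\mathbf{z}_i^\top \mathbf{z}_j)
$$

(下方实现使用可学习的双线性形式 $\sigma(\mathbf{z}_i^\top \mathbf{W} \mathbf{z}_j + b)$ 作为推广。)

变分后验(编码器):

$$
q(\mathbf{Z} \mid \mathbf{X}, \mathbf{A}) = \prod_{i=1}^{N} \mathcal{N}\big(\mathbf{z}_i \mid \boldsymbol{\mu}_i, \operatorname{diag}(\boldsymbol{\sigma}_i^2)\big),\qquad
\boldsymbol{\mu} = \mathrm{GCN}_{\mu}(\mathbf{X}, \mathbf{A}),\ \ \log \boldsymbol{\sigma}^2 = \mathrm{GCN}_{\sigma}(\mathbf{X}, \mathbf{A})
$$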

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
 
class VariationalGraphAutoEncoder(nn.Module):
    """
    Variational graph auto-encoder (VGAE).

    Encoder: two GCN heads give the mean and log-variance of a diagonal
    Gaussian posterior for every node. Decoder: a learned bilinear form
    scores node pairs; edge probability is sigmoid(z_i^T W z_j + b).
    Used for graph structure learning and node representation learning.
    """
    def __init__(self, num_features, latent_dim, hidden_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        # NOTE: hidden_dim is accepted for interface compatibility but is not
        # used; both encoder heads map features directly to latent_dim.

        # Encoder: parameters of the latent posterior.
        self.gcn_mean = GCNConv(num_features, latent_dim)
        self.gcn_logvar = GCNConv(num_features, latent_dim)

        # Decoder: bilinear edge scorer (one output logit per node pair).
        self.decoder = nn.Bilinear(latent_dim, latent_dim, 1)

    def encode(self, x, edge_index):
        """
        Map node features to the latent posterior parameters.

        Returns:
            mu: posterior means, one row per node.
            logvar: posterior log-variances, same shape as mu.
        """
        mu = self.gcn_mean(x, edge_index)
        logvar = self.gcn_logvar(x, edge_index)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Sample z = mu + sigma * eps in training mode; return mu in eval mode.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _edge_logits(self, z, edge_index):
        """Unnormalised decoder scores for the (row, col) node pairs."""
        row, col = edge_index
        return self.decoder(z[row], z[col]).squeeze(-1)

    def decode(self, z, edge_index):
        """Edge probabilities for the node pairs listed in edge_index."""
        return torch.sigmoid(self._edge_logits(z, edge_index))

    def decode_all(self, z):
        """
        Dense (num_nodes, num_nodes) matrix of edge probabilities.

        Uses the same learned bilinear decoder as `decode`, so entry (i, j)
        equals the probability `decode` returns for the pair (i, j).
        """
        # nn.Bilinear computes x1^T W[k] x2 + b[k]; there is one output, k = 0.
        logits = z @ self.decoder.weight[0] @ z.T
        if self.decoder.bias is not None:
            logits = logits + self.decoder.bias
        return torch.sigmoid(logits)

    def forward(self, x, edge_index):
        """
        Encode and sample latents.

        Returns:
            z: latent sample (posterior mean in eval mode).
            mu, logvar: posterior parameters.
        """
        mu, logvar = self.encode(x, edge_index)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def loss(self, x, edge_index, edge_index_neg):
        """
        VGAE training loss: reconstruction BCE on positive and negative edges
        plus the KL divergence to a standard normal prior.

        BCE is computed from logits, which stays finite when the decoder
        saturates. Otherwise it matches the BCE of the probabilities that
        `decode` returns.

        Args:
            x: node features.
            edge_index: observed (positive) edges; also used for encoding.
            edge_index_neg: sampled non-edges (negatives).
        """
        z, mu, logvar = self.forward(x, edge_index)

        pos_logits = self._edge_logits(z, edge_index)
        pos_loss = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits))

        neg_logits = self._edge_logits(z, edge_index_neg)
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits))

        recon_loss = pos_loss + neg_loss

        # KL(q(z | x) || N(0, I)), averaged over nodes and latent dimensions.
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

        return recon_loss + kl_loss
 
 
class GraphAutoEncoder(nn.Module):
    """
    Graph auto-encoder (non-variational version).

    A two-layer GCN encoder produces a deterministic embedding per node; a
    learned bilinear decoder scores node pairs as edge probabilities.
    """
    def __init__(self, num_features, latent_dim, hidden_dim=32):
        super().__init__()

        # The encoder layers live in an nn.Sequential only as a container, so
        # parameter names stay encoder.0.* / encoder.2.*. It must not be
        # called as a whole: nn.Sequential.forward passes a single input, but
        # each GCNConv also needs edge_index.
        self.encoder = nn.Sequential(
            GCNConv(num_features, hidden_dim),
            nn.ReLU(),
            GCNConv(hidden_dim, latent_dim)
        )

        # Decoder: bilinear edge scorer.
        self.decoder = nn.Bilinear(latent_dim, latent_dim, 1)

    def forward(self, x, edge_index):
        """Encode node features into latent embeddings (one row per node)."""
        conv_in, activation, conv_out = self.encoder
        h = activation(conv_in(x, edge_index))
        return conv_out(h, edge_index)

    def _edge_logits(self, z, edge_index):
        """Unnormalised decoder scores for the (row, col) node pairs."""
        row, col = edge_index
        return self.decoder(z[row], z[col]).squeeze(-1)

    def decode(self, z, edge_index):
        """Edge probabilities for the node pairs listed in edge_index."""
        return torch.sigmoid(self._edge_logits(z, edge_index))

    def loss(self, x, edge_index, edge_index_neg):
        """
        Reconstruction loss: BCE of positive edges against 1 plus BCE of
        negative edges against 0. It is computed from logits for numerical
        stability.
        """
        z = self.forward(x, edge_index)

        pos_logits = self._edge_logits(z, edge_index)
        neg_logits = self._edge_logits(z, edge_index_neg)

        pos_loss = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits))
        neg_loss = F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits))

        return pos_loss + neg_loss

图结构学习

联合学习框架

传统GNN假设图结构是固定的。图结构学习(Graph Structure Learning, GSL)则在训练下游任务的同时,联合学习图结构(邻接矩阵)与模型参数:

class GraphStructureLearner(nn.Module):
    """
    Graph structure learner.

    Scores every ordered node pair with an MLP on the concatenated features
    and returns a dense, symmetric, zero-diagonal adjacency matrix.
    """
    def __init__(self, node_dim, hidden_dim):
        super().__init__()

        # Pairwise similarity: [x_i, x_j] -> score in (0, 1).
        self.similarity = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Sparsification threshold (stored; not applied in forward).
        self.threshold = 0.5

    def forward(self, x):
        """
        Learn an adjacency matrix.

        Args:
            x: (num_nodes, node_dim) node features.

        Returns:
            adj: (num_nodes, num_nodes) learned adjacency, allocated like x.
            It is symmetrised as (S + S^T) / 2, has a zero diagonal, and is all
            zeros when there are fewer than two nodes.
        """
        n = x.shape[0]
        adj = x.new_zeros(n, n)
        if n < 2:
            # No off-diagonal pairs to score.
            return adj

        # All ordered pairs (i, j) with i != j, in row-major order, scored in
        # one batched MLP call instead of one call per pair.
        idx = torch.arange(n, device=x.device)
        rows = idx.repeat_interleave(n)
        cols = idx.repeat(n)
        off_diag = rows != cols
        rows, cols = rows[off_diag], cols[off_diag]

        pairs = torch.cat([x[rows], x[cols]], dim=1)
        scores = self.similarity(pairs).squeeze(-1)
        adj = adj.index_put((rows, cols), scores)

        # Symmetrise.
        return (adj + adj.T) / 2

    def forward_with_learned_graph(self, x, edge_index):
        """
        One round of message passing over the learned graph.

        The adjacency is row-normalised and then multiplied with x.
        edge_index is unused; it is kept for interface compatibility.
        """
        adj = self.forward(x)

        # Row-normalise; the clamp guards rows that sum to zero (a single
        # node, or similarities that underflow), which would otherwise be NaN.
        deg = adj.sum(dim=1, keepdim=True).clamp(min=torch.finfo(adj.dtype).tiny)
        adj = adj / deg

        return adj @ x

概率图结构学习

class ProbabilisticGraphLearner(nn.Module):
    """
    Probabilistic graph-structure learner.

    Scores every ordered node pair (i, j), i != j, with a linear layer on
    [x_i, x_j] and turns the score into a (relaxed) Bernoulli edge using the
    binary Concrete / Gumbel-Sigmoid relaxation.
    """
    def __init__(self, node_dim, temperature=1.0):
        super().__init__()
        self.temperature = temperature

        # Edge logit from the concatenated pair features.
        self.edge_logits = nn.Linear(node_dim * 2, 1)

    def forward(self, x, hard=True):
        """
        Score candidate edges and produce an edge mask.

        Args:
            x: node features, shape (num_nodes, node_dim).
            hard: if True, threshold the edge probabilities at 0.5 (no
                sampling). A straight-through estimator lets gradients still
                reach edge_logits. If False, draw a relaxed sample with
                logistic noise at self.temperature.

        Returns:
            edge_probs: sigmoid edge probabilities, one per candidate pair.
            edge_mask: hard 0/1 mask, or a relaxed sample in (0, 1).
            (rows, cols): long index tensors of the candidate pairs, in
                row-major order over all i != j. They are empty when
                num_nodes < 2.
        """
        n = x.shape[0]

        # All ordered off-diagonal pairs, built without a Python double loop.
        idx = torch.arange(n, device=x.device)
        rows = idx.repeat_interleave(n)
        cols = idx.repeat(n)
        off_diag = rows != cols
        rows, cols = rows[off_diag], cols[off_diag]

        pair = torch.cat([x[rows], x[cols]], dim=1)
        logits = self.edge_logits(pair).squeeze(-1)
        probs = torch.sigmoid(logits)

        if hard:
            hard_mask = (probs > 0.5).to(probs.dtype)
            # Straight-through: the forward value is exactly hard_mask, and the
            # gradient is that of probs.
            edge_mask = hard_mask - probs.detach() + probs
        else:
            # Binary Concrete relaxation. A two-class Gumbel-Softmax reduces
            # to logistic noise (the difference of two Gumbels):
            # log(u) - log(1 - u). A single Gumbel sample would bias the edges.
            u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
            noise = torch.log(u) - torch.log1p(-u)
            edge_mask = torch.sigmoid((logits + noise) / self.temperature)

        return probs, edge_mask, (rows, cols)

知识图谱补全的概率视角

知识图谱的表示学习

知识图谱由三元组 $(h, r, t)$ 组成,表示头实体 $h$ 与尾实体 $t$ 之间存在关系 $r$。

概率模型

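翻译模型(TransE)可以概率化:

$$
p\big((h, r, t)\ \text{为真}\big) \propto \exp\big(-\lVert \mathbf{e}_h + \mathbf{r} - \mathbf{e}_t \rVert_2\big)
$$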

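双线性模型

双线性模型(如 RESCAL、DistMult)用关系矩阵对三元组打分:

$$
f(h, r, t) = \mathbf{e}_h^\top \mathbf{M}_r\, \mathbf{e}_t,\qquad
p\big((h, r, t)\ \text{为真}\big) = \sigma\big(f(h, r, t)\big)
$$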

变分知识图谱嵌入

class VariationalKnowledgeGraphEmbedding(nn.Module):
    """
    Variational knowledge-graph embedding with TransE scoring.

    Each entity has a diagonal Gaussian embedding
    N(entity_mu, exp(entity_logvar)); relations are deterministic vectors.
    In training mode the entity embeddings used by the loss are sampled with
    the reparameterisation trick, so the learned variances influence the
    ranking loss and are not shaped by the KL term alone.
    """
    def __init__(self, num_entities, num_relations, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # Entity embeddings: mean and log-variance.
        self.entity_mu = nn.Embedding(num_entities, latent_dim)
        self.entity_logvar = nn.Embedding(num_entities, latent_dim)

        # Relation embeddings (point estimates).
        self.relation = nn.Embedding(num_relations, latent_dim)

        # Standard normal prior (the KL term in `loss` uses its closed form).
        self.prior = torch.distributions.Normal(0, 1)

    def reparameterize(self, mu, logvar):
        """Return mu + exp(logvar / 2) * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _entity_embedding(self, idx, sample):
        """Posterior means of the entities in idx, or a reparameterised sample."""
        mu = self.entity_mu(idx)
        if not sample:
            return mu
        return self.reparameterize(mu, self.entity_logvar(idx))

    def score(self, h, r, t, sample=False):
        """
        TransE plausibility score -||e_h + r - e_t||_2 (higher is better).

        Args:
            h, r, t: index tensors of head entities, relations, tail entities.
            sample: if True, draw entity embeddings from the variational
                posterior instead of using the means. The default is
                deterministic scoring, suitable for evaluation and ranking.
        """
        h_emb = self._entity_embedding(h, sample)
        r_emb = self.relation(r)
        t_emb = self._entity_embedding(t, sample)

        distance = torch.norm(h_emb + r_emb - t_emb, p=2, dim=-1)
        return -distance

    def loss(self, pos_triples, neg_triples, margin=1.0, kl_weight=0.01):
        """
        Margin ranking loss plus a weighted KL regulariser.

        Entity embeddings are sampled in training mode. In eval mode the
        posterior means are used, which makes the loss deterministic.

        Args:
            pos_triples: (h, r, t) index tensors of observed triples.
            neg_triples: (h, r, t) index tensors of corrupted triples, with a
                shape that broadcasts against the positives.
            margin: hinge margin (default 1.0).
            kl_weight: weight of the KL term (default 0.01).
        """
        sample = self.training

        h, r, t = pos_triples
        pos_score = self.score(h, r, t, sample=sample)

        h_neg, r_neg, t_neg = neg_triples
        neg_score = self.score(h_neg, r_neg, t_neg, sample=sample)

        ranking_loss = torch.clamp(neg_score - pos_score + margin, min=0).mean()

        # KL(q || N(0, I)) for the head and tail entities of the positives.
        h_mu = self.entity_mu(h)
        h_logvar = self.entity_logvar(h)
        t_mu = self.entity_mu(t)
        t_logvar = self.entity_logvar(t)

        kl_loss = -0.5 * torch.mean(
            1 + h_logvar - h_mu.pow(2) - h_logvar.exp() +
            1 + t_logvar - t_mu.pow(2) - t_logvar.exp()
        )

        return ranking_loss + kl_weight * kl_loss

图上的变分推断

消息传递变分推断

将GNN视为变分推断的一个步骤:

class MessagePassingVI(nn.Module):
    """
    Message-passing view of variational inference.

    Keeps one free variational parameter vector per node (q_params) and
    refines it with a GNN-style update: messages from neighbours are
    mean-aggregated and combined with the current parameters (the E-step).
    """
    def __init__(self, num_nodes, node_dim, hidden_dim):
        super().__init__()
        self.num_nodes = num_nodes

        # Variational parameters, one row per node.
        self.q_params = nn.Parameter(torch.randn(num_nodes, hidden_dim))

        # Message network: [receiver features, sender features] -> message.
        self.message_net = nn.Sequential(
            nn.Linear(node_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Update network: [current q, aggregated message] -> new q.
        self.update_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU()
        )

    def e_step(self, x, edge_index):
        """
        E-step: overwrite the variational parameters with one round of
        message passing.

        Args:
            x: node features with node_dim columns.
            edge_index: (2, num_edges) integer tensor. Row 0 is the receiving
                node and row 1 the sending node.

        Returns:
            The q_params parameter itself, updated in place.

        The new values are written without autograd history, so no gradient
        reaches message_net or update_net through this step.
        """
        row, col = edge_index

        with torch.no_grad():
            messages = self.message_net(torch.cat([x[row], x[col]], dim=1))

            # Mean aggregation per receiving node. Buffers are allocated on
            # the messages' device and dtype.
            aggr = messages.new_zeros(self.num_nodes, messages.shape[1])
            aggr.index_add_(0, row, messages)

            deg = messages.new_zeros(self.num_nodes)
            deg.index_add_(0, row, messages.new_ones(row.numel()))
            deg = deg.clamp(min=1)
            aggr = aggr / deg.unsqueeze(1)

            new_q = self.update_net(torch.cat([self.q_params, aggr], dim=1))
            # In-place copy keeps q_params the same leaf Parameter; it replaces
            # the discouraged `.data =` reassignment.
            self.q_params.copy_(new_q)

        return self.q_params

    def m_step(self, x):
        """
        M-step placeholder.

        In EM the M-step would update model parameters with the variational
        distribution held fixed. It is intentionally a no-op here.
        """
        pass

    def elbo(self, x, edge_index, kl_weight=0.1):
        """
        Training objective to MINIMISE, i.e. a negative-ELBO surrogate:
        reconstruction error + kl_weight * KL(q || N(0, 1)).

        Args:
            x: node features. The simplified reconstruction term compares x
                directly with q_params, so x should have q_params' shape
                (node_dim == hidden_dim).
            edge_index: unused; kept for a uniform interface.
            kl_weight: weight of the KL term (default 0.1).
        """
        # Reconstruction term (simplified).
        recon_loss = F.mse_loss(self.q_params, x)

        # KL divergence between q = N(q_params, 1) and the standard normal prior.
        prior = torch.distributions.Normal(0, 1)
        q = torch.distributions.Normal(self.q_params, 1)
        kl_loss = torch.distributions.kl.kl_divergence(q, prior).mean()

        return recon_loss + kl_weight * kl_loss

图神经网络的不确定性量化

集成方法

class EnsembleGNN(nn.Module):
    """
    Ensemble of GNNs for uncertainty quantification.

    The spread of the members' predictions is used as an uncertainty
    estimate.
    """
    def __init__(self, num_features, hidden_dim, num_classes, num_models=5):
        super().__init__()
        # NOTE(review): `GNN` is not defined in this section; it must be
        # provided elsewhere as GNN(num_features, hidden_dim, num_classes).
        self.models = nn.ModuleList([
            GNN(num_features, hidden_dim, num_classes)
            for _ in range(num_models)
        ])

    def forward(self, x, edge_index):
        """
        Inference-only ensemble prediction.

        Each member runs in eval mode without gradient tracking. Its previous
        train/eval mode is restored afterwards, so calling this method does
        not silently change how the members are trained.

        Returns:
            mean: average member output.
            variance: unbiased variance across members; zeros when the
                ensemble has a single member.
        """
        predictions = []
        for model in self.models:
            was_training = model.training
            model.eval()
            try:
                with torch.no_grad():
                    predictions.append(model(x, edge_index))
            finally:
                model.train(was_training)

        predictions = torch.stack(predictions)

        mean = predictions.mean(dim=0)
        if predictions.size(0) > 1:
            variance = predictions.var(dim=0)
        else:
            variance = torch.zeros_like(mean)

        return mean, variance

    def predict_with_uncertainty(self, x, edge_index):
        """Return the ensemble mean and the standard deviation across members."""
        mean, variance = self.forward(x, edge_index)
        std = torch.sqrt(variance)

        return mean, std

Monte Carlo Dropout

class MCDropoutGNN(nn.Module):
    """
    Monte Carlo Dropout GNN.

    Estimates predictive uncertainty by running several stochastic forward
    passes with dropout kept active (also in eval mode) and reporting the
    sample mean and variance of the logits.
    """
    def __init__(self, num_features, hidden_dim, num_classes, dropout=0.5):
        super().__init__()

        self.conv1 = GCNConv(num_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        # Holds the dropout probability. forward applies it through
        # F.dropout(training=True), so model.eval() does not switch it off.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, num_samples=10):
        """
        MC Dropout forward pass.

        Args:
            x: node features.
            edge_index: edge index tensor for the GCN layers.
            num_samples: number of stochastic passes.

        Returns:
            mean: average logits over the samples.
            variance: unbiased sample variance of the logits; zeros when
                num_samples == 1.
        """
        p = self.dropout.p

        # Nothing before the first dropout is random, so conv1 is computed
        # once and only the dropout masks are resampled.
        first_hidden = F.relu(self.conv1(x, edge_index))

        predictions = []
        for _ in range(num_samples):
            # training=True keeps dropout active regardless of model mode;
            # that is what makes this Monte Carlo dropout.
            h = F.dropout(first_hidden, p=p, training=True)
            h = F.dropout(F.relu(self.conv2(h, edge_index)), p=p, training=True)
            predictions.append(self.classifier(h))

        predictions = torch.stack(predictions)

        mean = predictions.mean(dim=0)
        if num_samples > 1:
            variance = predictions.var(dim=0)
        else:
            variance = torch.zeros_like(mean)

        return mean, variance

与现有wiki内容的联系


参考


相关阅读

Footnotes

  1. Kipf, T. N., & Welling, M. (2017). Semi-supervised classification with graph convolutional networks. ICLR.