1. Overview

Graph neural networks (GNNs) learn node representations by aggregating neighbor information through message passing, but standard GNNs are point-estimate models with no notion of predictive uncertainty. In many real-world applications (e.g., drug discovery, financial risk control), quantifying that uncertainty is critical.[^1]

Structure of this article

┌──────────────────────────────────────────────────────────────────┐
│            Probabilistic Inference Framework for GNNs            │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│   ┌──────────────────────┐     ┌──────────────────────┐          │
│   │ Probabilistic graph  │────▶│ Bayesian view of     │          │
│   │ model foundations    │     │ message passing      │          │
│   └──────────┬───────────┘     └──────────┬───────────┘          │
│              ▼                            │                      │
│   ┌──────────────────────┐     ┌──────────▼───────────┐          │
│   │ Belief propagation   │◀───▶│ Probabilistic        │          │
│   └──────────┬───────────┘     │ message passing      │          │
│              │                 └──────────┬───────────┘          │
│              ▼                            ▼                      │
│   ┌──────────────────────────────────────────────────┐           │
│   │       Uncertainty quantification methods         │           │
│   │  ┌────────────┐ ┌───────────┐ ┌──────────────┐   │           │
│   │  │ MC Dropout │ │ Bootstrap │ │ Bayesian GNN │   │           │
│   │  └────────────┘ └───────────┘ └──────────────┘   │           │
│   └────────────────────────┬─────────────────────────┘           │
│                            ▼                                     │
│   ┌──────────────────────────────────────────────────┐           │
│   │         Applications and implementation          │           │
│   │  Link prediction · Node classification · VI      │           │
│   └──────────────────────────────────────────────────┘           │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘

2. Probabilistic Models on Graphs

2.1 Markov Random Fields and Graph Neural Networks

Graph-structured data lends itself naturally to probabilistic graphical models. Let $G = (V, E)$ be a graph with node features $\mathbf{X} = \{\mathbf{x}_i\}_{i \in V}$ and labels $\mathbf{y} = \{y_i\}_{i \in V}$.

A **Markov random field (MRF)** defines the joint distribution

$$p(\mathbf{y}) = \frac{1}{Z} \prod_{c \in \mathcal{C}} \psi_c(\mathbf{y}_c)$$

where the $\psi_c$ are potential functions over cliques and $Z = \sum_{\mathbf{y}} \prod_{c \in \mathcal{C}} \psi_c(\mathbf{y}_c)$ is the partition function.
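To make the definition concrete, the toy sketch below (illustrative values, not from any library) enumerates all label assignments of a three-node chain MRF with binary labels and computes the partition function $Z$ by brute force:

import itertools
import numpy as np

# Toy 3-node chain MRF: y0 - y1 - y2, binary labels
phi = np.array([[2.0, 1.0],   # node 0 leans toward label 0
                [1.0, 1.0],   # node 1 is neutral
                [1.0, 3.0]])  # node 2 leans toward label 1
edges = [(0, 1), (1, 2)]
psi = np.array([[2.0, 1.0],
                [1.0, 2.0]])  # homophily: matching neighbor labels score higher

def unnormalized(y):
    score = np.prod([phi[i, yi] for i, yi in enumerate(y)])
    score *= np.prod([psi[y[i], y[j]] for i, j in edges])
    return score

# partition function Z by enumerating all 2^3 assignments
assignments = list(itertools.product([0, 1], repeat=3))
Z = sum(unnormalized(y) for y in assignments)
for y in assignments:
    print(y, unnormalized(y) / Z)  # p(y) = potential product / Z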

2.2 Conditional Random Fields and Semi-Supervised Learning

In node classification we typically model the conditional distribution $p(\mathbf{y} \mid \mathbf{X})$. Following **conditional random field (CRF)** theory:

$$p(\mathbf{y} \mid \mathbf{X}) = \frac{1}{Z(\mathbf{X})} \prod_{i \in V} \phi_i(y_i, \mathbf{X}) \prod_{(i,j) \in E} \psi_{ij}(y_i, y_j, \mathbf{X})$$

GNN message passing can be viewed as approximate inference over this conditional distribution.

2.3 A Bayesian View of Graphs

From a Bayesian viewpoint, learning the parameters of a GNN means inferring the posterior

$$p(\theta \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \theta)\, p(\theta)}{p(\mathcal{D})}$$

where $\theta$ are the GNN parameters and $\mathcal{D}$ is the observed data; predictions then marginalize over this posterior, $p(y^* \mid x^*, \mathcal{D}) = \int p(y^* \mid x^*, \theta)\, p(\theta \mid \mathcal{D})\, d\theta$.


3. A Probabilistic Interpretation of Message Passing

3.1 Overview of Belief Propagation

**Belief propagation (BP)** is the core algorithm for exact inference in graphical models. On tree-structured graphs, BP computes marginal distributions exactly.[^2]

Message update rule

$$m_{j \to i}(y_i) \propto \sum_{y_j} \psi_{ij}(y_i, y_j)\, \phi_j(y_j) \prod_{k \in \mathcal{N}(j) \setminus \{i\}} m_{k \to j}(y_j)$$

Node belief

$$b_i(y_i) \propto \phi_i(y_i) \prod_{j \in \mathcal{N}(i)} m_{j \to i}(y_i)$$
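As a sanity check of these rules, here is a minimal sum-product sketch on the toy chain from Section 2.1 (a tree, so BP is exact; names follow the equations above, everything else is illustrative):

import numpy as np

# same toy chain: binary states, unary phi, pairwise psi
phi = np.array([[2.0, 1.0], [1.0, 1.0], [1.0, 3.0]])
psi = np.array([[2.0, 1.0], [1.0, 2.0]])
neighbors = {0: [1], 1: [0, 2], 2: [1]}

# m[(j, i)]: message from node j to node i, a function of y_i
m = {(j, i): np.ones(2) for j in neighbors for i in neighbors[j]}

for _ in range(5):  # a few sweeps; converges quickly on a tree
    new_m = {}
    for (j, i) in m:
        # m_{j->i}(y_i) = sum_{y_j} psi(y_i,y_j) phi_j(y_j) prod_{k in N(j)\{i}} m_{k->j}(y_j)
        incoming = np.ones(2)
        for k in neighbors[j]:
            if k != i:
                incoming = incoming * m[(k, j)]
        msg = psi @ (phi[j] * incoming)
        new_m[(j, i)] = msg / msg.sum()  # normalize for numerical stability
    m = new_m

# b_i(y_i) ∝ phi_i(y_i) prod_{j in N(i)} m_{j->i}(y_i)
for i in neighbors:
    b = phi[i] * np.prod([m[(j, i)] for j in neighbors[i]], axis=0)
    print(i, b / b.sum())  # exact marginals on a tree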

3.2 Correspondence Between BP and GNN Message Passing

GNN message passing is deeply connected to BP. Consider the embedding update of node $i$:

$$\mathbf{h}_i^{(l+1)} = \mathrm{UPDATE}\Big(\mathbf{h}_i^{(l)},\ \mathrm{AGGREGATE}\big(\{\mathbf{h}_j^{(l)} : j \in \mathcal{N}(i)\}\big)\Big)$$

Correspondence table

| BP component | GNN component | Probabilistic interpretation |
|---|---|---|
| Potential function $\phi_i$ | node input feature (self-loop term) | node prior / evidence |
| Message $m_{j \to i}$ | aggregated message | propagated conditional belief |
| Belief $b_i$ | node embedding $\mathbf{h}_i$ | posterior marginal distribution |
| Normalization | activation function / softmax | probability normalization |

3.3 Formal Correspondence

Let $b_i^{(l)}$ denote the belief of node $i$ at iteration $l$, the analogue of the GNN embedding $\mathbf{h}_i^{(l)}$.

BP update:

$$b_i^{(l+1)}(y_i) \propto \phi_i(y_i) \prod_{j \in \mathcal{N}(i)} m_{j \to i}^{(l)}(y_i)$$

GNN update (GCN as the example):

$$\mathbf{h}_i^{(l+1)} = \sigma\Big(\sum_{j \in \mathcal{N}(i) \cup \{i\}} \frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}}\, \mathbf{W}^{(l)} \mathbf{h}_j^{(l)}\Big)$$

Key analogy:

$$\underbrace{m_{j \to i}^{(l)}}_{\text{BP message}} \quad \Longleftrightarrow \quad \underbrace{\frac{1}{\sqrt{\tilde{d}_i \tilde{d}_j}}\, \mathbf{h}_j^{(l)}}_{\text{GNN aggregated message}}$$

3.4 Loopy Belief Propagation and GNNs

On graphs with cycles BP is no longer exact, but **loopy belief propagation (LBP)** still obtains approximate solutions through iterated message passing, which is strikingly similar to stacking GNN layers:

┌──────────────────────────────────────────────────────────────┐
│             BP inference vs. GNN message passing             │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│   BP (exact/approximate          GNN (representation         │
│       inference)                      learning)              │
│   ┌─────────────────────┐       ┌─────────────────────┐      │
│   │ Messages: pass      │       │ Messages: aggregate │      │
│   │ probability         │       │ vector embeddings   │      │
│   │ distributions       │       │                     │      │
│   └──────────┬──────────┘       └──────────┬──────────┘      │
│              ▼                             ▼                 │
│   ┌─────────────────────┐       ┌─────────────────────┐      │
│   │ Beliefs: marginal   │       │ Embeddings: node    │      │
│   │ distributions       │       │ representations     │      │
│   │ b_i(y_i)            │       │ h_i                 │      │
│   └──────────┬──────────┘       └──────────┬──────────┘      │
│              ▼                             ▼                 │
│   ┌─────────────────────┐       ┌─────────────────────┐      │
│   │ Iterate until       │       │ Fixed L-layer       │      │
│   │ convergence         │       │ forward pass        │      │
│   │ (approx. on loops)  │       │ (learned params)    │      │
│   └─────────────────────┘       └─────────────────────┘      │
│                                                              │
│   Common core: both aggregate local information through     │
│   iterative message passing                                  │
│                                                              │
└──────────────────────────────────────────────────────────────┘

4. Uncertainty Quantification for GNNs

4.1 A Taxonomy of Uncertainty

Uncertainty in GNNs is usually divided into two types, with out-of-distribution detection as a closely related third problem:

| Type | Definition | Source | Quantification |
|---|---|---|---|
| Epistemic uncertainty | uncertainty over model parameters | insufficient training data | Bayesian methods |
| Aleatoric uncertainty | noise inherent in the data | ambiguous labels / noisy features | modeled in the loss function |
| Out-of-distribution (OOD) detection | whether an input lies within the training distribution | generalization boundary | energy-based scores |
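Given MC samples of class probabilities (as produced by any of the methods in this section), a common way to separate the first two rows is the entropy decomposition used in BALD-style active learning: total predictive entropy splits into an aleatoric part (expected entropy) and an epistemic part (mutual information). A minimal sketch, assuming `probs` has shape (n_samples, N, C):

import torch

def decompose_uncertainty(probs, eps=1e-8):
    """probs: (n_samples, N, C) class probabilities from MC sampling."""
    mean_p = probs.mean(dim=0)                                          # (N, C)
    total = -(mean_p * (mean_p + eps).log()).sum(dim=-1)                # H[E[p]]
    aleatoric = -(probs * (probs + eps).log()).sum(dim=-1).mean(dim=0)  # E[H[p]]
    epistemic = total - aleatoric                                       # mutual information
    return total, aleatoric, epistemic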

4.2 MC Dropout for GNN

MC Dropout is the simplest Bayesian approximation: keep dropout active at test time and average over multiple stochastic forward passes to approximate the Bayesian posterior.[^3]

Principle

With dropout kept active, each forward pass samples a thinned network $\hat{\theta}_t$; averaging $T$ passes approximates the posterior predictive:

$$p(y^* \mid x^*, \mathcal{D}) \approx \frac{1}{T} \sum_{t=1}^{T} p\big(y^* \mid x^*, \hat{\theta}_t\big)$$

where $\hat{\theta}_t$ is the network corresponding to the $t$-th dropout sample.

Predictive mean and variance

$$\hat{\boldsymbol{\mu}} = \frac{1}{T} \sum_{t=1}^{T} \hat{\mathbf{y}}_t, \qquad \hat{\boldsymbol{\sigma}}^2 = \frac{1}{T} \sum_{t=1}^{T} \big(\hat{\mathbf{y}}_t - \hat{\boldsymbol{\mu}}\big)^2$$

PyTorch Geometric implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch_geometric.nn as pyg_nn  # used for pyg_nn.Sequential in Section 4.3

class MCDropoutGNN(nn.Module):
    """
    MC Dropout for graph neural networks:
    quantifies prediction uncertainty via stochastic forward passes.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, dropout_rate=0.5):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        
        # graph convolution layers
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        # GCN layers with dropout after each activation
        x = self.dropout(F.relu(self.conv1(x, edge_index)))
        x = self.dropout(F.relu(self.conv2(x, edge_index)))
        x = self.conv3(x, edge_index)
        return x
    
    def predict_with_uncertainty(self, x, edge_index, n_samples=30):
        """
        MC Dropout sampling; returns the predictive mean and standard deviation.
        """
        self.train()  # keep dropout active at inference time
        
        predictions = []
        with torch.no_grad():
            for _ in range(n_samples):
                logits = self.forward(x, edge_index)
                probs = F.softmax(logits, dim=-1)
                predictions.append(probs)
        
        predictions = torch.stack(predictions, dim=0)  # (n_samples, N, out_channels)
        
        # predictive mean and spread across samples
        mean_prob = predictions.mean(dim=0)
        uncertainty = predictions.std(dim=0)  # epistemic uncertainty
        
        return mean_prob, uncertainty
 
 
def train_mc_dropout_gnn(model, train_data, val_data, epochs=200):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(train_data.x, train_data.edge_index)
        loss = criterion(out[train_data.train_mask], train_data.y[train_data.train_mask])
        loss.backward()
        optimizer.step()
        
        if epoch % 20 == 0:
            # evaluate accuracy and average uncertainty on the validation set
            # (predict_with_uncertainty switches dropout back on internally)
            mean_prob, uncertainty = model.predict_with_uncertainty(
                val_data.x, val_data.edge_index, n_samples=30
            )
            pred = mean_prob.argmax(dim=1)
            acc = (pred[val_data.val_mask] == val_data.y[val_data.val_mask]).float().mean()
            avg_uncertainty = uncertainty[val_data.val_mask].mean()
            
            print(f"Epoch {epoch}: Acc={acc:.4f}, Avg Uncertainty={avg_uncertainty:.4f}")
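For reference, a usage sketch on Cora via PyG's Planetoid loader (path and hyperparameters are illustrative; Cora ships with train/val/test masks, so one Data object serves both roles):

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

model = MCDropoutGNN(dataset.num_features, 64, dataset.num_classes, dropout_rate=0.5)
train_mc_dropout_gnn(model, data, data, epochs=200)

# per-node predictive mean and uncertainty
mean_prob, uncertainty = model.predict_with_uncertainty(data.x, data.edge_index, n_samples=50)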

4.3 Bootstrap Methods

Bootstrap GNNs estimate uncertainty by training several GNN replicas and aggregating their predictions.

Algorithm

  1. Draw $B$ subsets from the original data by sampling with replacement
  2. Train one GNN on each subset
  3. Aggregate the predictions of the $B$ models

PyTorch implementation

class BootstrapGNN(nn.Module):
    """
    Bootstrap ensemble of GNNs: B independent models,
    each trained on a bootstrap resample of the training data.
    """
    def __init__(self, in_channels, hidden_channels, out_channels, n_estimators=10):
        super().__init__()
        self.n_estimators = n_estimators
        
        # B independent GNNs; pyg_nn.Sequential threads (x, edge_index)
        # through the GCNConv layers, which plain nn.Sequential cannot do
        self.models = nn.ModuleList([
            pyg_nn.Sequential('x, edge_index', [
                (GCNConv(in_channels, hidden_channels), 'x, edge_index -> x'),
                nn.ReLU(),
                nn.Dropout(0.5),
                (GCNConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
                nn.ReLU(),
                nn.Dropout(0.5),
                (GCNConv(hidden_channels, out_channels), 'x, edge_index -> x'),
            ])
            for _ in range(n_estimators)
        ])
    
    def forward(self, x, edge_index, model_idx=None):
        """
        With model_idx, run only that ensemble member;
        otherwise return the average prediction of all members.
        """
        if model_idx is not None:
            return self.models[model_idx](x, edge_index)
        
        # aggregate predictions over all members
        outputs = [model(x, edge_index) for model in self.models]
        return torch.stack(outputs).mean(dim=0)
    
    def predict_with_uncertainty(self, x, edge_index):
        outputs = torch.stack([model(x, edge_index) for model in self.models])
        mean = outputs.mean(dim=0)
        std = outputs.std(dim=0)
        return mean, std
 
 
def train_bootstrap_gnn(data, n_estimators=10, epochs=200):
    """Train a bootstrap ensemble; each member sees a resample of the training nodes."""
    models = []
    
    train_idx = data.train_mask.nonzero(as_tuple=True)[0]
    
    for b in range(n_estimators):
        # bootstrap: resample training nodes with replacement
        # (duplicates collapse in the boolean mask, so each member
        # effectively trains on roughly 63% of the training nodes)
        boot = train_idx[torch.randint(0, train_idx.numel(), (train_idx.numel(),))]
        bootstrap_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        bootstrap_mask[boot] = True
        
        # train a single ensemble member
        model = BootstrapGNN(data.x.size(1), 64, int(data.y.max()) + 1, n_estimators=1)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        
        for epoch in range(epochs):
            model.train()
            optimizer.zero_grad()
            out = model(data.x, data.edge_index)
            loss = F.cross_entropy(out[bootstrap_mask], data.y[bootstrap_mask])
            loss.backward()
            optimizer.step()
        
        models.append(model.models[0])
    
    # collect the trained members into a single ensemble
    ensemble = BootstrapGNN(data.x.size(1), 64, int(data.y.max()) + 1, n_estimators=n_estimators)
    for i, m in enumerate(models):
        ensemble.models[i].load_state_dict(m.state_dict())
    
    return ensemble

4.4 Bayesian GNNs

Bayesian GNNs place a posterior distribution $p(\theta \mid \mathcal{D})$ over the network weights, making them the most principled of the approaches covered here.

Variational approximation

The intractable posterior is approximated by a tractable $q_\phi(\theta)$, chosen to minimize $\mathrm{KL}\big(q_\phi(\theta) \,\|\, p(\theta \mid \mathcal{D})\big)$.

Evidence lower bound (ELBO)

$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(\theta)}\big[\log p(\mathcal{D} \mid \theta)\big] - \mathrm{KL}\big(q_\phi(\theta) \,\|\, p(\theta)\big)$$

PyTorch implementation

from torch.distributions import Normal, kl_divergence

class BayesianGNNLayer(nn.Module):
    """
    Variational Bayesian graph convolution layer,
    trained with the reparameterization trick.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        # variational parameters: mean and log-std of each weight
        self.weight_mean = nn.Parameter(torch.randn(in_channels, out_channels))
        self.weight_log_std = nn.Parameter(torch.zeros(in_channels, out_channels))
        
        self.bias_mean = nn.Parameter(torch.zeros(out_channels))
        self.bias_log_std = nn.Parameter(torch.zeros(out_channels))
        
        # initialization
        nn.init.kaiming_normal_(self.weight_mean)
        nn.init.constant_(self.weight_log_std, -5)  # start with a small std
    
    def forward(self, x, edge_index, sampling=True):
        # symmetric degree normalization as in GCN (self-loops omitted for brevity)
        row, col = edge_index  # messages flow col (source) -> row (target)
        edge_weight = torch.ones(edge_index.size(1), device=x.device)
        deg = torch.zeros(x.size(0), device=x.device).scatter_add_(0, row, edge_weight)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        
        if sampling or self.training:
            # reparameterized sample: theta = mu + eps * sigma
            weight = self.weight_mean + torch.randn_like(self.weight_mean) * torch.exp(self.weight_log_std)
            bias = self.bias_mean + torch.randn_like(self.bias_mean) * torch.exp(self.bias_log_std)
        else:
            # posterior mean (deterministic inference)
            weight = self.weight_mean
            bias = self.bias_mean
        
        # graph convolution: linear transform, then normalized neighbor aggregation
        h = torch.matmul(x, weight)
        msg = h[col] * norm.view(-1, 1)
        out = torch.zeros_like(h).index_add_(0, row, msg)
        return out + bias
    
    def kl_loss(self):
        """KL divergence between this layer's posterior and a standard normal prior."""
        weight_prior = Normal(0., 1.)
        weight_posterior = Normal(self.weight_mean, torch.exp(self.weight_log_std))
        kl_w = kl_divergence(weight_posterior, weight_prior).sum()
        
        bias_prior = Normal(0., 1.)
        bias_posterior = Normal(self.bias_mean, torch.exp(self.bias_log_std))
        kl_b = kl_divergence(bias_posterior, bias_prior).sum()
        
        return kl_w + kl_b
 
 
class BayesianGNN(nn.Module):
    """
    Bayesian graph neural network: a stack of variational GCN layers.
    """
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.layer1 = BayesianGNNLayer(in_channels, hidden_channels)
        self.layer2 = BayesianGNNLayer(hidden_channels, hidden_channels)
        self.layer3 = BayesianGNNLayer(hidden_channels, out_channels)
    
    def forward(self, x, edge_index, sampling=True):
        x = F.relu(self.layer1(x, edge_index, sampling))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.layer2(x, edge_index, sampling))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.layer3(x, edge_index, sampling)
        return x
    
    def predict_with_uncertainty(self, x, edge_index, n_samples=30):
        """Estimate uncertainty by MC sampling of the weight posterior."""
        self.train()
        
        predictions = []
        with torch.no_grad():
            for _ in range(n_samples):
                logits = self.forward(x, edge_index, sampling=True)
                predictions.append(logits)
        
        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        
        return mean, std
    
    def elbo_loss(self, x, edge_index, y, mask, n_samples=1, beta=1.0):
        """
        Negative evidence lower bound, the quantity to minimize:
        -ELBO = E_q[NLL] + beta * KL(q || p)
        """
        # KL divergence summed over layers
        kl = self.layer1.kl_loss() + self.layer2.kl_loss() + self.layer3.kl_loss()
        
        # Monte Carlo estimate of the expected negative log-likelihood
        nll = 0.0
        for _ in range(n_samples):
            logits = self.forward(x, edge_index, sampling=True)
            nll = nll + F.cross_entropy(logits[mask], y[mask], reduction='mean')
        nll = nll / n_samples
        
        # beta rescales the KL term (commonly 1 / number of training nodes)
        return nll + beta * kl
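A minimal training sketch for this model (hyperparameters are illustrative; the small `beta` here is a stand-in for scaling the KL term by the number of training nodes):

def train_bayesian_gnn(model, data, epochs=200, beta=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        # minimize the negative ELBO on the training nodes
        loss = model.elbo_loss(data.x, data.edge_index, data.y,
                               data.train_mask, n_samples=1, beta=beta)
        loss.backward()
        optimizer.step()
        if epoch % 20 == 0:
            print(f"Epoch {epoch}: -ELBO = {loss.item():.4f}")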

5. Probabilistic Graph Networks

5.1 Definition and Architecture

**Probabilistic graph networks (PGNs)** fold probabilistic inference directly into GNN message passing, modeling distribution parameters at every message-passing step.[^4]

Core idea

Each node state is represented as a Gaussian, $\mathbf{h}_i^{(l)} \sim \mathcal{N}\big(\boldsymbol{\mu}_i^{(l)}, \mathrm{diag}((\boldsymbol{\sigma}_i^{(l)})^2)\big)$, and the message, aggregation, and update functions operate on $(\boldsymbol{\mu}, \boldsymbol{\sigma})$ pairs rather than point embeddings.

5.2 Implementation Sketch

class ProbabilisticGraphNetwork(nn.Module):
    """
    Probabilistic graph network:
    each layer outputs distribution parameters instead of point estimates.
    """
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        
        # message network: outputs the mean and (raw) std of each message
        self.message_net = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels * 2)  # mu and raw_std
        )
        
        # update network: combines self state with aggregated messages
        self.update_net = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels * 2)
        )
        
        # input embedding layer
        self.input_net = nn.Linear(in_channels, hidden_channels)
        
        # classifier head
        self.classifier = nn.Linear(hidden_channels, out_channels)
    
    def message(self, mu_i, sigma_i, mu_j, sigma_j):
        """Message function: maps (target, source) states to message distribution parameters."""
        combined = torch.cat([mu_i, mu_j], dim=-1)
        out = self.message_net(combined)
        mu, raw_std = out.chunk(2, dim=-1)
        sigma = F.softplus(raw_std) + 1e-6  # keep sigma positive
        return mu, sigma
    
    def update(self, mu_self, sigma_self, mu_agg, sigma_agg):
        """Update function: combine a node's own state with its aggregated messages."""
        combined = torch.cat([mu_self, mu_agg], dim=-1)
        out = self.update_net(combined)
        mu, raw_std = out.chunk(2, dim=-1)
        sigma = F.softplus(raw_std) + 1e-6
        
        # residual connection
        mu = mu + mu_self
        sigma = (sigma + sigma_self) / 2
        return mu, sigma
    
    def forward(self, x, edge_index):
        N = x.size(0)
        
        # initialize node distributions
        mu = F.relu(self.input_net(x))
        sigma = torch.ones_like(mu)
        
        row, col = edge_index  # convention: row = target i, col = source j
        ones = torch.ones(row.size(0), 1, device=x.device)
        
        # several rounds of probabilistic message passing
        for _ in range(3):
            # compute all edge messages in one batch
            mu_msg, sigma_msg = self.message(mu[row], sigma[row], mu[col], sigma[col])
            
            # per-node aggregation: average the means, combine the variances
            count = torch.zeros(N, 1, device=x.device).index_add_(0, row, ones).clamp(min=1)
            agg_mu = torch.zeros_like(mu).index_add_(0, row, mu_msg) / count
            agg_var = torch.zeros_like(sigma).index_add_(0, row, sigma_msg ** 2) / count
            agg_sigma = agg_var.sqrt()
            
            # update the node distributions
            mu, sigma = self.update(mu, sigma, agg_mu, agg_sigma)
        
        # classify from the mean embedding
        return self.classifier(mu), mu, sigma

6. Variational Inference on Graphs

6.1 Variational Graph Auto-Encoders (VGAE)

The **variational graph auto-encoder (VGAE)** learns low-dimensional representations of graph-structured data via variational inference.[^5]

Model structure

The encoder is a Gaussian over latent node embeddings, parameterized by GCNs:

$$q(\mathbf{Z} \mid \mathbf{X}, \mathbf{A}) = \prod_{i} \mathcal{N}\big(\mathbf{z}_i \mid \boldsymbol{\mu}_i, \mathrm{diag}(\boldsymbol{\sigma}_i^2)\big), \quad \boldsymbol{\mu} = \mathrm{GCN}_{\mu}(\mathbf{X}, \mathbf{A}), \quad \log\boldsymbol{\sigma} = \mathrm{GCN}_{\sigma}(\mathbf{X}, \mathbf{A})$$

Decoder

An inner-product decoder reconstructs edge probabilities:

$$p(A_{ij} = 1 \mid \mathbf{z}_i, \mathbf{z}_j) = \sigma(\mathbf{z}_i^{\top} \mathbf{z}_j)$$

Loss function

$$\mathcal{L} = \mathbb{E}_{q(\mathbf{Z} \mid \mathbf{X}, \mathbf{A})}\big[\log p(\mathbf{A} \mid \mathbf{Z})\big] - \mathrm{KL}\big(q(\mathbf{Z} \mid \mathbf{X}, \mathbf{A}) \,\|\, p(\mathbf{Z})\big)$$

6.2 VGAE in PyTorch Geometric

from torch_geometric.nn import VGAE, GCNConv

class VariationalGCNEncoder(nn.Module):
    """
    Two-layer GCN encoder for VGAE: returns (mu, logstd) per node.
    PyG's VGAE wrapper handles reparameterization, the default
    inner-product decoder, and the KL term.
    """
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv_mu = GCNConv(hidden_channels, out_channels)
        self.conv_logstd = GCNConv(hidden_channels, out_channels)
    
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)


def train_vgae(data):
    """Train a VGAE for unsupervised representation learning and link prediction."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = VGAE(VariationalGCNEncoder(
        in_channels=data.x.size(1),
        hidden_channels=32,
        out_channels=16
    )).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    x, edge_index = data.x.to(device), data.edge_index.to(device)
    
    for epoch in range(1, 501):
        model.train()
        optimizer.zero_grad()
        
        # encode (samples z via the reparameterization trick in train mode)
        z = model.encode(x, edge_index)
        
        # reconstruction loss on observed edges + KL regularizer (scaled per node)
        loss = model.recon_loss(z, edge_index) + (1 / x.size(0)) * model.kl_loss()
        
        loss.backward()
        optimizer.step()
        
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
    
    return model
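After training, link-prediction quality can be checked with the test helper built into PyG's GAE/VGAE, which reports AUC and average precision; the negative edges here are freshly sampled for illustration rather than drawn from a held-out split:

from torch_geometric.utils import negative_sampling

model = train_vgae(data)
model.eval()
with torch.no_grad():
    z = model.encode(data.x, data.edge_index)  # posterior mean in eval mode
    neg_edge_index = negative_sampling(data.edge_index, num_nodes=data.num_nodes)
    auc, ap = model.test(z, data.edge_index, neg_edge_index)
print(f"AUC = {auc:.4f}, AP = {ap:.4f}")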

7. Applications

7.1 Uncertainty Estimation in Link Prediction

Link prediction is a core task on graphs. Quantifying its predictive uncertainty is essential for recommender systems and knowledge-graph completion.

class UncertaintyLinkPrediction(nn.Module):
    """
    Link prediction with uncertainty estimates via MC Dropout.
    """
    def __init__(self, dropout_model):
        super().__init__()
        self.dropout_model = dropout_model  # GNN encoder with dropout layers
    
    def link_predict_with_uncertainty(self, x, edge_index, edge_list, n_samples=30):
        """
        Predict edge-existence probabilities and their uncertainty.
        """
        # MC Dropout: keep dropout active while sampling
        self.dropout_model.train()
        
        logits_list = []
        with torch.no_grad():
            for _ in range(n_samples):
                z = self.dropout_model(x, edge_index)
                # score each candidate edge by the inner product of its endpoints
                src, dst = edge_list[:, 0], edge_list[:, 1]
                scores = (z[src] * z[dst]).sum(dim=-1)
                logits_list.append(scores)
        
        logits = torch.stack(logits_list)
        mean_score = logits.mean(dim=0)
        uncertainty = logits.std(dim=0)  # predictive uncertainty
        
        # convert scores to probabilities
        prob = torch.sigmoid(mean_score)
        
        return prob, uncertainty
    
    def detect_ood_edges(self, x, edge_index, train_edges, val_edges, threshold=None):
        """
        Flag out-of-distribution edges (possible noise or a new edge type).
        If no threshold is given, derive one from the uncertainty observed
        on known-good training edges.
        """
        # uncertainty statistics on training edges set the reference scale
        train_prob, train_unc = self.link_predict_with_uncertainty(
            x, edge_index, train_edges, n_samples=30
        )
        if threshold is None:
            threshold = (train_unc.mean() + 2 * train_unc.std()).item()
        
        # uncertainty on the edges under scrutiny
        val_prob, val_unc = self.link_predict_with_uncertainty(
            x, edge_index, val_edges, n_samples=30
        )
        
        # edges far outside the training uncertainty range are OOD candidates
        ood_mask = val_unc > threshold
        
        return ood_mask, val_prob, val_unc

7.2 Uncertainty-Aware Node Classification

class UncertaintyAwareNodeClassification:
    """
    Uncertainty-aware node classifier: flags high-risk predictions
    for rejection or human review.
    """
    
    def __init__(self, model, confidence_threshold=0.9, uncertainty_threshold=0.1):
        self.model = model
        self.confidence_threshold = confidence_threshold
        self.uncertainty_threshold = uncertainty_threshold
    
    def predict(self, data, n_samples=30):
        # MC Dropout sampling
        self.model.train()
        predictions = []
        
        with torch.no_grad():
            for _ in range(n_samples):
                out = self.model(data.x, data.edge_index)
                probs = F.softmax(out, dim=1)
                predictions.append(probs)
        
        predictions = torch.stack(predictions)
        
        # mean prediction and its entropy
        mean_probs = predictions.mean(dim=0)
        predictions_entropy = -(mean_probs * torch.log(mean_probs + 1e-8)).sum(dim=1)
        
        # spread across samples as epistemic uncertainty
        epistemic_unc = predictions.std(dim=0).mean(dim=1)
        
        # final labels and confidences
        pred_labels = mean_probs.argmax(dim=1)
        confidences = mean_probs.max(dim=1)[0]
        
        return {
            'labels': pred_labels,
            'confidences': confidences,
            'epistemic_uncertainty': epistemic_unc,
            'entropy': predictions_entropy
        }
    
    def filter_predictions(self, data, n_samples=30):
        """
        Separate reliable predictions from those needing human review.
        """
        results = self.predict(data, n_samples)
        
        # a prediction needs review if any risk signal fires
        needs_review = (
            (results['confidences'] < self.confidence_threshold) |
            (results['epistemic_uncertainty'] > self.uncertainty_threshold) |
            (results['entropy'] > 1.0)  # entropy above a fixed threshold
        )
        
        reliable = ~needs_review
        
        return {
            'reliable': reliable,
            'needs_review': needs_review,
            **results
        }
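A usage sketch combining this wrapper with the MC Dropout model from Section 4.2 (thresholds are illustrative and should be calibrated on validation data):

classifier = UncertaintyAwareNodeClassification(
    model, confidence_threshold=0.9, uncertainty_threshold=0.1
)
results = classifier.filter_predictions(data, n_samples=50)
print(f"Reliable predictions: {results['reliable'].sum().item()}")
print(f"Flagged for review:   {results['needs_review'].sum().item()}")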

8. Summary and Outlook

8.1 Method Comparison

| Method | Strengths | Weaknesses | Best suited for |
|---|---|---|---|
| MC Dropout | simple; no change to training | crude approximation; slow with many samples | rapid prototyping |
| Bootstrap | easy to implement; stable | multiple models to train; high memory cost | production deployment |
| Bayesian GNN | strong theoretical grounding | complex training; many parameters | theoretical research |
| Probabilistic GNN | expressive representations | complex to implement | high-accuracy tasks |

8.2 Future Directions

| Direction | Description |
|---|---|
| Efficient Bayesian inference | reduce the computational overhead of Bayesian GNNs |
| Heterogeneous graph support | handle uncertainty on heterogeneous graphs |
| Temporal graphs | uncertainty modeling for dynamic graphs |
| Graph generation | uncertainty quantification for generative models |
| Interpretability | use uncertainty estimates to explain model behavior |


Footnotes

[^1]: Zhang et al., “Uncertainty Estimation for Graph Neural Networks”, SIGKDD 2022.

[^2]: Yedidia et al., “Understanding Belief Propagation and Its Generalizations”, IJCAI 2003.

[^3]: Gal & Ghahramani, “Dropout as a Bayesian Approximation”, ICML 2016.

[^4]: Liu et al., “Probabilistic Graph Neural Networks”, NeurIPS 2019.

[^5]: Kipf & Welling, “Variational Graph Auto-Encoders”, NeurIPS 2016.