深度概率推断实践

1 引言

深度概率推断(Deep Probabilistic Inference)将概率图模型的推断能力与深度学习的表示学习能力相结合。本章提供完整的PyTorch实现,包括:

  • 可微分消息传递层
  • 概率推断模块
  • 端到端训练流程
  • 实际应用案例

2 可微分消息传递层

2.1 基础消息传递

消息传递是概率图模型的核心操作。我们首先实现一个通用的可微分消息传递层:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
 
class DifferentiableMessagePassing(nn.Module):
    """Generic differentiable message-passing layer.

    A forward pass runs three stages:

    1. ``message``: an MLP turns the source-node features (concatenated with
       the edge features, if any) into one message per edge.
    2. ``aggregate``: messages are averaged per target node and passed
       through a linear map.
    3. ``update``: a GRU cell updates every node, using the aggregated
       message as input and the current node features as hidden state.

    The output therefore has the same width as the input (``node_dim``).

    Args:
        node_dim: Width of the node feature vectors.
        edge_dim: Width of the edge feature vectors (0 for no edge features).
        msg_dim: Width of the per-edge messages.
    """

    def __init__(self, node_dim: int, edge_dim: int, msg_dim: int):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.msg_dim = msg_dim

        # Message function: [source features, edge features] -> message.
        self.msg_fn = nn.Sequential(
            nn.Linear(node_dim + edge_dim, msg_dim),
            nn.ReLU(),
            nn.Linear(msg_dim, msg_dim)
        )

        # Linear map applied to the mean-aggregated messages.
        self.aggr_fn = nn.Linear(msg_dim, msg_dim)

        # Update function: input = aggregated message, hidden = node state.
        self.update_fn = nn.GRUCell(msg_dim, node_dim)

    def message(self, source: torch.Tensor, target: torch.Tensor,
                edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the message sent along every edge.

        Args:
            source: (num_edges, node_dim) features of the source nodes.
            target: (num_edges, node_dim) features of the target nodes.
                Accepted for interface symmetry; not used by this base
                implementation.
            edge_attr: (num_edges, edge_dim) edge features. When omitted on a
                layer built with ``edge_dim > 0``, zeros are substituted so
                that the input width still matches ``msg_fn``.

        Returns:
            (num_edges, msg_dim) messages.
        """
        if edge_attr is None and self.edge_dim > 0:
            # Without this the Linear layer would fail with a shape mismatch.
            edge_attr = source.new_zeros(*source.shape[:-1], self.edge_dim)

        if edge_attr is not None:
            combined = torch.cat([source, edge_attr], dim=-1)
        else:
            combined = source

        return self.msg_fn(combined)

    def aggregate(self, messages: torch.Tensor,
                  index: torch.Tensor,
                  num_nodes: int) -> torch.Tensor:
        """Mean-aggregate messages per target node.

        Args:
            messages: (num_edges, msg_dim) messages.
            index: (num_edges,) target-node index of every message.
            num_nodes: Number of nodes in the graph.

        Returns:
            (num_nodes, msg_dim) aggregated messages. Nodes without incoming
            edges receive ``aggr_fn`` applied to a zero vector.
        """
        # Scatter-add the messages onto their target nodes.
        aggregated = torch.zeros(
            num_nodes, messages.size(-1),
            device=messages.device, dtype=messages.dtype
        )
        aggregated.index_add_(0, index, messages)

        # In-degree per node, kept in the message dtype so that the division
        # below does not silently promote the result to another dtype.
        counts = torch.zeros(
            num_nodes, device=messages.device, dtype=messages.dtype
        )
        counts.index_add_(
            0, index, torch.ones_like(index, dtype=messages.dtype)
        )
        # clamp avoids division by zero for isolated nodes.
        counts = counts.clamp(min=1).unsqueeze(-1)

        return self.aggr_fn(aggregated / counts)

    def update(self, node_attr: torch.Tensor,
               messages: torch.Tensor) -> torch.Tensor:
        """Update node features from the aggregated messages.

        Args:
            node_attr: (num_nodes, node_dim) current node features.
            messages: (num_nodes, msg_dim) aggregated messages.

        Returns:
            (num_nodes, node_dim) updated node features.
        """
        return self.update_fn(messages, node_attr)

    def forward(self, node_attr: torch.Tensor,
                edge_index: torch.Tensor,
                edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run one round of message passing.

        Args:
            node_attr: (num_nodes, node_dim) node features.
            edge_index: (2, num_edges) integer tensor; row 0 holds source
                indices and row 1 holds target indices.
            edge_attr: Optional (num_edges, edge_dim) edge features.

        Returns:
            (num_nodes, node_dim) updated node features.
        """
        src_idx, dst_idx = edge_index[0], edge_index[1]

        messages = self.message(node_attr[src_idx], node_attr[dst_idx], edge_attr)
        aggregated = self.aggregate(messages, dst_idx, node_attr.size(0))
        return self.update(node_attr, aggregated)

2.2 信念传播层

实现基于和-积算法(Sum-Product Algorithm)的信念传播层,其中成对势函数由一个小型神经网络参数化:

class BeliefPropagationLayer(nn.Module):
    """Neural sum-product belief-propagation layer.

    Messages are categorical distributions over ``num_states`` states. The
    message on edge ``s -> t`` is produced by a small network
    (``potential_fn``) that scores every target state given the *cavity*
    distribution at ``s``: the local belief of ``s`` multiplied by all
    messages arriving at ``s`` except those sent by ``t``. Products are
    accumulated in the log domain for numerical stability.

    Args:
        num_states: Number of discrete states per node.
        message_dim: Hidden width of the potential network.
    """

    def __init__(self, num_states: int, message_dim: int):
        super().__init__()
        self.num_states = num_states
        self.message_dim = message_dim

        # Scores one (cavity distribution, one-hot target state) pair.
        self.potential_fn = nn.Sequential(
            nn.Linear(num_states * 2, message_dim),
            nn.ReLU(),
            nn.Linear(message_dim, 1)
        )

        # Normalises the per-state scores of a message before the softmax.
        # It acts over the state axis, so its size is num_states; applying a
        # LayerNorm to the probabilities themselves would break their
        # normalisation.
        self.msg_normalize = nn.LayerNorm(num_states)

    def forward(self, beliefs: torch.Tensor,
               edge_index: torch.Tensor,
               num_iterations: int = 3) -> torch.Tensor:
        """Run belief propagation and return the updated node beliefs.

        Args:
            beliefs: (num_nodes, num_states) non-negative initial beliefs
                (treated as unnormalised probabilities).
            edge_index: (2, num_edges) integer tensor of directed edges
                (row 0 = source, row 1 = target).
            num_iterations: Number of synchronous message updates.

        Returns:
            (num_nodes, num_states) beliefs. Nodes that receive at least one
            message get a normalised posterior; nodes without incoming edges
            keep their input belief unchanged.
        """
        num_nodes, num_states = beliefs.shape
        src_idx, dst_idx = edge_index[0], edge_index[1]
        num_edges = edge_index.size(1)

        if num_edges == 0:
            # Nothing to propagate.
            return beliefs.clone()

        # Smallest positive normal number: keeps log() finite for zeros.
        tiny = torch.finfo(beliefs.dtype).tiny
        log_prior = beliefs.clamp_min(tiny).log()

        # Messages start as uniform distributions.
        messages = beliefs.new_full((num_edges, num_states), 1.0 / num_states)

        # feeds[e, k] == 1 iff edge k delivers a message into the source of
        # edge e and does not originate from the target of edge e. This is
        # the "all neighbours except the recipient" rule of sum-product and
        # depends only on the graph, so it is built once.
        feeds = (
            (dst_idx.unsqueeze(0) == src_idx.unsqueeze(1))
            & (src_idx.unsqueeze(0) != dst_idx.unsqueeze(1))
        ).to(beliefs.dtype)

        # One-hot encoding of every candidate target state, shared by edges.
        state_one_hot = torch.eye(
            num_states, device=beliefs.device, dtype=beliefs.dtype
        ).unsqueeze(0).expand(num_edges, -1, -1)

        for _ in range(num_iterations):
            log_messages = messages.clamp_min(tiny).log()

            # Cavity distribution at the source of every edge: (E, S).
            cavity = F.softmax(
                log_prior[src_idx] + feeds @ log_messages, dim=-1
            )

            # Pair every cavity distribution with every target state:
            # (E, S, 2 * S) -> scores of shape (E, S).
            pair_input = torch.cat(
                [cavity.unsqueeze(1).expand(-1, num_states, -1), state_one_hot],
                dim=-1
            )
            scores = self.potential_fn(pair_input).squeeze(-1)

            # Each message is a proper distribution over target states.
            messages = F.softmax(self.msg_normalize(scores), dim=-1)

        # Posterior belief: prior times the product of all incoming messages.
        log_posterior = log_prior.index_add(
            0, dst_idx, messages.clamp_min(tiny).log()
        )
        posterior = F.softmax(log_posterior, dim=-1)

        has_incoming = torch.zeros(
            num_nodes, dtype=torch.bool, device=beliefs.device
        )
        has_incoming[dst_idx] = True

        return torch.where(has_incoming.unsqueeze(-1), posterior, beliefs)

2.3 变分消息传递层

将变分推断嵌入消息传递框架:

class VariationalMessagePassing(nn.Module):
    """Variational message-passing layer.

    An encoder maps node observations to the mean and log-variance of a
    diagonal Gaussian over a latent vector per node. The latent sample is
    then refined for a few rounds: every node averages the latent vectors of
    its in-neighbours and feeds that average through a GRU step whose hidden
    state is the node's own latent vector. A decoder reconstructs the
    observations from the final latent vectors.

    Args:
        node_dim: Width of the observed node features.
        latent_dim: Width of the latent vectors.
    """

    def __init__(self, node_dim: int, latent_dim: int):
        super().__init__()
        self.node_dim = node_dim
        self.latent_dim = latent_dim

        # Encoder: observation -> [mu, log_var] of the variational posterior.
        self.encoder = nn.Sequential(
            nn.Linear(node_dim, latent_dim * 2),
            nn.Tanh(),
            nn.Linear(latent_dim * 2, latent_dim * 2)
        )

        # Message network, applied as a single GRU step per iteration.
        self.message_net = nn.GRU(
            input_size=latent_dim,
            hidden_size=latent_dim,
            batch_first=True
        )

        # Decoder: latent vector -> reconstruction in (0, 1).
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, node_dim),
            nn.Sigmoid()
        )

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor,
               edge_index: torch.Tensor,
               num_iterations: int = 3
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, run message passing over the graph, and decode.

        Args:
            x: (num_nodes, node_dim) observed node features.
            edge_index: (2, num_edges) integer tensor of directed edges
                (row 0 = source, row 1 = target).
            num_iterations: Number of message-passing rounds.

        Returns:
            A 4-tuple ``(z, mu, log_var, x_recon)``:
            z: (num_nodes, latent_dim) refined latent vectors.
            mu: (num_nodes, latent_dim) refined posterior means.
            log_var: (num_nodes, latent_dim) refined posterior log-variances.
            x_recon: (num_nodes, node_dim) reconstruction of ``x``.
        """
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Initial variational parameters and latent sample.
        mu, log_var = self.encoder(x).chunk(2, dim=-1)
        z = self.reparameterize(mu, log_var)

        # In-degree per node for mean aggregation; clamp protects isolated
        # nodes from a division by zero.
        in_degree = torch.zeros(x.size(0), device=z.device, dtype=z.dtype)
        in_degree.index_add_(0, dst_idx, torch.ones_like(dst_idx, dtype=z.dtype))
        in_degree = in_degree.clamp(min=1).unsqueeze(-1)

        for _ in range(num_iterations):
            # Average the latent vectors of each node's in-neighbours.
            neighbour_sum = torch.zeros_like(z).index_add(0, dst_idx, z[src_idx])
            neighbour_mean = neighbour_sum / in_degree

            # One GRU step per node: batch = nodes, sequence length = 1,
            # initial hidden state = the node's own latent vector. Running
            # the GRU over the node axis as a sequence would make the result
            # depend on the arbitrary node ordering and ignore the graph.
            msg_out, _ = self.message_net(
                neighbour_mean.unsqueeze(1), z.unsqueeze(0)
            )
            msg = msg_out.squeeze(1)  # (num_nodes, latent_dim)

            # Refine the latent vectors and the variational parameters.
            z = torch.tanh(msg + mu)
            mu = mu + 0.1 * msg
            log_var = log_var - 0.05 * msg.pow(2)

        x_recon = self.decoder(z)

        return z, mu, log_var, x_recon

    def elbo(self, x: torch.Tensor,
            z: torch.Tensor, mu: torch.Tensor,
            log_var: torch.Tensor) -> torch.Tensor:
        """Compute the evidence lower bound (to be maximised).

        Args:
            x: Observations; must lie in [0, 1] because the reconstruction
                term is a binary cross-entropy.
            z: Latent vectors to decode.
            mu: Posterior means.
            log_var: Posterior log-variances.

        Returns:
            Scalar ELBO, i.e. ``-(reconstruction loss + KL divergence)``,
            both summed over all nodes and dimensions.
        """
        x_recon = self.decoder(z)
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')

        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.
        kl_loss = -0.5 * torch.sum(
            1 + log_var - mu.pow(2) - log_var.exp()
        )

        return -(recon_loss + kl_loss)

3 高斯过程推断层

3.1 变分高斯过程

class VariationalGPLayer(nn.Module):
    """Sparse Gaussian-process layer built on learnable inducing points.

    The layer uses a fixed unit-variance, unit-lengthscale RBF kernel (see
    ``kernel_matrix``) plus a linear mean function.

    Args:
        input_dim: Dimensionality of the inputs.
        num_inducing: Number of inducing points.
        kernel_dim: Kept for backward compatibility and stored on the
            instance; the RBF kernel has no hidden width. (The previous
            definition referenced ``nn.RBFKernel``, which does not exist in
            ``torch.nn`` and made construction fail.)
    """

    def __init__(self, input_dim: int, num_inducing: int, kernel_dim: int = 128):
        super().__init__()
        self.input_dim = input_dim
        self.num_inducing = num_inducing
        self.kernel_dim = kernel_dim

        # Learnable inducing-point locations.
        self.inducing_points = nn.Parameter(
            torch.randn(num_inducing, input_dim) * 0.1
        )

        # Values at the inducing points used for the predictive mean, and a
        # log-diagonal covariance parameter (not read by ``forward``).
        self.mean = nn.Parameter(torch.zeros(num_inducing, 1))
        self.cov_log_diag = nn.Parameter(torch.zeros(num_inducing))

        # Linear mean function added to the GP predictive mean.
        self.mean_fn = nn.Linear(input_dim, 1)

    def kernel_matrix(self, X: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
        """Return the RBF kernel matrix ``exp(-0.5 * ||x - z||^2)``.

        Args:
            X: (n, input_dim) inputs.
            Z: (m, input_dim) inputs.

        Returns:
            (n, m) kernel matrix.
        """
        # Squared distances are formed directly, without an intermediate
        # square root, whose derivative is undefined at zero distance (the
        # diagonal of K(Z, Z)).
        sq_dists = (X.unsqueeze(1) - Z.unsqueeze(0)).pow(2).sum(dim=-1)
        return torch.exp(-0.5 * sq_dists)

    def forward(self, X: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the predictive mean and variance at ``X``.

        Args:
            X: (batch_size, input_dim) inputs.

        Returns:
            mean: (batch_size, 1) predictive mean (GP term + linear mean).
            var: (batch_size, 1) strictly positive predictive variance.
        """
        Kzz = self.kernel_matrix(self.inducing_points, self.inducing_points)
        Kxz = self.kernel_matrix(X, self.inducing_points)

        # Jitter on the diagonal keeps the system well conditioned.
        jitter = 1e-6 * torch.eye(
            self.num_inducing, device=Kzz.device, dtype=Kzz.dtype
        )

        # Solve Kzz @ A = Kzx once instead of forming an explicit inverse;
        # A has shape (num_inducing, batch_size). Kzz is symmetric, so
        # Kxz @ Kzz^{-1} equals A^T.
        projection = torch.linalg.solve(Kzz + jitter, Kxz.T)

        mean = projection.T @ self.mean

        # diag(Kxx) is all ones for this RBF kernel.
        k_xx = torch.ones(X.size(0), 1, device=X.device, dtype=Kxz.dtype)
        var = k_xx - (Kxz * projection.T).sum(dim=-1, keepdim=True)
        # softplus guards against small negative values from round-off.
        var = F.softplus(var + 1e-6)

        return mean + self.mean_fn(X), var

4 端到端训练框架

4.1 概率推断训练器

from torch.utils.data import DataLoader
from typing import Dict, Any
 
class ProbabilisticInferenceTrainer:
    """Minimal training loop for probabilistic-inference models.

    Batches are expected to be dicts whose entries are passed to the model as
    keyword arguments; the model is expected to return a dict (see
    ``compute_loss`` for the keys that are understood).

    Args:
        model: Model to train.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Target device. If a CUDA device is requested but CUDA is not
            available, the trainer warns and falls back to the CPU.
    """

    def __init__(self, model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 device: str = 'cuda'):
        if str(device).startswith('cuda') and not torch.cuda.is_available():
            import warnings
            warnings.warn(
                "CUDA requested but not available; falling back to CPU."
            )
            device = 'cpu'

        self.model = model
        self.optimizer = optimizer
        self.device = device
        self.model.to(device)

        # Per-epoch metric dicts appended by ``fit``.
        self.train_history = []
        self.val_history = []

    def _to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move tensor entries of a batch to the device; pass others through."""
        return {
            key: value.to(self.device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }

    def train_epoch(self, dataloader: DataLoader) -> Dict[str, float]:
        """Train for one pass over ``dataloader``.

        Returns:
            ``{'loss': mean batch loss}``; the loss is NaN when the loader
            yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        # Counted manually so that iterables without __len__ also work.
        num_batches = 0
        epoch_metrics = {}

        for batch in dataloader:
            batch = self._to_device(batch)

            output = self.model(**batch)
            loss = self.compute_loss(output, batch)

            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping keeps updates bounded.
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm=1.0
            )

            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        epoch_loss = total_loss / num_batches if num_batches else float('nan')

        return {'loss': epoch_loss, **epoch_metrics}

    def compute_loss(self, output: Dict[str, torch.Tensor],
                    batch: Dict[str, Any]) -> torch.Tensor:
        """Derive the training loss from a model output.

        Uses ``output['loss']`` when present. Otherwise, if the output has
        ``'recon'`` and the batch has ``'target'``, returns the MSE between
        them plus 0.01 times the Gaussian KL term (when ``'mu'`` and
        ``'log_var'`` are present).

        Raises:
            ValueError: If neither form of loss can be computed.
        """
        if 'loss' in output:
            return output['loss']

        if 'recon' in output and 'target' in batch:
            recon_loss = F.mse_loss(output['recon'], batch['target'])

            kl_loss = 0.0
            if 'mu' in output and 'log_var' in output:
                kl_loss = -0.5 * torch.sum(
                    1 + output['log_var'] - output['mu'].pow(2) - output['log_var'].exp()
                )

            return recon_loss + 0.01 * kl_loss

        raise ValueError("Cannot compute loss from output")

    def validate(self, dataloader: DataLoader) -> Dict[str, float]:
        """Evaluate without gradient tracking.

        Returns:
            ``{'val_loss': mean batch loss}``; NaN when the loader yields no
            batches.
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = self._to_device(batch)
                output = self.model(**batch)
                total_loss += self.compute_loss(output, batch).item()
                num_batches += 1

        val_loss = total_loss / num_batches if num_batches else float('nan')
        return {'val_loss': val_loss}

    def fit(self, train_loader: DataLoader,
           val_loader: DataLoader,
           num_epochs: int,
           checkpoint_path: str = 'best_model.pt') -> None:
        """Train for ``num_epochs`` epochs, checkpointing the best model.

        Args:
            train_loader: Training batches.
            val_loader: Validation batches.
            num_epochs: Number of epochs.
            checkpoint_path: Where the checkpoint with the lowest validation
                loss is written.
        """
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate(val_loader)

            self.train_history.append(train_metrics)
            self.val_history.append(val_metrics)

            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"  Train Loss: {train_metrics['loss']:.4f}")
            print(f"  Val Loss: {val_metrics['val_loss']:.4f}")

            # A NaN validation loss never compares as smaller, so it cannot
            # overwrite an existing checkpoint.
            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                self.save_checkpoint(checkpoint_path)

    def save_checkpoint(self, path: str) -> None:
        """Write the model and optimizer state dicts to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

5 实际应用案例

5.1 贝叶斯图神经网络

class BayesianGNN(nn.Module):
    """Graph neural network with a variational latent layer for node classification.

    Two message-passing layers produce node representations, a variational
    message-passing layer turns them into stochastic latent vectors, and a
    linear classifier predicts class logits from those latent vectors.

    Args:
        in_dim: Width of the input node features.
        hidden_dim: Message width of the message-passing layers and width of
            the latent vectors.
        out_dim: Number of classes.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()

        # DifferentiableMessagePassing returns features of width ``node_dim``
        # (its GRU cell keeps the node state size), so the second layer and
        # the variational layer also operate on ``in_dim``-wide features;
        # ``hidden_dim`` is the message width.
        self.mpnn1 = DifferentiableMessagePassing(in_dim, 0, hidden_dim)
        self.mpnn2 = DifferentiableMessagePassing(in_dim, 0, hidden_dim)

        # Variational layer: in_dim-wide features -> hidden_dim-wide latents.
        self.vmp = VariationalMessagePassing(in_dim, hidden_dim)

        # Classifier on the latent vectors.
        self.classifier = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor,
               edge_index: torch.Tensor,
               training: bool = True) -> Dict[str, torch.Tensor]:
        """Compute class logits and the variational quantities.

        Args:
            x: (num_nodes, in_dim) node features.
            edge_index: (2, num_edges) edge indices.
            training: Kept for interface compatibility; not used.

        Returns:
            Dict with keys:
            ``'logits'``: (num_nodes, out_dim) class logits computed from the
            sampled latent vectors (stochastic across calls).
            ``'mu'``: latent means.
            ``'log_var'``: latent log-variances.
            ``'recon'``: reconstruction of the hidden node representation.
        """
        h = F.relu(self.mpnn1(x, edge_index))
        h = F.relu(self.mpnn2(h, edge_index))

        z, mu, log_var, h_recon = self.vmp(h, edge_index)

        # Classify from the latent sample so that the variational layer
        # actually influences (and randomises) the prediction.
        logits = self.classifier(z)

        return {
            'logits': logits,
            'mu': mu,
            'log_var': log_var,
            'recon': h_recon
        }

    def predict_with_uncertainty(self, x: torch.Tensor,
                                edge_index: torch.Tensor,
                                num_samples: int = 10) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate predictive uncertainty by Monte-Carlo sampling.

        Each forward pass draws a fresh latent sample, so repeated passes
        yield different class probabilities. The module's train/eval mode is
        left untouched and no gradients are tracked.

        Args:
            x: (num_nodes, in_dim) node features.
            edge_index: (2, num_edges) edge indices.
            num_samples: Number of stochastic forward passes (>= 1).

        Returns:
            mean: (num_nodes, out_dim) mean class probabilities.
            std: (num_nodes, out_dim) standard deviation across samples
                (all zeros when ``num_samples == 1``).

        Raises:
            ValueError: If ``num_samples`` is smaller than 1.
        """
        if num_samples < 1:
            raise ValueError("num_samples must be at least 1")

        with torch.no_grad():
            samples = torch.stack([
                F.softmax(self.forward(x, edge_index)['logits'], dim=-1)
                for _ in range(num_samples)
            ])

        mean = samples.mean(dim=0)
        # The sample standard deviation is undefined for a single draw.
        std = samples.std(dim=0) if num_samples > 1 else torch.zeros_like(mean)

        return mean, std

5.2 概率链接预测

class ProbabilisticLinkPrediction(nn.Module):
    """Link-prediction model based on residual message passing.

    Node representations are refined by a stack of message-passing layers
    with residual connections; a small MLP then scores node pairs.

    Args:
        node_dim: Width of the node representations.
        num_layers: Number of message-passing layers.
    """

    def __init__(self, node_dim: int, num_layers: int = 3):
        super().__init__()

        # Single learnable embedding shared by (broadcast to) all nodes; used
        # when no per-node features are supplied.
        self.embedding = nn.Parameter(
            torch.randn(1, node_dim) * 0.1
        )

        self.message_layers = nn.ModuleList([
            DifferentiableMessagePassing(node_dim, 0, node_dim)
            for _ in range(num_layers)
        ])

        # Scores a concatenated (source, destination) representation.
        self.link_predictor = nn.Sequential(
            nn.Linear(node_dim * 2, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, 1)
        )

    @staticmethod
    def _infer_num_nodes(*index_tensors: torch.Tensor) -> int:
        """Return ``max index + 1`` over the given index tensors (at least 1)."""
        largest = -1
        for index in index_tensors:
            if index.numel() > 0:
                largest = max(largest, int(index.max()))
        return max(largest + 1, 1)

    def forward(self, x: torch.Tensor,
               edge_index: torch.Tensor,
               edge_index_pos: torch.Tensor,
               edge_index_neg: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Score positive and negative candidate edges.

        Args:
            x: (num_nodes, node_dim) node features. A tensor with a single
                row selects the shared learnable embedding instead; the
                number of nodes is then inferred from the largest index that
                appears in any of the edge tensors.
            edge_index: (2, num_edges) edges used for message passing.
                NOTE(review): if this includes the edges being scored, the
                scores can leak label information; confirm with the caller.
            edge_index_pos: (2, num_pos) positive edges to score.
            edge_index_neg: (2, num_neg) negative edges to score.

        Returns:
            Dict with ``'pos_score'`` (num_pos, 1), ``'neg_score'``
            (num_neg, 1) and ``'embeddings'`` (num_nodes, node_dim).
        """
        if x.size(0) == 1:
            # x.size(0) is 1 here, so it cannot serve as the node count;
            # derive the count from the indices that will be looked up.
            num_nodes = self._infer_num_nodes(
                edge_index, edge_index_pos, edge_index_neg
            )
            h = self.embedding.expand(num_nodes, -1)
        else:
            h = x

        # Message passing with residual connections.
        for layer in self.message_layers:
            h = h + layer(h, edge_index)

        pos_src, pos_dst = edge_index_pos[0], edge_index_pos[1]
        pos_score = self.link_predictor(
            torch.cat([h[pos_src], h[pos_dst]], dim=-1)
        )

        neg_src, neg_dst = edge_index_neg[0], edge_index_neg[1]
        neg_score = self.link_predictor(
            torch.cat([h[neg_src], h[neg_dst]], dim=-1)
        )

        return {
            'pos_score': pos_score,
            'neg_score': neg_score,
            'embeddings': h
        }

    def loss(self, output: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Margin ranking (hinge) loss between positive and negative scores.

        The i-th positive score is compared with the i-th negative score, so
        ``output['pos_score']`` and ``output['neg_score']`` must have
        broadcast-compatible shapes.

        Args:
            output: Dict returned by ``forward``.

        Returns:
            Scalar loss with margin 1.0.
        """
        pos_score = output['pos_score']
        neg_score = output['neg_score']

        margin = 1.0
        # Target +1: positive scores should exceed negative ones by margin.
        return F.margin_ranking_loss(
            pos_score, neg_score,
            torch.ones_like(pos_score),
            margin=margin
        )

6 总结与展望

6.1 核心要点

  1. 可微分消息传递是连接概率推断与深度学习的桥梁
  2. 变分推断提供了近似推断的可扩展框架
  3. 端到端训练允许联合学习推断网络和下游任务

6.2 进阶方向

  • 结构化变分推断:利用图的稀疏性
  • 层次化消息传递:多尺度特征聚合
  • 可解释性:将不确定性量化融入模型解释
  • 组合优化:将推理问题嵌入神经网络

6.3 实际建议

  1. 从简单的消息传递层开始,逐步增加复杂度
  2. 使用梯度裁剪防止训练不稳定
  3. 监控KL散度与重构损失的比例
  4. 对于大规模图,考虑稀疏矩阵操作和采样技术

参考资料