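Overview

Bayesian neural networks (BNNs) model uncertainty by treating network weights as random variables, and they are a major paradigm for uncertainty quantification in deep learning. [1] Exact Bayesian inference in a high-dimensional parameter space, however, remains a fundamental challenge. This document surveys advances from 2024-2025, including precise Bayesian inference, accelerated Hamiltonian Monte Carlo, and function-space variational inference, among other methods.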


Advances in Precise Bayesian Inference

The challenge of exact inference in Bayesian neural networks

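Exact inference in a conventional BNN requires the posterior over the weights, $p(w \mid \mathcal{D})$, which involves a high-dimensional integral over millions of parameters:

$$p(w \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid w)\, p(w)}{p(\mathcal{D})}, \qquad p(\mathcal{D}) = \int p(\mathcal{D} \mid w)\, p(w)\, dw$$

where $p(\mathcal{D})$ is the marginal likelihood, which cannot be computed analytically.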
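
For contrast, the posterior does have a closed form when the model is linear and Gaussian. The sketch below, in plain Python, computes it for a one-weight regression $y = w x + \varepsilon$. The prior precision `alpha`, the noise precision `beta`, and the toy data are illustrative assumptions, not values taken from the cited papers.

```python
import random

def exact_posterior(xs, ys, alpha=1.0, beta=25.0):
    """Closed-form posterior N(mean, variance) over w for y = w*x + noise.

    Prior: w ~ N(0, 1/alpha). Noise: eps ~ N(0, 1/beta).
    """
    precision = alpha + beta * sum(x * x for x in xs)
    mean = beta * sum(x * y for x, y in zip(xs, ys)) / precision
    return mean, 1.0 / precision

rng = random.Random(0)
true_w = 1.5
xs = [rng.uniform(-1.0, 1.0) for _ in range(200)]
# Noise std 0.2 corresponds to beta = 1 / 0.2**2 = 25
ys = [true_w * x + rng.gauss(0.0, 0.2) for x in xs]

post_mean, post_var = exact_posterior(xs, ys)
```

A nonlinear network breaks this conjugacy: $p(\mathcal{D})$ no longer has a closed form, which is why the approximate methods in the rest of this document are needed.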

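Precise Bayesian Neural Networks (PBNN)

arXiv:2506.19726 proposes a precise Bayesian neural network method that achieves tractable Bayesian inference through carefully designed network architectures and normalization layers. [2]

Core design principles

  1. Hessian-aligned structure: use a covariance parameterization aligned with the structure of the Hessian matrix
  2. Directional uncertainty: model uncertainty along specific directions in parameter space
  3. Geometry awareness: exploit the Riemannian geometry of the neural network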

import torch
import torch.nn as nn
import torch.nn.functional as F


class PreciseBayesianLayer(nn.Module):
    """
    Precise Bayesian neural network layer (illustrative sketch).

    Core idea: model uncertainty along a few fixed directions instead of
    learning a full covariance matrix.
    """
    def __init__(self, in_features, out_features, num_directions=16):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_directions = num_directions

        # Mean of the weight matrix
        self.mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)

        # Precision along each direction
        self.direction_precisions = nn.Parameter(torch.ones(num_directions))

        # Fixed random unit directions in input space (not trained)
        directions = F.normalize(torch.randn(num_directions, in_features), dim=1)
        self.register_buffer('directions', directions)

    def sample_weights(self, num_samples=1):
        """
        Sample weight matrices from the approximate posterior.

        Returns a tensor of shape (num_samples, out_features, in_features).
        """
        # Sample in the direction subspace:
        # w = mean + sum_i eps_i / sqrt(precision_i) * direction_i
        epsilons = torch.randn(num_samples, self.num_directions, device=self.mean.device)

        # Std along each direction is precision^(-1/2); the clamp keeps it finite
        direction_weights = epsilons * self.direction_precisions.clamp_min(1e-6).rsqrt()

        # (num_samples, num_directions) x (num_directions, in_features):
        # one perturbation per sample, shared by every output row
        direction_contribution = torch.einsum('sd,di->si', direction_weights, self.directions)

        # Broadcast the perturbation over the output dimension
        weights = self.mean.unsqueeze(0) + direction_contribution.unsqueeze(1)

        return weights

    def forward(self, x):
        """
        Forward pass with a single sampled weight matrix.
        """
        weights = self.sample_weights(num_samples=1)[0]
        return F.linear(x, weights)


class PreciseBNN(nn.Module):
    """
    Precise Bayesian neural network.

    Uses the directional-precision parameterization for tractable inference.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()

        self.layer1 = PreciseBayesianLayer(input_dim, hidden_dim, num_directions=32)
        self.layer2 = PreciseBayesianLayer(hidden_dim, hidden_dim, num_directions=32)
        self.layer3 = PreciseBayesianLayer(hidden_dim, output_dim, num_directions=16)

        self.activation = nn.ReLU()

    def forward(self, x, num_samples=1):
        """
        Forward pass.

        Args:
            x: input tensor (batch_size, input_dim)
            num_samples: number of Monte Carlo samples
        """
        h = x.unsqueeze(0).expand(num_samples, -1, -1)  # (num_samples, batch, dim)

        # Each einsum applies sample s's weight matrix to sample s's activations
        h = self.activation(torch.einsum('soi,sbi->sbo', self.layer1.sample_weights(num_samples), h))
        h = self.activation(torch.einsum('soi,sbi->sbo', self.layer2.sample_weights(num_samples), h))
        logits = torch.einsum('soi,sbi->sbo', self.layer3.sample_weights(num_samples), h)

        return logits  # (num_samples, batch, output_dim)

    def predict(self, x, num_samples=100):
        """
        Predict: return class predictions and their uncertainty.
        """
        logits = self.forward(x, num_samples)

        # Mean class probabilities over the Monte Carlo samples
        probs = F.softmax(logits, dim=-1)
        mean_probs = probs.mean(dim=0)

        # Uncertainty: entropy of the mean prediction
        entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)

        # Predicted class
        predictions = mean_probs.argmax(dim=-1)

        return predictions, entropy
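
The entropy of the mean prediction mixes two sources of uncertainty. As a hedged illustration in plain Python (not part of the PBNN method itself), the sketch below splits it into an aleatoric part, the average entropy of the individual samples, and an epistemic part, the remainder, which equals the mutual information between prediction and weights. The function names are hypothetical.

```python
import math

def entropy(p):
    """Shannon entropy in nats; zero-probability classes contribute nothing."""
    return -sum(q * math.log(q) for q in p if q > 0.0)

def decompose_uncertainty(prob_samples):
    """Split predictive entropy for one input.

    prob_samples: list of class-probability vectors, one per MC weight sample.
    Returns (total, aleatoric, epistemic) in nats.
    """
    n = len(prob_samples)
    num_classes = len(prob_samples[0])
    mean_probs = [sum(p[c] for p in prob_samples) / n for c in range(num_classes)]
    total = entropy(mean_probs)                            # entropy of the mean prediction
    aleatoric = sum(entropy(p) for p in prob_samples) / n  # mean entropy per sample
    return total, aleatoric, total - aleatoric             # remainder = mutual information

# Samples agree but are unsure: the uncertainty is aleatoric
agree = decompose_uncertainty([[0.5, 0.5], [0.5, 0.5]])
# Samples are confident but disagree: the uncertainty is epistemic
disagree = decompose_uncertainty([[1.0, 0.0], [0.0, 1.0]])
```

Both toy inputs have the same total entropy, but only the second one signals disagreement between weight samples.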

Comparison with traditional methods

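| Method | Covariance parameterization | Inference complexity | Uncertainty quality |
| --- | --- | --- | --- |
| Standard MC Dropout | Implicit | | Medium |
| Mean Field VI | Diagonal covariance | | Lower |
| PBNN | Directional precision | | |
| Full Covariance | Full covariance | | Highest |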

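Accelerating Hamiltonian Monte Carlo

Hybrid inference framework

arXiv:2507.14652 proposes a hybrid method that combines variational inference with Hamiltonian Monte Carlo (HMC) to accelerate inference in Bayesian neural networks. [3]

Core ideas

  1. Variational guidance: initialize the HMC chain from the variational posterior
  2. Selective HMC: run HMC only along the most critical parameter directions
  3. Adaptive step size: adjust the step size using gradient information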

import torch
import torch.nn.functional as F
from torch.func import functional_call  # available in PyTorch >= 2.0


class HybridBNNSampler:
    """
    Hybrid VI-HMC sampler (illustrative sketch).

    Combines the efficiency of variational inference with the accuracy of HMC.
    `vi_posterior` is assumed to expose `sample(shape)`, as
    torch.distributions objects do.
    """
    def __init__(self, model, vi_posterior, hmc_fraction=0.1):
        self.model = model
        self.vi_posterior = vi_posterior  # pre-trained variational posterior
        self.hmc_fraction = hmc_fraction  # fraction of parameter tensors refined by HMC

        # Choose the parameter subset handled by HMC
        self._identify_important_parameters()

    def _identify_important_parameters(self):
        """
        Pick the parameter tensors that receive HMC updates.

        The intended strategy is to rank by gradient variance; this sketch
        uses the mean absolute weight as a cheap stand-in.
        """
        importance = {
            name: param.detach().abs().mean().item()
            for name, param in self.model.named_parameters()
        }

        # Keep the top-k most important tensors (at least one)
        ranked = sorted(importance, key=importance.get, reverse=True)
        num_hmc = max(1, int(len(ranked) * self.hmc_fraction))
        self.hmc_param_names = set(ranked[:num_hmc])

    def sample(self, data, num_samples=100, leapfrog_steps=10, step_size=0.01):
        """
        Hybrid sampling loop; returns a list of parameter dicts.
        """
        samples = []

        # HMC parameters start from the current model weights
        current_params = {
            name: param.detach().clone()
            for name, param in self.model.named_parameters()
        }

        for _ in range(num_samples):
            # Step 1: redraw non-critical parameters from the variational posterior
            for name, param in self.model.named_parameters():
                if name not in self.hmc_param_names:
                    current_params[name] = self.vi_posterior.sample(param.shape)

            # Step 2: update critical parameters with one HMC transition
            current_params = self._hmc_step(current_params, data, leapfrog_steps, step_size)

            # Record the sample
            samples.append({name: p.clone() for name, p in current_params.items()})

        return samples

    def _hmc_step(self, params, data, L, eps):
        """
        One HMC transition over the HMC parameter subset.
        """
        names = self.hmc_param_names
        start = params
        proposal = {name: p.clone() for name, p in params.items()}

        # Sample momentum for the HMC parameters only
        momentum = {name: torch.randn_like(proposal[name]) for name in names}

        # Hamiltonian = potential energy + kinetic energy
        start_h = self._compute_energy(start, data) + self._kinetic(momentum)

        grad = self._compute_gradients(proposal, data)

        # Leapfrog integration: momentum moves against the energy gradient
        for _ in range(L):
            for name in names:
                momentum[name] -= 0.5 * eps * grad[name]  # half momentum step
                proposal[name] += eps * momentum[name]    # full position step

            grad = self._compute_gradients(proposal, data)

            for name in names:
                momentum[name] -= 0.5 * eps * grad[name]  # half momentum step

        proposed_h = self._compute_energy(proposal, data) + self._kinetic(momentum)

        # Metropolis-Hastings: accept with probability min(1, exp(start_h - proposed_h))
        if torch.rand(()).log() < start_h - proposed_h:
            return proposal
        return start  # rejected: stay at the starting point

    @staticmethod
    def _kinetic(momentum):
        """Kinetic energy for unit-mass momentum."""
        return 0.5 * sum(m.pow(2).sum() for m in momentum.values())

    def _energy(self, params, data):
        """Potential energy -log p(w, D) up to a constant, with a N(0, I) prior."""
        x, y = data
        logits = functional_call(self.model, params, (x,))
        nll = F.cross_entropy(logits, y, reduction='sum')  # negative log-likelihood
        prior = sum(p.pow(2).sum() for p in params.values())  # negative log-prior (x2)
        return nll + 0.5 * prior

    def _compute_energy(self, params, data):
        """Evaluate the energy without tracking gradients."""
        with torch.no_grad():
            return self._energy(params, data)

    def _compute_gradients(self, params, data):
        """Gradient of the energy with respect to the HMC parameters."""
        names = sorted(self.hmc_param_names)
        leaves = {
            name: p.detach().requires_grad_(name in self.hmc_param_names)
            for name, p in params.items()
        }
        grads = torch.autograd.grad(self._energy(leaves, data), [leaves[n] for n in names])
        return dict(zip(names, grads))
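
To see the leapfrog and accept/reject logic in isolation, the following plain-Python sketch runs HMC on a one-dimensional standard normal target, where the correct answer (mean 0, variance 1) is known. The step size, trajectory length, and sample count are illustrative assumptions, and the function name is hypothetical.

```python
import math
import random

def grad_u(q):
    """Gradient of the potential U(q) = q^2 / 2 of a standard normal."""
    return q

def hmc_standard_normal(num_samples=5000, step_size=0.2, leapfrog_steps=10, seed=0):
    """HMC for a 1-D standard normal target; returns samples and acceptance rate."""
    rng = random.Random(seed)
    q, samples, accepted = 0.0, [], 0
    for _ in range(num_samples):
        p = rng.gauss(0.0, 1.0)                       # resample momentum
        q_new, p_new = q, p
        p_new -= 0.5 * step_size * grad_u(q_new)      # initial half momentum step
        for step in range(leapfrog_steps):
            q_new += step_size * p_new                # full position step
            if step < leapfrog_steps - 1:
                p_new -= step_size * grad_u(q_new)    # full momentum step
        p_new -= 0.5 * step_size * grad_u(q_new)      # final half momentum step
        h_old = 0.5 * q * q + 0.5 * p * p             # potential + kinetic
        h_new = 0.5 * q_new * q_new + 0.5 * p_new * p_new
        # Accept with probability min(1, exp(h_old - h_new))
        if math.log(1.0 - rng.random()) < h_old - h_new:
            q, accepted = q_new, accepted + 1
        samples.append(q)
    return samples, accepted / num_samples

samples, accept_rate = hmc_standard_normal()
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
```

Note the two sign conventions that are easy to get wrong: momentum is updated against the gradient of the energy, and a proposal is accepted when the log of a uniform draw is below the decrease in the Hamiltonian.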

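Function-Space Variational Inference

From parameter space to function space

Conventional variational inference operates in parameter space, whose geometry is complicated. Function-space variational inference instead models distributions directly over functions, which is a more natural choice. [4]

Core idea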

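Let $f(\cdot\,; w)$ be the function defined by the neural network. Variational inference aims to approximate the posterior $p(f \mid \mathcal{D})$.

The prior is placed in function space rather than parameter space:

$$p(f) = \mathcal{GP}\big(0,\, k(x, x')\big)$$

where $\mathcal{GP}$ denotes a Gaussian process and $k$ is the kernel function.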
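
One common way to make such a prior computable is a random Fourier feature expansion, which replaces the kernel with an inner product of finite feature vectors. The plain-Python sketch below assumes an RBF kernel with unit lengthscale. That kernel choice and the helper names `make_rff` and `rbf` are assumptions for illustration, not details from the cited paper.

```python
import math
import random

def make_rff(input_dim, num_pairs, lengthscale=1.0, seed=0):
    """Return phi with phi(x) . phi(y) ~= exp(-|x - y|^2 / (2 * lengthscale^2))."""
    rng = random.Random(seed)
    # Frequencies drawn from the spectral density of the RBF kernel
    freqs = [[rng.gauss(0.0, 1.0 / lengthscale) for _ in range(input_dim)]
             for _ in range(num_pairs)]
    scale = 1.0 / math.sqrt(num_pairs)

    def phi(x):
        proj = [sum(w_i * x_i for w_i, x_i in zip(w, x)) for w in freqs]
        return [scale * math.cos(p) for p in proj] + [scale * math.sin(p) for p in proj]

    return phi

def rbf(x, y, lengthscale=1.0):
    """Exact RBF kernel value for comparison."""
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * lengthscale ** 2))

phi = make_rff(input_dim=2, num_pairs=2000)
x, y = [0.3, -0.5], [1.0, 0.2]
approx = sum(a * b for a, b in zip(phi(x), phi(y)))  # feature inner product
exact = rbf(x, y)                                    # kernel it approximates
```

With a standard normal prior on the weights of a linear model over `phi(x)`, the induced prior over functions is approximately the Gaussian process with kernel `rbf`.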

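import math

import torch
import torch.nn as nn


class FunctionSpaceVariationalBNN(nn.Module):
    """
    Function-space variational Bayesian neural network (illustrative sketch).

    Core idea: approximate the function-space prior with a random feature
    expansion, then place a Gaussian variational posterior over the
    feature weights.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_features=512):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim  # hidden_dim is kept in the signature but unused here

        # Random feature expansion approximating a GP kernel (cos and sin halves)
        self.feature_dim = num_features // 2
        self.num_features = 2 * self.feature_dim

        # Random frequencies, sampled once and then fixed
        self.register_buffer(
            'random_frequencies',
            torch.randn(self.feature_dim, input_dim) / math.sqrt(input_dim)
        )

        # Variational posterior over feature weights:
        # one mean per output, diagonal variance shared across outputs
        self.function_mean = nn.Parameter(torch.zeros(output_dim, self.num_features))
        self.function_log_var = nn.Parameter(torch.zeros(self.num_features))

    def forward(self, x):
        """
        Forward pass: draw one function sample and evaluate it at x.
        """
        # Random Fourier features, scaled so that z(x) . z(x') approximates the kernel
        proj = x @ self.random_frequencies.T
        z = torch.cat([torch.cos(proj), torch.sin(proj)], dim=-1) / math.sqrt(self.feature_dim)

        # Reparameterized sample of the feature weights
        std = torch.exp(0.5 * self.function_log_var)
        weights = self.function_mean + std * torch.randn_like(self.function_mean)

        return z @ weights.T  # (batch, output_dim)

    def kl_divergence(self):
        """
        KL divergence from the variational posterior to the prior.

        Prior: p(f) = GP(0, k), which becomes N(0, I) over feature weights
        under the random feature expansion.
        Approximate posterior: q = N(mu, diag(sigma^2)).
        """
        # Closed-form Gaussian KL: 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2)
        var_term = (torch.exp(self.function_log_var) - 1.0 - self.function_log_var).sum()
        kl = 0.5 * (self.output_dim * var_term + (self.function_mean ** 2).sum())
        return kl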

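Neural tangent kernel perspective

Function-space variational inference is closely connected to neural tangent kernel (NTK) theory:

  • NTK: describes gradient-descent training of infinitely wide networks as kernel regression in function space; in the same limit the Bayesian posterior over functions is a Gaussian process
  • Function-space VI: models $p(f \mid \mathcal{D})$ directly rather than $p(w \mid \mathcal{D})$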
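
At finite width the NTK can be computed empirically as the Gram matrix of parameter gradients, $K(x, x') = \langle \nabla_\theta f(x), \nabla_\theta f(x') \rangle$. The plain-Python sketch below does this for a two-layer ReLU network with hand-written gradients. The architecture, the width `m`, and the function names are assumptions chosen to keep the example small.

```python
import math
import random

def param_gradient(x, W, a):
    """Gradient of f(x) = sum_j a_j * relu(w_j . x) / sqrt(m) w.r.t. all parameters."""
    m = len(a)
    pre = [sum(w_i * x_i for w_i, x_i in zip(w, x)) for w in W]
    # d f / d a_j = relu(w_j . x) / sqrt(m)
    grad_a = [max(p, 0.0) / math.sqrt(m) for p in pre]
    # d f / d w_j = a_j * 1[w_j . x > 0] * x / sqrt(m)
    grad_W = [a_j * (1.0 if p > 0 else 0.0) * x_i / math.sqrt(m)
              for a_j, p in zip(a, pre) for x_i in x]
    return grad_a + grad_W

def empirical_ntk(xs, W, a):
    """Gram matrix K[i][j] = <grad f(x_i), grad f(x_j)>."""
    grads = [param_gradient(x, W, a) for x in xs]
    return [[sum(g * h for g, h in zip(gi, gj)) for gj in grads] for gi in grads]

rng = random.Random(0)
m, d = 512, 2  # hidden width and input dimension
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
a = [rng.gauss(0.0, 1.0) for _ in range(m)]
xs = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
K = empirical_ntk(xs, W, a)
```

For this bias-free ReLU network, opposite inputs such as `xs[0]` and `xs[2]` never activate the same hidden unit, so their kernel entry is exactly zero.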

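Mode Connectivity and Inference

Mode connectivity theory

Recent work finds that inference in BNNs can be simplified by exploiting mode connectivity. [5]

Key findings

In the loss landscape of deep networks, local optima are connected to one another by low-loss paths. This implies:

  1. The posterior $p(w \mid \mathcal{D})$ concentrates mainly along mode-connecting paths
  2. Starting from an arbitrary initial point, these paths lead to the high-probability core of the posterior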
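
import copy

import numpy as np
import torch
import torch.nn.functional as F


class ModeConnectedBNNSampler:
    """
    BNN sampler based on mode connectivity (illustrative sketch).
    """
    def __init__(self, model, dataloader, num_modes=5):
        self.model = model
        self.dataloader = dataloader  # iterable of (x, y) training batches
        self.num_modes = num_modes

    def path_interpolation(self, w1, w2, alpha):
        """
        Interpolate along the path connecting two modes.

        Linear interpolation is the simplest choice; spherical linear
        interpolation (SLERP) could be used instead.
        """
        return (1 - alpha) * w1 + alpha * w2

    def sample_via_mode_connection(self, num_samples=100):
        """
        Sample via mode connection.

        Strategy:
        1. Find several local optima by training multiple times
        2. Sample points between these solutions
        """
        # Train several models to find different modes
        modes = [self._extract_weights(self._train_mode()) for _ in range(self.num_modes)]

        # Sample between modes
        samples = []
        for _ in range(num_samples):
            # Pick two distinct modes at random
            i, j = np.random.choice(len(modes), 2, replace=False)

            # Random interpolation coefficient, concentrated around 0.5
            alpha = np.random.beta(2, 2)

            sample = {
                name: self.path_interpolation(modes[i][name], modes[j][name], alpha)
                for name in modes[0]
            }
            samples.append(sample)

        return samples

    @staticmethod
    def _extract_weights(model):
        """Copy a model's parameters into a name -> tensor dict."""
        return {name: p.detach().clone() for name, p in model.named_parameters()}

    def _train_mode(self, epochs=100):
        """Train a fresh copy of the model to find one local optimum."""
        model = copy.deepcopy(self.model)

        # Re-initialize so that each run can land in a different mode
        for module in model.modules():
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Short training run
        for _ in range(epochs):
            for x, y in self.dataloader:
                optimizer.zero_grad()
                loss = F.cross_entropy(model(x), y)
                loss.backward()
                optimizer.step()

        return model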
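
A straight line between two modes does not always stay in a low-loss region, which is why the shape of the connecting path matters. The toy sketch below, in plain Python, uses a hypothetical two-parameter loss whose minima form a circle and compares the loss barrier along a straight line with the barrier along the circular arc.

```python
import math

def loss(w):
    """Toy loss whose minima form the unit circle: zero wherever |w| = 1."""
    return (w[0] ** 2 + w[1] ** 2 - 1.0) ** 2

def linear_path(w1, w2, t):
    """Point at fraction t along the straight line from w1 to w2."""
    return [(1 - t) * a + t * b for a, b in zip(w1, w2)]

def arc_path(theta1, theta2, t):
    """Point at fraction t along the unit-circle arc between two angles."""
    theta = (1 - t) * theta1 + t * theta2
    return [math.cos(theta), math.sin(theta)]

def barrier(path_fn, steps=101):
    """Highest loss along the path minus the higher of the two endpoint losses."""
    losses = [loss(path_fn(k / (steps - 1))) for k in range(steps)]
    return max(losses) - max(losses[0], losses[-1])

mode_a, mode_b = [1.0, 0.0], [-1.0, 0.0]
linear_barrier = barrier(lambda t: linear_path(mode_a, mode_b, t))
arc_barrier = barrier(lambda t: arc_path(0.0, math.pi, t))
```

The straight line passes through the origin, where the loss is 1, while the arc stays on the set of minima. Real networks are not this simple, but the same barrier measurement applies to any candidate path between trained weights.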

Practical recommendations

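| Strategy | Use case | Complexity |
| --- | --- | --- |
| MC Dropout | Rapid prototyping, real-time applications | |
| Mean Field VI | Resource-constrained settings | |
| PBNN | When precise uncertainty is required | |
| HMC Hybrid | High-accuracy requirements | |
| Function Space VI | Function-space modeling | |
| Mode Connection | Exploring posterior structure | |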

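Accuracy-Uncertainty Trade-off

Choosing architecture and inference method

arXiv:2503.11808 analyzes the trade-off between accuracy and uncertainty quantification in BNNs. [6]

Key findings

  1. Architecture: wider networks tend to give better point estimates but may weaken uncertainty estimates
  2. Inference method: different inference methods strike different balances between accuracy and uncertainty
  3. Regularization: regularizers such as weight decay affect both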
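
import torch
import torch.nn.functional as F


class AccuracyUncertaintyTradeoff:
    """
    Analyze the trade-off between accuracy and uncertainty.

    `test_data` is assumed to be an (inputs, labels) pair of tensors.
    """
    def __init__(self, model, train_data, test_data):
        self.model = model
        self.train_data = train_data
        self.test_data = test_data

    def evaluate_tradeoff(self, inference_method='mc_dropout', num_samples=100):
        """
        Evaluate the accuracy-uncertainty trade-off of an inference method.
        """
        # Predictions, predictive entropy, and top-class confidence
        predictions, uncertainties, confidences = self._get_predictions(
            inference_method,
            num_samples
        )

        # Metrics
        accuracy = (predictions == self.test_data[1]).float().mean()
        mean_uncertainty = uncertainties.mean()

        # Calibration error
        calibration_error = self._compute_calibration(predictions, confidences)

        return {
            'accuracy': accuracy.item(),
            'mean_uncertainty': mean_uncertainty.item(),
            'calibration_error': calibration_error.item()
        }

    def _get_predictions(self, inference_method, num_samples):
        """Return predicted classes, predictive entropy, and top-class confidence."""
        if inference_method != 'mc_dropout':
            raise NotImplementedError(f"unsupported inference method: {inference_method}")

        self.model.train()  # keep dropout active at test time
        with torch.no_grad():
            probs = torch.stack([
                F.softmax(self.model(self.test_data[0]), dim=-1)
                for _ in range(num_samples)
            ]).mean(dim=0)

        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
        confidences, predictions = probs.max(dim=-1)
        return predictions, entropy, confidences

    def _compute_calibration(self, predictions, confidences, num_bins=15):
        """
        Compute the ECE (Expected Calibration Error).

        Confidence is the top-class probability, which lies in [0, 1];
        entropy is not used here because it can exceed 1.
        """
        accuracies = (predictions == self.test_data[1]).float()

        bin_boundaries = torch.linspace(0, 1, num_bins + 1)
        ece = confidences.new_zeros(())

        for i in range(num_bins):
            in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
            if in_bin.sum() > 0:
                bin_accuracy = accuracies[in_bin].mean()
                bin_confidence = confidences[in_bin].mean()
                # Weight each bin's gap by the fraction of samples it holds
                ece = ece + in_bin.float().mean() * (bin_accuracy - bin_confidence).abs()

        return ece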
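
A small hand-checkable case makes the ECE computation concrete. The plain-Python sketch below uses four hypothetical predictions and five equal-width bins; the function name and the numbers are made up for illustration.

```python
def expected_calibration_error(confidences, correct, num_bins=5):
    """ECE with equal-width bins (lo, hi] over [0, 1]."""
    n = len(confidences)
    ece = 0.0
    for b in range(num_bins):
        lo, hi = b / num_bins, (b + 1) / num_bins
        idx = [i for i, c in enumerate(confidences) if lo < c <= hi]
        if not idx:
            continue  # empty bins contribute nothing
        bin_acc = sum(correct[i] for i in idx) / len(idx)
        bin_conf = sum(confidences[i] for i in idx) / len(idx)
        # Weight the accuracy/confidence gap by the bin's share of samples
        ece += len(idx) / n * abs(bin_acc - bin_conf)
    return ece

confidences = [0.55, 0.55, 0.95, 0.95]  # top-class probabilities
correct = [1, 0, 1, 1]                  # 1 if the prediction was right
ece = expected_calibration_error(confidences, correct)
```

Here the bin (0.4, 0.6] has accuracy 0.5 against confidence 0.55, and the bin (0.8, 1.0] has accuracy 1.0 against confidence 0.95. Each gap is 0.05 and each bin holds half the samples, so the ECE is 0.05.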

Implementation Comparison

Comparison of mainstream methods

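| Method | Implementation | Characteristics | Use case |
| --- | --- | --- | --- |
| MC Dropout | PyTorch native | Simple, no extra parameters | Real-time applications |
| Bayes by Backprop | blitz | Variational inference, KL regularization | Research prototypes |
| Stochastic Weight Averaging | torchcontrib | Wide minima, smoothing | Improved generalization |
| Last Layer BNN | bbox | Efficient, partially Bayesian | Large models |
| SWA-Gaussian | Custom | SWA plus uncertainty | Production deployment |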

Complete implementation example

import torch


class BayesianInferenceEnsemble:
    """
    Unified Bayesian inference interface (illustrative sketch).
    Supports several inference methods and their combination.
    """
    def __init__(self, model, inference_method='all'):
        self.model = model
        self.method = inference_method

    def predict(self, x, num_samples=100):
        """
        Unified prediction interface; returns (mean prediction, uncertainty).
        """
        if self.method == 'mc_dropout':
            return self._mc_dropout_predict(x, num_samples)
        elif self.method == 'vi':
            return self._vi_predict(x, num_samples)
        elif self.method == 'swa':
            return self._swa_predict(x)
        elif self.method == 'all':
            return self._ensemble_predict(x, num_samples)
        raise ValueError(f"unknown inference method: {self.method}")

    def _mc_dropout_predict(self, x, num_samples):
        """MC Dropout prediction."""
        self.model.train()  # keep dropout active

        predictions = []
        for _ in range(num_samples):
            with torch.no_grad():
                pred = self.model(x)
                predictions.append(pred)

        predictions = torch.stack(predictions)
        mean_pred = predictions.mean(dim=0)
        uncertainty = predictions.std(dim=0)

        return mean_pred, uncertainty

    def _vi_predict(self, x, num_samples):
        """Variational inference prediction.

        Assumes the model exposes `sample(x)`, a forward pass with freshly
        sampled weights.
        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for _ in range(num_samples):
                predictions.append(self.model.sample(x))

        predictions = torch.stack(predictions)
        return predictions.mean(dim=0), predictions.std(dim=0)

    def _swa_predict(self, x):
        """SWA prediction."""
        self.model.eval()

        # Prediction of the SWA model (a single deterministic model)
        with torch.no_grad():
            mean_pred = self.model(x)

        # Placeholder: a single deterministic model has no sampling spread.
        # An estimate based on the training data distribution would go here.
        uncertainty = torch.zeros_like(mean_pred)

        return mean_pred, uncertainty

    def _ensemble_predict(self, x, num_samples):
        """Combine several methods."""
        per_method = max(1, num_samples // 3)
        results = {
            'mc_dropout': self._mc_dropout_predict(x, per_method),
            'vi': self._vi_predict(x, per_method),
            'swa': self._swa_predict(x)
        }

        # Weighted average over methods
        weights = {'mc_dropout': 0.4, 'vi': 0.4, 'swa': 0.2}

        mean_pred = sum(w * results[name][0] for name, w in weights.items())
        uncertainty = sum(w * results[name][1] for name, w in weights.items())

        return mean_pred, uncertainty
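
Averaging the per-method standard deviations ignores disagreement between the methods themselves. One alternative, sketched below in plain Python under the assumption that each method's prediction is treated as one component of a mixture, applies the law of total variance. The function name `mixture_moments` is hypothetical.

```python
import math

def mixture_moments(means, stds, weights):
    """Mean and std of a weighted mixture of per-method predictive distributions."""
    total = sum(weights)
    weights = [w / total for w in weights]  # normalize to sum to 1
    mean = sum(w * m for w, m in zip(weights, means))
    # Law of total variance: within-method variance plus between-method spread
    second_moment = sum(w * (s ** 2 + m ** 2) for w, m, s in zip(weights, means, stds))
    return mean, math.sqrt(max(second_moment - mean ** 2, 0.0))

# Two methods with equal spread that disagree about the mean
mean, std = mixture_moments(means=[0.0, 2.0], stds=[1.0, 1.0], weights=[0.5, 0.5])
naive_std = 0.5 * 1.0 + 0.5 * 1.0  # weighted average of stds, for comparison
```

In this example the mixture standard deviation is larger than the averaged one because the two methods predict different means.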

Summary

A summary of the advanced inference methods for Bayesian neural networks:

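| Method | Inference complexity | Uncertainty quality | Implementation difficulty |
| --- | --- | --- | --- |
| MC Dropout | | Medium | |
| Mean Field VI | | Lower | |
| PBNN | | | |
| HMC Hybrid | | Highest | |
| Function Space VI | | | |
| Mode Connection | | | |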

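Practical advice

  • Real-time applications: use MC Dropout
  • Resource-constrained settings: use Mean Field VI or a Last Layer BNN
  • High-accuracy requirements: use HMC Hybrid or PBNN
  • General-purpose use: use Function Space VI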

Footnotes

  1. A Survey on Bayesian Deep Learning (arXiv:2404.14642)

  2. Precise Bayesian Neural Networks (arXiv:2506.19726)

  3. Accelerating Hamiltonian Monte Carlo for Bayesian Neural Networks (arXiv:2507.14652)

  4. Tractable Function-Space Variational Inference in Bayesian Neural Networks (arXiv:2312.17199)

  5. Connecting the Dots: Is Mode-Connectedness the Key to Feasible Sample-Based Inference in BNNs? (arXiv:2402.01484)

  6. Understanding the Trade-offs in Accuracy and Uncertainty Quantification (arXiv:2503.11808)