概述
贝叶斯神经网络(Bayesian Neural Networks, BNNs)通过将网络权重视为随机变量来建模不确定性,是深度学习中不确定性量化的重要范式。1 然而,高维参数空间中的精确贝叶斯推断是一个根本性挑战。本文档介绍2024-2025年的最新进展,包括精确贝叶斯推断、哈密顿蒙特卡洛加速、以及函数空间变分推断等方法。
精确贝叶斯推断的进展
精确贝叶斯神经网络的挑战
传统BNN的精确推断需要计算权重后验 ,这涉及对数百万参数的高维积分:
其中 是无法解析计算的边缘似然。
精确贝叶斯神经网络(PBNN)
arXiv:2506.19726 提出了一种精确贝叶斯神经网络方法,通过精心设计的网络架构和归一化层来实现可追踪的贝叶斯推断。2
核心设计原则
- 对齐Hessian结构:使用与Hessian矩阵结构对齐的协方差参数化
- 方向不确定性:在参数空间中建模特定方向的不确定性
- 网络几何感知:利用神经网络的黎曼几何结构
import torch
import torch.nn as nn
import torch.nn.functional as F
class PreciseBayesianLayer(nn.Module):
"""
精确贝叶斯神经网络层
核心思想:在特定方向上建模不确定性,而非完全协方差
"""
def __init__(self, in_features, out_features, num_directions=16):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.num_directions = num_directions
# 均值参数
self.mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
# 方向精度参数(在特定方向上的精度)
self.direction_precisions = nn.Parameter(torch.ones(num_directions))
# 随机方向向量(固定)
self.register_buffer(
'directions',
torch.randn(num_directions, in_features)
)
# 归一化方向向量
with torch.no_grad():
self.directions = F.normalize(self.directions, dim=1)
def sample_weights(self, num_samples=1):
"""
从后验分布采样权重
"""
batch_size = self.mean.shape[0]
# 在方向子空间采样
# w = mean + sum_i(precision_i * direction_i) * epsilon_i
epsilons = torch.randn(num_samples, self.num_directions, device=self.mean.device)
# 计算方向权重
direction_weights = epsilons * torch.sqrt(self.direction_precisions)
# 重构权重矩阵
# weight[b] = mean[b] + directions^T @ direction_weights[b]
direction_contribution = torch.einsum(
'di,dsi->dsi',
self.directions,
direction_weights
)
# 权重: (out_features, in_features)
weights = self.mean.unsqueeze(0) + direction_contribution.sum(dim=1)
return weights
def forward(self, x):
"""
使用采样的权重进行前向传播
"""
weights = self.sample_weights(num_samples=1)
return F.linear(x, weights)
class PreciseBNN(nn.Module):
"""
精确贝叶斯神经网络
使用方向精度参数化实现可追踪推断
"""
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.layer1 = PreciseBayesianLayer(input_dim, hidden_dim, num_directions=32)
self.layer2 = PreciseBayesianLayer(hidden_dim, hidden_dim, num_directions=32)
self.layer3 = PreciseBayesianLayer(hidden_dim, output_dim, num_directions=16)
self.activation = nn.ReLU()
def forward(self, x, num_samples=1):
"""
前向传播
Args:
x: 输入张量 (batch_size, input_dim)
num_samples: MC采样数量
"""
h = x.unsqueeze(0).expand(num_samples, -1, -1) # (num_samples, batch, dim)
h = self.activation(self.layer1.sample_weights(num_samples).unsqueeze(1) @ h.unsqueeze(-1))
h = self.activation(self.layer2.sample_weights(num_samples).unsqueeze(1) @ h.unsqueeze(-1))
logits = self.layer3.sample_weights(num_samples).unsqueeze(1) @ h.squeeze(-1)
return logits # (num_samples, batch, output_dim)
def predict(self, x, num_samples=100):
"""
预测:返回均值和不确定性
"""
logits = self.forward(x, num_samples)
# 概率均值
probs = F.softmax(logits, dim=-1)
mean_probs = probs.mean(dim=0)
# 不确定性:熵
entropy = -(mean_probs * torch.log(mean_probs + 1e-10)).sum(dim=-1)
# 预测类别
predictions = mean_probs.argmax(dim=-1)
return predictions, entropy与传统方法的对比
| 方法 | 协方差参数化 | 推断复杂度 | 不确定性质量 |
|---|---|---|---|
| 标准MC Dropout | 隐式 | 中等 | |
| Mean Field VI | 对角协方差 | 较低 | |
| PBNN | 方向精度 | 高 | |
| Full Covariance | 满协方差 | 最高 |
哈密顿蒙特卡洛加速方法
混合推断框架
arXiv:2507.14652 提出了一种将变分推断与哈密顿蒙特卡洛(HMC)相结合的混合方法,用于加速贝叶斯神经网络推断。3
核心思想
- 变分引导:使用变分后验初始化HMC链
- 选择性HMC:只在最关键的参数方向上运行HMC
- 自适应步长:根据梯度信息自适应调整
import torch
import torch.nn as nn
from torch.distributions import Normal
class HybridBNNSampler:
"""
混合VI-HMC采样器
结合变分推断的效率和HMC的准确性
"""
def __init__(self, model, vi_posterior, hmc_fraction=0.1):
self.model = model
self.vi_posterior = vi_posterior # 预训练的变分后验
self.hmc_fraction = hmc_fraction # 用于HMC的参数比例
# 确定HMC参数子集
self._identify_important_parameters()
def _identify_important_parameters(self):
"""
识别需要精确HMC采样的关键参数
策略:选择梯度方差最大的参数
"""
# 计算参数重要性(梯度方差)
param_importance = {}
for name, param in self.model.named_parameters():
# 近似计算参数的后验方差
param_importance[name] = param.abs().mean()
# 选择top-k最重要的参数
sorted_params = sorted(param_importance.items(), key=lambda x: x[1], reverse=True)
num_hmc = int(len(sorted_params) * self.hmc_fraction)
self.hmc_param_names = set([name for name, _ in sorted_params[:num_hmc]])
def sample(self, data, num_samples=100, leapfrog_steps=10, step_size=0.01):
"""
混合采样过程
"""
samples = []
# 初始化:从变分后验采样
current_params = {}
for name, param in self.model.named_parameters():
if name in self.hmc_param_names:
# 对关键参数使用HMC
current_params[name] = param.data.clone()
else:
# 对其他参数使用变分后验
current_params[name] = self.vi_posterior.sample(param.shape)
for _ in range(num_samples):
# 步骤1:使用变分后验更新非关键参数
for name, param in self.model.named_parameters():
if name not in self.hmc_param_names:
current_params[name] = self.vi_posterior.sample(param.shape)
# 步骤2:使用HMC更新关键参数
current_params = self._hmc_step(
current_params,
data,
leapfrog_steps,
step_size
)
# 记录样本
samples.append({name: param.clone() for name, param in current_params.items()})
return samples
def _hmc_step(self, params, data, L, eps):
"""
HMC步长
"""
# 采样动量
momentum = {name: torch.randn_like(p) for name, p in params.items()}
# 保存当前位置
current_params = {name: p.clone() for name, p in params.items()}
# 计算当前能量
current_energy = self._compute_energy(params, data)
# 梯度
current_grad = self._compute_gradients(params, data)
# Leapfrog积分
for l in range(L):
# 更新动量
for name in momentum:
momentum[name] += 0.5 * eps * current_grad[name]
# 更新位置
for name in params:
params[name] += eps * momentum[name]
# 重新计算梯度
current_grad = self._compute_gradients(params, data)
# 更新动量
for name in momentum:
momentum[name] += 0.5 * eps * current_grad[name]
# 计算提议能量
proposed_energy = self._compute_energy(params, data)
# Metropolis-Hastings接受
delta_E = proposed_energy - current_energy
if torch.rand(1).log() > delta_E:
# 拒绝:恢复到原位置
params = current_params
return params
def _compute_energy(self, params, data):
"""计算联合能量 -log p(w, D)"""
# 设置模型参数
with torch.no_grad():
for name, param in self.model.named_parameters():
param.data = params[name]
# 计算负对数似然
x, y = data
logits = self.model(x)
nll = F.cross_entropy(logits, y, reduction='sum')
# 计算负对数先验
prior = sum(p.pow(2).sum() for p in params.values())
return nll + 0.5 * prior
def _compute_gradients(self, params, data):
"""计算能量关于参数的梯度"""
# 设置模型参数
for name, param in self.model.named_parameters():
param.data = params[name]
param.requires_grad_(name in self.hmc_param_names)
x, y = data
logits = self.model(x)
loss = F.cross_entropy(logits, y)
loss.backward()
grads = {}
for name, param in self.model.named_parameters():
if name in self.hmc_param_names:
grads[name] = param.grad.clone()
return grads函数空间变分推断
从参数空间到函数空间
传统的变分推断在参数空间进行,但参数空间的几何结构复杂。函数空间变分推断直接在函数空间进行建模,更加自然。4
核心思想
设 是神经网络定义的函数,变分推断目标是近似后验 。
将先验放在函数空间而非参数空间:
其中 是高斯过程, 是核函数。
class FunctionSpaceVariationalBNN(nn.Module):
"""
函数空间变分贝叶斯神经网络
核心思想:用随机特征展开近似函数空间先验
"""
def __init__(self, input_dim, hidden_dim, output_dim, num_features=512):
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.num_features = num_features
# 随机特征展开:近似高斯过程核
self.feature_dim = num_features // 2
# 随机频率(从先验采样后固定)
self.register_buffer(
'random_frequencies',
torch.randn(self.feature_dim, input_dim) / np.sqrt(input_dim)
)
# 变分后验的均值和方差(函数空间)
self.function_mean = nn.Parameter(torch.zeros(output_dim, num_features))
self.function_log_var = nn.Parameter(torch.zeros(num_features))
# 输出层
self.output_layer = nn.Linear(num_features, output_dim)
def forward(self, x):
"""
前向传播:使用随机特征展开
"""
# 随机傅里叶特征
z = torch.cat([
torch.cos(x @ self.random_frequencies.T),
torch.sin(x @ self.random_frequencies.T)
], dim=-1)
# 应用函数空间变分
mean = z @ self.function_mean.T
# 对角协方差
var = torch.exp(self.function_log_var)
std = torch.sqrt(var)
# 添加噪声
output = mean + std * torch.randn_like(mean)
return output
def kl_divergence(self):
"""
计算与函数空间先验的KL散度
先验:p(f) = GP(0, k)
近似后验:q(f) = N(mu, diag(sigma^2))
"""
# KL(q(f) || p(f)) 的变分下界
# 这里使用简化的近似
kl = 0.5 * (
torch.exp(self.function_log_var).sum() +
(self.function_mean ** 2).sum() -
self.num_features
)
return kl神经切线核视角
函数空间变分推断与神经切线核(NTK)理论有深刻联系:
- NTK 描述了神经网络在无限宽度极限下的贝叶斯后验
- 函数空间变分 直接在 上建模,而非
模式连接性与推断
模式连接性理论
最近的研究发现,BNN的推断问题可以通过模式连接性得到简化。5
核心发现
深度网络的损失景观中,局部最优解通过低损失路径相互连接。这意味着:
- 后验 主要集中在连接模式(mode-connecting)路径上
- 可以从任意初始点通过这些路径到达后验分布的核心区域
class ModeConnectedBNNSampler:
"""
基于模式连接性的BNN采样器
"""
def __init__(self, model):
self.model = model
def path_interpolation(self, w1, w2, alpha):
"""
沿连接路径插值
使用球面线性插值(SLERP)或线性插值
"""
# 简单线性插值
return (1 - alpha) * w1 + alpha * w2
def sample_via_mode_connection(self, num_samples=100):
"""
通过模式连接进行采样
策略:
1. 找到多个局部最优解(通过多次训练)
2. 在这些解之间采样
"""
# 训练多个模型以找到不同的模式
modes = []
for _ in range(5):
model = self._train_mode()
modes.append(self._extract_weights(model))
# 在模式之间采样
samples = []
for _ in range(num_samples):
# 随机选择两个模式
i, j = np.random.choice(len(modes), 2, replace=False)
# 随机插值系数
alpha = np.random.beta(2, 2)
# 生成样本
sample = {}
for name in modes[0].keys():
sample[name] = self.path_interpolation(
modes[i][name],
modes[j][name],
alpha
)
samples.append(sample)
return samples
def _train_mode(self):
"""训练以找到一个局部最优模式"""
model = copy.deepcopy(self.model)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 短期训练
for _ in range(100):
for x, y in dataloader:
optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
return model实用建议
| 策略 | 适用场景 | 复杂度 |
|---|---|---|
| MC Dropout | 快速原型、实时应用 | |
| Mean Field VI | 资源受限场景 | |
| PBNN | 需要精确不确定性 | |
| HMC Hybrid | 高精度需求 | |
| Function Space VI | 函数空间建模 | |
| Mode Connection | 探索后验结构 |
准确率-不确定性权衡
架构与推断的选择
arXiv:2503.11808 分析了BNN中准确率和不确定性量化之间的权衡。6
关键发现
- 架构影响:更宽的网络倾向于更好的点估计,但可能削弱不确定性
- 推断方法:不同的推断方法在准确率和不确定性之间有不同的权衡
- 正则化:权重衰减等正则化同时影响两者
class AccuracyUncertaintyTradeoff:
"""
分析准确率和不确定性的权衡
"""
def __init__(self, model, train_data, test_data):
self.model = model
self.train_data = train_data
self.test_data = test_data
def evaluate_tradeoff(self, inference_method='mc_dropout', num_samples=100):
"""
评估给定推断方法的准确率-不确定性权衡
"""
# 获取预测和不确定性
predictions, uncertainties = self._get_predictions(
inference_method,
num_samples
)
# 计算指标
accuracy = (predictions == self.test_data[1]).float().mean()
mean_uncertainty = uncertainties.mean()
# 校准误差
calibration_error = self._compute_calibration(predictions, uncertainties)
return {
'accuracy': accuracy.item(),
'mean_uncertainty': mean_uncertainty.item(),
'calibration_error': calibration_error.item()
}
def _compute_calibration(self, predictions, uncertainties, num_bins=15):
"""
计算ECE(Expected Calibration Error)
"""
confidences = 1 - uncertainties # 转换为置信度
accuracies = (predictions == self.test_data[1]).float()
bin_boundaries = torch.linspace(0, 1, num_bins + 1)
ece = 0.0
for i in range(num_bins):
in_bin = (confidences > bin_boundaries[i]) & (confidences <= bin_boundaries[i + 1])
if in_bin.sum() > 0:
bin_accuracy = accuracies[in_bin].mean()
bin_confidence = confidences[in_bin].mean()
ece += in_bin.float().mean() * abs(bin_accuracy - bin_confidence)
return ece代码实现对比
主流方法对比
| 方法 | 库 | 特点 | 适用场景 |
|---|---|---|---|
| MC Dropout | PyTorch原生 | 简单、无额外参数 | 实时应用 |
| Bayes by Backprop | blitz | 变分推断、KL正则 | 研究原型 |
| Stochastic Weight Averaging | torchcontrib | 宽极小值、平滑 | 泛化提升 |
| Last Layer BNN | bbox | 高效、部分贝叶斯 | 大模型 |
| SWA-Gaussian | 自定义 | SWA+不确定性 | 生产部署 |
完整实现示例
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
class BayesianInferenceEnsemble:
"""
统一的贝叶斯推断接口
支持多种推断方法的集成
"""
def __init__(self, model, inference_method='all'):
self.model = model
self.method = inference_method
def predict(self, x, num_samples=100):
"""
统一的预测接口
"""
if self.method == 'mc_dropout':
return self._mc_dropout_predict(x, num_samples)
elif self.method == 'vi':
return self._vi_predict(x, num_samples)
elif self.method == 'swa':
return self._swa_predict(x)
elif self.method == 'all':
return self._ensemble_predict(x, num_samples)
def _mc_dropout_predict(self, x, num_samples):
"""MC Dropout预测"""
self.model.train() # 启用dropout
predictions = []
for _ in range(num_samples):
with torch.no_grad():
pred = self.model(x)
predictions.append(pred)
predictions = torch.stack(predictions)
mean_pred = predictions.mean(dim=0)
uncertainty = predictions.std(dim=0)
return mean_pred, uncertainty
def _vi_predict(self, x, num_samples):
"""变分推断预测"""
self.model.eval()
predictions = []
for _ in range(num_samples):
pred = self.model.sample(x)
predictions.append(pred)
predictions = torch.stack(predictions)
return predictions.mean(dim=0), predictions.std(dim=0)
def _swa_predict(self, x):
"""SWA预测"""
self.model.eval()
# SWA模型的预测(单一确定性模型)
with torch.no_grad():
mean_pred = self.model(x)
# 估计不确定性(基于训练数据分布)
uncertainty = self._estimate_prior_uncertainty(x)
return mean_pred, uncertainty
def _ensemble_predict(self, x, num_samples):
"""集成多种方法"""
results = {
'mc_dropout': self._mc_dropout_predict(x, num_samples // 3),
'vi': self._vi_predict(x, num_samples // 3),
'swa': self._swa_predict(x)
}
# 加权平均
weights = {'mc_dropout': 0.4, 'vi': 0.4, 'swa': 0.2}
mean_pred = sum(w * r[0] for w, r in weights.items())
uncertainty = sum(w * r[1] for w, r in weights.items())
return mean_pred, uncertainty总结
贝叶斯神经网络高级推断方法的总结:
| 方法 | 推断复杂度 | 不确定性质量 | 实现难度 |
|---|---|---|---|
| MC Dropout | 中等 | 低 | |
| Mean Field VI | 较低 | 中 | |
| PBNN | 高 | 中 | |
| HMC Hybrid | 最高 | 高 | |
| Function Space VI | 高 | 中 | |
| Mode Connection | 高 | 高 |
实践建议:
- 实时应用:使用MC Dropout
- 资源受限:使用Mean Field VI或Last Layer BNN
- 高精度需求:使用HMC Hybrid或PBNN
- 通用场景:使用Function Space VI
参考
相关文章
- bayesian-neural-networks-uncertainty — 贝叶斯神经网络与不确定性量化基础
- mc-dropout — MC Dropout方法详解
- variational-inference — 变分推断基础
- expectation-propagation — 期望传播算法
- bayesian-last-layer-deep-learning — 贝叶斯最后一层方法
Footnotes
-
Accelerating Hamiltonian Monte Carlo for Bayesian Neural Networks (arXiv:2507.14652) ↩
-
Tractable Function-Space Variational Inference in Bayesian Neural Networks (arXiv:2312.17199) ↩
-
Connecting the Dots: Is Mode-Connectedness the Key to Feasible Sample-Based Inference in BNNs? (arXiv:2402.01484) ↩
-
Understanding the Trade-offs in Accuracy and Uncertainty Quantification (arXiv:2503.11808) ↩