贝叶斯深度学习实践指南
1. 引言
贝叶斯深度学习将贝叶斯推断与深度学习结合,为神经网络提供不确定性量化能力[1]。在实际应用中,不确定性量化对于风险评估、主动学习和联邦学习等领域至关重要。
本指南从实践角度出发,系统介绍贝叶斯深度学习的核心方法、实现技巧和工程实践。
2. 不确定性类型与来源
2.1 认知不确定性(Epistemic Uncertainty)
定义:由于训练数据不足或模型表达能力有限导致的不确定性。
特点:
- 可以通过更多数据减少
- 在数据稀少区域较高
- 适合用于主动学习
建模方式:对网络权重 $w$ 设定先验 $p(w)$,并用后验分布 $p(w \mid \mathcal{D})$(或其近似 $q(w)$)刻画权重的不确定性。
2.2 偶然不确定性(Aleatoric Uncertainty)
定义:数据本身的随机性导致的固有不确定性。
特点:
- 无法通过更多数据减少
- 与输入数据相关
- 适合用于异常检测
建模方式:
- 同方差: 噪声方差 $\sigma^2$ 为常数,不随输入变化
- 异方差: 噪声方差 $\sigma^2(x)$ 随输入变化
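2.3 不确定性分解
根据全方差公式,总预测方差可以分解为偶然不确定性与认知不确定性之和:
$$\mathrm{Var}[y \mid x] = \underbrace{\mathbb{E}_{w}\left[\sigma^2(x, w)\right]}_{\text{偶然不确定性}} + \underbrace{\mathrm{Var}_{w}\left[\mu(x, w)\right]}_{\text{认知不确定性}}$$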
3. 核心方法对比
3.1 方法总览
| 方法 | 实现复杂度 | 计算开销 | 不确定性质量 | 可扩展性 |
|---|---|---|---|---|
| MC Dropout | ★☆☆ | ★☆☆ | ★★☆ | ★★★ |
| 变分推断 | ★★☆ | ★★☆ | ★★★ | ★★☆ |
| MCMC | ★★★ | ★★★ | ★★★ | ★☆☆ |
| 集成方法 | ★★☆ | ★★★ | ★★★ | ★★☆ |
| SWA-Gaussian | ★★☆ | ★☆☆ | ★★☆ | ★★★ |
3.2 MC Dropout
MC Dropout 是最简单的不确定性估计方法[2]。
理论依据:Dropout 近似贝叶斯推断。
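数学推导:
$$p(y \mid x, \mathcal{D}) \approx \frac{1}{T} \sum_{t=1}^{T} p(y \mid x, \hat{w}_t)$$
其中 $\hat{w}_t$($t = 1, \dots, T$)是 $T$ 次前向传播时应用 Dropout 得到的权重。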
import torch
import torch.nn as nn
import torch.nn.functional as F
class MCDropout(nn.Module):
    """Estimate predictive uncertainty with Monte Carlo Dropout.

    The wrapped ``model`` is run ``n_samples`` times with its Dropout layers
    active; the mean and variance across those stochastic passes are
    returned.

    Args:
        model: The network to wrap. The stochasticity comes from the Dropout
            layers already inside this model.
        dropout_rate: Rate of the ``self.dropout`` layer registered on the
            wrapper (see the note in ``__init__``).
        n_samples: Number of stochastic forward passes.
    """

    def __init__(self, model, dropout_rate=0.5, n_samples=50):
        super().__init__()
        self.model = model
        # Kept for backward compatibility: this layer is registered on the
        # wrapper but forward() never applies it.
        self.dropout = nn.Dropout(dropout_rate)
        self.n_samples = n_samples

    def forward(self, x, return_uncertainty=True):
        """Run the wrapped model, optionally with MC Dropout sampling.

        Args:
            x: Input passed unchanged to the wrapped model.
            return_uncertainty: When False, perform a single ordinary forward
                pass and return its output.

        Returns:
            ``model(x)`` when ``return_uncertainty`` is False, otherwise a
            ``(mean, variance)`` tuple computed over the sampling dimension.
            The variance uses ``Tensor.var`` defaults, so ``n_samples == 1``
            yields NaN.
        """
        if not return_uncertainty:
            return self.model(x)

        # Remember every submodule's mode so sampling has no lasting side
        # effect on the wrapped model.
        saved_modes = [(module, module.training) for module in self.model.modules()]

        # Only Dropout layers must be stochastic. Switching the whole model to
        # train() would also make BatchNorm use (and update) batch statistics.
        self.model.eval()
        for module in self.model.modules():
            if isinstance(module, nn.modules.dropout._DropoutNd):
                module.train()

        try:
            predictions = torch.stack(
                [self.model(x) for _ in range(self.n_samples)]
            )
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)
        return mean, variance


# 3.3 变分推断
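变分推断通过变分分布 $q_\phi(w)$ 近似后验 $p(w \mid \mathcal{D})$。
变分下界(ELBO):
$$\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(w)}\left[\log p(\mathcal{D} \mid w)\right] - \mathrm{KL}\left(q_\phi(w) \,\|\, p(w)\right)$$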
class BayesianLinear(nn.Module):
    """Linear layer with a factorized Gaussian variational posterior.

    Each weight and bias has a learnable mean and log-variance. A fresh
    sample is drawn on every forward pass via the reparameterization trick.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        prior_mean: Mean of the Gaussian prior shared by all parameters.
        prior_var: Variance of the Gaussian prior shared by all parameters.
    """

    def __init__(self, in_features, out_features, prior_mean=0, prior_var=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Variational parameters.
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features))
        self.weight_log_var = nn.Parameter(torch.zeros(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.randn(out_features))
        self.bias_log_var = nn.Parameter(torch.zeros(out_features))
        # Prior parameters.
        self.prior_mean = prior_mean
        self.prior_var = prior_var
        # Log of the observation-noise variance used by loss().
        self.log_noise = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply the layer with freshly sampled weights and bias."""
        weight_std = torch.exp(0.5 * self.weight_log_var)
        bias_std = torch.exp(0.5 * self.bias_log_var)
        weight = self.weight_mu + weight_std * torch.randn_like(self.weight_mu)
        bias = self.bias_mu + bias_std * torch.randn_like(self.bias_mu)
        return F.linear(x, weight, bias)

    @staticmethod
    def _gaussian_kl(mu, log_var, prior_mean, prior_var):
        """Return KL(N(mu, exp(log_var)) || N(prior_mean, prior_var)), summed."""
        prior_var = torch.as_tensor(prior_var, dtype=mu.dtype, device=mu.device)
        return 0.5 * (
            torch.log(prior_var)
            - log_var
            + (torch.exp(log_var) + (mu - prior_mean).pow(2)) / prior_var
            - 1
        ).sum()

    def kl_divergence(self):
        """Return the KL divergence from the posterior to the configured prior.

        With the default prior N(0, 1) this reduces to the usual
        ``-0.5 * sum(1 + log_var - mu^2 - exp(log_var))``.
        """
        weight_kl = self._gaussian_kl(
            self.weight_mu, self.weight_log_var, self.prior_mean, self.prior_var
        )
        bias_kl = self._gaussian_kl(
            self.bias_mu, self.bias_log_var, self.prior_mean, self.prior_var
        )
        return weight_kl + bias_kl

    def loss(self, x, y, reduction='mean'):
        """Return the Gaussian reconstruction loss plus the KL term.

        Args:
            x: Input batch.
            y: Targets, broadcastable against the layer output.
            reduction: ``'mean'`` divides the total by ``x.size(0)``; any
                other value returns the un-normalized sum.

        Returns:
            A scalar tensor.
        """
        output = self(x)
        # Reconstruction term: Gaussian negative log-likelihood with a learned
        # noise variance (the constant 0.5 * log(2 * pi) is dropped).
        noise_var = torch.exp(self.log_noise)
        recon_loss = 0.5 * (
            (y - output).pow(2) / noise_var + torch.log(noise_var)
        ).sum()
        kl_loss = self.kl_divergence()
        # NOTE(review): under 'mean' the KL term is divided by the batch size,
        # not the dataset size, so it is counted once per batch - confirm this
        # weighting is intended.
        if reduction == 'mean':
            return (recon_loss + kl_loss) / x.size(0)
        return recon_loss + kl_loss


# 3.4 集成方法
集成多个模型的预测以估计不确定性。
class EnsembleModel:
    """Deep ensemble: independently built models whose predictions are pooled.

    Args:
        base_model_fn: Zero-argument callable returning a new model. Each
            model must expose ``parameters()``, ``train()``, ``eval()``,
            ``__call__(x)`` and ``loss(x, y)``.
        n_models: Number of ensemble members to create.
    """

    def __init__(self, base_model_fn, n_models=10):
        self.models = [base_model_fn() for _ in range(n_models)]
        # One Adam optimizer (default settings) per member.
        self.optimizers = [
            torch.optim.Adam(member.parameters()) for member in self.models
        ]

    def train_epoch(self, train_loader):
        """Train every member for one pass over ``train_loader``.

        Members are trained one after another, each seeing the whole loader.
        """
        for model, optimizer in zip(self.models, self.optimizers):
            model.train()
            for x, y in train_loader:
                optimizer.zero_grad()
                loss = model.loss(x, y)
                loss.backward()
                optimizer.step()

    def predict(self, x, return_uncertainty=True):
        """Return the ensemble mean and, optionally, the variance.

        Args:
            x: Input batch passed to every member.
            return_uncertainty: When True return ``(mean, variance)``,
                otherwise only ``mean``.

        Returns:
            The mean prediction, or a ``(mean, variance)`` tuple. The
            variance uses ``Tensor.var`` defaults, so a single-member
            ensemble yields NaN.
        """
        member_outputs = []
        with torch.no_grad():
            for model in self.models:
                model.eval()
                member_outputs.append(model(x))
        predictions = torch.stack(member_outputs)
        mean = predictions.mean(dim=0)
        if not return_uncertainty:
            # Skip the variance reduction when the caller does not need it.
            return mean
        return mean, predictions.var(dim=0)


# 4. 实践技巧
4.1 训练稳定性
学习率设置:
- 贝叶斯神经网络的训练学习率通常较低
- 推荐使用: $10^{-4}$ 到 $10^{-3}$
- 使用学习率预热和余弦退火
权重初始化:
def init_weights(m):
    """Initialize an ``nn.Linear`` in place; leave any other module untouched.

    Weights get Xavier-uniform initialization and the bias, when present, is
    zeroed. Intended to be passed to ``model.apply(init_weights)``.

    Args:
        m: A module visited by ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


# 4.2 先验选择
| 先验 | 适用场景 | 特性 |
|---|---|---|
| Gaussian | 通用 | 无偏,数学性质好 |
| Laplace | 稀疏性 | 促进稀疏解 |
| Horseshoe | 稀疏性 | 层次先验,强稀疏 |
| Spike-and-Slab | 稀疏性 | 二元混合,最稀疏 |
Horseshoe 先验实现:
class HorseshoePrior:
    """Heavy-tailed sparsity prior with a closed-form log-density surrogate.

    NOTE(review): the exact horseshoe density has no closed form. This class
    evaluates ``-0.5 * sum(log(w^2 + 3 * scale^2))``, which is a surrogate -
    confirm it is the intended approximation.

    Args:
        scale: Global scale of the prior.
    """

    def __init__(self, scale=1.0):
        self.scale = scale

    def log_prob(self, w):
        """Return the summed (unnormalized) log-density of ``w``.

        Args:
            w: Tensor of weights.

        Returns:
            A scalar tensor.
        """
        # log(1 + 3 s^2 / w^2) + log(w^2) == log(w^2 + 3 s^2). The combined
        # form stays finite at w == 0, where the split form evaluates
        # inf + (-inf) = NaN and produces NaN gradients.
        return -0.5 * torch.sum(torch.log(w**2 + 3 * self.scale**2))


# 4.3 异常值处理
异方差模型可以自动检测异常:
def detect_anomalies(model, data_loader, threshold=3):
    """Flag samples whose predictive standard deviation exceeds a threshold.

    Args:
        model: Object exposing ``eval()`` and ``predict(x)``, the latter
            returning a ``(mean, variance)`` pair.
        data_loader: Iterable of ``(x, target)`` batches; targets are ignored.
        threshold: Standard-deviation cut-off.

    Returns:
        A list of integer sample indices, counted across the whole loader in
        iteration order. A sample with several outputs is flagged when any of
        them exceeds the threshold.
    """
    model.eval()
    anomalies = []
    offset = 0  # Number of samples seen in earlier batches.
    with torch.no_grad():
        for x, _ in data_loader:
            _, variance = model.predict(x)
            flagged = variance.sqrt() > threshold
            if flagged.dim() == 0:
                flagged = flagged.reshape(1)
            elif flagged.dim() > 1:
                # Collapse every non-batch dimension to one flag per sample.
                flagged = flagged.flatten(start_dim=1).any(dim=1)
            # flatten() keeps a 1-D result even for a single hit, so tolist()
            # always yields a list (squeeze() would collapse it to a scalar).
            hits = flagged.nonzero().flatten()
            anomalies.extend((hits + offset).tolist())
            offset += flagged.shape[0]
    return anomalies


# 5. 不确定性量化评估
5.1 校准曲线
评估预测概率与真实频率的一致性:
import numpy as np
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
def plot_calibration(y_true, y_prob, n_bins=10):
    """Plot a reliability diagram for a binary classifier.

    Args:
        y_true: Ground-truth binary labels.
        y_prob: Predicted probabilities of the positive class.
        n_bins: Number of bins passed to ``calibration_curve``.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    fraction_of_positives, mean_predicted_value = calibration_curve(
        y_true, y_prob, n_bins=n_bins
    )

    plt.figure(figsize=(8, 6))
    # The diagonal is the reference line of a perfectly calibrated model.
    plt.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    plt.plot(mean_predicted_value, fraction_of_positives, 'o-', label='Model')
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.legend()
    plt.show()


# 5.2 期望校准误差(ECE)
def expected_calibration_error(y_true, y_prob, n_bins=15):
    """Compute the expected calibration error (ECE) with equal-width bins.

    Args:
        y_true: Binary labels (0/1), array-like.
        y_prob: Predicted probabilities in [0, 1], array-like of equal length.
        n_bins: Number of equal-width bins over [0, 1].

    Returns:
        The weighted mean absolute gap between the observed positive rate and
        the mean predicted probability per bin, as a float. Empty input
        returns 0.0.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_prob = np.asarray(y_prob, dtype=float)
    if y_true.size == 0:
        return 0.0

    bin_edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        lower, upper = bin_edges[i], bin_edges[i + 1]
        if i == n_bins - 1:
            # Close the last bin on the right so y_prob == 1.0 is counted.
            mask = (y_prob >= lower) & (y_prob <= upper)
        else:
            mask = (y_prob >= lower) & (y_prob < upper)
        count = mask.sum()
        if count > 0:
            acc = y_true[mask].mean()
            conf = y_prob[mask].mean()
            ece += count / len(y_true) * abs(acc - conf)
    return float(ece)


# 5.3 分布校准
评估预测分布与真实分布的一致性:
def nll(y_true, y_pred_mean, y_pred_var):
    """Return the element-wise Gaussian negative log-likelihood.

    The additive constant ``0.5 * log(2 * pi)`` is omitted, so values are
    comparable across models but are not absolute log-likelihoods.

    Args:
        y_true: Observed targets, array-like.
        y_pred_mean: Predicted means, broadcastable against ``y_true``.
        y_pred_var: Predicted variances; expected to be strictly positive.

    Returns:
        An array of per-element NLL values (no reduction is applied).
    """
    # Coerce so plain Python lists work as well as arrays.
    y_true = np.asarray(y_true, dtype=float)
    y_pred_mean = np.asarray(y_pred_mean, dtype=float)
    y_pred_var = np.asarray(y_pred_var, dtype=float)
    squared_error = (y_true - y_pred_mean) ** 2
    return 0.5 * (np.log(y_pred_var) + squared_error / y_pred_var)
def crps(y_true, y_samples):
    """Return the mean sample-based continuous ranked probability score.

    For each observation ``y`` with predictive samples ``X`` the estimator is
    ``E|X - y| - 0.5 * E|X - X'|``, which is non-negative and equals the mean
    absolute error when all samples coincide.

    Args:
        y_true: Observations, array-like of shape ``(n,)``.
        y_samples: Predictive samples, array-like of shape ``(n, m)``.

    Returns:
        The CRPS averaged over the ``n`` observations, as a float.

    Raises:
        ValueError: If the input is empty or the shapes do not match.
    """
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_samples = np.asarray(y_samples, dtype=float)
    if y_true.size == 0:
        raise ValueError("crps() requires at least one observation")
    if y_samples.ndim != 2 or y_samples.shape[0] != y_true.shape[0]:
        raise ValueError("y_samples must have shape (len(y_true), n_samples)")

    n_samples = y_samples.shape[1]
    abs_error = np.abs(y_samples - y_true[:, None]).mean(axis=1)

    # 0.5 * E|X - X'| from order statistics:
    # sum_{i,j} |x_i - x_j| = 2 * sum_i (2i - m - 1) * x_(i), i = 1..m,
    # which avoids building the m x m matrix of pairwise differences.
    sorted_samples = np.sort(y_samples, axis=1)
    weights = 2 * np.arange(1, n_samples + 1) - n_samples - 1
    half_spread = (sorted_samples * weights).sum(axis=1) / n_samples**2

    return float((abs_error - half_spread).mean())


# 6. 大规模部署
6.1 内存优化
后验近似:使用低秩近似减少参数量。
class LowRankBayesianLinear(nn.Module):
    """Bayesian linear layer whose weight mean is a low-rank product.

    The mean weight matrix is ``U @ V^T``; an independent Gaussian
    perturbation with learnable per-element variance is added on every
    forward pass. The layer has no bias term.

    NOTE(review): ``log_var`` is stored at full ``out_features x in_features``
    size, so only the mean is compressed - confirm this is intended.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Rank of the mean-weight factorization.
    """

    def __init__(self, in_features, out_features, rank=10):
        super().__init__()
        self.rank = rank
        # Low-rank factors of the mean: W = U V^T.
        self.U = nn.Parameter(torch.randn(out_features, rank))
        self.V = nn.Parameter(torch.randn(in_features, rank))
        # Element-wise log-variance of the weight posterior.
        self.log_var = nn.Parameter(torch.zeros(out_features, in_features))

    def forward(self, x):
        """Apply the layer with one freshly sampled weight matrix."""
        mean_weight = self.U @ self.V.t()
        std = torch.exp(self.log_var).sqrt()
        # Reparameterized sample: W + sigma * eps, with eps ~ N(0, I).
        sampled_weight = mean_weight + std * torch.randn_like(mean_weight)
        return F.linear(x, sampled_weight)


# 6.2 推理加速
单次前向传播近似:
class SWAG(nn.Module):
    """Simplified Stochastic Weight Averaging Gaussian (SWAG) wrapper.

    Snapshots of the model's ``state_dict`` are collected during training.
    At prediction time one snapshot is picked at random, perturbed with
    isotropic Gaussian noise and loaded into the model.

    NOTE(review): this does not fit SWAG's mean/covariance moments; it
    perturbs stored snapshots directly - confirm the simplification is
    intended.

    Args:
        model: The network being trained.
        start_epoch: First epoch from which snapshots are kept.
    """

    def __init__(self, model, start_epoch=10):
        super().__init__()
        self.model = model
        self.start_epoch = start_epoch
        self.weights = []

    def collect_weights(self, epoch):
        """Store a copy of the current weights once ``epoch >= start_epoch``."""
        if epoch >= self.start_epoch:
            state = {k: v.clone() for k, v in self.model.state_dict().items()}
            self.weights.append(state)

    def predict(self, x, noise_scale=0.1):
        """Predict with a randomly chosen, noise-perturbed weight snapshot.

        The sampled weights are loaded into ``self.model`` and stay loaded
        after the call.

        Args:
            x: Input batch.
            noise_scale: Standard deviation of the Gaussian noise added to
                every floating-point tensor in the snapshot.

        Returns:
            The model output for ``x`` under the sampled weights.

        Raises:
            ValueError: If no snapshot has been collected yet.
        """
        if not self.weights:
            raise ValueError(
                "no weights collected; call collect_weights() at or after start_epoch"
            )
        idx = np.random.randint(len(self.weights))
        state = {}
        for key, value in self.weights[idx].items():
            if value.is_floating_point() or value.is_complex():
                state[key] = value + noise_scale * torch.randn_like(value)
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked)
                # cannot take Gaussian noise; copy them through unchanged.
                state[key] = value
        self.model.load_state_dict(state)
        return self.model(x)


# 6.3 在线学习
贝叶斯在线学习支持增量更新:
class BayesianOnlineUpdate:
    """Recursive Bayesian (Kalman-style) update for a linear-Gaussian model.

    Maintains a Gaussian posterior over a weight vector and refines it one
    scalar observation ``y = x . w + noise`` at a time.

    Args:
        prior_mean: Initial mean vector, shape ``(d,)``.
        prior_cov: Initial covariance matrix, shape ``(d, d)``.
    """

    def __init__(self, prior_mean, prior_cov):
        self.posterior_mean = prior_mean
        self.posterior_cov = prior_cov

    def update(self, x, y, noise_var=0.1):
        """Incorporate one observation into the posterior.

        Args:
            x: Feature vector, shape ``(d,)``.
            y: Scalar observation.
            noise_var: Observation-noise variance.
        """
        prior_cov = self.posterior_cov

        # Innovation variance and Kalman gain.
        S = x @ prior_cov @ x + noise_var
        K = prior_cov @ x / S

        # Posterior mean.
        residual = y - x @ self.posterior_mean
        self.posterior_mean = self.posterior_mean + K * residual

        # Posterior covariance in Joseph form:
        # (I - K x^T) P (I - K x^T)^T + noise_var * K K^T.
        # The noise term must be the outer product K K^T; an element-wise
        # K * K broadcast across columns gives an asymmetric, wrong matrix.
        identity = torch.eye(
            len(self.posterior_mean), dtype=prior_cov.dtype, device=prior_cov.device
        )
        I_Kx = identity - torch.outer(K, x)
        self.posterior_cov = (
            I_Kx @ prior_cov @ I_Kx.t() + noise_var * torch.outer(K, K)
        )


# 7. 应用场景
7.1 自动驾驶
# Uncertainty-aware driving decision.
def driving_decision(model, observations, uncertainty_threshold=0.3):
    """Choose a driving policy from the model's mean predictive variance.

    Args:
        model: Object whose ``predict(observations)`` returns a
            ``(mean, variance)`` pair.
        observations: Input forwarded unchanged to ``model.predict``.
        uncertainty_threshold: Mean-variance level above which the
            conservative policy is chosen.

    Returns:
        ``"CONSERVATIVE"`` when the mean variance exceeds the threshold,
        otherwise ``"NORMAL"``.
    """
    _, pred_var = model.predict(observations)
    # Fall back to the conservative policy when the model is unsure.
    if pred_var.mean() > uncertainty_threshold:
        return "CONSERVATIVE"
    return "NORMAL"


# 7.2 医疗诊断
# Diagnostic report with uncertainty.
def diagnostic_report(model, patient_data):
    """Build a diagnosis report with an uncertainty estimate.

    Args:
        model: Object whose ``predict(patient_data)`` returns a
            ``(mean, variance)`` pair of single-element tensors; the mean is
            treated as a logit.
        patient_data: Input forwarded unchanged to ``model.predict``.

    Returns:
        A dict with:
            ``"diagnosis"``: sigmoid of the predicted logit (a probability).
            ``"uncertainty"``: standard deviation of the logit.
            ``"confidence_interval"``: ``[low, high]`` probabilities obtained
                by mapping ``mean -/+ 2 * std`` through the sigmoid.
    """
    pred_mean, pred_var = model.predict(patient_data)
    pred_std = pred_var.sqrt()
    # The sigmoid is monotone, so transforming the logit-space bounds gives an
    # interval on the same probability scale as "diagnosis" that always
    # contains it.
    lower = torch.sigmoid(pred_mean - 2 * pred_std)
    upper = torch.sigmoid(pred_mean + 2 * pred_std)
    report = {
        "diagnosis": torch.sigmoid(pred_mean).item(),
        "uncertainty": pred_std.item(),
        "confidence_interval": [lower.item(), upper.item()],
    }
    return report


# 7.3 主动学习
def select_samples_for_annotation(model, unlabeled_loader, n_select=100):
    """Select the most uncertain unlabeled samples for annotation.

    Args:
        model: Object whose ``predict(x)`` returns a ``(mean, variance)``
            pair with the batch on the first dimension of ``variance``.
        unlabeled_loader: Iterable of ``(x, _)`` batches.
        n_select: Maximum number of samples to select.

    Returns:
        A NumPy array of sample indices (positions in loader iteration
        order), sorted by ascending uncertainty so the most uncertain sample
        is last. Empty when ``n_select <= 0``.
    """
    uncertainties = []
    with torch.no_grad():
        for x, _ in unlabeled_loader:
            _, var = model.predict(x)
            if var.dim() > 1:
                # One score per sample: average over all non-batch dimensions.
                var = var.flatten(start_dim=1).mean(dim=1)
            uncertainties.extend(var.tolist())

    order = np.argsort(uncertainties)
    if n_select <= 0:
        # order[-0:] would return every index, so handle zero explicitly.
        return order[:0]
    return order[-n_select:]


# 8. 常见问题与解决方案
8.1 KL 消失问题
问题:训练早期 KL 项被迅速压到接近 0,近似后验退化为先验,模型不再利用数据中的信息。
解决方案:
- 使用 KL 退火策略
- 增加先验强度
- 使用非平凡先验
def kl_annealing(epoch, warmup=10, method='cyclical'):
    """Return the KL-term weight for the given epoch.

    Args:
        epoch: Current epoch (0-based).
        warmup: Schedule length in epochs; must be positive.
        method: ``'linear'`` ramps from 0 to 1 over ``warmup`` epochs and then
            stays at 1. ``'cyclical'`` follows
            ``(sin(epoch * pi / warmup) + 1) / 2``, oscillating in [0, 1] with
            period ``2 * warmup``.

    Returns:
        The annealing weight, in [0, 1] for non-negative epochs.

    Raises:
        ValueError: If ``warmup`` is not positive or ``method`` is unknown.
    """
    if warmup <= 0:
        raise ValueError(f"warmup must be positive, got {warmup!r}")
    if method == 'linear':
        return min(1.0, epoch / warmup)
    if method == 'cyclical':
        # NOTE(review): this schedule starts at 0.5 for epoch 0, not at 0 -
        # confirm that is intended.
        return (np.sin(epoch * np.pi / warmup) + 1) / 2
    # Fail loudly: silently returning None would only surface later, when the
    # weight is multiplied into the loss.
    raise ValueError(f"unknown annealing method: {method!r}")


# 8.2 过度自信
问题:模型在训练集上过度自信。
解决方案:
- 添加标签平滑
- 使用 mixup/cutmix
- 温度缩放
def temperature_scale(logits, temperature):
    """Apply temperature scaling by dividing the logits by ``temperature``.

    Args:
        logits: Unnormalized class scores.
        temperature: Scaling factor; values above 1 flatten the softmax
            distribution, values below 1 sharpen it.

    Returns:
        The scaled logits.
    """
    return logits / temperature
# Optimize the temperature.
def find_optimal_temperature(model, val_loader):
    """Find the temperature in [0.5, 5.0] that minimizes validation NLL.

    The model is run once over ``val_loader`` with gradients disabled and the
    logits are cached; only the scalar temperature is optimized. The model's
    train/eval mode is not changed here, so the caller should set it.

    Args:
        model: Callable mapping an input batch to logits.
        val_loader: Iterable of ``(x, y)`` batches, ``y`` holding class
            indices.

    Returns:
        The optimal temperature found by the bounded scalar minimizer.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    from scipy.optimize import minimize_scalar
    import torch.nn.functional as F

    # The logits do not depend on the temperature, so compute them once
    # instead of re-running the model on every objective evaluation.
    cached = []
    with torch.no_grad():
        for x, y in val_loader:
            cached.append((model(x), y))
    n_total = sum(y.shape[0] for _, y in cached)
    if n_total == 0:
        raise ValueError("val_loader yielded no samples")

    def nll(T):
        total_nll = 0.0
        for logits, y in cached:
            scaled_logits = temperature_scale(logits, T)
            # Sum per sample so a smaller last batch is not over-weighted.
            total_nll += F.cross_entropy(scaled_logits, y, reduction='sum').item()
        return total_nll / n_total

    # `bounds` is only honoured by the bounded method; request it explicitly
    # rather than relying on the SciPy version's default.
    result = minimize_scalar(nll, bounds=(0.5, 5.0), method='bounded')
    return result.x