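A Practical Guide to Bayesian Deep Learning

1. Introduction

Bayesian deep learning combines Bayesian inference with deep learning so that neural networks can quantify the uncertainty of their predictions. [1] In practice, uncertainty quantification is essential in areas such as risk assessment, active learning, and federated learning.

This guide takes a practical perspective and systematically covers the core methods, implementation techniques, and engineering practices of Bayesian deep learning.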

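2. Types and Sources of Uncertainty

2.1 Epistemic Uncertainty

Definition: uncertainty caused by insufficient training data or limited model capacity.

Characteristics

  • Can be reduced by collecting more data
  • Is higher in regions where data is sparse
  • Well suited to active learning

Modeling approach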
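
A common way to model epistemic uncertainty is to place a distribution over the network weights and average predictions over the posterior. The following is a sketch under assumptions: the weight prior $p(w)$ and the training set $\mathcal{D}$ are symbols introduced here for illustration and are not taken from the original text.

```latex
p(w \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid w)\, p(w)}{p(\mathcal{D})},
\qquad
p(y \mid x, \mathcal{D}) = \int p(y \mid x, w)\, p(w \mid \mathcal{D})\, dw
```

How much $p(y \mid x, w)$ changes as $w$ varies under the posterior reflects epistemic uncertainty. It shrinks as more data concentrates the posterior.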

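2.2 Aleatoric Uncertainty

Definition: inherent uncertainty caused by randomness in the data itself.

Characteristics

  • Cannot be reduced by collecting more data
  • Depends on the input data
  • Well suited to anomaly detection

Modeling approach

  • Homoscedastic: the noise variance does not change with the input
  • Heteroscedastic: the noise variance changes with the input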
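
The difference between the two noise models can be made concrete with the Gaussian negative log-likelihood. The sketch below is illustrative only: the function name `gaussian_nll` and the toy numbers are assumptions, not part of the original guide. A homoscedastic model shares one variance across all inputs, while a heteroscedastic model supplies one variance per input.

```python
import numpy as np

def gaussian_nll(y, mean, log_var):
    """Mean Gaussian negative log-likelihood.

    log_var may be a scalar (homoscedastic) or an array with one entry
    per input (heteroscedastic); NumPy broadcasting handles both cases.
    """
    y = np.asarray(y, dtype=float)
    mean = np.asarray(mean, dtype=float)
    log_var = np.asarray(log_var, dtype=float)
    var = np.exp(log_var)
    return float(np.mean(0.5 * (np.log(2 * np.pi) + log_var + (y - mean) ** 2 / var)))

# Toy data (hypothetical): the last point is much noisier than the others
y = np.array([1.0, 2.0, 3.0, 10.0])
mean = np.array([1.1, 1.9, 3.2, 6.0])

# Homoscedastic: a single shared variance for every input
nll_homo = gaussian_nll(y, mean, log_var=np.log(1.0))

# Heteroscedastic: a larger variance is assigned to the noisy input
nll_hetero = gaussian_nll(y, mean, log_var=np.log([0.05, 0.05, 0.05, 16.0]))
```

On this toy data the heteroscedastic variances give a lower negative log-likelihood, because the model is allowed to be confident on the clean points and cautious on the noisy one.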

2.3 Uncertainty Decomposition
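
A standard way to separate the two sources is the law of total variance: the total predictive variance equals the mean of the predicted variances (aleatoric part) plus the variance of the predicted means (epistemic part). The sketch below assumes that each posterior sample, for example one MC Dropout pass or one ensemble member, outputs a mean and a variance for every input. The function name `decompose_uncertainty` and the array layout are assumptions made for this example.

```python
import numpy as np

def decompose_uncertainty(means, variances):
    """Split predictive variance into aleatoric and epistemic parts.

    means, variances: arrays of shape (n_samples, n_points), one row per
    posterior sample (for example, one stochastic forward pass).
    """
    means = np.asarray(means, dtype=float)
    variances = np.asarray(variances, dtype=float)
    aleatoric = variances.mean(axis=0)   # E_w[sigma^2(x, w)]
    epistemic = means.var(axis=0)        # Var_w[mu(x, w)]
    return aleatoric, epistemic, aleatoric + epistemic

# Three hypothetical posterior samples evaluated at two inputs
means = [[1.0, 0.0], [1.0, 2.0], [1.0, 4.0]]
variances = [[0.5, 0.1], [0.5, 0.1], [0.5, 0.1]]
aleatoric, epistemic, total = decompose_uncertainty(means, variances)
```

At the first input all samples agree on the mean, so the epistemic part is zero. At the second input the samples disagree, so most of the total variance is epistemic.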

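3. Comparison of Core Methods

3.1 Overview

| Method        | Implementation complexity | Computational cost | Uncertainty quality | Scalability |
|---------------|---------------------------|--------------------|---------------------|-------------|
| MC Dropout    | ★☆☆                       | ★☆☆                | ★★☆                 | ★★★         |
| Variational inference | ★★☆               | ★★☆                | ★★★                 | ★★☆         |
| MCMC          | ★★★                       | ★★★                | ★★★                 | ★☆☆         |
| Ensembles     | ★★☆                       | ★★★                | ★★★                 | ★★☆         |
| SWA-Gaussian  | ★★☆                       | ★☆☆                | ★★☆                 | ★★★         |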

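3.2 MC Dropout

MC Dropout is the simplest of the methods compared here to implement. [2]

Theoretical basis: training with Dropout can be interpreted as approximate Bayesian inference, so keeping Dropout active at test time draws samples from an approximate posterior over the weights.

Mathematical formulation

$$
p(y \mid x, \mathcal{D}) \approx \frac{1}{T} \sum_{t=1}^{T} p(y \mid x, \hat{w}_t)
$$

where $\hat{w}_t$ denotes the weights obtained by applying Dropout in the $t$-th of $T$ forward passes.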

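import torch
import torch.nn as nn

class MCDropout(nn.Module):
    """Wraps a model that already contains nn.Dropout layers."""
    def __init__(self, model, dropout_rate=None, n_samples=50):
        super().__init__()
        self.model = model
        self.n_samples = n_samples
        # Optionally override the rate of the model's own Dropout layers
        if dropout_rate is not None:
            for m in self.model.modules():
                if isinstance(m, nn.Dropout):
                    m.p = dropout_rate

    def _enable_dropout(self):
        # Keep layers such as BatchNorm in eval mode; activate only Dropout
        self.model.eval()
        for m in self.model.modules():
            if isinstance(m, nn.Dropout):
                m.train()

    def forward(self, x, return_uncertainty=True):
        if not return_uncertainty:
            return self.model(x)

        self._enable_dropout()
        with torch.no_grad():
            # One stochastic forward pass per sample
            predictions = torch.stack([self.model(x) for _ in range(self.n_samples)])

        # Predictive mean and variance across the stochastic passes
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)

        return mean, variance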

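3.3 Variational Inference

Variational inference approximates the posterior $p(w \mid \mathcal{D})$ with a variational distribution $q_\phi(w)$.

Evidence lower bound (ELBO)

$$
\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(w)}\big[\log p(\mathcal{D} \mid w)\big] - \mathrm{KL}\big(q_\phi(w) \,\|\, p(w)\big)
$$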
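
For a fully factorized Gaussian posterior, which is the setting of the layer implemented in this section, the KL term of the ELBO has a closed form. The per-weight symbols below ($\mu$, $\sigma^2$ for the variational distribution and $\mu_0$, $\sigma_0^2$ for the prior) are introduced for this sketch:

```latex
\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(\mu_0, \sigma_0^2)\big)
= \frac{1}{2}\left[\log\frac{\sigma_0^2}{\sigma^2}
+ \frac{\sigma^2 + (\mu - \mu_0)^2}{\sigma_0^2} - 1\right]
```

With $\mu_0 = 0$ and $\sigma_0^2 = 1$ this reduces to $-\tfrac{1}{2}\,(1 + \log\sigma^2 - \mu^2 - \sigma^2)$, summed over all weights.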

import math
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    def __init__(self, in_features, out_features, prior_mean=0, prior_var=1):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Variational parameters: mean and log-variance of every weight
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features))
        self.weight_log_var = nn.Parameter(torch.zeros(out_features, in_features))

        self.bias_mu = nn.Parameter(torch.randn(out_features))
        self.bias_log_var = nn.Parameter(torch.zeros(out_features))

        # Prior parameters
        self.prior_mean = prior_mean
        self.prior_var = prior_var

        # Observation noise: a single learned scalar, i.e. homoscedastic noise
        self.log_noise = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Sample weights with the reparameterization trick
        weight = self.weight_mu + torch.exp(0.5 * self.weight_log_var) * torch.randn_like(self.weight_mu)
        bias = self.bias_mu + torch.exp(0.5 * self.bias_log_var) * torch.randn_like(self.bias_mu)

        output = F.linear(x, weight, bias)
        return output

    def _gaussian_kl(self, mu, log_var):
        """KL( N(mu, exp(log_var)) || N(prior_mean, prior_var) ), summed over elements"""
        return 0.5 * (
            math.log(self.prior_var) - log_var
            + (torch.exp(log_var) + (mu - self.prior_mean).pow(2)) / self.prior_var
            - 1
        ).sum()

    def kl_divergence(self):
        """Compute the KL divergence between the variational posterior and the prior"""
        weight_kl = self._gaussian_kl(self.weight_mu, self.weight_log_var)
        bias_kl = self._gaussian_kl(self.bias_mu, self.bias_log_var)
        return weight_kl + bias_kl

    def loss(self, x, y, reduction='mean'):
        """Compute the loss including the KL term (negative ELBO up to a constant)"""
        output = self(x)

        # Reconstruction loss: Gaussian negative log-likelihood without the 2*pi constant
        noise_var = torch.exp(self.log_noise)
        recon_loss = 0.5 * ((y - output).pow(2) / noise_var + torch.log(noise_var)).sum()

        # KL loss
        kl_loss = self.kl_divergence()

        # Note: with mini-batches the KL term is commonly scaled by the dataset
        # size rather than the batch size; the batch size is used here.
        if reduction == 'mean':
            return (recon_loss + kl_loss) / x.size(0)
        return recon_loss + kl_loss

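3.4 Ensemble Methods

An ensemble combines the predictions of several separately trained models; the disagreement between members serves as the uncertainty estimate.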

class EnsembleModel:
    def __init__(self, base_model_fn, n_models=10):
        self.models = [base_model_fn() for _ in range(n_models)]
        self.optimizers = [torch.optim.Adam(m.parameters()) for m in self.models]
    
    def train_epoch(self, train_loader):
        for model, optimizer in zip(self.models, self.optimizers):
            model.train()
            for x, y in train_loader:
                optimizer.zero_grad()
                loss = model.loss(x, y)
                loss.backward()
                optimizer.step()
    
    def predict(self, x, return_uncertainty=True):
        predictions = []
        for model in self.models:
            model.eval()
            with torch.no_grad():
                pred = model(x)
                predictions.append(pred)
        
        predictions = torch.stack(predictions)
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)
        
        if return_uncertainty:
            return mean, variance
        return mean

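4. Practical Tips

4.1 Training Stability

Learning rate

  • Bayesian neural networks are usually trained with a relatively low learning rate
  • Recommended:
  • Use learning-rate warm-up followed by cosine annealing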
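
Warm-up followed by cosine annealing can be sketched as a single multiplier on the base learning rate. The function name `warmup_cosine_factor`, the step counts, and the base learning rate below are hypothetical values chosen for illustration, not recommendations from this guide.

```python
import math

def warmup_cosine_factor(step, warmup_steps, total_steps, min_factor=0.0):
    """Multiplier applied to the base learning rate at a given step."""
    if step < warmup_steps:
        # Linear warm-up from 1/warmup_steps up to 1.0
        return (step + 1) / warmup_steps
    # Cosine decay from 1.0 down to min_factor over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    return min_factor + 0.5 * (1.0 - min_factor) * (1.0 + math.cos(math.pi * progress))

# Hypothetical schedule: 5 warm-up steps out of 50, base learning rate 1e-3
base_lr = 1e-3
schedule = [base_lr * warmup_cosine_factor(s, warmup_steps=5, total_steps=50)
            for s in range(50)]
```

In PyTorch the same function can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR`, which multiplies the optimizer's initial learning rate by the returned factor.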

Weight initialization

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

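4.2 Choosing a Prior

| Prior          | Use case        | Properties                            |
|----------------|-----------------|---------------------------------------|
| Gaussian       | General purpose | Unbiased, mathematically convenient   |
| Laplace        | Sparsity        | Promotes sparse solutions             |
| Horseshoe      | Sparsity        | Hierarchical prior, strong sparsity   |
| Spike-and-Slab | Sparsity        | Two-component mixture, sparsest       |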
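
To make the first two rows concrete, the sketch below compares the log-densities of a zero-mean Gaussian prior and a zero-mean Laplace prior. The function names and the unit scales are assumptions chosen for illustration. The negative log-prior acts as a regularizer: quadratic for the Gaussian, linear in the absolute value for the Laplace.

```python
import math

def gaussian_log_prior(w, sigma=1.0):
    """Log-density of N(0, sigma^2); its negative is an L2 penalty up to a constant."""
    return -0.5 * (w / sigma) ** 2 - 0.5 * math.log(2 * math.pi * sigma ** 2)

def laplace_log_prior(w, b=1.0):
    """Log-density of Laplace(0, b); its negative is an L1 penalty up to a constant."""
    return -abs(w) / b - math.log(2 * b)

# Penalty for moving a weight from 0 to w under each prior
for w in (0.01, 0.1, 1.0):
    gauss_penalty = gaussian_log_prior(0.0) - gaussian_log_prior(w)
    laplace_penalty = laplace_log_prior(0.0) - laplace_log_prior(w)
    print(f"w={w}: gaussian={gauss_penalty:.5f}, laplace={laplace_penalty:.5f}")
```

For small weights the Laplace penalty is much larger than the Gaussian one, which is the mechanism that pushes small weights toward exactly zero.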

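Horseshoe prior implementation

import math

class HorseshoePrior:
    def __init__(self, scale=1.0, eps=1e-8):
        self.scale = scale
        self.eps = eps  # avoids division by zero at w == 0

    def log_prob(self, w):
        # The horseshoe density has no closed form. This approximation uses its
        # known upper bound K * log(1 + 2 * scale^2 / w^2) with
        # K = 1 / sqrt(2 * pi^3), so the result is an unnormalized log-density.
        w2 = w.pow(2) + self.eps
        log_k = -0.5 * math.log(2 * math.pi ** 3)
        inner = torch.log1p(2 * self.scale ** 2 / w2)
        return torch.sum(log_k - math.log(self.scale) + torch.log(inner))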

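4.3 Outlier Handling

A heteroscedastic model can flag anomalies through its predicted variance:

def detect_anomalies(model, data_loader, threshold=3):
    """Detect anomalies based on the predicted variance"""
    model.eval()
    anomalies = []
    offset = 0  # dataset index of the first sample in the current batch

    with torch.no_grad():
        for x, _ in data_loader:
            _, variance = model.predict(x)
            # One standard deviation per sample, averaged over output dimensions
            std = variance.sqrt().reshape(x.size(0), -1).mean(dim=1)
            batch_idx = torch.nonzero(std > threshold).flatten()
            # Convert batch-local indices to dataset-level indices
            anomalies.extend((batch_idx + offset).tolist())
            offset += x.size(0)

    return anomalies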

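5. Evaluating Uncertainty Quantification

5.1 Calibration Curve

A calibration curve checks how well predicted probabilities agree with observed frequencies: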

import numpy as np
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt
 
def plot_calibration(y_true, y_prob, n_bins=10):
    """Plot the calibration curve"""
    fraction_of_positives, mean_predicted_value = calibration_curve(
        y_true, y_prob, n_bins=n_bins
    )
    
    plt.figure(figsize=(8, 6))
    plt.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')
    plt.plot(mean_predicted_value, fraction_of_positives, 'o-', label='Model')
    plt.xlabel('Mean predicted probability')
    plt.ylabel('Fraction of positives')
    plt.legend()
    plt.show()

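5.2 Expected Calibration Error (ECE)

def expected_calibration_error(y_true, y_prob, n_bins=15):
    """Compute the ECE"""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    bin_edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0

    for i in range(n_bins):
        lower, upper = bin_edges[i], bin_edges[i + 1]
        # The last bin is closed on the right so that y_prob == 1.0 is counted
        if i == n_bins - 1:
            mask = (y_prob >= lower) & (y_prob <= upper)
        else:
            mask = (y_prob >= lower) & (y_prob < upper)
        if mask.sum() > 0:
            acc = y_true[mask].mean()
            conf = y_prob[mask].mean()
            # Weight each bin by the fraction of samples it contains
            ece += mask.sum() / len(y_true) * abs(acc - conf)

    return ece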

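5.3 Distributional Calibration

These metrics check how well the predictive distribution agrees with the true distribution:

def nll(y_true, y_pred_mean, y_pred_var):
    """Per-sample Gaussian negative log-likelihood (without the 0.5 * log(2 * pi) constant)"""
    return 0.5 * (np.log(y_pred_var) + (y_true - y_pred_mean)**2 / y_pred_var)

def crps(y_true, y_samples):
    """Continuous ranked probability score, estimated from samples.

    y_true: shape (n_points,); y_samples: shape (n_points, n_samples).
    Uses CRPS = E|X - y| - 0.5 * E|X - X'| with X, X' drawn from the forecast.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_samples = np.asarray(y_samples, dtype=float)

    # E|X - y|: mean absolute error of the samples against the observation
    term1 = np.abs(y_samples - y_true[:, None]).mean(axis=1)
    # 0.5 * E|X - X'|: half the mean absolute difference between sample pairs
    term2 = 0.5 * np.abs(y_samples[:, :, None] - y_samples[:, None, :]).mean(axis=(1, 2))

    return (term1 - term2).mean()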
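
When the predictive distribution is Gaussian, CRPS does not need samples because it has a well-known closed form. The sketch below assumes a forecast $\mathcal{N}(\text{mean}, \text{std}^2)$ per observation; the function name `gaussian_crps` is introduced here for illustration.

```python
import math

def gaussian_crps(y, mean, std):
    """Closed-form CRPS of a Gaussian forecast N(mean, std^2) at observation y."""
    z = (y - mean) / std
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)  # standard normal density
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2)))          # standard normal CDF
    return std * (z * (2 * cdf - 1) + 2 * pdf - 1 / math.sqrt(math.pi))

# A sharp, well-centered forecast scores better (lower) than a vague one
sharp = gaussian_crps(y=0.0, mean=0.0, std=0.5)
vague = gaussian_crps(y=0.0, mean=0.0, std=2.0)
```

Lower is better: CRPS rewards forecasts that are both centered on the observation and no wider than necessary.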

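6. Large-Scale Deployment

6.1 Memory Optimization

Posterior approximation: use a low-rank approximation to reduce the number of parameters.

import torch.nn.functional as F

class LowRankBayesianLinear(nn.Module):
    """Bayesian linear layer whose weight mean uses a low-rank factorization"""
    def __init__(self, in_features, out_features, rank=10):
        super().__init__()
        self.rank = rank

        # Low-rank factorization of the weight mean: W = U V^T
        self.U = nn.Parameter(torch.randn(out_features, rank))
        self.V = nn.Parameter(torch.randn(in_features, rank))

        # Element-wise log-variance (still a full out_features x in_features matrix)
        self.log_var = nn.Parameter(torch.zeros(out_features, in_features))

    def forward(self, x):
        W = torch.mm(self.U, self.V.t())
        variance = torch.exp(self.log_var)

        # Sample the weights with the reparameterization trick
        noise = torch.randn_like(W)
        W_sample = W + variance.sqrt() * noise

        return F.linear(x, W_sample)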
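
The saving from the factorization can be checked with simple arithmetic. The helper `mean_param_counts` and the layer sizes below are hypothetical and only illustrate the count for the weight mean.

```python
def mean_param_counts(in_features, out_features, rank):
    """Number of parameters needed to store the weight mean."""
    full = in_features * out_features               # dense mean matrix
    low_rank = rank * (in_features + out_features)  # U and V factors
    return full, low_rank

full, low_rank = mean_param_counts(in_features=1024, out_features=1024, rank=10)
ratio = full / low_rank
print(full, low_rank, ratio)
```

Note that the layer above still stores an element-wise log-variance of full size, so the total saving is smaller than this ratio unless the variance is also shared or factorized.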

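6.2 Inference Acceleration

Single forward pass approximation

class SWAG(nn.Module):
    """Stochastic Weight Averaging Gaussian (diagonal covariance)"""
    def __init__(self, model, start_epoch=10):
        super().__init__()
        self.model = model
        self.start_epoch = start_epoch
        self.n = 0  # number of weight snapshots collected so far
        # Running first and second moments of every parameter
        self.mean = {k: torch.zeros_like(p) for k, p in model.named_parameters()}
        self.sq_mean = {k: torch.zeros_like(p) for k, p in model.named_parameters()}

    def collect_weights(self, epoch):
        if epoch < self.start_epoch:
            return
        with torch.no_grad():
            for k, p in self.model.named_parameters():
                self.mean[k] = (self.n * self.mean[k] + p) / (self.n + 1)
                self.sq_mean[k] = (self.n * self.sq_mean[k] + p ** 2) / (self.n + 1)
        self.n += 1

    def predict(self, x):
        # Sample weights from N(mean, diag(sq_mean - mean^2)), then run one forward pass.
        # This overwrites the parameters currently stored in self.model.
        with torch.no_grad():
            for k, p in self.model.named_parameters():
                var = (self.sq_mean[k] - self.mean[k] ** 2).clamp(min=1e-30)
                p.copy_(self.mean[k] + var.sqrt() * torch.randn_like(p))
            return self.model(x)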

6.3 Online Learning

Bayesian online learning supports incremental updates:

class BayesianOnlineUpdate:
    def __init__(self, prior_mean, prior_cov):
        self.posterior_mean = prior_mean
        self.posterior_cov = prior_cov

    def update(self, x, y, noise_var=0.1):
        """Online Bayesian update for one observation (x: 1-D feature vector, y: scalar)"""
        # Kalman gain
        S = x @ self.posterior_cov @ x + noise_var
        K = self.posterior_cov @ x / S

        # Update the posterior mean
        residual = y - x @ self.posterior_mean
        self.posterior_mean = self.posterior_mean + K * residual

        # Update the posterior covariance (Joseph form):
        # (I - K x^T) P (I - K x^T)^T + noise_var * K K^T
        I_Kx = torch.eye(len(self.posterior_mean)) - torch.outer(K, x)
        self.posterior_cov = I_Kx @ self.posterior_cov @ I_Kx.t() + torch.outer(K, K) * noise_var

7. Applications

7.1 Autonomous Driving

# Uncertainty-aware driving decision
def driving_decision(model, observations, uncertainty_threshold=0.3):
    pred_mean, pred_var = model.predict(observations)
    
    if pred_var.mean() > uncertainty_threshold:
        return "CONSERVATIVE"  # fall back to a conservative policy when uncertainty is high
    return "NORMAL"

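7.2 Medical Diagnosis

# Diagnostic report with uncertainty
def diagnostic_report(model, patient_data):
    # pred_mean is treated as a logit, since the diagnosis applies a sigmoid to it
    pred_mean, pred_var = model.predict(patient_data)
    pred_std = pred_var.sqrt()
    
    report = {
        "diagnosis": torch.sigmoid(pred_mean).item(),
        # Standard deviation of the logit, not of the probability
        "uncertainty": pred_std.item(),
        # Two-standard-deviation interval, mapped from logit space to probability space
        "confidence_interval": [
            torch.sigmoid(pred_mean - 2 * pred_std).item(),
            torch.sigmoid(pred_mean + 2 * pred_std).item()
        ]
    }
    return report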

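7.3 Active Learning

def select_samples_for_annotation(model, unlabeled_loader, n_select=100):
    """Uncertainty-based sample selection (the loader must not shuffle)"""
    uncertainties = []
    
    with torch.no_grad():
        for x, _ in unlabeled_loader:
            _, var = model.predict(x)
            # One score per sample: average the variance over all output dimensions
            uncertainties.extend(var.reshape(var.size(0), -1).mean(dim=1).tolist())
    
    # Select the samples with the highest uncertainty
    indices = np.argsort(uncertainties)[-n_select:]
    return indices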

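8. Common Problems and Solutions

8.1 KL Vanishing

Problem: the KL term collapses to nearly zero early in training, which means the approximate posterior has degenerated into the prior.

Solutions

  • Use a KL annealing schedule
  • Increase the prior strength
  • Use a non-trivial prior

def kl_annealing(epoch, warmup=10, method='cyclical'):
    """Weight in [0, 1] applied to the KL term at a given epoch"""
    if method == 'linear':
        return min(1.0, epoch / warmup)
    elif method == 'cyclical':
        # Cosine cycle: 0 at epoch 0, 1 at epoch == warmup, back to 0 at 2 * warmup
        return (1 - np.cos(epoch * np.pi / warmup)) / 2
    raise ValueError(f"unknown annealing method: {method}")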

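8.2 Overconfidence

Problem: the model is overconfident on the training set.

Solutions

  • Add label smoothing
  • Use mixup/cutmix
  • Apply temperature scaling

def temperature_scale(logits, temperature):
    """Temperature scaling for calibration"""
    return logits / temperature

# Optimize the temperature on a validation set
def find_optimal_temperature(model, val_loader):
    from scipy.optimize import minimize_scalar
    import torch.nn.functional as F
    
    # Compute the logits once; they do not depend on the temperature
    model.eval()
    with torch.no_grad():
        batches = [(model(x), y) for x, y in val_loader]
    logits = torch.cat([out for out, _ in batches])
    labels = torch.cat([y for _, y in batches])
    
    def nll(T):
        scaled_logits = temperature_scale(logits, float(T))
        return F.cross_entropy(scaled_logits, labels).item()
    
    # Bounded scalar minimization over the temperature
    result = minimize_scalar(nll, bounds=(0.5, 5.0), method='bounded')
    return result.x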
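
The effect of the temperature can be seen on a single prediction. The sketch below uses a plain-Python softmax and hypothetical logits; it is only meant to show that a temperature above 1 flattens the distribution without changing the predicted class.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 1.0, 0.0]  # hypothetical logits for a 3-class problem
p_raw = softmax(logits, temperature=1.0)
p_scaled = softmax(logits, temperature=2.0)
```

Because dividing by a positive temperature preserves the ordering of the logits, temperature scaling changes confidence but not accuracy.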

9. References

  1. Gal (2016). "Uncertainty in Deep Learning." PhD Thesis, University of Cambridge.

  2. Gal & Ghahramani (2016). "Dropout as a Bayesian Approximation." ICML 2016.