概述

Laplace近似是贝叶斯统计中最古老也是最实用的近似方法之一。其核心思想是:用高斯分布近似后验分布。在深度学习中,由于参数维度极高(百万至数十亿),精确贝叶斯推断不可行,Laplace近似提供了一种计算上可行的替代方案。

本系统介绍Laplace近似的理论基础、在神经网络中的应用、计算优化以及实践中的注意事项。12


理论基础

拉普拉斯方法

定理:设 是定义在 上的可积函数,假设 处取得唯一最大值,且 负定。则:

其中

直觉:高斯分布是熵最大的分布,在仅知道均值和方差时,是最”无偏”的近似。

应用于贝叶斯后验

对于后验分布

其中:

精度矩阵的分解

后验精度矩阵为:


神经网络中的应用

神经网络的后验

对于神经网络权重

近似后验

其中

预测分布

预测分布通过边缘化得到:

Monte Carlo近似

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
 
class LaplaceApproximation:
    """
    神经网络Laplace近似实现
    
    参考: Daxberger et al. "Laplace Redux" (NeurIPS 2021)
    """
    def __init__(self, model, prior_precision=1.0, prior_mean=0.0):
        self.model = model
        self.prior_precision = prior_precision
        self.prior_mean = prior_mean
        
        # 获取模型参数
        self.params = list(model.parameters())
        self.num_params = sum(p.numel() for p in self.params)
        
        # 存储MAP估计和精度矩阵
        self.map_params = None
        self.precision = None
    
    def fit(self, train_loader, lr=1e-3, max_epochs=100):
        """
        找到MAP估计
        """
        optimizer = torch.optim.Adam(self.params, lr=lr)
        
        for epoch in range(max_epochs):
            total_loss = 0
            for batch in train_loader:
                x, y = batch
                optimizer.zero_grad()
                
                # 负对数后验 = 负对数似然 + 负对数先验
                output = self.model(x)
                nll = F.cross_entropy(output, y)
                
                # 高斯先验下的L2正则化
                l2_loss = sum(p.pow(2).sum() for p in self.params)
                loss = nll + 0.5 * self.prior_precision * l2_loss
                
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            
            if epoch % 20 == 0:
                print(f"Epoch {epoch}, Loss: {total_loss:.4f}")
        
        # 保存MAP估计
        self.map_params = [p.detach().clone() for p in self.params]
        
        # 计算Hessian(近似精度矩阵)
        self._compute_precision()
    
    def _compute_precision(self):
        """
        计算后验精度矩阵(使用Fisher信息矩阵近似)
        """
        # 使用减掉一个迷你批次的方式计算Hessian
        self.model.eval()
        
        # 初始化精度矩阵
        self.precision = []
        
        for param in self.params:
            # 创建与参数形状相同的精度矩阵
            # 对于大型网络,这里应该使用结构化近似
            n = param.numel()
            precision_diag = torch.zeros(n, device=param.device)
            self.precision.append(precision_diag)
    
    def sample(self, num_samples=100):
        """
        从近似后验采样
        """
        self.model.load_state_dict(
            {name: p for name, p in zip(self.model.state_dict().keys(), self.map_params)}
        )
        
        samples = []
        for _ in range(num_samples):
            # 从近似后验采样权重
            sampled_params = []
            for param, prec in zip(self.params, self.precision):
                std = 1.0 / torch.sqrt(prec + 1e-6)
                sampled = Normal(param, std).sample()
                sampled_params.append(sampled)
            samples.append(sampled_params)
        
        return samples
    
    def predict(self, x, num_samples=100):
        """
        贝叶斯预测
        """
        self.model.eval()
        
        logits_list = []
        samples = self.sample(num_samples)
        
        for sample in samples:
            # 设置模型权重
            for param, sampled in zip(self.params, sample):
                param.data = sampled
            
            with torch.no_grad():
                logits = self.model(x)
                logits_list.append(logits)
        
        # 恢复MAP权重
        for param, map_param in zip(self.params, self.map_params):
            param.data = map_param
        
        # 平均预测分布
        logits_avg = torch.stack(logits_list).mean(dim=0)
        return logits_avg

计算优化:结构化近似

全精度矩阵的问题

对于参数量为 的网络,精度矩阵是 矩阵,存储需要 空间。对于现代网络(),这是不可行的。

对角近似

最简单的近似是对角精度矩阵

class DiagonalLaplace(LaplaceApproximation):
    """
    对角精度矩阵的Laplace近似
    
    存储复杂度: O(p) 而不是 O(p^2)
    """
    
    def _compute_precision(self):
        """
        计算对角精度矩阵
        """
        self.precision = []
        
        for param in self.params:
            n = param.numel()
            # 使用二阶导数近似
            precision_diag = torch.zeros(n, device=param.device)
            
            # 数值计算Hessian对角元素
            eps = 1e-5
            for i in range(min(n, 1000)):  # 限制计算量
                param_flat = param.detach().clone().flatten()
                
                # f(θ + εe_i)
                param_flat[i] += eps
                param.copy_(param_flat.view(param.shape))
                loss_plus = self._neg_log_posterior()
                
                # f(θ - εe_i)
                param_flat[i] -= 2 * eps
                param.copy_(param_flat.view(param.shape))
                loss_minus = self._neg_log_posterior()
                
                # 恢复原值
                param_flat[i] += eps
                param.copy_(param_flat.view(param.shape))
                
                # 二阶导数近似
                precision_diag[i] = (loss_plus - 2 * self._neg_log_posterior() + loss_minus) / (eps ** 2)
            
            self.precision.append(precision_diag)

K-FAC近似

Kronecker-Factored Approximate Curvature (K-FAC) 将精度矩阵分解为Kronecker积:

class KFACLaplace(LaplaceApproximation):
    """
    K-FAC近似的Laplace近似
    
    将每个层的精度矩阵分解为Kronecker积
    
    对于层权重 W (out_features × in_features):
    Σ^{-1} ≈ A^{-1} ⊗ B^{-1}
    
    存储复杂度: O(d_in² + d_out²) 而不是 O(d_in² × d_out²)
    """
    
    def __init__(self, model, prior_precision=1.0):
        super().__init__(model, prior_precision)
        self.kfac_matrices = {}  # 存储A和B的逆矩阵
    
    def _compute_kfac(self, layer, activations, gradients):
        """
        计算单层的K-FAC矩阵
        
        Args:
            layer: 神经网络层
            activations: 前向传播激活 (batch, in_features)
            gradients: 反向传播梯度 (batch, out_features)
        """
        batch_size = activations.shape[0]
        
        # E[aa^T]: 输入的协方差
        A = (activations.t() @ activations) / batch_size + 1e-6 * torch.eye(activations.shape[1])
        
        # E[gg^T]: 梯度的协方差
        G = (gradients.t() @ gradients) / batch_size + 1e-6 * torch.eye(gradients.shape[1])
        
        return A, G
    
    def _compute_precision(self):
        """
        计算K-FAC精度矩阵
        """
        # 需要hooks来收集中间结果
        self.activations = {}
        self.gradients = {}
        self.kfac_matrices = {}
        
        def get_activation(name):
            def hook(module, input, output):
                self.activations[name] = input[0].detach()
            return hook
        
        def get_gradient(name):
            def hook(module, grad_input, grad_output):
                self.gradients[name] = grad_output[0].detach()
            return hook
        
        # 注册hooks
        hooks = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                hooks.append(module.register_forward_hook(get_activation(name)))
                hooks.append(module.register_full_backward_hook(
                    lambda module, grad_input, grad_output, name=name: 
                    self.gradients.update({name: grad_output[0]})
                ))
        
        # 前向+反向计算
        # ... (省略训练循环代码)
        
        # 移除hooks
        for h in hooks:
            h.remove()

Last-Layer Laplace

对于大型网络,可以只对最后一层进行Laplace近似:

class LastLayerLaplace(LaplaceApproximation):
    """
    只对最后一层应用Laplace近似
    
    大幅降低计算复杂度,同时保留大部分不确定性量化能力
    """
    
    def fit(self, train_loader, feature_extractor_lr=1e-5, head_lr=1e-3, max_epochs=100):
        """
        训练策略:冻结特征提取器,只微调最后一层
        """
        # 分离特征提取器和分类头
        feature_extractor = nn.Sequential(*list(self.model.children())[:-1])
        classifier_head = self.model[-1]
        
        # 冻结特征提取器
        for param in feature_extractor.parameters():
            param.requires_grad = False
        
        # 只训练分类头
        optimizer = torch.optim.Adam(classifier_head.parameters(), lr=head_lr)
        
        for epoch in range(max_epochs):
            for batch in train_loader:
                x, y = batch
                
                with torch.no_grad():
                    features = feature_extractor(x)
                
                logits = classifier_head(features)
                loss = F.cross_entropy(logits, y)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        
        # 对最后一层应用Laplace近似
        self._compute_last_layer_precision(classifier_head)

后验预测分布

预测均值

预测方差

集成预测

def predictive_distribution(model, x, num_samples=100):
    """
    计算预测分布的均值和方差
    """
    model.eval()
    
    probs_list = []
    
    for _ in range(num_samples):
        # 从后验采样
        # ... (采样代码)
        
        with torch.no_grad():
            logits = model(x)
            probs = F.softmax(logits, dim=-1)
            probs_list.append(probs)
    
    probs_stack = torch.stack(probs_list)  # (num_samples, batch, classes)
    
    # 预测均值
    pred_mean = probs_stack.mean(dim=0)
    
    # 预测方差
    pred_var = probs_stack.var(dim=0)
    
    # 不确定性(熵)
    pred_entropy = -(pred_mean * torch.log(pred_mean + 1e-8)).sum(dim=-1)
    
    return {
        'mean': pred_mean,
        'variance': pred_var,
        'entropy': pred_entropy,
        'samples': probs_stack
    }

与其他方法的对比

方法比较

方法计算复杂度近似质量实现难度
精确后验完美简单
Laplace近似中等中等
对角Laplace较低简单
K-FAC中等较难
变分推断(VI)取决于变分族中等
MC Dropout取决于dropout率简单
SWAG较高中等

实验对比

def compare_uncertainty_methods(model, x, y_true, train_loader):
    """
    比较不同不确定性量化方法的性能
    """
    results = {}
    
    # 1. MC Dropout
    mc_dropout_preds = mc_dropout_predict(model, x, num_samples=50)
    results['mc_dropout'] = evaluate_uncertainty(mc_dropout_preds, y_true)
    
    # 2. Laplace近似
    laplace = LaplaceApproximation(model)
    laplace.fit(train_loader)
    laplace_preds = laplace.predict(x, num_samples=50)
    results['laplace'] = evaluate_uncertainty(laplace_preds, y_true)
    
    # 3. SWAG
    swag = SWAG(model)
    swag.fit(train_loader)
    swag_preds = swag.predict(x, num_samples=50)
    results['swag'] = evaluate_uncertainty(swag_preds, y_true)
    
    return results

实践指南

何时使用Laplace近似

适合场景

  • 需要可信的不确定性估计
  • 模型已经训练完成,只需添加不确定性
  • 需要快速的预测分布计算
  • 资源有限,无法使用ensemble

不适合场景

  • 数据非高斯噪声
  • 需要捕获多模态后验
  • 网络结构非常复杂
  • 实时性要求极高

先验精度选择

先验精度 控制后验的收缩程度:

# 策略1:默认设置
prior_precision = 1.0
 
# 策略2:交叉验证
from sklearn.model_selection import cross_val_score
 
best_precision = None
best_score = -np.inf
for precision in [0.01, 0.1, 1.0, 10.0]:
    score = cross_val_score(model, X, y, cv=5)  # 某种验证
    if score > best_score:
        best_score = score
        best_precision = precision
 
# 策略3:经验贝叶斯(边际似然最大化)

数值稳定性

def stable_precision_inverse(P):
    """
    数值稳定的精度矩阵求逆
    
    添加正则化项防止奇异
    """
    # 方法1:加正则化
    P_reg = P + 1e-6 * torch.eye(P.shape[0])
    
    # 方法2:特征值截断
    eigenvalues, eigenvectors = torch.linalg.eigh(P)
    eigenvalues = torch.clamp(eigenvalues, min=1e-6)
    P_inv = eigenvectors @ torch.diag(1.0 / eigenvalues) @ eigenvectors.t()
    
    return P_inv

应用场景

1. 不确定性感知决策

def uncertainty_aware_decision(x, model, laplace, threshold=0.1):
    """
    基于不确定性的决策
    
    当预测不确定时,选择保守策略
    """
    pred = laplace.predict(x.unsqueeze(0), num_samples=100)
    pred_probs = F.softmax(pred, dim=-1)
    
    max_prob, pred_class = pred_probs.max(dim=-1)
    uncertainty = 1 - max_prob  # 不确定性 = 1 - 置信度
    
    if uncertainty > threshold:
        return "abstain"  # 弃权,选择保守策略
    return pred_class.item()

2. 异常检测

def detect_anomalies(model, laplace, data_loader, k=5):
    """
    基于预测不确定性的异常检测
    """
    uncertainties = []
    
    for x, _ in data_loader:
        pred = laplace.predict(x, num_samples=50)
        pred_probs = F.softmax(pred, dim=-1)
        
        # 使用预测熵作为不确定性度量
        entropy = -(pred_probs * torch.log(pred_probs + 1e-8)).sum(dim=-1)
        uncertainties.extend(entropy.cpu().numpy())
    
    # 选择top-k最不确定的样本作为潜在异常
    threshold = np.percentile(uncertainties, 100 - k)
    anomalies = [i for i, u in enumerate(uncertainties) if u > threshold]
    
    return anomalies

3. 主动学习

def active_learning_criterion(model, laplace, x_pool):
    """
    基于不确定性的主动学习样本选择
    """
    uncertainties = []
    
    for x in x_pool:
        pred = laplace.predict(x.unsqueeze(0), num_samples=50)
        pred_probs = F.softmax(pred, dim=-1)
        
        # 最大预测概率
        max_prob = pred_probs.max(dim=-1)[0]
        uncertainty = 1 - max_prob
        uncertainties.append(uncertainty.item())
    
    # 选择不确定性最大的样本
    query_idx = np.argmax(uncertainties)
    return query_idx

参考


相关阅读


Footnotes

  1. Daxberger, E., et al. (2021). “Laplace Redux — Effortless Bayesian Deep Learning”. NeurIPS 2021.

  2. Kristiadi, A., et al. (2020). “Being Bayesian about Uncertainty”. ICML 2020.