概述
Laplace近似是贝叶斯统计中最古老也是最实用的近似方法之一。其核心思想是:用高斯分布近似后验分布。在深度学习中,由于参数维度极高(百万至数十亿),精确贝叶斯推断不可行,Laplace近似提供了一种计算上可行的替代方案。
本系统介绍Laplace近似的理论基础、在神经网络中的应用、计算优化以及实践中的注意事项。12
理论基础
拉普拉斯方法
定理:设 是定义在 上的可积函数,假设 在 处取得唯一最大值,且 负定。则:
其中 。
直觉:高斯分布是熵最大的分布,在仅知道均值和方差时,是最”无偏”的近似。
应用于贝叶斯后验
对于后验分布 :
其中:
精度矩阵的分解
后验精度矩阵为:
神经网络中的应用
神经网络的后验
对于神经网络权重 :
近似后验
其中 。
预测分布
预测分布通过边缘化得到:
Monte Carlo近似:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
class LaplaceApproximation:
"""
神经网络Laplace近似实现
参考: Daxberger et al. "Laplace Redux" (NeurIPS 2021)
"""
def __init__(self, model, prior_precision=1.0, prior_mean=0.0):
self.model = model
self.prior_precision = prior_precision
self.prior_mean = prior_mean
# 获取模型参数
self.params = list(model.parameters())
self.num_params = sum(p.numel() for p in self.params)
# 存储MAP估计和精度矩阵
self.map_params = None
self.precision = None
def fit(self, train_loader, lr=1e-3, max_epochs=100):
"""
找到MAP估计
"""
optimizer = torch.optim.Adam(self.params, lr=lr)
for epoch in range(max_epochs):
total_loss = 0
for batch in train_loader:
x, y = batch
optimizer.zero_grad()
# 负对数后验 = 负对数似然 + 负对数先验
output = self.model(x)
nll = F.cross_entropy(output, y)
# 高斯先验下的L2正则化
l2_loss = sum(p.pow(2).sum() for p in self.params)
loss = nll + 0.5 * self.prior_precision * l2_loss
loss.backward()
optimizer.step()
total_loss += loss.item()
if epoch % 20 == 0:
print(f"Epoch {epoch}, Loss: {total_loss:.4f}")
# 保存MAP估计
self.map_params = [p.detach().clone() for p in self.params]
# 计算Hessian(近似精度矩阵)
self._compute_precision()
def _compute_precision(self):
"""
计算后验精度矩阵(使用Fisher信息矩阵近似)
"""
# 使用减掉一个迷你批次的方式计算Hessian
self.model.eval()
# 初始化精度矩阵
self.precision = []
for param in self.params:
# 创建与参数形状相同的精度矩阵
# 对于大型网络,这里应该使用结构化近似
n = param.numel()
precision_diag = torch.zeros(n, device=param.device)
self.precision.append(precision_diag)
def sample(self, num_samples=100):
"""
从近似后验采样
"""
self.model.load_state_dict(
{name: p for name, p in zip(self.model.state_dict().keys(), self.map_params)}
)
samples = []
for _ in range(num_samples):
# 从近似后验采样权重
sampled_params = []
for param, prec in zip(self.params, self.precision):
std = 1.0 / torch.sqrt(prec + 1e-6)
sampled = Normal(param, std).sample()
sampled_params.append(sampled)
samples.append(sampled_params)
return samples
def predict(self, x, num_samples=100):
"""
贝叶斯预测
"""
self.model.eval()
logits_list = []
samples = self.sample(num_samples)
for sample in samples:
# 设置模型权重
for param, sampled in zip(self.params, sample):
param.data = sampled
with torch.no_grad():
logits = self.model(x)
logits_list.append(logits)
# 恢复MAP权重
for param, map_param in zip(self.params, self.map_params):
param.data = map_param
# 平均预测分布
logits_avg = torch.stack(logits_list).mean(dim=0)
return logits_avg计算优化:结构化近似
全精度矩阵的问题
对于参数量为 的网络,精度矩阵是 矩阵,存储需要 空间。对于现代网络(),这是不可行的。
对角近似
最简单的近似是对角精度矩阵:
class DiagonalLaplace(LaplaceApproximation):
"""
对角精度矩阵的Laplace近似
存储复杂度: O(p) 而不是 O(p^2)
"""
def _compute_precision(self):
"""
计算对角精度矩阵
"""
self.precision = []
for param in self.params:
n = param.numel()
# 使用二阶导数近似
precision_diag = torch.zeros(n, device=param.device)
# 数值计算Hessian对角元素
eps = 1e-5
for i in range(min(n, 1000)): # 限制计算量
param_flat = param.detach().clone().flatten()
# f(θ + εe_i)
param_flat[i] += eps
param.copy_(param_flat.view(param.shape))
loss_plus = self._neg_log_posterior()
# f(θ - εe_i)
param_flat[i] -= 2 * eps
param.copy_(param_flat.view(param.shape))
loss_minus = self._neg_log_posterior()
# 恢复原值
param_flat[i] += eps
param.copy_(param_flat.view(param.shape))
# 二阶导数近似
precision_diag[i] = (loss_plus - 2 * self._neg_log_posterior() + loss_minus) / (eps ** 2)
self.precision.append(precision_diag)K-FAC近似
Kronecker-Factored Approximate Curvature (K-FAC) 将精度矩阵分解为Kronecker积:
class KFACLaplace(LaplaceApproximation):
"""
K-FAC近似的Laplace近似
将每个层的精度矩阵分解为Kronecker积
对于层权重 W (out_features × in_features):
Σ^{-1} ≈ A^{-1} ⊗ B^{-1}
存储复杂度: O(d_in² + d_out²) 而不是 O(d_in² × d_out²)
"""
def __init__(self, model, prior_precision=1.0):
super().__init__(model, prior_precision)
self.kfac_matrices = {} # 存储A和B的逆矩阵
def _compute_kfac(self, layer, activations, gradients):
"""
计算单层的K-FAC矩阵
Args:
layer: 神经网络层
activations: 前向传播激活 (batch, in_features)
gradients: 反向传播梯度 (batch, out_features)
"""
batch_size = activations.shape[0]
# E[aa^T]: 输入的协方差
A = (activations.t() @ activations) / batch_size + 1e-6 * torch.eye(activations.shape[1])
# E[gg^T]: 梯度的协方差
G = (gradients.t() @ gradients) / batch_size + 1e-6 * torch.eye(gradients.shape[1])
return A, G
def _compute_precision(self):
"""
计算K-FAC精度矩阵
"""
# 需要hooks来收集中间结果
self.activations = {}
self.gradients = {}
self.kfac_matrices = {}
def get_activation(name):
def hook(module, input, output):
self.activations[name] = input[0].detach()
return hook
def get_gradient(name):
def hook(module, grad_input, grad_output):
self.gradients[name] = grad_output[0].detach()
return hook
# 注册hooks
hooks = []
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
hooks.append(module.register_forward_hook(get_activation(name)))
hooks.append(module.register_full_backward_hook(
lambda module, grad_input, grad_output, name=name:
self.gradients.update({name: grad_output[0]})
))
# 前向+反向计算
# ... (省略训练循环代码)
# 移除hooks
for h in hooks:
h.remove()Last-Layer Laplace
对于大型网络,可以只对最后一层进行Laplace近似:
class LastLayerLaplace(LaplaceApproximation):
"""
只对最后一层应用Laplace近似
大幅降低计算复杂度,同时保留大部分不确定性量化能力
"""
def fit(self, train_loader, feature_extractor_lr=1e-5, head_lr=1e-3, max_epochs=100):
"""
训练策略:冻结特征提取器,只微调最后一层
"""
# 分离特征提取器和分类头
feature_extractor = nn.Sequential(*list(self.model.children())[:-1])
classifier_head = self.model[-1]
# 冻结特征提取器
for param in feature_extractor.parameters():
param.requires_grad = False
# 只训练分类头
optimizer = torch.optim.Adam(classifier_head.parameters(), lr=head_lr)
for epoch in range(max_epochs):
for batch in train_loader:
x, y = batch
with torch.no_grad():
features = feature_extractor(x)
logits = classifier_head(features)
loss = F.cross_entropy(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 对最后一层应用Laplace近似
self._compute_last_layer_precision(classifier_head)后验预测分布
预测均值
预测方差
集成预测
def predictive_distribution(model, x, num_samples=100):
"""
计算预测分布的均值和方差
"""
model.eval()
probs_list = []
for _ in range(num_samples):
# 从后验采样
# ... (采样代码)
with torch.no_grad():
logits = model(x)
probs = F.softmax(logits, dim=-1)
probs_list.append(probs)
probs_stack = torch.stack(probs_list) # (num_samples, batch, classes)
# 预测均值
pred_mean = probs_stack.mean(dim=0)
# 预测方差
pred_var = probs_stack.var(dim=0)
# 不确定性(熵)
pred_entropy = -(pred_mean * torch.log(pred_mean + 1e-8)).sum(dim=-1)
return {
'mean': pred_mean,
'variance': pred_var,
'entropy': pred_entropy,
'samples': probs_stack
}与其他方法的对比
方法比较
| 方法 | 计算复杂度 | 近似质量 | 实现难度 |
|---|---|---|---|
| 精确后验 | 完美 | 简单 | |
| Laplace近似 | 中等 | 中等 | |
| 对角Laplace | 较低 | 简单 | |
| K-FAC | 中等 | 较难 | |
| 变分推断(VI) | 取决于变分族 | 中等 | |
| MC Dropout | 取决于dropout率 | 简单 | |
| SWAG | 较高 | 中等 |
实验对比
def compare_uncertainty_methods(model, x, y_true, train_loader):
"""
比较不同不确定性量化方法的性能
"""
results = {}
# 1. MC Dropout
mc_dropout_preds = mc_dropout_predict(model, x, num_samples=50)
results['mc_dropout'] = evaluate_uncertainty(mc_dropout_preds, y_true)
# 2. Laplace近似
laplace = LaplaceApproximation(model)
laplace.fit(train_loader)
laplace_preds = laplace.predict(x, num_samples=50)
results['laplace'] = evaluate_uncertainty(laplace_preds, y_true)
# 3. SWAG
swag = SWAG(model)
swag.fit(train_loader)
swag_preds = swag.predict(x, num_samples=50)
results['swag'] = evaluate_uncertainty(swag_preds, y_true)
return results实践指南
何时使用Laplace近似
✅ 适合场景:
- 需要可信的不确定性估计
- 模型已经训练完成,只需添加不确定性
- 需要快速的预测分布计算
- 资源有限,无法使用ensemble
❌ 不适合场景:
- 数据非高斯噪声
- 需要捕获多模态后验
- 网络结构非常复杂
- 实时性要求极高
先验精度选择
先验精度 控制后验的收缩程度:
# 策略1:默认设置
prior_precision = 1.0
# 策略2:交叉验证
from sklearn.model_selection import cross_val_score
best_precision = None
best_score = -np.inf
for precision in [0.01, 0.1, 1.0, 10.0]:
score = cross_val_score(model, X, y, cv=5) # 某种验证
if score > best_score:
best_score = score
best_precision = precision
# 策略3:经验贝叶斯(边际似然最大化)数值稳定性
def stable_precision_inverse(P):
"""
数值稳定的精度矩阵求逆
添加正则化项防止奇异
"""
# 方法1:加正则化
P_reg = P + 1e-6 * torch.eye(P.shape[0])
# 方法2:特征值截断
eigenvalues, eigenvectors = torch.linalg.eigh(P)
eigenvalues = torch.clamp(eigenvalues, min=1e-6)
P_inv = eigenvectors @ torch.diag(1.0 / eigenvalues) @ eigenvectors.t()
return P_inv应用场景
1. 不确定性感知决策
def uncertainty_aware_decision(x, model, laplace, threshold=0.1):
"""
基于不确定性的决策
当预测不确定时,选择保守策略
"""
pred = laplace.predict(x.unsqueeze(0), num_samples=100)
pred_probs = F.softmax(pred, dim=-1)
max_prob, pred_class = pred_probs.max(dim=-1)
uncertainty = 1 - max_prob # 不确定性 = 1 - 置信度
if uncertainty > threshold:
return "abstain" # 弃权,选择保守策略
return pred_class.item()2. 异常检测
def detect_anomalies(model, laplace, data_loader, k=5):
"""
基于预测不确定性的异常检测
"""
uncertainties = []
for x, _ in data_loader:
pred = laplace.predict(x, num_samples=50)
pred_probs = F.softmax(pred, dim=-1)
# 使用预测熵作为不确定性度量
entropy = -(pred_probs * torch.log(pred_probs + 1e-8)).sum(dim=-1)
uncertainties.extend(entropy.cpu().numpy())
# 选择top-k最不确定的样本作为潜在异常
threshold = np.percentile(uncertainties, 100 - k)
anomalies = [i for i, u in enumerate(uncertainties) if u > threshold]
return anomalies3. 主动学习
def active_learning_criterion(model, laplace, x_pool):
"""
基于不确定性的主动学习样本选择
"""
uncertainties = []
for x in x_pool:
pred = laplace.predict(x.unsqueeze(0), num_samples=50)
pred_probs = F.softmax(pred, dim=-1)
# 最大预测概率
max_prob = pred_probs.max(dim=-1)[0]
uncertainty = 1 - max_prob
uncertainties.append(uncertainty.item())
# 选择不确定性最大的样本
query_idx = np.argmax(uncertainties)
return query_idx参考
相关阅读
- 贝叶斯估计理论基础 — 贝叶斯统计基础
- BNN不确定性量化 — 深度学习中的不确定性
- 变分推断进阶 — 另一种后验近似方法
- MC Dropout — 简单的贝叶斯近似
- 集成学习统一理论 — 集成方法与不确定性的联系