Laplace 近似与贝叶斯神经网络
Laplace 近似是贝叶斯统计中一种经典的后验近似方法,通过用高斯分布逼近复杂后验来实现贝叶斯推断。对于神经网络,Laplace 近似将训练好的网络权重附近的后验分布用高斯分布建模,从而获得预测不确定性。1
Laplace 近似基础
拉普拉斯方法
对于难以精确计算的后验分布 ,Laplace 近似的核心思想是:
- 找到后验的峰值(MAP 估计)
- 在峰值处用二阶泰勒展开近似对数后验
- 得到的近似分布是高斯分布
数学推导
后验分布的对数:
在 MAP 估计 处进行二阶泰勒展开:
其中 是Hessian 矩阵(对数后验的二阶导数):
因此后验近似为高斯分布:
关键挑战:Hessian 矩阵
对于大型神经网络:
- 权重数量: -
- Hessian 矩阵大小: -
- 无法直接存储和求逆
神经网络的 Hessian
Hessian 的组成
对于分类任务(交叉熵损失),Hessian 可以分解为:
Generalized Gauss-Newton (GGN) 近似
精确的 Hessian 计算成本极高。GGN 近似使用 Fisher 信息矩阵代替:
优势:GGN 是 Hessian 的上界(Positive Semi-Definite),且计算更高效。
def ggn_matrix(model, dataloader, device='cuda'):
"""
计算 Generalized Gauss-Newton 矩阵
注意:返回的是 GGN 的对角近似
"""
model.eval()
# 收集梯度
gradients = []
for x, y in dataloader:
x, y = x.to(device), y.to(device)
output = model(x)
loss = F.cross_entropy(output, y)
model.zero_grad()
loss.backward()
# 收集梯度
grad = torch.cat([p.grad.flatten() for p in model.parameters()])
gradients.append(grad)
# GGN 的对角近似
gradients = torch.stack(gradients)
ggn_diag = (gradients ** 2).mean(dim=0)
return ggn_diagKronecker-Factored 近似
基本思想
为了避免存储完整的 Hessian/K-FAC 将 Hessian 分解为层间独立的 Kronecker 积:
- :输入激活的相关矩阵
- :输出梯度的相关矩阵
K-FAC 的核心公式
对于第 层的权重 :
其中:
- :第 个样本在第 层输入的激活
- :第 层权重的梯度
PyTorch 实现(简化版)
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
class KFACPreconditioner:
"""
Kronecker-Factored Approximate Curvature (K-FAC) 预条件器
用于自然梯度下降和 Laplace 近似
"""
def __init__(self, model, damping=1e-3):
self.model = model
self.damping = damping
self.steps = 0
# 存储 A 和 G 矩阵
self.m_A = {} # 输入协方差
self.m_G = {} # 输出梯度协方差
self.momentum = {}
def register(self):
"""注册 Hooks 用于收集统计量"""
def forward_hook(module, input, output):
if isinstance(module, nn.Linear):
self._save_input(module, input[0].detach())
def backward_hook(module, grad_input, grad_output):
if isinstance(module, nn.Linear):
self._save_grad_output(module, grad_output[0].detach())
self.handles = []
for module in self.model.modules():
if isinstance(module, nn.Linear):
self.handles.append(module.register_forward_hook(forward_hook))
self.handles.append(module.register_backward_hook(backward_hook))
def _save_input(self, module, input):
# 计算输入的协方差
x = input.flatten(0, -2) # (batch * seq, dim)
A = (x.T @ x) / x.shape[0] # (dim, dim)
name = self._get_module_name(module)
self.m_A[name] = A
def _save_grad_output(self, module, grad_output):
# 计算梯度的协方差
g = grad_output.flatten(0, -2)
G = (g.T @ g) / g.shape[0]
name = self._get_module_name(module)
self.m_G[name] = G
def _get_module_name(self, module):
for name, m in self.model.named_modules():
if m is module:
return name
return None
def update_running_stats(self):
"""更新 K-FAC 矩阵的移动平均"""
self.steps += 1
for name in self.m_A:
# 指数移动平均
if name not in self.momentum:
self.momentum[name] = {'A': 0.9, 'G': 0.9}
# 更新 A 和 G
pass # 实现更新逻辑
def precondition_grad(self, named_parameters):
"""
用 K-FAC 近似曲率预条件梯度
返回预条件后的梯度
"""
for name, param in named_parameters:
if name in self.m_A and name in self.m_G:
A = self.m_A[name] + self.damping * torch.eye(param.shape[1], device=param.device)
G = self.m_G[name] + self.damping * torch.eye(param.shape[0], device=param.device)
# Kronecker 逆
# inv(A ⊗ G) = inv(A) ⊗ inv(G)
grad = param.grad.view(param.shape[0], -1)
# 预条件
# grad_new = G^{-1} @ grad @ A^{-1}
grad_new = torch.linalg.solve(G, grad)
grad_new = torch.linalg.solve(A.T, grad_new.T).T
param.grad.copy_(grad_new.view_as(param))Last-layer Laplace
核心思想
Last-layer Laplace 只对神经网络的最后一层进行贝叶斯推断,将前面的层视为固定的特征提取器。2
输入 → [固定特征提取层] → [贝叶斯最后一层] → 输出
为什么只考虑最后一层?
| 考虑因素 | 说明 |
|---|---|
| 计算效率 | 只需存储最后一层的协方差矩阵 |
| 内存占用 | 而非 ,其中 |
| 实用性 | 对于大多数任务足够了 |
数学形式
最后一层权重 的 Laplace 近似:
其中 是训练得到的最后一层权重, 通过 Hessian/K-FAC 近似。
预测分布:
laplace-torch 使用示例
# 安装:pip install laplace-torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from laplace import Laplace
# 1. 训练基础模型
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
return self.fc2(x)
model = MLP()
# ... 训练模型 ...
# 2. 使用 Laplace 近似
la = Laplace(model, 'classification',
subset_of_weights='last_layer', # 只对最后一层
hessian_structure='diag') # 对角近似
# 3. 拟合 Laplace
la.fit(train_loader)
# 4. 预测(包含不确定性)
predictions = la.predict(test_x)
# predictions 包含预测概率和预测方差
# 5. OOD 检测
ood_score = la.predictive_entropy(test_x) # 预测熵完整示例
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace
from laplace.utils import LargestVarianceDiag
# 定义网络
class Net(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Linear(1, 64),
nn.ReLU(),
nn.Linear(64, 64),
nn.ReLU(),
)
self.output = nn.Linear(64, 1)
def forward(self, x):
x = self.features(x)
return self.output(x)
# 训练数据
torch.manual_seed(42)
X_train = torch.linspace(-5, 5, 100).unsqueeze(1)
y_train = torch.sin(X_train) + 0.3 * torch.randn_like(X_train)
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32)
# 训练模型
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
for x, y in train_loader:
optimizer.zero_grad()
pred = model(x)
loss = F.mse_loss(pred, y)
loss.backward()
optimizer.step()
# 使用 Laplace 近似(last-layer, 对角)
la = Laplace(model, 'regression',
subset_ofweights='last_layer',
hessian_structure='diag')
la.fit(train_loader)
# 预测(包含不确定性)
X_test = torch.linspace(-6, 6, 200).unsqueeze(1)
pred_mean, pred_std = la.predict(X_test)
# 可视化
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, alpha=0.5, label='Training data')
plt.plot(X_test, pred_mean.squeeze(), 'r-', label='Mean prediction')
plt.fill_between(X_test.squeeze(),
(pred_mean - 2*pred_std).squeeze(),
(pred_mean + 2*pred_std).squeeze(),
alpha=0.3, label='±2 std')
plt.legend()
plt.show()
# 分析不确定性的组成
# 在训练数据附近:Epistemic uncertainty 小
# 在训练数据之外:Epistemic uncertainty 大不同近似方法的比较
| 方法 | 近似精度 | 计算成本 | 内存占用 | 适用场景 |
|---|---|---|---|---|
| Full Laplace | 高 | 小模型 | ||
| Diag Laplace | 中 | 中等模型 | ||
| KF-Laplace | 中高 | 大模型 | ||
| Last-layer | 中 | 超大模型 |
其中 是总参数量, 是最后一层参数量。
预测分布的计算
线性化预测
Laplace 近似下的预测分布通过线性化近似:
因此预测分布为:
其中 是数据噪声方差。
协方差传播
def laplace_predict(model, la, x, n_samples=100):
"""
Laplace 近似的预测
Returns:
mean: 预测均值
variance: 预测方差
"""
model.eval()
# 获取预测均值
with torch.no_grad():
mean = model(x)
# 获取雅可比
def jacobian(y, x):
return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),
create_graph=True)[0]
# 计算预测方差
# Var[y] = J @ Sigma @ J.T
variance = torch.zeros_like(mean)
for i in range(x.shape[0]):
J = jacobian(model(x[i:i+1]), model.output.weight)
var_i = J @ la.posterior_covariance @ J.T
variance[i] = var_i
return mean, variance核心公式速查
| 概念 | 公式 |
|---|---|
| Laplace 近似 | |
| GGN 近似 | |
| K-FAC 分解 | |
| 预测方差 |
参考
相关文章
- 贝叶斯神经网络 — BNN 基础概念
- MC Dropout — Dropout 的贝叶斯解释
- Bayes by Backprop — 变分推断方法
- 变分推断 — VI 理论基础