Laplace 近似与贝叶斯神经网络

Laplace 近似是贝叶斯统计中一种经典的后验近似方法,通过用高斯分布逼近复杂后验来实现贝叶斯推断。对于神经网络,Laplace 近似将训练好的网络权重附近的后验分布用高斯分布建模,从而获得预测不确定性。1

Laplace 近似基础

拉普拉斯方法

对于难以精确计算的后验分布 ,Laplace 近似的核心思想是:

  1. 找到后验的峰值(MAP 估计)
  2. 在峰值处用二阶泰勒展开近似对数后验
  3. 得到的近似分布是高斯分布

数学推导

后验分布的对数:

在 MAP 估计 处进行二阶泰勒展开:

其中 Hessian 矩阵(对数后验的二阶导数):

因此后验近似为高斯分布:

关键挑战:Hessian 矩阵

对于大型神经网络:

  • 权重数量: -
  • Hessian 矩阵大小: -
  • 无法直接存储和求逆

神经网络的 Hessian

Hessian 的组成

对于分类任务(交叉熵损失),Hessian 可以分解为:

Generalized Gauss-Newton (GGN) 近似

精确的 Hessian 计算成本极高。GGN 近似使用 Fisher 信息矩阵代替:

优势:GGN 是 Hessian 的上界(Positive Semi-Definite),且计算更高效。

def ggn_matrix(model, dataloader, device='cuda'):
    """
    计算 Generalized Gauss-Newton 矩阵
    
    注意:返回的是 GGN 的对角近似
    """
    model.eval()
    
    # 收集梯度
    gradients = []
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        
        output = model(x)
        loss = F.cross_entropy(output, y)
        
        model.zero_grad()
        loss.backward()
        
        # 收集梯度
        grad = torch.cat([p.grad.flatten() for p in model.parameters()])
        gradients.append(grad)
    
    # GGN 的对角近似
    gradients = torch.stack(gradients)
    ggn_diag = (gradients ** 2).mean(dim=0)
    
    return ggn_diag

Kronecker-Factored 近似

基本思想

为了避免存储完整的 Hessian/K-FAC 将 Hessian 分解为层间独立的 Kronecker 积:

  • :输入激活的相关矩阵
  • :输出梯度的相关矩阵

K-FAC 的核心公式

对于第 层的权重

其中:

  • :第 个样本在第 层输入的激活
  • :第 层权重的梯度

PyTorch 实现(简化版)

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
 
class KFACPreconditioner:
    """
    Kronecker-Factored Approximate Curvature (K-FAC) 预条件器
    
    用于自然梯度下降和 Laplace 近似
    """
    
    def __init__(self, model, damping=1e-3):
        self.model = model
        self.damping = damping
        self.steps = 0
        
        # 存储 A 和 G 矩阵
        self.m_A = {}  # 输入协方差
        self.m_G = {}  # 输出梯度协方差
        self.momentum = {}
    
    def register(self):
        """注册 Hooks 用于收集统计量"""
        def forward_hook(module, input, output):
            if isinstance(module, nn.Linear):
                self._save_input(module, input[0].detach())
        
        def backward_hook(module, grad_input, grad_output):
            if isinstance(module, nn.Linear):
                self._save_grad_output(module, grad_output[0].detach())
        
        self.handles = []
        for module in self.model.modules():
            if isinstance(module, nn.Linear):
                self.handles.append(module.register_forward_hook(forward_hook))
                self.handles.append(module.register_backward_hook(backward_hook))
    
    def _save_input(self, module, input):
        # 计算输入的协方差
        x = input.flatten(0, -2)  # (batch * seq, dim)
        A = (x.T @ x) / x.shape[0]  # (dim, dim)
        
        name = self._get_module_name(module)
        self.m_A[name] = A
    
    def _save_grad_output(self, module, grad_output):
        # 计算梯度的协方差
        g = grad_output.flatten(0, -2)
        G = (g.T @ g) / g.shape[0]
        
        name = self._get_module_name(module)
        self.m_G[name] = G
    
    def _get_module_name(self, module):
        for name, m in self.model.named_modules():
            if m is module:
                return name
        return None
    
    def update_running_stats(self):
        """更新 K-FAC 矩阵的移动平均"""
        self.steps += 1
        
        for name in self.m_A:
            # 指数移动平均
            if name not in self.momentum:
                self.momentum[name] = {'A': 0.9, 'G': 0.9}
            
            # 更新 A 和 G
            pass  # 实现更新逻辑
    
    def precondition_grad(self, named_parameters):
        """
        用 K-FAC 近似曲率预条件梯度
        
        返回预条件后的梯度
        """
        for name, param in named_parameters:
            if name in self.m_A and name in self.m_G:
                A = self.m_A[name] + self.damping * torch.eye(param.shape[1], device=param.device)
                G = self.m_G[name] + self.damping * torch.eye(param.shape[0], device=param.device)
                
                # Kronecker 逆
                # inv(A ⊗ G) = inv(A) ⊗ inv(G)
                grad = param.grad.view(param.shape[0], -1)
                
                # 预条件
                # grad_new = G^{-1} @ grad @ A^{-1}
                grad_new = torch.linalg.solve(G, grad)
                grad_new = torch.linalg.solve(A.T, grad_new.T).T
                
                param.grad.copy_(grad_new.view_as(param))

Last-layer Laplace

核心思想

Last-layer Laplace 只对神经网络的最后一层进行贝叶斯推断,将前面的层视为固定的特征提取器。2

输入 → [固定特征提取层] → [贝叶斯最后一层] → 输出

为什么只考虑最后一层?

考虑因素说明
计算效率只需存储最后一层的协方差矩阵
内存占用 而非 ,其中
实用性对于大多数任务足够了

数学形式

最后一层权重 的 Laplace 近似:

其中 是训练得到的最后一层权重, 通过 Hessian/K-FAC 近似。

预测分布:

laplace-torch 使用示例

# 安装:pip install laplace-torch
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from laplace import Laplace
 
# 1. 训练基础模型
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)
 
model = MLP()
# ... 训练模型 ...
 
# 2. 使用 Laplace 近似
la = Laplace(model, 'classification', 
             subset_of_weights='last_layer',  # 只对最后一层
             hessian_structure='diag')         # 对角近似
 
# 3. 拟合 Laplace
la.fit(train_loader)
 
# 4. 预测(包含不确定性)
predictions = la.predict(test_x)
# predictions 包含预测概率和预测方差
 
# 5. OOD 检测
ood_score = la.predictive_entropy(test_x)  # 预测熵

完整示例

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace
from laplace.utils import LargestVarianceDiag
 
# 定义网络
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.output = nn.Linear(64, 1)
    
    def forward(self, x):
        x = self.features(x)
        return self.output(x)
 
# 训练数据
torch.manual_seed(42)
X_train = torch.linspace(-5, 5, 100).unsqueeze(1)
y_train = torch.sin(X_train) + 0.3 * torch.randn_like(X_train)
 
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=32)
 
# 训练模型
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
for epoch in range(100):
    for x, y in train_loader:
        optimizer.zero_grad()
        pred = model(x)
        loss = F.mse_loss(pred, y)
        loss.backward()
        optimizer.step()
 
# 使用 Laplace 近似(last-layer, 对角)
la = Laplace(model, 'regression', 
             subset_ofweights='last_layer',
             hessian_structure='diag')
 
la.fit(train_loader)
 
# 预测(包含不确定性)
X_test = torch.linspace(-6, 6, 200).unsqueeze(1)
pred_mean, pred_std = la.predict(X_test)
 
# 可视化
import matplotlib.pyplot as plt
 
plt.figure(figsize=(10, 6))
plt.scatter(X_train, y_train, alpha=0.5, label='Training data')
plt.plot(X_test, pred_mean.squeeze(), 'r-', label='Mean prediction')
plt.fill_between(X_test.squeeze(), 
                  (pred_mean - 2*pred_std).squeeze(),
                  (pred_mean + 2*pred_std).squeeze(), 
                  alpha=0.3, label='±2 std')
plt.legend()
plt.show()
 
# 分析不确定性的组成
# 在训练数据附近:Epistemic uncertainty 小
# 在训练数据之外:Epistemic uncertainty 大

不同近似方法的比较

方法近似精度计算成本内存占用适用场景
Full Laplace小模型
Diag Laplace中等模型
KF-Laplace中高大模型
Last-layer超大模型

其中 是总参数量, 是最后一层参数量。

预测分布的计算

线性化预测

Laplace 近似下的预测分布通过线性化近似:

因此预测分布为:

其中 是数据噪声方差。

协方差传播

def laplace_predict(model, la, x, n_samples=100):
    """
    Laplace 近似的预测
    
    Returns:
        mean: 预测均值
        variance: 预测方差
    """
    model.eval()
    
    # 获取预测均值
    with torch.no_grad():
        mean = model(x)
    
    # 获取雅可比
    def jacobian(y, x):
        return torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y),
                                   create_graph=True)[0]
    
    # 计算预测方差
    # Var[y] = J @ Sigma @ J.T
    variance = torch.zeros_like(mean)
    
    for i in range(x.shape[0]):
        J = jacobian(model(x[i:i+1]), model.output.weight)
        var_i = J @ la.posterior_covariance @ J.T
        variance[i] = var_i
    
    return mean, variance

核心公式速查

概念公式
Laplace 近似
GGN 近似
K-FAC 分解
预测方差

参考

相关文章

Footnotes

  1. Daxberger, E., et al. (2021). “Laplace Redux — Effortless Bayesian Deep Learning”. NeurIPS 2021.

  2. Kristiadi, A., et al. (2020). “Being Bayesian, Even Just a Bit Counts When Estimating Uncertainty”. arxiv:2006.10577.