概述
贝叶斯最后一层(Bayesian Last Layer, BLL)方法是一种混合贝叶斯-深度学习策略,只对网络的最后一层进行贝叶斯建模,同时保持其他层作为确定性特征提取器。1 这种方法在保持深度学习强大表达能力的同时,为预测提供了有原则的不确定性估计,特别适合大规模模型的高效部署。
为什么需要贝叶斯最后一层?
完整贝叶斯神经网络的挑战
对整个深度神经网络进行贝叶斯推断面临以下挑战:
- 参数规模巨大:现代大模型可能有数十亿参数
- 推断成本高昂:完整贝叶斯推断需要采样整个网络
- 后验近似困难:高维空间的变分推断质量难以保证
BLL的核心思想
BLL通过任务分解来简化问题:
- 特征提取层:使用预训练/微调的深度网络提取判别性特征
- 贝叶斯输出层:只对最后一层进行贝叶斯建模
┌─────────────────────────────────────────────────────┐
│ 完整深度网络 │
│ │
│ ┌──────────┐ ┌──────────┐ ┌──────────┐ │
│ │ Input │→ │ Hidden │→ │ Hidden │→ → ... │
│ │ Layer │ │ Layer 1 │ │ Layer 2 │ │
│ └──────────┘ └──────────┘ └──────────┘ │
│ │
│ ↓ (冻结/微调特征) │
│ │
│ ┌──────────┐ │
│ │ Features │ ← 确定性的特征表示 │
│ └────┬─────┘ │
│ │ │
│ ↓ │
│ ┌──────────┐ │
│ │ BLL │ ← 贝叶斯建模的最后一层 │
│ │ (Bayes) │ │
│ └────┬─────┘ │
│ │ │
│ ↓ │
│ ┌──────────┐ │
│ │ Output │ ← 预测 + 不确定性 │
│ └──────────┘ │
└─────────────────────────────────────────────────────┘
数学框架
问题形式化
设深度网络为 $y = g_w(f_\theta(x))$,其中:
- $h = f_\theta(x)$ 是特征提取层,参数为 $\theta$
- $g_w$ 是最后一层,参数为 $w$
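BLL的目标是:
$$p(y \mid x, \mathcal{D}) = \int p\big(y \mid f_\theta(x), w\big)\, p(w \mid \mathcal{D})\, \mathrm{d}w$$
其中 $p(w \mid \mathcal{D})$ 是最后一层权重的后验分布。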
关键假设
- 条件独立假设:给定特征 $h = f_\theta(x)$ 和最后一层权重 $w$,输出与输入条件独立:$p(y \mid x, w, \theta) = p(y \mid h, w)$
- 特征确定性假设:前面的层参数 $\theta$ 被视为确定的(或使用点估计)
预测分布
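给定输入 $x^*$,预测分布为:
$$p(y^* \mid x^*, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p\big(y^* \mid f_\theta(x^*), w^{(s)}\big), \qquad w^{(s)} \sim q(w)$$
其中 $w^{(s)}$ 是从变分后验采样的权重。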
方法详解
1. 线性高斯BLL
最简单的BLL形式是假设最后一层是线性层,输出服从高斯分布。2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Gamma
class BayesianLastLayerLinear(nn.Module):
    """Linear Bayesian last layer with a mean-field Gaussian posterior.

    The layer computes ``y = W @ h + b``. Every weight and bias entry has an
    independent Gaussian variational posterior, parameterised by a mean and a
    log-precision, and a zero-mean Gaussian prior with precision ``alpha``:
    ``W_ij ~ N(0, 1/alpha)``, ``b_j ~ N(0, 1/alpha)``.

    Args:
        in_features: Size of the feature vector fed to the layer.
        out_features: Number of outputs.
        prior_precision: Prior precision ``alpha``; must be positive.

    Raises:
        ValueError: If ``prior_precision`` is not positive.
    """

    def __init__(self, in_features, out_features, prior_precision=1.0):
        super().__init__()
        if prior_precision <= 0:
            raise ValueError("prior_precision must be positive")
        self.in_features = in_features
        self.out_features = out_features

        # Posterior means.
        self.W_mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.b_mean = nn.Parameter(torch.zeros(out_features))

        # Posterior log-precisions (precision = 1 / variance); zero is unit variance.
        self.W_log_precision = nn.Parameter(torch.zeros(out_features, in_features))
        self.b_log_precision = nn.Parameter(torch.zeros(out_features))

        # Prior precision shared by all weights and biases.
        self.prior_precision = prior_precision

    def forward(self, h, samples=1):
        """Draw ``samples`` weight sets and apply each of them to ``h``.

        Args:
            h: Features of shape ``(batch_size, in_features)``.
            samples: Number of Monte Carlo weight draws.

        Returns:
            Tensor of shape ``(samples, batch_size, out_features)``.
        """
        W_samples = self._sample_weights(self.W_mean, self.W_log_precision, samples)
        b_samples = self._sample_weights(self.b_mean, self.b_log_precision, samples)

        # (batch, in) @ (samples, in, out) broadcasts to (samples, batch, out).
        projected = torch.matmul(h, W_samples.transpose(-2, -1))

        # The bias needs a singleton *batch* axis, (samples, 1, out), so that
        # draw s of the bias is paired with draw s of the weights.
        return projected + b_samples.unsqueeze(1)

    def _sample_weights(self, mean, log_precision, num_samples):
        """Reparameterised draw from the diagonal Gaussian posterior.

        Returns:
            Tensor of shape ``(num_samples, *mean.shape)``.
        """
        # std = precision ** -0.5, computed in log space.
        std = torch.exp(-0.5 * log_precision)
        eps = torch.randn(num_samples, *mean.shape, device=mean.device, dtype=mean.dtype)
        return mean.unsqueeze(0) + std.unsqueeze(0) * eps

    def _kl_to_prior(self, mean, log_precision):
        """Element-wise ``KL(N(mean, 1/lambda) || N(0, 1/alpha))``."""
        alpha = torch.as_tensor(self.prior_precision, dtype=mean.dtype, device=mean.device)
        var = torch.exp(-log_precision)
        return 0.5 * (alpha * (var + mean ** 2) - 1.0 + log_precision - torch.log(alpha))

    def kl_divergence(self):
        """Return ``KL(q || p)`` summed over all weights and biases.

        Per parameter, with posterior ``N(mu, sigma^2)``, posterior precision
        ``lambda = 1 / sigma^2`` and prior ``N(0, 1/alpha)``::

            KL = 0.5 * (alpha * (sigma^2 + mu^2) - 1 + log(lambda) - log(alpha))

        The value is non-negative and is zero when the posterior equals the
        prior.
        """
        kl_W = self._kl_to_prior(self.W_mean, self.W_log_precision)
        kl_b = self._kl_to_prior(self.b_mean, self.b_log_precision)
        return kl_W.sum() + kl_b.sum()

    def elbo(self, h, y, num_samples=10, dataset_size=None):
        """Return the per-example negative ELBO, i.e. a loss to minimise.

        The training loops in this file call ``elbo(...)`` and minimise the
        result, so the value returned is ``-ELBO / N``::

            E_q[-log p(y | h, w)] + KL(q(w) || p(w)) / N

        The expectation is a Monte Carlo average of the cross-entropy over
        ``num_samples`` weight draws.

        Args:
            h: Features of shape ``(batch_size, in_features)``.
            y: Integer class labels of shape ``(batch_size,)``.
            num_samples: Number of Monte Carlo weight draws.
            dataset_size: ``N`` used to scale the KL term. Defaults to the
                batch size when omitted.

        Returns:
            Scalar loss tensor.
        """
        outputs = self.forward(h, samples=num_samples)
        batch_size = h.shape[0]

        # Flatten (samples, batch) into one axis; y.repeat tiles the labels in
        # the same sample-major order.
        nll = F.cross_entropy(
            outputs.flatten(0, 1),
            y.repeat(num_samples),
            reduction='mean',
        )

        kl_scale = batch_size if dataset_size is None else dataset_size
        return nll + self.kl_divergence() / kl_scale

    def predict(self, h, num_samples=100):
        """Predict classes from the Monte Carlo averaged softmax.

        Returns:
            Tuple ``(predictions, confidence, entropy)``, each of shape
            ``(batch_size,)``: arg-max class, its averaged probability, and
            the entropy of the averaged distribution.
        """
        logits = self.forward(h, samples=num_samples)
        avg_probs = F.softmax(logits, dim=-1).mean(dim=0)
        confidence, predictions = avg_probs.max(dim=-1)
        entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)
        return predictions, confidence, entropy


# Article section heading that was fused onto the last source line of this block:
# 2. 回归任务的BLL
对于回归任务,BLL可以同时建模均值和方差(异方差回归):
class HeteroscedasticBLL(nn.Module):
    """Heteroscedastic regression head predicting a mean and a variance.

    Two linear heads map the features to a predictive mean and to an
    input-dependent variance, so the predicted uncertainty adapts per input.

    Args:
        in_features: Size of the feature vector fed to the heads.
        out_features: Number of regression targets.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Predictive mean.
        self.mean_layer = nn.Linear(in_features, out_features)

        # Despite its name, this head outputs a positive *variance* (Softplus
        # of a linear map); the logarithm is taken afterwards. The attribute
        # name is kept so existing state dicts keep loading.
        self.log_var_layer = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.Softplus(),
        )

        # Not read by forward() or loss(); kept so the parameter set of the
        # module does not change.
        self.prior_log_precision = nn.Parameter(torch.zeros(1))

    def _moments(self, h):
        """Return the predictive ``(mean, log_var)`` for features ``h``."""
        mean = self.mean_layer(h)
        # The small constant keeps the log finite if Softplus underflows to 0.
        log_var = torch.log(self.log_var_layer(h) + 1e-6)
        return mean, log_var

    def forward(self, h, num_samples=1):
        """Return ``(output, mean, log_var)`` for features ``h``.

        In training mode with ``num_samples > 1`` the first element is ONE
        reparameterised draw from ``N(mean, exp(log_var))``; ``num_samples``
        only switches sampling on, it does not add a sample axis. Otherwise
        the first element is the mean itself. All three tensors have the
        shape of ``mean``.
        """
        mean, log_var = self._moments(h)
        if num_samples > 1 and self.training:
            std = torch.exp(0.5 * log_var)
            draw = mean + std * torch.randn_like(mean)
            return draw, mean, log_var
        return mean, mean, log_var

    def loss(self, h, y, num_samples=5):
        """Return the mean heteroscedastic Gaussian negative log-likelihood.

        Per element, dropping the constant ``0.5 * log(2 * pi)``::

            0.5 * ((y - mean)^2 / var + log var)

        The likelihood is centred on the predicted mean in both training and
        evaluation mode. Centring it on a noisy draw instead adds gradient
        noise and shifts the expected loss by 0.5.

        Args:
            h: Features; the leading axis is the batch.
            y: Targets, broadcastable against the predicted mean.
            num_samples: Accepted for backward compatibility; the loss is
                analytic and does not use it.

        Returns:
            Scalar loss tensor.
        """
        mean, log_var = self._moments(h)
        nll = 0.5 * ((y - mean) ** 2 * torch.exp(-log_var) + log_var)
        return nll.mean()


# Article section heading that was fused onto the last source line of this block:
# 3. 分类任务的BLL
对于分类任务,BLL结合softmax输出:
class BayesianLastLayerClassifier(nn.Module):
    """Bayesian last layer for classification with a softmax output.

    The weights and biases of a linear layer have independent Gaussian
    posteriors, each parameterised by a mean and a log-precision.

    Args:
        in_features: Size of the feature vector fed to the layer.
        num_classes: Number of classes.
        temperature: Divisor applied to the logits before the softmax.
    """

    def __init__(self, in_features, num_classes, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.num_classes = num_classes

        # Posterior means of the linear layer.
        self.W = nn.Parameter(torch.randn(num_classes, in_features) * 0.01)
        self.b = nn.Parameter(torch.zeros(num_classes))

        # Posterior log-precisions (diagonal covariance); zero is unit variance.
        self.W_log_precision = nn.Parameter(torch.zeros(num_classes, in_features))
        self.b_log_precision = nn.Parameter(torch.zeros(num_classes))

    def forward(self, h, num_samples=1):
        """Draw ``num_samples`` weight sets and return ``(probs, logits)``.

        Args:
            h: Features of shape ``(batch_size, in_features)``.
            num_samples: Number of Monte Carlo weight draws.

        Returns:
            ``probs`` and ``logits``, both of shape
            ``(num_samples, batch_size, num_classes)``. ``probs`` is the
            softmax of ``logits / temperature``; ``logits`` are unscaled.
        """
        W_sample = self._sample(self.W, self.W_log_precision, num_samples)
        b_sample = self._sample(self.b, self.b_log_precision, num_samples)

        # (batch, in) @ (samples, in, classes) broadcasts to
        # (samples, batch, classes); the bias gets a singleton *batch* axis so
        # draw s of the bias is paired with draw s of the weights.
        logits = torch.matmul(h, W_sample.transpose(-2, -1)) + b_sample.unsqueeze(1)

        probs = F.softmax(logits / self.temperature, dim=-1)
        return probs, logits

    def _sample(self, mean, log_precision, num_samples):
        """Reparameterised draw of shape ``(num_samples, *mean.shape)``."""
        std = torch.exp(-0.5 * log_precision)
        # Match the parameter dtype as well as its device.
        eps = torch.randn(num_samples, *mean.shape, device=mean.device, dtype=mean.dtype)
        return mean.unsqueeze(0) + std.unsqueeze(0) * eps

    def predict(self, h, num_samples=100):
        """Predict classes from the Monte Carlo averaged probabilities.

        Returns:
            Tuple ``(predictions, confidence, entropy)``, each of shape
            ``(batch_size,)``: arg-max class, its averaged probability, and
            the entropy of the averaged distribution.
        """
        probs, _ = self.forward(h, num_samples)
        avg_probs = probs.mean(dim=0)
        confidence, predictions = avg_probs.max(dim=-1)
        entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)
        return predictions, confidence, entropy


# Article section heading that was fused onto the last source line of this block:
# 训练策略
1. 两阶段训练
class TwoStageBLL:
    """Two-stage trainer.

    Stage 1 fits the feature extractor; stage 2 freezes it and fits the
    Bayesian last layer on its features.

    Args:
        backbone: Module mapping inputs to features; it is frozen on
            construction.
        bll_layer: Module exposing ``elbo(features, y)``, whose result is
            minimised in stage 2.
    """

    def __init__(self, backbone, bll_layer):
        self.backbone = backbone
        self.bll_layer = bll_layer
        # The backbone starts out frozen.
        self._set_trainable(self.backbone, False)

    @staticmethod
    def _set_trainable(module, flag):
        """Set ``requires_grad`` on every parameter of ``module``."""
        for weight in module.parameters():
            weight.requires_grad = flag

    def train_backbone(self, train_loader, epochs=100):
        """Stage 1: unfreeze and train the backbone with cross-entropy.

        The backbone output is passed directly to ``nn.CrossEntropyLoss``.
        """
        self._set_trainable(self.backbone, True)
        opt = torch.optim.Adam(self.backbone.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()

        self.backbone.train()
        for _ in range(epochs):
            for inputs, targets in train_loader:
                opt.zero_grad()
                batch_loss = loss_fn(self.backbone(inputs), targets)
                batch_loss.backward()
                opt.step()

    def train_bll(self, train_loader, epochs=50):
        """Stage 2: freeze the backbone and train only the BLL layer.

        Prints the average per-batch loss after every epoch.
        """
        self._set_trainable(self.backbone, False)
        self._set_trainable(self.bll_layer, True)
        opt = torch.optim.Adam(self.bll_layer.parameters(), lr=1e-2)

        self.backbone.eval()
        self.bll_layer.train()
        for epoch in range(epochs):
            running = 0
            for inputs, targets in train_loader:
                opt.zero_grad()
                # Features are computed without building a graph: the
                # backbone is frozen in this stage.
                with torch.no_grad():
                    feats = self.backbone(inputs)
                batch_loss = self.bll_layer.elbo(feats, targets)
                batch_loss.backward()
                opt.step()
                running += batch_loss.item()
            print(f"Epoch {epoch}, Loss: {running / len(train_loader):.4f}")


# Article section heading that was fused onto the last source line of this block:
# 2. 端到端训练
对于需要联合优化的场景:
class EndToEndBLL:
    """End-to-end trainer that optimises the backbone and the BLL jointly.

    Args:
        backbone: Module mapping inputs to features.
        bll_layer: Module exposing ``elbo(features, y, num_samples=...)``.
        backbone_lr: Learning rate for the backbone parameter group.
        bll_lr: Learning rate for the BLL parameter group.
    """

    def __init__(self, backbone, bll_layer, backbone_lr=1e-4, bll_lr=1e-2):
        self.backbone = backbone
        self.bll_layer = bll_layer

        # One Adam optimiser with a separate learning rate per sub-module.
        param_groups = [
            {'params': backbone.parameters(), 'lr': backbone_lr},
            {'params': bll_layer.parameters(), 'lr': bll_lr},
        ]
        self.optimizer = torch.optim.Adam(param_groups)

    def train_step(self, x, y, num_samples=5):
        """Run one optimisation step and return the loss as a Python float."""
        opt = self.optimizer
        opt.zero_grad()
        objective = self.bll_layer.elbo(self.backbone(x), y, num_samples=num_samples)
        objective.backward()
        opt.step()
        return objective.item()


# Article section heading that was fused onto the last source line of this block:
# 大规模应用
与预训练模型的结合
BLL可以方便地与预训练模型结合:
class PretrainedBLL:
    """BLL head on top of a pretrained ``timm`` backbone.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        num_classes: Number of output classes of the BLL head.
    """

    def __init__(self, backbone_name='resnet50', num_classes=10):
        # Imported locally: the module-level imports of this file do not
        # include timm, so the bare name would raise NameError.
        import timm

        # num_classes=0 is passed so the backbone output feeds the BLL head.
        self.backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0)
        in_features = self.backbone.num_features

        # Bayesian last layer on top of the backbone features.
        self.bll = BayesianLastLayerLinear(in_features, num_classes)

    def finetune(self, train_loader, val_loader, epochs=10, lr=1e-3):
        """Fine-tune the model and validate after every epoch.

        All BLL parameters are optimised, plus every backbone parameter whose
        ``requires_grad`` is set; freeze backbone layers beforehand to train
        only part of it. ``self.bll.elbo(...)`` is treated as the loss to
        minimise, as ``TwoStageBLL.train_bll`` does.

        Args:
            train_loader: Iterable of ``(x, y)`` training batches.
            val_loader: Iterable of ``(x, y)`` validation batches.
            epochs: Number of passes over ``train_loader``.
            lr: Adam learning rate.

        Returns:
            List with the validation accuracy after each epoch.
        """
        trainable = [p for p in self.backbone.parameters() if p.requires_grad]
        trainable.extend(self.bll.parameters())
        optimizer = torch.optim.Adam(trainable, lr=lr)

        history = []
        for _ in range(epochs):
            # evaluate() switches the head to eval mode, so re-enable
            # training mode at the start of every epoch.
            self.bll.train()
            for x, y in train_loader:
                optimizer.zero_grad()
                features = self.backbone(x)
                loss = self.bll.elbo(features, y)
                loss.backward()
                optimizer.step()
            history.append(self.evaluate(val_loader))
        return history

    def evaluate(self, val_loader):
        """Return the classification accuracy on ``val_loader``.

        Returns ``0.0`` when the loader yields no examples.
        """
        correct = 0
        total = 0
        for x, y in val_loader:
            predictions, _, _ = self.predict_with_uncertainty(x)
            correct += (predictions == y).sum().item()
            total += y.numel()
        return correct / total if total else 0.0

    def predict_with_uncertainty(self, x, num_samples=100):
        """Predict classes with confidence and predictive entropy.

        The head is called as ``self.bll(features, samples=num_samples)`` and
        is expected to return logits of shape ``(samples, batch, classes)``.
        The softmax is averaged over the sample axis.

        Returns:
            Tuple ``(predictions, confidence, entropy)``, each of shape
            ``(batch,)``.
        """
        self.backbone.eval()
        self.bll.eval()
        with torch.no_grad():
            features = self.backbone(x)
            logits = self.bll(features, samples=num_samples)
            avg_probs = F.softmax(logits, dim=-1).mean(dim=0)
            confidence, predictions = avg_probs.max(dim=-1)
            entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)
        return predictions, confidence, entropy


# Article section heading that was fused onto the last source line of this block:
# 模型部署
class BLLDeployment:
    """Deployment helper that runs batched, gradient-free prediction.

    Args:
        model: Object providing ``to(device)``, ``eval()``, a ``backbone``
            callable and a ``bll`` head with ``predict(features, num_samples)``.
        device: Device the model and every input batch are moved to.
    """

    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.to(device)
        self.model.eval()

    @torch.no_grad()
    def predict(self, x, num_samples=50, batch_size=32):
        """Predict in chunks of ``batch_size`` to bound memory use.

        Returns:
            Tuple ``(predictions, confidences, uncertainties)``, each the
            concatenation of the per-chunk results along dim 0, moved to the
            CPU.
        """
        pred_chunks, conf_chunks, unc_chunks = [], [], []

        for start in range(0, len(x), batch_size):
            chunk = x[start:start + batch_size].to(self.device)
            pred, conf, entropy = self._predict_batch(chunk, num_samples)
            pred_chunks.append(pred.cpu())
            conf_chunks.append(conf.cpu())
            unc_chunks.append(entropy.cpu())

        merged_pred = torch.cat(pred_chunks, dim=0)
        merged_conf = torch.cat(conf_chunks, dim=0)
        merged_unc = torch.cat(unc_chunks, dim=0)
        return (merged_pred, merged_conf, merged_unc)

    def _predict_batch(self, x, num_samples):
        """Run backbone and BLL head on a single chunk."""
        feats = self.model.backbone(x)
        return self.model.bll.predict(feats, num_samples)


# Article section heading that was fused onto the last source line of this block:
# 不确定性估计
不确定性分解
BLL估计的不确定性可以分解为:
- 偶然不确定性(Aleatoric):数据本身的噪声
- 认知不确定性(Epistemic):模型对参数的不确定
class UncertaintyDecomposition:
    """Split predictive uncertainty into aleatoric and epistemic parts.

    Args:
        bll_model: Model exposing posterior means ``W`` and ``b``,
            log-precisions ``W_log_precision`` and ``b_log_precision``, and
            ``_sample(mean, log_precision, num_samples)``. An optional
            ``temperature`` attribute is honoured.
    """

    def __init__(self, bll_model):
        self.model = bll_model

    def estimate(self, h, num_samples=100):
        """Estimate total, aleatoric and epistemic uncertainty.

        Weights are drawn ``num_samples`` times from the model posterior; a
        deterministic forward pass would make every draw identical and the
        epistemic term always zero.

        Args:
            h: Features of shape ``(batch, in_features)``.
            num_samples: Number of Monte Carlo weight draws.

        Returns:
            Tuple ``(total_uncertainty, aleatoric, epistemic)``, each of shape
            ``(batch,)``: entropy of the averaged prediction, average entropy
            of the individual predictions, and their difference clamped at
            zero.
        """
        model = self.model
        W_samples = model._sample(model.W, model.W_log_precision, num_samples)
        b_samples = model._sample(model.b, model.b_log_precision, num_samples)

        # (batch, in) @ (samples, in, classes) -> (samples, batch, classes).
        logits = torch.matmul(h, W_samples.transpose(-2, -1)) + b_samples.unsqueeze(1)
        temperature = getattr(model, 'temperature', 1.0)
        probs_stack = F.softmax(logits / temperature, dim=-1)

        # Total uncertainty: entropy of the averaged class probabilities.
        avg_probs = probs_stack.mean(dim=0)
        total_uncertainty = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)

        # Aleatoric uncertainty: average entropy of each sampled prediction.
        individual_entropy = -(probs_stack * torch.log(probs_stack + 1e-10)).sum(dim=-1)
        aleatoric = individual_entropy.mean(dim=0)

        # Epistemic uncertainty is total minus aleatoric. It is clamped at
        # zero because the 1e-10 smoothing and rounding can push the
        # difference slightly negative.
        epistemic = (total_uncertainty - aleatoric).clamp_min(0.0)

        return total_uncertainty, aleatoric, epistemic


# Article section heading that was fused onto the last source line of this block:
# OOD检测
BLL的认知不确定性可以用于分布外检测:
class OODDetector:
    """Out-of-distribution detector based on epistemic uncertainty.

    Args:
        bll_model: Object exposing a ``backbone`` callable and a ``bll`` head
            that is called as ``bll(features, num_samples)``.
        threshold: Epistemic-uncertainty cut-off. When ``None`` the threshold
            is taken from ``_calibrate_threshold``.
    """

    def __init__(self, bll_model, threshold=None):
        self.model = bll_model
        self.threshold = threshold
        # Calibrate only when the caller did not choose a threshold;
        # otherwise the explicit value would be overwritten.
        if self.threshold is None:
            self._calibrate_threshold()

    def _calibrate_threshold(self):
        """Set the threshold; currently a fixed placeholder value."""
        # TODO: derive the threshold from a calibration / validation set.
        self.threshold = 0.5

    def estimate_uncertainty(self, features, num_samples=100):
        """Decompose predictive uncertainty for pre-computed features.

        The head may return ``(probs, logits)`` or raw logits; either way the
        tensor used is expected to be ``(samples, batch, classes)``.

        Returns:
            Tuple ``(total, aleatoric, epistemic)``, each of shape
            ``(batch,)``; ``epistemic`` is clamped at zero.
        """
        output = self.model.bll(features, num_samples)
        if isinstance(output, tuple):
            # The head already returns class probabilities first.
            probs = output[0]
        else:
            probs = F.softmax(output, dim=-1)

        avg_probs = probs.mean(dim=0)
        total = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)
        aleatoric = (-(probs * torch.log(probs + 1e-10)).sum(dim=-1)).mean(dim=0)
        epistemic = (total - aleatoric).clamp_min(0.0)
        return total, aleatoric, epistemic

    @torch.no_grad()
    def detect(self, x, num_samples=100):
        """Flag inputs as in-distribution or out-of-distribution.

        Returns:
            Tuple ``(is_id, epistemic)``: a boolean tensor that is ``True``
            where the epistemic uncertainty is below the threshold, and the
            epistemic uncertainty scores themselves.
        """
        features = self.model.backbone(x)
        _, _, epistemic = self.estimate_uncertainty(features, num_samples)
        is_id = epistemic < self.threshold
        return is_id, epistemic


# Article section heading that was fused onto the last source line of this block:
# 实践指南
何时使用BLL
| 场景 | 推荐使用BLL | 原因 |
|---|---|---|
| 大模型微调 | ✅ | 只需训练最后一层 |
| 不确定性量化 | ✅ | 原生支持不确定性 |
| OOD检测 | ✅ | 认知不确定性可分离 |
| 实时应用 | ✅ | 推理开销小 |
| 完全贝叶斯建模 | ❌ | 需要完整贝叶斯NN |
超参数设置
| 超参数 | 推荐值 | 说明 |
|---|---|---|
| 先验精度 | 1.0 | 权重衰减强度 |
| 温度 | 0.5-1.0 | 预测分布锐度 |
| MC样本数 | 50-100 | 推断精度vs速度 |
| 学习率(BLL) | 0.01-0.1 | 通常高于backbone |
常见问题
- KL散度过大:降低先验精度或增加学习率
- 预测过于自信:增加温度参数
- 训练不稳定:使用梯度裁剪或学习率调度
总结
贝叶斯最后一层方法的核心要点:
| 特性 | 说明 |
|---|---|
| 核心思想 | 只对最后一层进行贝叶斯建模 |
| 优势 | 计算高效、易于实现、与预训练模型兼容 |
| 适用 | 大模型微调、不确定性量化、OOD检测 |
| 局限 | 不捕获特征层的不确定性 |
参考
相关文章
- bayesian-neural-networks-uncertainty — 贝叶斯神经网络基础
- bayesian-neural-networks-advanced-inference — 高级推断方法
- mc-dropout — MC Dropout技术
- uncertainty-quantification-deep-learning — 深度学习中的不确定性量化