Overview

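Bayesian Last Layer (BLL) methods are a hybrid Bayesian/deep-learning strategy. Only the final layer of the network is modeled in a Bayesian way, and all earlier layers remain a deterministic feature extractor.[1] This keeps the representational power of deep learning while giving predictions principled uncertainty estimates. Because only one layer is stochastic, the approach is also well suited to deploying large models efficiently.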


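Why a Bayesian last layer?

Challenges of fully Bayesian neural networks

Bayesian inference over an entire deep neural network runs into several problems:

  1. Huge parameter count: modern large models can have billions of parameters.
  2. Expensive inference: full Bayesian inference requires sampling the weights of the whole network.
  3. Hard posterior approximation: the quality of variational inference in very high-dimensional spaces is difficult to guarantee.

The core idea of BLL

BLL simplifies the problem by splitting it into two parts:

  • Feature extraction layers: a pretrained or fine-tuned deep network extracts discriminative features.
  • Bayesian output layer: only the last layer is modeled in a Bayesian way.
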
┌────────────────────────────────────────────────────────┐
│                   Full deep network                    │
│                                                        │
│  ┌──────────┐   ┌──────────┐   ┌──────────┐            │
│  │ Input    │ → │ Hidden   │ → │ Hidden   │ → ...      │
│  │ layer    │   │ layer 1  │   │ layer 2  │            │
│  └──────────┘   └──────────┘   └──────────┘            │
│                                                        │
│        ↓ (frozen / fine-tuned features)                │
│                                                        │
│  ┌──────────┐                                          │
│  │ Features │  ← deterministic feature representation  │
│  └────┬─────┘                                          │
│       │                                                │
│       ↓                                                │
│  ┌──────────┐                                          │
│  │ BLL      │  ← Bayesian last layer                   │
│  │ (Bayes)  │                                          │
│  └────┬─────┘                                          │
│       │                                                │
│       ↓                                                │
│  ┌──────────┐                                          │
│  │ Output   │  ← prediction + uncertainty              │
│  └──────────┘                                          │
└────────────────────────────────────────────────────────┘
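
The split in the diagram can be sketched in PyTorch as follows. This minimal illustration assumes a toy `nn.Sequential` MLP with arbitrary layer sizes. Everything except the final `nn.Linear` becomes the frozen, deterministic feature extractor, and only the final layer would receive a Bayesian treatment. Counting parameters shows how small that last layer is relative to the whole network.

```python
import torch.nn as nn

# Toy network; the sizes are illustrative only.
net = nn.Sequential(
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 10),            # last layer -> candidate for Bayesian treatment
)

# Feature extractor phi_theta = every layer except the last; head = the last layer.
backbone = net[:-1]
head = net[-1]

# Freeze the feature extractor so it stays a deterministic point estimate.
for p in backbone.parameters():
    p.requires_grad = False

def count(module):
    return sum(p.numel() for p in module.parameters())

total, last = count(net), count(head)
print(f"total params: {total}, last layer: {last} ({100 * last / total:.1f}%)")
```

A mean-field Gaussian posterior over the head only doubles its parameter count (one mean and one variance per weight). The backbone is left untouched.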

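Mathematical framework

Problem formulation

Write the deep network as $f(x) = g_w(\phi_\theta(x))$, where:

  • $\phi_\theta$ is the feature extractor, with parameters $\theta$;
  • $g_w$ is the last layer, with parameters $w$.

The goal of BLL is the posterior predictive distribution obtained by integrating over the last layer only:

$$p(y \mid x, \mathcal{D}) = \int p\big(y \mid \phi_\theta(x), w\big)\, p(w \mid \mathcal{D})\, dw,$$

where $p(w \mid \mathcal{D})$ is the posterior distribution of the last-layer weights given the training data $\mathcal{D}$.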
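
In one special case the integral above can be computed exactly. Take a linear last layer with a Gaussian prior $w \sim \mathcal{N}(0, \alpha^{-1} I)$ and Gaussian observation noise of precision $\beta$; this regression setting is assumed here for illustration. The posterior is then $\mathcal{N}(\mu, \Sigma)$ with $\Sigma = (\alpha I + \beta \Phi^\top \Phi)^{-1}$ and $\mu = \beta \Sigma \Phi^\top y$. The predictive variance at features $\phi^*$ is $1/\beta + \phi^{*\top} \Sigma \phi^*$. The sketch below computes this on random features that stand in for $\phi_\theta(x)$. The values of `alpha` and `beta` and the data are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
N, D = 200, 16
alpha, beta = 1.0, 25.0          # prior precision, noise precision (assumed)

Phi = torch.randn(N, D, dtype=torch.float64)    # stand-in for frozen features phi_theta(x)
w_true = torch.randn(D, dtype=torch.float64)
y = Phi @ w_true + torch.randn(N, dtype=torch.float64) / beta ** 0.5

# Exact Gaussian posterior over the last-layer weights
precision = alpha * torch.eye(D, dtype=torch.float64) + beta * Phi.T @ Phi
Sigma = torch.linalg.inv(precision)
mu = beta * Sigma @ Phi.T @ y

# Posterior predictive mean and variance for new feature vectors
phi_star = torch.randn(5, D, dtype=torch.float64)
pred_mean = phi_star @ mu
pred_var = 1.0 / beta + ((phi_star @ Sigma) * phi_star).sum(dim=-1)  # noise + phi^T Sigma phi
print(pred_mean, pred_var)
```

The predictive variance always exceeds the noise floor $1/\beta$. The extra term $\phi^{*\top}\Sigma\phi^*$ is the epistemic part contributed by the weight posterior.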

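Key assumptions

  1. Conditional independence: given the last-layer weights $w$, the output depends on the input only through the features, $p(y \mid x, w) = p\big(y \mid \phi_\theta(x), w\big)$.
  2. Deterministic features: the parameters $\theta$ of the earlier layers are treated as fixed (a point estimate $\hat\theta$) rather than integrated over.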

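Predictive distribution

For a new input $x^*$, the predictive distribution is approximated by Monte Carlo:

$$p(y^* \mid x^*, \mathcal{D}) \approx \frac{1}{S} \sum_{s=1}^{S} p\big(y^* \mid \phi_{\hat\theta}(x^*), w^{(s)}\big), \qquad w^{(s)} \sim q(w),$$

where $w^{(s)}$ are weights sampled from the variational posterior $q(w) \approx p(w \mid \mathcal{D})$.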
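
This Monte Carlo estimate averages the *probabilities* produced by each weight sample, not the weights or the logits. Below is a minimal sketch for a softmax last layer with a mean-field Gaussian $q(w)$. The posterior means `mu_W` and standard deviations `sigma_W` are random placeholders, not trained values, and the bias is omitted for brevity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, C, S = 4, 8, 3, 200         # batch, feature dim, classes, MC samples

h = torch.randn(B, D)              # features phi_theta(x*) for a batch of inputs
mu_W = torch.randn(C, D)           # placeholder posterior mean of W
sigma_W = 0.5 * torch.ones(C, D)   # placeholder posterior std of W

# w^(s) ~ q(w): draw S weight samples with the reparameterization trick
W_samples = mu_W + sigma_W * torch.randn(S, C, D)

# p(y* | phi(x*), w^(s)) for every sample, then average over s
logits = torch.einsum("bd,scd->sbc", h, W_samples)   # (S, B, C)
mc_probs = F.softmax(logits, dim=-1).mean(dim=0)      # (B, C)

# Plug-in prediction using only the posterior mean, for comparison
plugin_probs = F.softmax(h @ mu_W.T, dim=-1)
print(mc_probs, plugin_probs)
```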


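Methods in detail

1. Linear Gaussian BLL

The simplest form of BLL assumes the last layer is linear and places a Gaussian distribution over its weights. For fixed features, the layer's outputs are then also Gaussian.[2] The implementation below uses a mean-field (diagonal) Gaussian variational posterior with a softmax cross-entropy likelihood for classification.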

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class BayesianLastLayerLinear(nn.Module):
    """
    Linear Bayesian last layer with a mean-field Gaussian posterior

    Last layer: y = W @ h + b
    Prior:      W_ij ~ N(0, α⁻¹), b_j ~ N(0, α⁻¹)
    Posterior:  W_ij ~ N(μ_ij, λ_ij⁻¹), stored as log λ_ij
    """
    def __init__(self, in_features, out_features, prior_precision=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Posterior means
        self.W_mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.b_mean = nn.Parameter(torch.zeros(out_features))
        
        # Posterior precisions, stored as log-precision (log λ = -log σ²)
        self.W_log_precision = nn.Parameter(torch.zeros(out_features, in_features))
        self.b_log_precision = nn.Parameter(torch.zeros(out_features))
        
        # Prior precision α
        self.prior_precision = prior_precision
    
    def forward(self, h, samples=1):
        """
        Forward pass
        
        Args:
            h: input features (batch_size, in_features)
            samples: number of MC samples
        
        Returns:
            outputs of shape (samples, batch_size, out_features)
        """
        # Sample from the posterior: W (S, out, in), b (S, out)
        W_samples = self._sample_weights(self.W_mean, self.W_log_precision, samples)
        b_samples = self._sample_weights(self.b_mean, self.b_log_precision, samples)
        
        # (S, B, in) @ (S, in, out) -> (S, B, out); the bias broadcasts as (S, 1, out)
        h_expanded = h.unsqueeze(0).expand(samples, -1, -1)
        outputs = torch.bmm(h_expanded, W_samples.transpose(-2, -1)) + b_samples.unsqueeze(1)
        
        return outputs
    
    def _sample_weights(self, mean, log_precision, num_samples):
        """Reparameterized sample from the diagonal Gaussian posterior"""
        std = torch.exp(-0.5 * log_precision)  # σ = λ^(-1/2)
        
        # Sample shape: (num_samples, *shape)
        eps = torch.randn(num_samples, *mean.shape, device=mean.device, dtype=mean.dtype)
        return mean.unsqueeze(0) + std.unsqueeze(0) * eps
    
    def _kl_to_prior(self, mean, log_precision):
        """Sum over entries of KL(N(μ, λ⁻¹) || N(0, α⁻¹))"""
        alpha = self.prior_precision
        var = torch.exp(-log_precision)
        kl = 0.5 * (alpha * var + alpha * mean ** 2 - 1.0 - math.log(alpha) + log_precision)
        return kl.sum()
    
    def kl_divergence(self):
        """
        KL divergence from the prior, per weight:
        
        KL(q || p) = 0.5 * (α σ² + α μ² - 1 - log(α σ²)),  with σ² = 1/λ
        """
        return (self._kl_to_prior(self.W_mean, self.W_log_precision)
                + self._kl_to_prior(self.b_mean, self.b_log_precision))
    
    def elbo(self, h, y, num_samples=10, num_data=None):
        """
        Negative ELBO per training example (the loss to minimize):
        
        -ELBO / N = -E_q[log p(y|h,w)] + KL(q(w)||p(w)) / N
        
        num_data: training-set size N. If None, the batch size is used,
        which over-weights the KL term when N is larger than the batch.
        """
        outputs = self.forward(h, samples=num_samples)  # (S, B, C)
        S, B, C = outputs.shape
        
        # MC estimate of the expected NLL: cross-entropy of every sample, averaged
        nll = F.cross_entropy(outputs.reshape(S * B, C), y.repeat(S), reduction='mean')
        
        kl = self.kl_divergence()
        n = num_data if num_data is not None else B
        return nll + kl / n
    
    @torch.no_grad()
    def predict(self, h, num_samples=100):
        """Predicted class, confidence (max probability), and predictive entropy"""
        probs = F.softmax(self.forward(h, samples=num_samples), dim=-1).mean(dim=0)
        confidence, predictions = probs.max(dim=-1)
        entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1)
        return predictions, confidence, entropy
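
The closed-form KL used in `kl_divergence` can be sanity-checked. The per-weight formula $\mathrm{KL} = \tfrac12(\alpha\sigma^2 + \alpha\mu^2 - 1 - \log(\alpha\sigma^2))$ is compared against PyTorch's `torch.distributions.kl_divergence` between two `Normal` distributions. The sketch uses the same parameterization (posterior log-precision $\log\lambda$, prior precision $\alpha$) on random tensors.

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
alpha = 2.0                                  # prior precision
mu = torch.randn(5, 3)                       # posterior means
log_precision = torch.randn(5, 3)            # posterior log-precisions log(lambda)

# Closed form: KL(N(mu, 1/lambda) || N(0, 1/alpha)), summed over entries
var = torch.exp(-log_precision)
kl_closed = (0.5 * (alpha * var + alpha * mu ** 2 - 1.0 - math.log(alpha) + log_precision)).sum()

# Reference value from torch.distributions (Normal takes a standard deviation)
q = Normal(mu, var.sqrt())
p = Normal(torch.zeros_like(mu), torch.full_like(mu, 1.0 / math.sqrt(alpha)))
kl_ref = kl_divergence(q, p).sum()

print(kl_closed.item(), kl_ref.item())
```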

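2. BLL for regression

For regression, the last layer can also predict an input-dependent variance alongside the mean (heteroscedastic regression). This captures aleatoric (data) noise: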

class HeteroscedasticBLL(nn.Module):
    """
    Heteroscedastic last layer for regression

    Predicts a mean and a variance per input, modeling aleatoric (data) noise.
    The weights here are point estimates. Capturing epistemic uncertainty as
    well requires a posterior over the weights, as in BayesianLastLayerLinear.
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        
        # Mean head
        self.mean_layer = nn.Linear(in_features, out_features)
        
        # Variance head: despite its name, it outputs the variance itself;
        # Softplus keeps the variance positive
        self.log_var_layer = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.Softplus()
        )
    
    def forward(self, h):
        """
        Returns the predicted mean and log-variance, each (batch, out_features).
        """
        mean = self.mean_layer(h)
        # A small floor avoids log(0)
        log_var = torch.log(self.log_var_layer(h) + 1e-6)
        return mean, log_var
    
    def loss(self, h, y):
        """
        Gaussian negative log-likelihood, dropping the constant 0.5 * log(2π):
        
        -log p(y|x) = 0.5 * ((y - mean)² / var + log var) + const
        """
        mean, log_var = self.forward(h)
        nll = 0.5 * ((y - mean) ** 2 * torch.exp(-log_var) + log_var)
        return nll.mean()
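
This hand-written loss matches PyTorch's built-in `F.gaussian_nll_loss` whenever the predicted variance exceeds that function's `eps` clamp. With its default `full=False`, the built-in also drops the constant term. A quick check on random tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
mean = torch.randn(8, 2)
var = torch.rand(8, 2) + 0.1        # strictly positive predicted variances
y = torch.randn(8, 2)

# The loss as written in HeteroscedasticBLL.loss
log_var = torch.log(var)
manual = (0.5 * ((y - mean) ** 2 * torch.exp(-log_var) + log_var)).mean()

# PyTorch built-in: 0.5 * (log var + (input - target)^2 / var), averaged
builtin = F.gaussian_nll_loss(mean, y, var, full=False, reduction="mean")
print(manual.item(), builtin.item())
```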

3. BLL for classification

For classification, the BLL produces logits that are passed through a softmax:

class BayesianLastLayerClassifier(nn.Module):
    """
    Bayesian last layer for classification
    """
    def __init__(self, in_features, num_classes, temperature=1.0):
        super().__init__()
        
        self.temperature = temperature
        self.num_classes = num_classes
        
        # Posterior means of the linear layer
        self.W = nn.Parameter(torch.randn(num_classes, in_features) * 0.01)
        self.b = nn.Parameter(torch.zeros(num_classes))
        
        # Posterior log-precisions (diagonal covariance)
        self.W_log_precision = nn.Parameter(torch.zeros(num_classes, in_features))
        self.b_log_precision = nn.Parameter(torch.zeros(num_classes))
    
    def forward(self, h, num_samples=1):
        """
        Returns per-sample probabilities and logits, each (num_samples, batch, num_classes).
        """
        # Sample weights from the posterior
        W_sample = self._sample(self.W, self.W_log_precision, num_samples)
        b_sample = self._sample(self.b, self.b_log_precision, num_samples)
        
        # Logits: (S, B, in) @ (S, in, C); the bias broadcasts as (S, 1, C)
        h_expanded = h.unsqueeze(0).expand(num_samples, -1, -1)
        logits = torch.bmm(h_expanded, W_sample.transpose(-2, -1)) + b_sample.unsqueeze(1)
        
        # Softmax with temperature scaling
        probs = F.softmax(logits / self.temperature, dim=-1)
        
        return probs, logits
    
    def _sample(self, mean, log_precision, num_samples):
        """Reparameterized sample from the diagonal Gaussian posterior"""
        std = torch.exp(-0.5 * log_precision)  # σ = λ^(-1/2)
        
        eps = torch.randn(num_samples, *mean.shape, device=mean.device, dtype=mean.dtype)
        return mean.unsqueeze(0) + std.unsqueeze(0) * eps
    
    def predict(self, h, num_samples=100):
        """
        Prediction: returns the class, its confidence, and the uncertainty
        """
        probs, _ = self.forward(h, num_samples)
        
        # Average the predicted probabilities over posterior samples
        avg_probs = probs.mean(dim=0)
        
        # Predicted class
        predictions = avg_probs.argmax(dim=-1)
        
        # Confidence (maximum probability)
        confidence = avg_probs.max(dim=-1)[0]
        
        # Uncertainty (predictive entropy)
        entropy = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)
        
        return predictions, confidence, entropy
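
The logits are linear in Gaussian weights, so under the mean-field posterior their moments have a closed form: $\mathbb{E}[z_c] = \mu_c^\top h + \mu_{b_c}$ and $\mathrm{Var}[z_c] = \sum_i h_i^2 \sigma_{ci}^2 + \sigma_{b_c}^2$. A well-known deterministic alternative to MC sampling is the probit (also called "mean-field") approximation. It rescales each mean logit by $1/\sqrt{1 + \pi\,\mathrm{Var}[z_c]/8}$ before the softmax. This comes from the binary probit approximation, and its multiclass use is a heuristic. `probit_predictive` is a hypothetical helper, not part of the class above. The sketch prints an MC estimate alongside for comparison.

```python
import math
import torch
import torch.nn.functional as F

def probit_predictive(h, W_mean, W_var, b_mean, b_var):
    """Approximate E_q[softmax(z)] without sampling (hypothetical helper).

    h: (B, D) features; W_mean, W_var: (C, D); b_mean, b_var: (C,)
    """
    z_mean = h @ W_mean.T + b_mean                   # (B, C) mean logits
    z_var = (h ** 2) @ W_var.T + b_var               # (B, C) logit variances
    kappa = 1.0 / torch.sqrt(1.0 + math.pi * z_var / 8.0)
    return F.softmax(kappa * z_mean, dim=-1)

torch.manual_seed(0)
B, D, C = 4, 6, 3
h = torch.randn(B, D)
W_mean, b_mean = torch.randn(C, D), torch.zeros(C)
W_var, b_var = torch.full((C, D), 0.3), torch.full((C,), 0.1)

approx = probit_predictive(h, W_mean, W_var, b_mean, b_var)

# Monte Carlo reference with many posterior samples
S = 20000
W = W_mean + W_var.sqrt() * torch.randn(S, C, D)
b = b_mean + b_var.sqrt() * torch.randn(S, C)
mc = F.softmax(torch.einsum("bd,scd->sbc", h, W) + b.unsqueeze(1), dim=-1).mean(0)
print(approx, mc)
```

With zero posterior variance, $\kappa = 1$ and the approximation reduces to the plain softmax of the mean logits.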

Training strategies

1. Two-stage training

class TwoStageBLL:
    """
    Two-stage training strategy
    
    Stage 1: train the feature extractor
    Stage 2: freeze the features and train the BLL
    """
    def __init__(self, backbone, bll_layer):
        self.backbone = backbone
        self.bll_layer = bll_layer
        
        # Freeze the backbone
        for param in self.backbone.parameters():
            param.requires_grad = False
    
    def train_backbone(self, train_loader, head, epochs=100):
        """
        Stage 1: train the backbone.
        
        head: a temporary deterministic classifier (e.g. nn.Linear) that maps
        features to logits; it is only used during this stage.
        """
        for param in self.backbone.parameters():
            param.requires_grad = True
        
        params = list(self.backbone.parameters()) + list(head.parameters())
        optimizer = torch.optim.Adam(params, lr=1e-3)
        criterion = nn.CrossEntropyLoss()
        
        self.backbone.train()
        head.train()
        for epoch in range(epochs):
            for x, y in train_loader:
                optimizer.zero_grad()
                logits = head(self.backbone(x))
                loss = criterion(logits, y)
                loss.backward()
                optimizer.step()
    
    def train_bll(self, train_loader, epochs=50):
        """Stage 2: train the BLL on frozen features"""
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self.bll_layer.parameters():
            param.requires_grad = True
        
        optimizer = torch.optim.Adam(self.bll_layer.parameters(), lr=1e-2)
        
        self.backbone.eval()
        self.bll_layer.train()
        
        for epoch in range(epochs):
            total_loss = 0
            for x, y in train_loader:
                optimizer.zero_grad()
                
                # The backbone is frozen, so no gradients are needed for the features
                with torch.no_grad():
                    features = self.backbone(x)
                
                # BLL training loss
                loss = self.bll_layer.elbo(features, y)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader):.4f}")
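
Because the backbone is frozen in stage 2, its features do not change between epochs. They can be computed once and cached instead of re-running the backbone every epoch. Below is a minimal sketch, assuming a `DataLoader` that yields `(x, y)` pairs. `cache_features` is a hypothetical helper, and the toy backbone and data are placeholders.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def cache_features(backbone, loader):
    """Run the frozen backbone once; return a dataset of (features, labels)."""
    backbone.eval()
    feats, labels = [], []
    for x, y in loader:
        feats.append(backbone(x))
        labels.append(y)
    return TensorDataset(torch.cat(feats), torch.cat(labels))

# Toy placeholders: 100 inputs of dim 20, a backbone producing 8-dim features
torch.manual_seed(0)
data = TensorDataset(torch.randn(100, 20), torch.randint(0, 3, (100,)))
backbone = nn.Sequential(nn.Linear(20, 8), nn.ReLU())

feature_set = cache_features(backbone, DataLoader(data, batch_size=32))
feature_loader = DataLoader(feature_set, batch_size=32, shuffle=True)
```

Stage 2 can then iterate over `feature_loader` and pass the cached features straight to the BLL. This only works without random data augmentation, because augmented inputs produce different features every epoch.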

2. End-to-end training

When the backbone and the BLL should be optimized jointly:

class EndToEndBLL:
    """
    End-to-end training strategy
    
    Optimizes the backbone and the BLL jointly
    """
    def __init__(self, backbone, bll_layer, backbone_lr=1e-4, bll_lr=1e-2):
        self.backbone = backbone
        self.bll_layer = bll_layer
        
        # Separate learning rates: small for the backbone, larger for the BLL
        self.optimizer = torch.optim.Adam([
            {'params': backbone.parameters(), 'lr': backbone_lr},
            {'params': bll_layer.parameters(), 'lr': bll_lr}
        ])
    
    def train_step(self, x, y, num_samples=5):
        """One training step"""
        self.backbone.train()
        self.bll_layer.train()
        self.optimizer.zero_grad()
        
        features = self.backbone(x)
        loss = self.bll_layer.elbo(features, y, num_samples=num_samples)
        
        loss.backward()
        self.optimizer.step()
        
        return loss.item()

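One detail matters for both training strategies. The ELBO over $N$ training examples is $\sum_n \mathbb{E}_q[\log p(y_n \mid h_n, w)] - \mathrm{KL}$. When the likelihood term is a per-example mean over a minibatch, the KL should therefore be divided by the training-set size $N$, not by the batch size. The arithmetic sketch below uses made-up per-example NLL values. It checks that summing the per-batch loss $B\cdot(\overline{\mathrm{NLL}}_{\text{batch}} + \mathrm{KL}/N)$ over one epoch recovers the full-data negative ELBO.

```python
import torch

torch.manual_seed(0)
N, B = 1000, 50
nll = torch.rand(N, dtype=torch.float64)       # made-up per-example negative log-likelihoods
kl = torch.tensor(37.0, dtype=torch.float64)   # made-up KL(q || p)

full_neg_elbo = nll.sum() + kl

# Minibatch objective: mean NLL over the batch + KL / N, rescaled by B and summed over an epoch
epoch_total = sum(
    B * (nll[i:i + B].mean() + kl / N) for i in range(0, N, B)
)
print(full_neg_elbo.item(), epoch_total.item())
```

Dividing the KL by the batch size instead would inflate the KL term by a factor of $N/B$ (20 in this example).
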
Large-scale applications

Combining with pretrained models

A BLL can easily be placed on top of a pretrained model:

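import timm  # assumes the timm package is installed


class PretrainedBLL(nn.Module):
    """
    BLL on top of a pretrained backbone
    """
    def __init__(self, backbone_name='resnet50', num_classes=10):
        super().__init__()
        # Load a pretrained backbone; num_classes=0 removes its classifier head
        self.backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0)
        in_features = self.backbone.num_features
        
        # Attach the BLL
        self.bll = BayesianLastLayerLinear(in_features, num_classes)
    
    def finetune(self, train_loader, val_loader, epochs=10, backbone_lr=1e-4, bll_lr=1e-2):
        """Fine-tune the whole model (drop the backbone group to keep it frozen)"""
        optimizer = torch.optim.Adam([
            {'params': self.backbone.parameters(), 'lr': backbone_lr},
            {'params': self.bll.parameters(), 'lr': bll_lr},
        ])
        
        for epoch in range(epochs):
            # Training
            self.train()
            for x, y in train_loader:
                optimizer.zero_grad()
                features = self.backbone(x)
                loss = self.bll.elbo(features, y)
                loss.backward()
                optimizer.step()
            
            # Validation
            print(f"Epoch {epoch}, val acc: {self.evaluate(val_loader):.4f}")
    
    @torch.no_grad()
    def evaluate(self, loader, num_samples=20):
        """Accuracy of the posterior-averaged prediction"""
        correct, total = 0, 0
        for x, y in loader:
            predictions, _, _ = self.predict_with_uncertainty(x, num_samples)
            correct += (predictions == y).sum().item()
            total += y.numel()
        return correct / total
    
    def predict_with_uncertainty(self, x, num_samples=100):
        """Prediction with uncertainty"""
        self.eval()
        
        with torch.no_grad():
            features = self.backbone(x)
            predictions, confidence, entropy = self.bll.predict(features, num_samples)
        
        return predictions, confidence, entropy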

Model deployment

class BLLDeployment:
    """
    Deployment helper for BLL models
    
    Expects an nn.Module with `backbone` and `bll` attributes (such as
    PretrainedBLL), where `bll.predict` returns (predictions, confidence, entropy).
    """
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.to(device)
        self.model.eval()
    
    @torch.no_grad()
    def predict(self, x, num_samples=50, batch_size=32):
        """
        Batched prediction (memory-efficient)
        """
        predictions = []
        confidences = []
        uncertainties = []
        
        for i in range(0, len(x), batch_size):
            batch = x[i:i+batch_size].to(self.device)
            
            # Predictions and uncertainty for this batch
            pred, conf, entropy = self._predict_batch(batch, num_samples)
            
            predictions.append(pred.cpu())
            confidences.append(conf.cpu())
            uncertainties.append(entropy.cpu())
        
        return (
            torch.cat(predictions, dim=0),
            torch.cat(confidences, dim=0),
            torch.cat(uncertainties, dim=0)
        )
    
    def _predict_batch(self, x, num_samples):
        """Prediction for a single batch"""
        features = self.model.backbone(x)
        return self.model.bll.predict(features, num_samples)

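Uncertainty estimation

Uncertainty decomposition

The uncertainty estimated by a BLL can be decomposed into:

  1. Aleatoric uncertainty: noise inherent in the data.
  2. Epistemic uncertainty: the model's uncertainty about its parameters.
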
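
The decomposition computed by the class that follows is the standard entropy / mutual-information split of the predictive distribution. Here $\bar p(y \mid x) = \mathbb{E}_{q(w)}[p(y \mid x, w)]$:

$$
\underbrace{\mathcal{H}\big[\bar p(y \mid x)\big]}_{\text{total}}
= \underbrace{\mathbb{E}_{q(w)}\Big[\mathcal{H}\big[p(y \mid x, w)\big]\Big]}_{\text{aleatoric}}
+ \underbrace{\mathcal{I}\big[y; w \mid x\big]}_{\text{epistemic}}
$$

Entropy is concave, so the entropy of the averaged samples is never smaller than the average of their entropies. The epistemic estimate is therefore non-negative even with finitely many samples.
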
class UncertaintyDecomposition:
    """
    Uncertainty decomposition
    
    Expects a model like BayesianLastLayerClassifier, whose forward(h, num_samples)
    returns per-sample probabilities of shape (samples, batch, classes).
    """
    def __init__(self, bll_model):
        self.model = bll_model
    
    @torch.no_grad()
    def estimate(self, h, num_samples=100):
        """
        Estimate the total uncertainty and its decomposition
        
        Returns:
            total_uncertainty: total uncertainty (predictive entropy)
            aleatoric: aleatoric uncertainty (expected entropy)
            epistemic: epistemic uncertainty (mutual information)
        """
        # Draw posterior weight samples: (samples, batch, classes)
        probs_stack, _ = self.model(h, num_samples)
        
        # Total uncertainty: entropy of the averaged probabilities
        avg_probs = probs_stack.mean(dim=0)
        total_uncertainty = -(avg_probs * torch.log(avg_probs + 1e-10)).sum(dim=-1)
        
        # Aleatoric uncertainty: average entropy of the individual samples
        individual_entropy = -(probs_stack * torch.log(probs_stack + 1e-10)).sum(dim=-1)
        aleatoric = individual_entropy.mean(dim=0)
        
        # Epistemic uncertainty: total minus aleatoric
        epistemic = total_uncertainty - aleatoric
        
        return total_uncertainty, aleatoric, epistemic
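
Two extreme cases make the split concrete. In the sketch below the probability samples are toy values, not model outputs. Samples that all agree on a 50/50 prediction carry purely aleatoric uncertainty. Confident samples that disagree with each other carry purely epistemic uncertainty. Both cases have the same total entropy, $\ln 2$. The hypothetical helper `decompose` reproduces the formulas from `UncertaintyDecomposition.estimate`.

```python
import math
import torch

def decompose(probs_stack, eps=1e-10):
    """probs_stack: (samples, batch, classes) -> (total, aleatoric, epistemic)."""
    avg = probs_stack.mean(dim=0)
    total = -(avg * torch.log(avg + eps)).sum(-1)
    aleatoric = -(probs_stack * torch.log(probs_stack + eps)).sum(-1).mean(0)
    return total, aleatoric, total - aleatoric

# Samples agree, each predicts 50/50 -> noise in the data (aleatoric)
agree = torch.tensor([[[0.5, 0.5]], [[0.5, 0.5]]])
# Samples are confident but disagree -> uncertainty about w (epistemic)
disagree = torch.tensor([[[1.0, 0.0]], [[0.0, 1.0]]])

for name, p in [("agree", agree), ("disagree", disagree)]:
    t, a, e = decompose(p)
    print(f"{name}: total={t.item():.3f} aleatoric={a.item():.3f} epistemic={e.item():.3f}")
```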

OOD detection

A BLL's epistemic uncertainty can be used for out-of-distribution (OOD) detection:

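class OODDetector:
    """
    Uncertainty-based OOD detection
    """
    def __init__(self, backbone, bll_model, threshold=None):
        self.backbone = backbone
        self.decomposer = UncertaintyDecomposition(bll_model)
        self.threshold = threshold
    
    @torch.no_grad()
    def calibrate_threshold(self, val_loader, quantile=0.95, num_samples=100):
        """
        Set the threshold on an in-distribution calibration set.
        
        One common choice: the `quantile`-th quantile of in-distribution
        epistemic scores, so that about 95% of ID samples are accepted.
        """
        scores = [self.score(x, num_samples) for x, _ in val_loader]
        self.threshold = torch.quantile(torch.cat(scores), quantile).item()
    
    @torch.no_grad()
    def score(self, x, num_samples=100):
        """Epistemic-uncertainty score (higher means more likely OOD)"""
        features = self.backbone(x)
        _, _, epistemic = self.decomposer.estimate(features, num_samples)
        return epistemic
    
    def detect(self, x, num_samples=100):
        """
        Detect OOD samples
        
        Returns:
            is_id: whether each sample is in-distribution
            uncertainty: epistemic uncertainty score
        """
        if self.threshold is None:
            raise RuntimeError("call calibrate_threshold() first or pass a threshold")
        epistemic = self.score(x, num_samples)
        is_id = epistemic < self.threshold
        
        return is_id, epistemic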
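
A single threshold hides how well the score separates in-distribution inputs from OOD inputs. AUROC gives a threshold-free check: it is the probability that a randomly chosen OOD sample receives a higher epistemic score than a randomly chosen in-distribution sample. The minimal sketch below uses the pairwise (Mann-Whitney) definition. The scores are synthetic placeholders, not BLL outputs.

```python
import torch

def auroc(id_scores, ood_scores):
    """P(score_ood > score_id), counting ties as 1/2 (pairwise, O(n*m) memory)."""
    diff = ood_scores.unsqueeze(1) - id_scores.unsqueeze(0)   # (n_ood, n_id)
    return ((diff > 0).float() + 0.5 * (diff == 0).float()).mean().item()

torch.manual_seed(0)
id_scores = torch.rand(500) * 0.2          # synthetic: low epistemic uncertainty
ood_scores = 0.1 + torch.rand(300) * 0.3   # synthetic: shifted higher
print(f"AUROC = {auroc(id_scores, ood_scores):.3f}")
```

An AUROC of 0.5 means the score separates the two groups no better than chance. An AUROC of 1.0 means every OOD sample scores above every in-distribution sample.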

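Practical guide

When to use BLL

| Scenario | Use BLL? | Reason |
|---|---|---|
| Fine-tuning large models | Yes | Only the last layer needs to be trained |
| Uncertainty quantification | Yes | Uncertainty estimates are built in |
| OOD detection | Yes | Epistemic uncertainty can be separated out |
| Real-time applications | Yes | Small inference overhead |
| Fully Bayesian modeling | No | Requires a full Bayesian NN |

Hyperparameter settings

| Hyperparameter | Recommended value | Notes |
|---|---|---|
| Prior precision | 1.0 | Weight-decay strength |
| Temperature | 0.5-1.0 | Sharpness of the predictive distribution |
| MC samples | 50-100 | Accuracy vs. speed |
| Learning rate (BLL) | 0.01-0.1 | Usually higher than the backbone's |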
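
The "weight-decay strength" interpretation of the prior precision comes from the Gaussian log-prior. With $p(w) = \mathcal{N}(0, \alpha^{-1} I)$,

$$-\log p(w) = \frac{\alpha}{2}\lVert w \rVert_2^2 + \text{const}.$$

A MAP estimate is therefore L2-regularized with coefficient $\alpha/2$, or $\alpha/(2N)$ when the data term is averaged over $N$ examples. In the variational objective, the same $\alpha$ appears in the $\tfrac{\alpha}{2}\mu^2$ term of the KL. A larger $\alpha$ pulls the last-layer weight means more strongly toward zero.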

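Common issues

  1. KL divergence too large: lower the prior precision or increase the learning rate.
  2. Overconfident predictions: increase the temperature.
  3. Unstable training: use gradient clipping or a learning-rate schedule.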
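
Post-hoc temperature scaling is one systematic way to increase the temperature. It picks the $T$ that minimizes the negative log-likelihood of `softmax(logits / T)` on a held-out validation set. The sketch below uses a simple grid search on synthetic logits that are deliberately overconfident (true logits multiplied by 4). The grid range and the data are assumptions for illustration, and `fit_temperature` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def fit_temperature(logits, labels, grid=None):
    """Return the temperature on `grid` with the lowest validation NLL."""
    if grid is None:
        grid = torch.linspace(0.25, 10.0, 196)
    nlls = torch.stack([F.cross_entropy(logits / t, labels) for t in grid])
    return grid[nlls.argmin()].item()

torch.manual_seed(0)
true_logits = 2.0 * torch.randn(5000, 5)
labels = torch.distributions.Categorical(logits=true_logits).sample()
overconfident = 4.0 * true_logits          # a model that is too sharp by a factor of 4

T = fit_temperature(overconfident, labels)
print(f"fitted temperature: {T:.2f}")
```

The synthetic logits are exactly 4 times too sharp, so the fitted $T$ should come out close to 4.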

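Summary

Key points of the Bayesian last layer approach:

| Aspect | Description |
|---|---|
| Core idea | Model only the last layer in a Bayesian way |
| Strengths | Computationally efficient, easy to implement, compatible with pretrained models |
| Applications | Large-model fine-tuning, uncertainty quantification, OOD detection |
| Limitation | Does not capture uncertainty in the feature layers |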


Footnotes

  1. Bayesian Last Layer: A Simple and Efficient Approach to Bayesian Deep Learning (2019)

  2. Bayesian Layers: A Module for Efficient Neural Network Uncertainty (NeurIPS 2019 Demo)