Bayes by Backprop

Bayes by Backprop is a variational inference method proposed by Blundell et al. in 2015 for learning probability distributions over neural network weights.¹ The core idea is to approximate the true posterior $p(\mathbf{w} \mid \mathcal{D})$ with a parameterized variational distribution $q(\mathbf{w} \mid \theta)$, and to enable end-to-end gradient-based optimization via the reparameterization trick.

Variational Inference Framework

Objective Function

We look for the optimal variational parameters $\theta^*$ that minimize the KL divergence to the true posterior:

$$\theta^* = \arg\min_{\theta}\, \mathrm{KL}\left[\, q(\mathbf{w} \mid \theta) \,\|\, p(\mathbf{w} \mid \mathcal{D}) \,\right]$$

By Bayes' rule:

$$p(\mathbf{w} \mid \mathcal{D}) = \frac{p(\mathcal{D} \mid \mathbf{w})\, p(\mathbf{w})}{p(\mathcal{D})}$$

Expanding the KL divergence:

$$\mathrm{KL}\left[\, q(\mathbf{w} \mid \theta) \,\|\, p(\mathbf{w} \mid \mathcal{D}) \,\right] = \mathbb{E}_{q}\left[\log q(\mathbf{w} \mid \theta)\right] - \mathbb{E}_{q}\left[\log p(\mathbf{w})\right] - \mathbb{E}_{q}\left[\log p(\mathcal{D} \mid \mathbf{w})\right] + \log p(\mathcal{D})$$

Since $\log p(\mathcal{D})$ does not depend on $\theta$, minimizing the KL divergence is equivalent to maximizing the evidence lower bound (ELBO):

$$\mathrm{ELBO}(\theta) = \mathbb{E}_{q(\mathbf{w} \mid \theta)}\left[\log p(\mathcal{D} \mid \mathbf{w})\right] - \mathrm{KL}\left[\, q(\mathbf{w} \mid \theta) \,\|\, p(\mathbf{w}) \,\right]$$

Intuition

ELBO = reconstruction quality − deviation from the prior

- High reconstruction term → the model fits the data well
- Low KL term → the posterior stays close to the prior (which guards against overfitting)

Mean-Field Approximation

Factorization Assumption

Bayes by Backprop uses a mean-field approximation that factorizes the variational posterior over individual weights:

$$q(\mathbf{w} \mid \theta) = \prod_i \mathcal{N}\left(w_i \mid \mu_i, \sigma_i^2\right)$$

where $\theta_i = (\mu_i, \sigma_i)$ are the variational parameters of each weight.
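
Concretely, every weight gets its own mean and log-variance. A minimal sketch of these per-weight variational parameters (the tensor names are illustrative, not from the paper):

import torch

# One (mu, log_var) pair per entry of a 3x2 weight matrix
mu = torch.zeros(3, 2, requires_grad=True)               # means mu_i
log_var = torch.full((3, 2), -6.0, requires_grad=True)   # log sigma_i^2

# The factorized posterior samples every weight independently
sigma = torch.exp(0.5 * log_var)
w = mu + sigma * torch.randn_like(mu)  # one draw from q(w | theta)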

Gaussian Prior

A zero-mean Gaussian prior is typically chosen:

$$p(\mathbf{w}) = \prod_i \mathcal{N}\left(w_i \mid 0, \sigma_p^2\right)$$
Closed-Form KL Divergence

For two Gaussian distributions the KL divergence has an analytical expression:

$$\mathrm{KL}\left[\mathcal{N}(\mu_i, \sigma_i^2) \,\|\, \mathcal{N}(0, \sigma_p^2)\right] = \log\frac{\sigma_p}{\sigma_i} + \frac{\sigma_i^2 + \mu_i^2}{2\sigma_p^2} - \frac{1}{2}$$

Summing over all weights:

$$\mathrm{KL}\left[\, q(\mathbf{w} \mid \theta) \,\|\, p(\mathbf{w}) \,\right] = \sum_i \mathrm{KL}\left[\mathcal{N}(\mu_i, \sigma_i^2) \,\|\, \mathcal{N}(0, \sigma_p^2)\right]$$
import math
import torch

def kl_divergence_gaussian(mu, log_var, prior_std=1.0):
    """
    KL divergence to a zero-mean Gaussian prior:
    
    D_KL( N(mu, sigma^2) || N(0, sigma_p^2) )
      = 0.5 * ( log(sigma_p^2) - log(sigma^2) + (sigma^2 + mu^2) / sigma_p^2 - 1 )
    """
    prior_var = prior_std ** 2
    var = torch.exp(log_var)
    
    kl = 0.5 * (
        math.log(prior_var) - log_var  # log(sigma_p^2 / sigma^2)
        + (var + mu ** 2) / prior_var
        - 1.0
    )
    return kl.sum()
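
As a quick sanity check (not part of the method itself), the closed-form result can be compared against PyTorch's built-in KL for two Normal distributions:

from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 3)
log_var = torch.randn(4, 3) - 2.0

analytic = kl_divergence_gaussian(mu, log_var, prior_std=1.0)
reference = kl_divergence(
    Normal(mu, torch.exp(0.5 * log_var)),   # q(w | theta)
    Normal(torch.zeros_like(mu), 1.0),      # p(w) = N(0, 1)
).sum()

assert torch.allclose(analytic, reference, atol=1e-5)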

Reparameterization Trick

The Problem

The KL term can be computed in closed form, but the reconstruction term involves an expectation:

$$\mathbb{E}_{q(\mathbf{w} \mid \theta)}\left[\log p(\mathcal{D} \mid \mathbf{w})\right]$$

Here $\mathbf{w}$ is a random variable, so this term cannot be differentiated with respect to $\theta$ directly.

The Solution

Use the reparameterization trick: express the random variable as a deterministic transformation of a noise variable,

$$\mathbf{w} = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)$$

The randomness of $\mathbf{w}$ now comes from $\epsilon$, while $(\mu, \sigma)$ are deterministic parameters through which gradients can flow.

Gradient Estimation

The expectation is estimated by Monte Carlo sampling:

$$\mathbb{E}_{q(\mathbf{w} \mid \theta)}\left[\log p(\mathcal{D} \mid \mathbf{w})\right] \approx \frac{1}{T}\sum_{t=1}^{T} \log p\left(\mathcal{D} \mid \mathbf{w}^{(t)}\right), \qquad \mathbf{w}^{(t)} = \mu + \sigma \odot \epsilon^{(t)}$$
def reparameterize(mu, log_var):
    """
    Reparameterized sampling: w = mu + sigma * eps, with eps ~ N(0, I)
    
    Args:
        mu: mean (any shape)
        log_var: log-variance (same shape as mu)
    
    Returns:
        sampled weights (same shape as mu)
    """
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std
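
The payoff of this transform is that gradients reach the variational parameters through ordinary backpropagation. A minimal check (the squared-sum loss is a stand-in for a negative log-likelihood):

mu = torch.zeros(5, requires_grad=True)
log_var = torch.full((5,), -6.0, requires_grad=True)

w = reparameterize(mu, log_var)  # differentiable w.r.t. mu and log_var
loss = (w ** 2).sum()            # toy loss
loss.backward()

print(mu.grad is not None, log_var.grad is not None)  # True True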

The Bayes by Backprop Algorithm

The Full Loss Function

$$\mathcal{L}(\theta) = \frac{1}{T}\sum_{t=1}^{T}\left[-\log p\left(\mathcal{D} \mid \mathbf{w}^{(t)}\right)\right] + \lambda \,\mathrm{KL}\left[\, q(\mathbf{w} \mid \theta) \,\|\, p(\mathbf{w}) \,\right]$$

where $\mathbf{w}^{(t)}$ is the weight drawn at the $t$-th sample and $\lambda$ is the KL regularization weight (`lambda_reg` in the code below).

Algorithm Flow

import torch.nn.functional as F

def bayes_by_backprop_loss(model, x, y, n_samples=1, lambda_reg=1.0):
    """
    Bayes by Backprop loss: Monte Carlo estimate of the negative log-likelihood
    plus the weighted KL term.
    
    Args:
        model: Bayesian neural network whose Bayesian layers expose a kl_loss()
               method (see the BayesianLinear implementation below)
        x, y: data batch
        n_samples: number of Monte Carlo samples T
        lambda_reg: weight on the KL regularization term
    """
    total_nll = 0.0
    
    for _ in range(n_samples):
        # Each forward pass draws fresh weights via the reparameterization trick
        output = model(x)
        
        # Reconstruction loss (negative log-likelihood)
        total_nll = total_nll + F.cross_entropy(output, y, reduction='sum')
    
    nll = total_nll / n_samples
    
    # The KL term has a closed form and does not depend on the sampled weights,
    # so it is computed once from the variational parameters of each Bayesian layer
    kl = sum(
        module.kl_loss()
        for module in model.modules()
        if module is not model and hasattr(module, 'kl_loss')
    )
    
    loss = nll + lambda_reg * kl
    
    return loss, nll, kl
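
A usage sketch of this loss (it assumes the BayesianMLP defined in the next section and a toy classification batch):

model = BayesianMLP(input_dim=784, hidden_dim=256, output_dim=10)
x, y = torch.randn(64, 784), torch.randint(0, 10, (64,))

loss, nll, kl = bayes_by_backprop_loss(model, x, y, n_samples=3, lambda_reg=1.0)
loss.backward()  # gradients flow into every mu and log_var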

PyTorch Implementation

Complete Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class BayesianLinear(nn.Module):
    """
    贝叶斯线性层
    
    使用均值场高斯近似
    """
    
    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        
        # Variational parameters: means and log-variances
        # Xavier-style scale for initializing the weight means
        scale = math.sqrt(2.0 / (in_features + out_features))
        
        self.weight_mu = nn.Parameter(
            torch.randn(out_features, in_features) * scale
        )
        self.weight_log_var = nn.Parameter(
            torch.zeros(out_features, in_features) - 6  # variance = exp(-6) ≈ 2.5e-3
        )
        
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        self.bias_log_var = nn.Parameter(
            torch.zeros(out_features) - 6
        )
    
    def forward(self, x):
        # Reparameterized sampling of weights and biases
        weight = self.reparameterize(self.weight_mu, self.weight_log_var)
        bias = self.reparameterize(self.bias_mu, self.bias_log_var)
        
        return F.linear(x, weight, bias)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def kl_loss(self):
        """计算与先验的 KL 散度"""
        def kl_gaussian(mu, log_var, prior_var):
            var = torch.exp(log_var)
            return 0.5 * (
                log_var - torch.log(torch.tensor(prior_var, device=mu.device))
                + (var + mu ** 2) / prior_var
                - 1.0
            ).sum()
        
        prior_var = self.prior_std ** 2
        kl_w = kl_gaussian(self.weight_mu, self.weight_log_var, prior_var)
        kl_b = kl_gaussian(self.bias_mu, self.bias_log_var, prior_var)
        
        return kl_w + kl_b
 
 
class BayesianMLP(nn.Module):
    """
    贝叶斯多层感知机
    """
    
    def __init__(self, input_dim, hidden_dim, output_dim, prior_std=1.0):
        super().__init__()
        
        self.fc1 = BayesianLinear(input_dim, hidden_dim, prior_std)
        self.fc2 = BayesianLinear(hidden_dim, hidden_dim, prior_std)
        self.fc3 = BayesianLinear(hidden_dim, output_dim, prior_std)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
    def kl_loss(self):
        return self.fc1.kl_loss() + self.fc2.kl_loss() + self.fc3.kl_loss()
    
    def predict(self, x, n_samples=50):
        """
        贝叶斯预测
        """
        predictions = []
        
        with torch.no_grad():
            for _ in range(n_samples):
                pred = self(x)
                predictions.append(pred)
        
        predictions = torch.stack(predictions)  # (T, batch, output)
        mean = predictions.mean(dim=0)
        variance = predictions.var(dim=0)
        
        return mean, variance, predictions
 
 
class BayesianTrainer:
    """
    贝叶斯神经网络训练器
    """
    
    def __init__(self, model, lr=0.001, kl_weight=1.0):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.kl_weight = kl_weight
    
    def train_step(self, x, y, n_samples=1):
        self.optimizer.zero_grad()
        
        total_loss = 0.0
        nll_sum = 0.0
        kl_sum = 0.0
        
        for _ in range(n_samples):
            output = self.model(x)
            nll = F.cross_entropy(output, y, reduction='sum')
            kl = self.model.kl_loss()
            
            loss = nll + self.kl_weight * kl
            # Scale by 1/n_samples so the accumulated gradients form a Monte Carlo average
            (loss / n_samples).backward()
            
            total_loss += loss.item()
            nll_sum += nll.item()
            kl_sum += kl.item()
        
        self.optimizer.step()
        
        return {
            'loss': total_loss / n_samples,
            'nll': nll_sum / n_samples,
            'kl': kl_sum / n_samples
        }
    
    def predict(self, x, n_samples=50):
        return self.model.predict(x, n_samples)
 
 
# Training example (assumes a `dataloader` yielding flattened 784-dim inputs, e.g. MNIST)
model = BayesianMLP(input_dim=784, hidden_dim=256, output_dim=10)
trainer = BayesianTrainer(model, lr=0.001, kl_weight=1.0)
 
for epoch in range(10):
    for batch_x, batch_y in dataloader:
        metrics = trainer.train_step(batch_x, batch_y, n_samples=1)
        print(f"Loss: {metrics['loss']:.4f}, NLL: {metrics['nll']:.4f}, KL: {metrics['kl']:.4f}")

Local Reparameterization Trick

To reduce the variance of the gradient estimates, the local reparameterization trick can be used:

class BayesianLinearLocalReparam(nn.Module):
    """
    使用局部重参数化技巧的贝叶斯线性层
    
    关键优化:直接在层的输入空间采样,而非在参数空间采样
    这减少了梯度估计的方差
    """
    
    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_std = prior_std
        
        # Means of the variational posterior (deterministic parameters)
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias_mu = nn.Parameter(torch.zeros(out_features))
        
        # Log-variances of the variational posterior
        self.weight_log_var = nn.Parameter(torch.zeros(out_features, in_features) - 6)
        self.bias_log_var = nn.Parameter(torch.zeros(out_features) - 6)
    
    def forward(self, x):
        # Compute the mean and variance of the pre-activation distribution
        # E[y] = x @ E[w] + E[b]
        # Var[y] = x² @ Var[w] + Var[b]
        
        weight_mean = self.weight_mu
        bias_mean = self.bias_mu
        weight_var = torch.exp(self.weight_log_var)
        bias_var = torch.exp(self.bias_log_var)
        
        # Local reparameterization: sample directly in the output space
        # mean = x @ w_mu + b_mu
        output_mean = F.linear(x, weight_mean, bias_mean)
        
        # var = (x² @ w_var) + b_var
        output_var = F.linear(x ** 2, weight_var, bias_var)
        
        # Sample the pre-activations
        output_std = torch.sqrt(output_var + 1e-8)
        eps = torch.randn_like(output_mean)
        output = output_mean + eps * output_std
        
        return output
    
    def kl_loss(self):
        prior_var = self.prior_std ** 2
        
        def kl_gaussian(mu, log_var):
            var = torch.exp(log_var)
            return 0.5 * (
                math.log(prior_var) - log_var  # log(sigma_p^2 / sigma^2)
                + (var + mu ** 2) / prior_var
                - 1.0
            ).sum()
        
        return kl_gaussian(self.weight_mu, self.weight_log_var) + \
               kl_gaussian(self.bias_mu, self.bias_log_var)
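
The layer is a drop-in replacement for BayesianLinear. A minimal sketch of the swap (the class name is illustrative):

class BayesianMLPLocalReparam(nn.Module):
    """Same architecture as BayesianMLP, but with local reparameterization layers."""
    
    def __init__(self, input_dim, hidden_dim, output_dim, prior_std=1.0):
        super().__init__()
        self.fc1 = BayesianLinearLocalReparam(input_dim, hidden_dim, prior_std)
        self.fc2 = BayesianLinearLocalReparam(hidden_dim, hidden_dim, prior_std)
        self.fc3 = BayesianLinearLocalReparam(hidden_dim, output_dim, prior_std)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    
    def kl_loss(self):
        return self.fc1.kl_loss() + self.fc2.kl_loss() + self.fc3.kl_loss()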

Comparison with Other Methods

| Method | Posterior approximation | Computational cost | Implementation difficulty |
| --- | --- | --- | --- |
| Bayes by Backprop | Mean-field Gaussian | Medium | Medium |
| MC Dropout | Bernoulli dropout | – | – |
| Laplace approximation | Gaussian (curvature from the Hessian) | Medium | – |
| MCMC | Exact posterior (via sampling) | Very high | – |

Advantages of Bayes by Backprop

  1. Flexibility: an arbitrary prior distribution can be specified
  2. End-to-end: fits into the standard neural-network training loop
  3. Uncertainty quantification: estimates both aleatoric and epistemic uncertainty

Disadvantages of Bayes by Backprop

  1. Mean-field assumption: weights are treated as independent, which may be an oversimplification
  2. No full posterior covariance: only a diagonal covariance is learned
  3. Training instability: careful tuning is required (e.g. kl_weight); a simple weighting schedule is sketched below
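
One practical way to stabilize training is to scale the KL term per minibatch and warm it up over the first few epochs. The schedule below is a common heuristic sketch, not the weighting scheme from the original paper; trainer and dataloader follow the earlier training example:

def kl_weight_schedule(epoch, num_batches, warmup_epochs=5):
    """Per-minibatch KL scaling (1 / num_batches) with a linear warm-up."""
    anneal = min(1.0, (epoch + 1) / warmup_epochs)
    return anneal / num_batches

num_batches = len(dataloader)
for epoch in range(10):
    trainer.kl_weight = kl_weight_schedule(epoch, num_batches)
    for batch_x, batch_y in dataloader:
        metrics = trainer.train_step(batch_x, batch_y, n_samples=1)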

Core Formula Quick Reference

| Concept | Formula |
| --- | --- |
| ELBO | $\mathbb{E}_{q(\mathbf{w}\mid\theta)}[\log p(\mathcal{D}\mid\mathbf{w})] - \mathrm{KL}[\,q(\mathbf{w}\mid\theta)\,\Vert\,p(\mathbf{w})\,]$ |
| Reparameterization | $\mathbf{w} = \mu + \sigma \odot \epsilon,\ \epsilon \sim \mathcal{N}(0, I)$ |
| Gaussian KL | $\log\frac{\sigma_p}{\sigma_i} + \frac{\sigma_i^2 + \mu_i^2}{2\sigma_p^2} - \frac{1}{2}$ |
| Gradient estimate | $\frac{1}{T}\sum_{t=1}^{T}\log p(\mathcal{D}\mid \mu + \sigma \odot \epsilon^{(t)})$ |

Footnotes

  1. Blundell, C., et al. (2015). “Weight Uncertainty in Neural Networks”. ICML 2015.