Bayes by Backprop
Bayes by Backprop 是由 Blundell 等人在 2015 年提出的变分推断方法,用于学习神经网络权重的概率分布。[^1] 核心思想是用一个参数化的变分分布 $q(w|\theta)$ 逼近真实后验 $p(w|\mathcal{D})$,并通过重参数化技巧实现端到端的梯度优化。
变分推断框架
目标函数
寻找最优的变分参数 $\theta^*$,使变分分布与真实后验之间的 KL 散度最小:

$$\theta^* = \arg\min_\theta \, \mathrm{KL}\big(q(w|\theta) \,\|\, p(w|\mathcal{D})\big)$$

根据贝叶斯公式:

$$p(w|\mathcal{D}) = \frac{p(\mathcal{D}|w)\,p(w)}{p(\mathcal{D})}$$

展开 KL 散度:

$$\mathrm{KL}\big(q(w|\theta)\,\|\,p(w|\mathcal{D})\big) = \mathbb{E}_{q(w|\theta)}\big[\log q(w|\theta)\big] - \mathbb{E}_{q(w|\theta)}\big[\log p(w)\big] - \mathbb{E}_{q(w|\theta)}\big[\log p(\mathcal{D}|w)\big] + \log p(\mathcal{D})$$

由于 $\log p(\mathcal{D})$ 不依赖于 $\theta$,最小化 KL 散度等价于最大化证据下界(ELBO):

$$\mathcal{L}(\theta) = \underbrace{\mathbb{E}_{q(w|\theta)}\big[\log p(\mathcal{D}|w)\big]}_{\text{重构项}} - \underbrace{\mathrm{KL}\big(q(w|\theta)\,\|\,p(w)\big)}_{\text{正则项}}$$
直观理解
ELBO = 重构能力 - 与先验的偏离
- 重构项高 → 模型能很好地拟合数据
- KL项低 → 后验接近先验(避免过拟合)
均值场近似
分解假设
Bayes by Backprop 使用均值场(mean-field)近似,将变分分布分解为各权重上相互独立的高斯分布:

$$q(w|\theta) = \prod_i \mathcal{N}\big(w_i \,\big|\, \mu_i, \sigma_i^2\big)$$

其中 $\theta_i = (\mu_i, \sigma_i)$ 是每个权重 $w_i$ 的变分参数。
高斯先验
通常选择零均值的各向同性高斯先验:

$$p(w) = \prod_i \mathcal{N}\big(w_i \,\big|\, 0, \sigma_p^2\big)$$
KL 散度的闭式解
对于两个一维高斯分布,KL 散度有解析表达式:

$$\mathrm{KL}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, \sigma_p^2)\big) = \log\frac{\sigma_p}{\sigma} + \frac{\sigma^2 + \mu^2}{2\sigma_p^2} - \frac{1}{2}$$

对所有权重求和:

$$\mathrm{KL}\big(q(w|\theta)\,\|\,p(w)\big) = \sum_i \left[\log\frac{\sigma_p}{\sigma_i} + \frac{\sigma_i^2 + \mu_i^2}{2\sigma_p^2} - \frac{1}{2}\right]$$
```python
import math
import torch

def kl_divergence_gaussian(mu, log_var, prior_std=1.0):
    """
    计算高斯先验下的 KL 散度
    D_KL(N(mu, sigma²) || N(0, sigma_p²))
    """
    prior_var = prior_std ** 2
    var = torch.exp(log_var)
    kl = 0.5 * (
        math.log(prior_var) - log_var        # 即 log(sigma_p² / sigma²)
        + (var + mu ** 2) / prior_var
        - 1.0
    )
    return kl.sum()
```
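作为一个简单的数值检查:当变分分布与先验完全相同时,KL 散度应为 0;偏离先验越远,KL 越大。

```python
# 数值检查:mu=0、log_var=0(即 sigma=1)与先验 N(0, 1) 相同时,KL 为 0
mu = torch.zeros(3, 4)
log_var = torch.zeros(3, 4)
print(kl_divergence_gaussian(mu, log_var, prior_std=1.0))    # tensor(0.)

# 均值偏离先验后,KL 变为正值:每个元素 0.5,共 3*4=12 个元素,合计 6.0
print(kl_divergence_gaussian(mu + 1.0, log_var, prior_std=1.0))    # tensor(6.)
```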
重参数化技巧

问题
KL 散度项可以解析计算,但重构项涉及对变分分布的期望:

$$\mathbb{E}_{q(w|\theta)}\big[\log p(\mathcal{D}|w)\big]$$

其中 $w \sim q(w|\theta)$ 是随机变量,采样操作不可导,因此无法直接对 $\theta$ 求导。
解决方案
使用重参数化技巧(Reparameterization Trick):
将随机变量表示为确定性变换加噪声:

$$w = \mu + \sigma \odot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)$$

这样 $w$ 的随机性完全来自 $\epsilon$,而 $\mu$、$\sigma$ 是确定性参数,梯度可以经由 $w$ 传回 $\theta = (\mu, \sigma)$。
梯度估计
通过蒙特卡洛采样估计重构项的梯度:

$$\nabla_\theta \, \mathbb{E}_{q(w|\theta)}\big[\log p(\mathcal{D}|w)\big] \approx \frac{1}{T}\sum_{t=1}^{T} \nabla_\theta \log p\big(\mathcal{D}\,\big|\,w^{(t)}\big), \quad w^{(t)} = \mu + \sigma \odot \epsilon^{(t)}$$
```python
def reparameterize(mu, log_var):
    """
    重参数化采样
    Args:
        mu: 均值 (任意形状)
        log_var: 对数方差 (任意形状)
    Returns:
        采样的权重 (与 mu 形状相同)
    """
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std
```
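一个小的梯度检查可以说明梯度确实能穿过采样操作传回变分参数:

```python
# 验证梯度能通过重参数化采样传回 mu 和 log_var
mu = torch.zeros(5, requires_grad=True)
log_var = torch.zeros(5, requires_grad=True)

w = reparameterize(mu, log_var)   # w = mu + sigma * eps
w.sum().backward()

print(mu.grad)        # 全为 1:dw/dmu = 1
print(log_var.grad)   # 等于 0.5 * eps * sigma,随采样噪声变化
```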
Bayes by Backprop 算法

完整的损失函数

训练时最小化负 ELBO,重构项用蒙特卡洛采样近似:

$$\mathcal{F}(\theta) \approx \mathrm{KL}\big(q(w|\theta)\,\|\,p(w)\big) - \frac{1}{T}\sum_{t=1}^{T}\log p\big(\mathcal{D}\,\big|\,w^{(t)}\big)$$

其中 $w^{(t)} = \mu + \sigma \odot \epsilon^{(t)}$ 是第 $t$ 次采样的权重。
算法流程
```python
import torch.nn.functional as F

def bayes_by_backprop_loss(model, x, y, n_samples=1, lambda_reg=1.0):
    """
    Bayes by Backprop 损失函数
    Args:
        model: 贝叶斯神经网络(需提供 kl_loss() 方法,如下文的 BayesianMLP)
        x, y: 数据
        n_samples: 采样次数
        lambda_reg: KL 正则化权重(实践中常取 1/批次数)
    """
    total_nll = 0.0
    for _ in range(n_samples):
        # 每次前向传播内部都会重参数化采样一组新的权重
        output = model(x)
        # 重构损失(负对数似然)
        total_nll = total_nll + F.cross_entropy(output, y, reduction='sum')
    nll = total_nll / n_samples
    # KL 散度只依赖变分参数,与采样次数无关,计算一次即可
    kl = model.kl_loss()
    loss = nll + lambda_reg * kl
    return loss, nll, kl
```
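一个最小的调用示意(假设 `model` 是下文实现的 `BayesianMLP`,`x`、`y` 是一个 batch 的数据):

```python
# 调用示意:BayesianMLP 的定义见下文
x = torch.randn(32, 784)                  # 一个 batch 的输入
y = torch.randint(0, 10, (32,))           # 类别标签
model = BayesianMLP(input_dim=784, hidden_dim=256, output_dim=10)

loss, nll, kl = bayes_by_backprop_loss(model, x, y, n_samples=2,
                                       lambda_reg=1.0 / 600)  # 假设训练集约 600 个 batch
loss.backward()
print(nll.item(), kl.item())
```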
PyTorch 实现

完整实现
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class BayesianLinear(nn.Module):
"""
贝叶斯线性层
使用均值场高斯近似
"""
def __init__(self, in_features, out_features, prior_std=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.prior_std = prior_std
# 变分参数:均值和对数方差
# 使用 Xavier 初始化确定均值
scale = math.sqrt(2.0 / (in_features + out_features))
self.weight_mu = nn.Parameter(
torch.randn(out_features, in_features) * scale
)
self.weight_log_var = nn.Parameter(
            torch.zeros(out_features, in_features) - 6  # 初始对数方差为 -6,方差约 0.0025,σ ≈ 0.05
)
self.bias_mu = nn.Parameter(torch.zeros(out_features))
self.bias_log_var = nn.Parameter(
torch.zeros(out_features) - 6
)
def forward(self, x):
# 重参数化采样
weight = self.reparameterize(self.weight_mu, self.weight_log_var)
bias = self.reparameterize(self.bias_mu, self.bias_log_var)
return F.linear(x, weight, bias)
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def kl_loss(self):
"""计算与先验的 KL 散度"""
def kl_gaussian(mu, log_var, prior_var):
var = torch.exp(log_var)
return 0.5 * (
                torch.log(torch.tensor(prior_var, device=mu.device)) - log_var
+ (var + mu ** 2) / prior_var
- 1.0
).sum()
prior_var = self.prior_std ** 2
kl_w = kl_gaussian(self.weight_mu, self.weight_log_var, prior_var)
kl_b = kl_gaussian(self.bias_mu, self.bias_log_var, prior_var)
return kl_w + kl_b
class BayesianMLP(nn.Module):
"""
贝叶斯多层感知机
"""
def __init__(self, input_dim, hidden_dim, output_dim, prior_std=1.0):
super().__init__()
self.fc1 = BayesianLinear(input_dim, hidden_dim, prior_std)
self.fc2 = BayesianLinear(hidden_dim, hidden_dim, prior_std)
self.fc3 = BayesianLinear(hidden_dim, output_dim, prior_std)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
def kl_loss(self):
return self.fc1.kl_loss() + self.fc2.kl_loss() + self.fc3.kl_loss()
def predict(self, x, n_samples=50):
"""
贝叶斯预测
"""
predictions = []
with torch.no_grad():
for _ in range(n_samples):
pred = self(x)
predictions.append(pred)
predictions = torch.stack(predictions) # (T, batch, output)
mean = predictions.mean(dim=0)
variance = predictions.var(dim=0)
return mean, variance, predictions
class BayesianTrainer:
"""
贝叶斯神经网络训练器
"""
def __init__(self, model, lr=0.001, kl_weight=1.0):
self.model = model
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.kl_weight = kl_weight
    def train_step(self, x, y, n_samples=1):
        self.optimizer.zero_grad()
        nll_sum = 0.0
        kl_sum = 0.0
        for _ in range(n_samples):
            # 每次前向传播都会重新采样一组权重
            output = self.model(x)
            nll = F.cross_entropy(output, y, reduction='sum')
            kl = self.model.kl_loss()
            # 按采样次数取平均,使梯度规模与 n_samples 无关
            loss = (nll + self.kl_weight * kl) / n_samples
            loss.backward()
            nll_sum += nll.item()
            kl_sum += kl.item()
        self.optimizer.step()
        return {
            'loss': (nll_sum + self.kl_weight * kl_sum) / n_samples,
            'nll': nll_sum / n_samples,
            'kl': kl_sum / n_samples
        }
def predict(self, x, n_samples=50):
return self.model.predict(x, n_samples)
# 训练示例
model = BayesianMLP(input_dim=784, hidden_dim=256, output_dim=10)
trainer = BayesianTrainer(model, lr=0.001, kl_weight=1.0)
for epoch in range(10):
for batch_x, batch_y in dataloader:
metrics = trainer.train_step(batch_x, batch_y, n_samples=1)
    print(f"Loss: {metrics['loss']:.4f}, NLL: {metrics['nll']:.4f}, KL: {metrics['kl']:.4f}")
```
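训练之后,可以通过多次采样前向传播来获得预测均值与不确定性。下面是一个调用示意(`x_test` 为假设的测试输入):

```python
# 预测示意:多次采样得到预测分布,方差反映(认知)不确定性
x_test = torch.randn(16, 784)             # 假设的一个测试 batch
mean, variance, samples = model.predict(x_test, n_samples=50)

probs = F.softmax(mean, dim=-1)           # 对平均 logits 取 softmax 作为近似类别概率
uncertainty = variance.mean(dim=-1)       # 每个样本的平均 logit 方差,越大越不确定
print(probs.shape, uncertainty.shape)     # torch.Size([16, 10]) torch.Size([16])
```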
局部重参数化技巧

为了减少梯度估计的方差,可以使用局部重参数化技巧(local reparameterization trick):
```python
class BayesianLinearLocalReparam(nn.Module):
"""
使用局部重参数化技巧的贝叶斯线性层
    关键优化:直接在层的输出(预激活)空间采样,而非在权重空间采样
这减少了梯度估计的方差
"""
def __init__(self, in_features, out_features, prior_std=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.prior_std = prior_std
# 只存储均值(确定性)
self.weight_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
self.bias_mu = nn.Parameter(torch.zeros(out_features))
# 存储对数方差
self.weight_log_var = nn.Parameter(torch.zeros(out_features, in_features) - 6)
self.bias_log_var = nn.Parameter(torch.zeros(out_features) - 6)
def forward(self, x):
# 计算输出分布的均值和方差
# E[y] = x @ E[w] + E[b]
# Var[y] = x² @ Var[w] + Var[b]
weight_mean = self.weight_mu
bias_mean = self.bias_mu
weight_var = torch.exp(self.weight_log_var)
bias_var = torch.exp(self.bias_log_var)
# 局部重参数化:直接在输出空间采样
# mean = x @ w_mu + b_mu
output_mean = F.linear(x, weight_mean, bias_mean)
# var = (x² @ w_var) + b_var
output_var = F.linear(x ** 2, weight_var, bias_var)
# 采样
output_std = torch.sqrt(output_var + 1e-8)
eps = torch.randn_like(output_mean)
output = output_mean + eps * output_std
return output
def kl_loss(self):
prior_var = self.prior_std ** 2
def kl_gaussian(mu, log_var):
var = torch.exp(log_var)
return 0.5 * (
                torch.log(torch.tensor(prior_var, device=mu.device)) - log_var
+ (var + mu ** 2) / prior_var
- 1.0
).sum()
        return kl_gaussian(self.weight_mu, self.weight_log_var) + \
               kl_gaussian(self.bias_mu, self.bias_log_var)
```
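一个简单的调用示意:同一输入两次前向传播得到的输出不同,因为噪声是在输出空间逐样本采样的;KL 散度的计算方式与标准实现一致。

```python
# 示意:局部重参数化层的随机性体现在输出上
layer = BayesianLinearLocalReparam(in_features=8, out_features=3)
x = torch.randn(4, 8)

y1 = layer(x)
y2 = layer(x)
print(torch.allclose(y1, y2))    # False:每次前向传播独立采样输出噪声
print(layer.kl_loss().item())    # 与先验的 KL 散度(标量)
```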
与其他方法的比较

| 方法 | 后验近似 | 计算复杂度 | 实现难度 |
|---|---|---|---|
| Bayes by Backprop | 均值场高斯 | 中等 | 中等 |
| MC Dropout | Bernoulli Dropout | 低 | 低 |
| Laplace 近似 | 高斯(曲率来自 Hessian) | 高 | 中等 |
| MCMC | 精确后验采样 | 极高 | 高 |
Bayes by Backprop 的优势
- 灵活性:可以指定任意先验分布
- 端到端:与标准神经网络训练流程一致
- 不确定性量化:权重后验刻画认知(epistemic)不确定性;若同时对观测噪声建模,也能估计数据(aleatoric)不确定性
Bayes by Backprop 的劣势
- 均值场假设:假设权重独立,可能过于简化
- 后验协方差受限:只学习对角协方差,无法刻画权重之间的后验相关性
- 训练不稳定:需要仔细调参(如 kl_weight 的大小与退火策略,见下方示意)
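关于 kl_weight 的一个常见处理(仅作示意)是按 mini-batch 数量缩放 KL 项,并在训练初期做线性退火(warm-up),避免 KL 项过早压制重构项:

```python
# 示意:KL 权重 = 线性退火系数 × (1 / batch 数)
def kl_weight_schedule(epoch, num_batches, warmup_epochs=5):
    """前 warmup_epochs 个 epoch 将 KL 权重从接近 0 线性升到 1/num_batches"""
    anneal = min(1.0, (epoch + 1) / warmup_epochs)
    return anneal / num_batches

# 用法示意:每个 epoch 开始时更新
# trainer.kl_weight = kl_weight_schedule(epoch, num_batches=len(dataloader))
```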
核心公式速查
| 概念 | 公式 |
|---|---|
| ELBO | $\mathcal{L}(\theta) = \mathbb{E}_{q(w\mid\theta)}[\log p(\mathcal{D}\mid w)] - \mathrm{KL}\big(q(w\mid\theta)\parallel p(w)\big)$ |
| 重参数化 | $w = \mu + \sigma \odot \epsilon,\ \epsilon \sim \mathcal{N}(0, I)$ |
| KL 高斯 | $\mathrm{KL} = \sum_i \big[\log\tfrac{\sigma_p}{\sigma_i} + \tfrac{\sigma_i^2 + \mu_i^2}{2\sigma_p^2} - \tfrac{1}{2}\big]$ |
| 梯度估计 | $\nabla_\theta\,\mathbb{E}_q[\log p(\mathcal{D}\mid w)] \approx \tfrac{1}{T}\sum_{t=1}^{T}\nabla_\theta \log p(\mathcal{D}\mid w^{(t)})$ |
参考
相关文章
- 贝叶斯神经网络 — BNN 基础概念
- MC Dropout — Dropout 的贝叶斯解释
- Laplace近似 — Hessian 近似方法
- 变分推断 — VI 理论基础
Footnotes

[^1]: Blundell, C., et al. (2015). “Weight Uncertainty in Neural Networks”. ICML 2015.