概述
随机梯度马尔可夫链蒙特卡洛(Stochastic Gradient MCMC,SG-MCMC)方法将经典MCMC方法与随机优化技术相结合,使得在大规模数据集上进行贝叶斯推断成为可能。1
传统MCMC方法(如Metropolis-Hastings)在高维深度学习参数空间中面临严重的计算瓶颈。SG-MCMC通过使用随机梯度来近似真实梯度,以 的每次迭代复杂度实现对后验分布的采样。
预备知识
贝叶斯深度学习回顾
给定数据集 ,参数的后验分布为:
其中:
- 是似然
- 是先验分布
目标
从后验分布 中采样,以估计:
- 后验预测分布:
- 后验统计量:,
随机梯度Langevin动力学(SGLD)
经典Langevin动力学
物理中的Langevin方程描述粒子在势能场中的随机运动:
其中:
- 是势能函数(负对数后验)
- 是维纳过程(Wiener process)
离散化:Metropolis-adjusted Langevin算法(MALA)
使用欧拉-Maruyama方法离散化:
其中 。
SGLD算法
Welling & Teh (2011) 提出用随机梯度近似真实梯度。2
关键观察:对数后验梯度可以分解为:
当 很大时,计算 成本很高。使用小批量 近似:
SGLD更新公式
其中:
- 是学习率(逐渐减小)
- 是第 步的随机小批量
SGLD的Python实现
import torch
from torch import nn
class SGLD:
def __init__(self, model, lr=1e-4, weight_decay=1e-5):
self.model = model
self.lr = lr
self.weight_decay = weight_decay
def step(self, batch_x, batch_y):
# 启用梯度
for p in self.model.parameters():
if p.grad is not None:
p.grad.zero_()
# 前向传播
output = self.model(batch_x)
loss = nn.functional.cross_entropy(output, batch_y)
# 添加先验梯度(权重衰减)
for p in self.model.parameters():
if p.requires_grad:
p.grad = torch.autograd.grad(loss, p)[0] + self.weight_decay * p
# 采样噪声
noise_std = torch.sqrt(2 * self.lr)
with torch.no_grad():
for p in self.model.parameters():
if p.requires_grad and p.grad is not None:
# 梯度下降 + 噪声
p.add_(p.grad, alpha=-0.5 * self.lr)
p.add_(torch.randn_like(p) * noise_std)
return loss.item()
def sample(self, num_samples=100, burn_in=500):
"""收集样本"""
samples = []
self.model.train()
for t in range(burn_in + num_samples):
# 获取小批量数据
batch = self.get_batch()
self.step(batch['x'], batch['y'])
if t >= burn_in:
samples.append(copy.deepcopy(self.model.state_dict()))
return samples随机梯度Hamiltonian蒙特卡洛(SGHMC)
Hamiltonian动力学
HMC利用Hamiltonian动力学来探索后验分布:
其中 是动量变量, 是质量矩阵(通常为恒等矩阵)。
动力学方程:
离散化:Leapfrog积分
SGHMC
Chen et al. (2014) 将HMC扩展到随机梯度设置。3
核心思想:将热噪声融入动力学方程,避免Metropolis-Hastings接受步骤。
SGHMC更新公式
其中:
- 是随机梯度
- 是噪声协方差估计
- 是摩擦系数矩阵
- 是阻尼系数
SGHMC实现
class SGHMC:
def __init__(self, model, lr=1e-4, alpha=0.01, gamma=0.1):
self.model = model
self.lr = lr
self.alpha = alpha # 噪声衰减
self.gamma = gamma # 摩擦系数
# 动量缓冲
self.momentum = {name: torch.zeros_like(param)
for name, param in model.named_parameters()}
def step(self, batch_x, batch_y):
# 计算随机梯度
output = self.model(batch_x)
loss = nn.functional.cross_entropy(output, batch_y)
loss.backward()
noise_std = torch.sqrt(2 * self.lr * self.gamma)
with torch.no_grad():
for name, param in self.model.named_parameters():
if param.grad is None:
continue
grad = param.grad.data
# 更新动量
self.momentum[name].mul_(1 - 2 * self.lr * self.gamma)
self.momentum[name].add_(grad, alpha=-self.lr)
self.momentum[name].add_(torch.randn_like(param) * noise_std)
# 更新参数
param.add_(self.momentum[name], alpha=self.lr)
# 清除梯度
param.grad.zero_()
return loss.item()预条件器设计
预条件器可以改善SGMCMC的收敛性。
1. Riemannian SGLD
使用Fisher信息矩阵作为Riemannian度量:
class RiemannianSGLD:
def __init__(self, model, lr=1e-4, preconditioner='fisher'):
self.model = model
self.lr = lr
self.preconditioner = preconditioner
self.fisher_ema = {} # 指数移动平均的Fisher信息
def compute_fisher(self, batch_x, batch_y, emp_fisher=True):
"""计算Fisher信息矩阵(或其对角近似)"""
self.model.eval()
output = self.model(batch_x)
if emp_fisher:
# 经验Fisher:对数似然的梯度外积
log_prob = nn.functional.log_softmax(output, dim=-1)
targets = torch.zeros_like(log_prob)
targets[torch.arange(len(batch_y)), batch_y] = 1
grads = torch.autograd.grad(
log_prob.sum(), self.model.parameters()
)
fisher = [g.pow(2) for g in grads]
else:
# True Fisher:关于真实后验的Fisher
fisher = [...] # 真实后验的梯度外积
return fisher
def step(self, batch_x, batch_y):
# 计算梯度
output = self.model(batch_x)
loss = nn.functional.cross_entropy(output, batch_y)
grads = torch.autograd.grad(loss, self.model.parameters())
grads = list(grad)
# 更新Fisher估计
fisher = self.compute_fisher(batch_x, batch_y)
self._update_fisher_ema(fisher)
# 应用预条件器
with torch.no_grad():
for i, (name, param) in enumerate(self.model.named_parameters()):
# 对角Fisher的倒数作为预条件器
precond = 1.0 / (self.fisher_ema[name] + 1e-6)
# 更新参数
param.add_(grads[i] * precond, alpha=-self.lr)
param.add_(torch.randn_like(param) * torch.sqrt(self.lr * precond))
return loss.item()2. K-FAC预条件器
Kronecker-Factored Approximate Curvature (K-FAC) 提供了一种高效的曲率近似方法。
与变分推断的对比
| 特性 | SGMCMC | 变分推断 |
|---|---|---|
| 目标分布 | 精确后验采样 | 后验近似 |
| 计算成本 | 中等 | 低 |
| 偏差 | 无(渐近) | 有(近似误差) |
| 收敛诊断 | 困难 | N/A |
| 隐式正则化 | 无 | 依赖近似族 |
| 实现复杂度 | 中等 | 低-中 |
选择指南
| 场景 | 推荐方法 |
|---|---|
| 小规模数据 | 精确MCMC |
| 大规模深度学习 | SGLD/SGHMC |
| 需要快速推断 | 变分推断 |
| 高度复杂后验 | SGMCMC |
大规模训练中的应用
1. 分布式SGMCMC
class DistributedSGLD:
def __init__(self, model, num_workers=4):
self.model = model
self.num_workers = num_workers
self.world_size = dist.get_world_size()
def distributed_step(self, batch_x, batch_y):
# 所有workers计算局部梯度
output = self.model(batch_x)
loss = nn.functional.cross_entropy(output, batch_y)
loss.backward()
# 梯度平均
for param in self.model.parameters():
if param.grad is not None:
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data.div_(self.world_size)
# SGLD更新
self._sgld_update()2. 异步SGMCMC
允许workers异步更新参数。
3. 冷却时间表
学习率的冷却策略对收敛至关重要:
def cosine_schedule(t, T, lr_min, lr_max):
"""余弦冷却"""
return lr_min + 0.5 * (lr_max - lr_min) * (1 + torch.cos(torch.pi * t / T))
def step_decay(t, milestones=[50000, 100000], gamma=0.1):
"""阶梯冷却"""
for m in milestones:
if t >= m:
return gamma
return 1.0后验预测推断
1. 样本收集
def collect_posterior_samples(model, train_loader, num_samples=100, burn_in=500):
"""收集后验样本"""
samples = []
sampler = SGLD(model, lr=1e-4)
for t, (x, y) in enumerate(train_loader):
sampler.step(x, y)
if t >= burn_in and (t - burn_in) % 10 == 0:
samples.append(copy.deepcopy(model.state_dict()))
if len(samples) >= num_samples:
break
return samples2. 预测分布
def posterior_predictive(model, samples, x_test):
"""
计算后验预测分布
p(y*|x*, D) ≈ (1/K) Σ p(y*|x*, θ_k)
"""
predictions = []
for state_dict in samples:
model.load_state_dict(state_dict)
model.eval()
with torch.no_grad():
pred = torch.softmax(model(x_test), dim=-1)
predictions.append(pred)
predictions = torch.stack(predictions)
# 点预测
mean_pred = predictions.mean(dim=0)
# 不确定性
pred_std = predictions.std(dim=0)
return {
'mean': mean_pred,
'std': pred_std,
'samples': predictions
}收敛诊断
1. Gelman-Rubin诊断(R̂)
def gelman_rubin(chains, target_dim):
"""
计算R̂统计量
R̂ ≈ 1 表示收敛
"""
m = len(chains) # 链数
n = chains[0].shape[0] # 每链样本数
# 计算链内方差
W = torch.stack([
chains[i][:, target_dim].var(dim=0) for i in range(m)
]).mean(dim=0)
# 计算链间方差
chain_means = torch.stack([
chains[i][:, target_dim].mean(dim=0) for i in range(m)
])
B = n * chain_means.var(dim=0) / (m - 1)
# 方差估计
var_hat = (n - 1) / n * W + B / n
# R̂
R_hat = torch.sqrt(var_hat / W)
return R_hat2. 有效样本量(ESS)
def effective_sample_size(chain):
"""
计算有效样本量
"""
n = len(chain)
# 自相关函数
acf = torch.tensor([torch.corrcoef(
torch.stack([chain[:-lag], chain[lag:]])
)[0, 1] for lag in range(1, n // 2)])
# 截断在第一个负自相关处
acf = acf[acf > 0]
# ESS
ess = n / (1 + 2 * acf.sum())
return ess实践技巧
1. 学习率选择
SGLD的学习率通常比标准SGD小一个数量级。
2. 预热期
在收集样本前进行足够长的预热(burn-in)。
3. 采样间隔
避免连续样本之间的相关性,使用间隔采样。
4. 先验设置
适当的先验(如高斯先验 + 权重衰减)可以改善采样质量。
参考
相关主题
- mcmc-methods - MCMC方法
- bayesian-neural-networks - 贝叶斯神经网络基础
- bayesian-neural-networks-advanced-inference - 贝叶斯神经网络高级推断
- variational-inference-advanced - 变分推断进阶
- bnn-uncertainty-quantification - BNN不确定性量化