概述

随机梯度马尔可夫链蒙特卡洛(Stochastic Gradient MCMC,SG-MCMC)方法将经典MCMC方法与随机优化技术相结合,使得在大规模数据集上进行贝叶斯推断成为可能。1

传统MCMC方法(如Metropolis-Hastings)在高维深度学习参数空间中面临严重的计算瓶颈。SG-MCMC通过使用随机梯度来近似真实梯度,以 的每次迭代复杂度实现对后验分布的采样。


预备知识

贝叶斯深度学习回顾

给定数据集 ,参数的后验分布为:

其中:

  • 是似然
  • 是先验分布

目标

从后验分布 中采样,以估计:

  • 后验预测分布:
  • 后验统计量:

随机梯度Langevin动力学(SGLD)

经典Langevin动力学

物理中的Langevin方程描述粒子在势能场中的随机运动:

其中:

  • 是势能函数(负对数后验)
  • 是维纳过程(Wiener process)

离散化:Metropolis-adjusted Langevin算法(MALA)

使用欧拉-Maruyama方法离散化:

其中

SGLD算法

Welling & Teh (2011) 提出用随机梯度近似真实梯度。2

关键观察:对数后验梯度可以分解为:

很大时,计算 成本很高。使用小批量 近似:

SGLD更新公式

其中:

  • 是学习率(逐渐减小)
  • 是第 步的随机小批量

SGLD的Python实现

import torch
from torch import nn
 
class SGLD:
    def __init__(self, model, lr=1e-4, weight_decay=1e-5):
        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
    
    def step(self, batch_x, batch_y):
        # 启用梯度
        for p in self.model.parameters():
            if p.grad is not None:
                p.grad.zero_()
        
        # 前向传播
        output = self.model(batch_x)
        loss = nn.functional.cross_entropy(output, batch_y)
        
        # 添加先验梯度(权重衰减)
        for p in self.model.parameters():
            if p.requires_grad:
                p.grad = torch.autograd.grad(loss, p)[0] + self.weight_decay * p
        
        # 采样噪声
        noise_std = torch.sqrt(2 * self.lr)
        with torch.no_grad():
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    # 梯度下降 + 噪声
                    p.add_(p.grad, alpha=-0.5 * self.lr)
                    p.add_(torch.randn_like(p) * noise_std)
        
        return loss.item()
    
    def sample(self, num_samples=100, burn_in=500):
        """收集样本"""
        samples = []
        self.model.train()
        
        for t in range(burn_in + num_samples):
            # 获取小批量数据
            batch = self.get_batch()
            self.step(batch['x'], batch['y'])
            
            if t >= burn_in:
                samples.append(copy.deepcopy(self.model.state_dict()))
        
        return samples

随机梯度Hamiltonian蒙特卡洛(SGHMC)

Hamiltonian动力学

HMC利用Hamiltonian动力学来探索后验分布:

其中 是动量变量, 是质量矩阵(通常为恒等矩阵)。

动力学方程

离散化:Leapfrog积分

SGHMC

Chen et al. (2014) 将HMC扩展到随机梯度设置。3

核心思想:将热噪声融入动力学方程,避免Metropolis-Hastings接受步骤。

SGHMC更新公式

其中:

  • 是随机梯度
  • 是噪声协方差估计
  • 是摩擦系数矩阵
  • 是阻尼系数

SGHMC实现

class SGHMC:
    def __init__(self, model, lr=1e-4, alpha=0.01, gamma=0.1):
        self.model = model
        self.lr = lr
        self.alpha = alpha  # 噪声衰减
        self.gamma = gamma  # 摩擦系数
        
        # 动量缓冲
        self.momentum = {name: torch.zeros_like(param) 
                        for name, param in model.named_parameters()}
    
    def step(self, batch_x, batch_y):
        # 计算随机梯度
        output = self.model(batch_x)
        loss = nn.functional.cross_entropy(output, batch_y)
        loss.backward()
        
        noise_std = torch.sqrt(2 * self.lr * self.gamma)
        
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.grad is None:
                    continue
                
                grad = param.grad.data
                
                # 更新动量
                self.momentum[name].mul_(1 - 2 * self.lr * self.gamma)
                self.momentum[name].add_(grad, alpha=-self.lr)
                self.momentum[name].add_(torch.randn_like(param) * noise_std)
                
                # 更新参数
                param.add_(self.momentum[name], alpha=self.lr)
                
                # 清除梯度
                param.grad.zero_()
        
        return loss.item()

预条件器设计

预条件器可以改善SGMCMC的收敛性。

1. Riemannian SGLD

使用Fisher信息矩阵作为Riemannian度量:

class RiemannianSGLD:
    def __init__(self, model, lr=1e-4, preconditioner='fisher'):
        self.model = model
        self.lr = lr
        self.preconditioner = preconditioner
        self.fisher_ema = {}  # 指数移动平均的Fisher信息
    
    def compute_fisher(self, batch_x, batch_y, emp_fisher=True):
        """计算Fisher信息矩阵(或其对角近似)"""
        self.model.eval()
        output = self.model(batch_x)
        
        if emp_fisher:
            # 经验Fisher:对数似然的梯度外积
            log_prob = nn.functional.log_softmax(output, dim=-1)
            targets = torch.zeros_like(log_prob)
            targets[torch.arange(len(batch_y)), batch_y] = 1
            
            grads = torch.autograd.grad(
                log_prob.sum(), self.model.parameters()
            )
            fisher = [g.pow(2) for g in grads]
        else:
            # True Fisher:关于真实后验的Fisher
            fisher = [...]  # 真实后验的梯度外积
        
        return fisher
    
    def step(self, batch_x, batch_y):
        # 计算梯度
        output = self.model(batch_x)
        loss = nn.functional.cross_entropy(output, batch_y)
        
        grads = torch.autograd.grad(loss, self.model.parameters())
        grads = list(grad)
        
        # 更新Fisher估计
        fisher = self.compute_fisher(batch_x, batch_y)
        self._update_fisher_ema(fisher)
        
        # 应用预条件器
        with torch.no_grad():
            for i, (name, param) in enumerate(self.model.named_parameters()):
                # 对角Fisher的倒数作为预条件器
                precond = 1.0 / (self.fisher_ema[name] + 1e-6)
                
                # 更新参数
                param.add_(grads[i] * precond, alpha=-self.lr)
                param.add_(torch.randn_like(param) * torch.sqrt(self.lr * precond))
        
        return loss.item()

2. K-FAC预条件器

Kronecker-Factored Approximate Curvature (K-FAC) 提供了一种高效的曲率近似方法。


与变分推断的对比

特性SGMCMC变分推断
目标分布精确后验采样后验近似
计算成本中等
偏差无(渐近)有(近似误差)
收敛诊断困难N/A
隐式正则化依赖近似族
实现复杂度中等低-中

选择指南

场景推荐方法
小规模数据精确MCMC
大规模深度学习SGLD/SGHMC
需要快速推断变分推断
高度复杂后验SGMCMC

大规模训练中的应用

1. 分布式SGMCMC

class DistributedSGLD:
    def __init__(self, model, num_workers=4):
        self.model = model
        self.num_workers = num_workers
        self.world_size = dist.get_world_size()
    
    def distributed_step(self, batch_x, batch_y):
        # 所有workers计算局部梯度
        output = self.model(batch_x)
        loss = nn.functional.cross_entropy(output, batch_y)
        loss.backward()
        
        # 梯度平均
        for param in self.model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                param.grad.data.div_(self.world_size)
        
        # SGLD更新
        self._sgld_update()

2. 异步SGMCMC

允许workers异步更新参数。

3. 冷却时间表

学习率的冷却策略对收敛至关重要:

def cosine_schedule(t, T, lr_min, lr_max):
    """余弦冷却"""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + torch.cos(torch.pi * t / T))
 
def step_decay(t, milestones=[50000, 100000], gamma=0.1):
    """阶梯冷却"""
    for m in milestones:
        if t >= m:
            return gamma
    return 1.0

后验预测推断

1. 样本收集

def collect_posterior_samples(model, train_loader, num_samples=100, burn_in=500):
    """收集后验样本"""
    samples = []
    sampler = SGLD(model, lr=1e-4)
    
    for t, (x, y) in enumerate(train_loader):
        sampler.step(x, y)
        
        if t >= burn_in and (t - burn_in) % 10 == 0:
            samples.append(copy.deepcopy(model.state_dict()))
            
        if len(samples) >= num_samples:
            break
    
    return samples

2. 预测分布

def posterior_predictive(model, samples, x_test):
    """
    计算后验预测分布
    p(y*|x*, D) ≈ (1/K) Σ p(y*|x*, θ_k)
    """
    predictions = []
    
    for state_dict in samples:
        model.load_state_dict(state_dict)
        model.eval()
        
        with torch.no_grad():
            pred = torch.softmax(model(x_test), dim=-1)
            predictions.append(pred)
    
    predictions = torch.stack(predictions)
    
    # 点预测
    mean_pred = predictions.mean(dim=0)
    
    # 不确定性
    pred_std = predictions.std(dim=0)
    
    return {
        'mean': mean_pred,
        'std': pred_std,
        'samples': predictions
    }

收敛诊断

1. Gelman-Rubin诊断(R̂)

def gelman_rubin(chains, target_dim):
    """
    计算R̂统计量
    R̂ ≈ 1 表示收敛
    """
    m = len(chains)  # 链数
    n = chains[0].shape[0]  # 每链样本数
    
    # 计算链内方差
    W = torch.stack([
        chains[i][:, target_dim].var(dim=0) for i in range(m)
    ]).mean(dim=0)
    
    # 计算链间方差
    chain_means = torch.stack([
        chains[i][:, target_dim].mean(dim=0) for i in range(m)
    ])
    B = n * chain_means.var(dim=0) / (m - 1)
    
    # 方差估计
    var_hat = (n - 1) / n * W + B / n
    
    # R̂
    R_hat = torch.sqrt(var_hat / W)
    
    return R_hat

2. 有效样本量(ESS)

def effective_sample_size(chain):
    """
    计算有效样本量
    """
    n = len(chain)
    
    # 自相关函数
    acf = torch.tensor([torch.corrcoef(
        torch.stack([chain[:-lag], chain[lag:]])
    )[0, 1] for lag in range(1, n // 2)])
    
    # 截断在第一个负自相关处
    acf = acf[acf > 0]
    
    # ESS
    ess = n / (1 + 2 * acf.sum())
    
    return ess

实践技巧

1. 学习率选择

SGLD的学习率通常比标准SGD小一个数量级。

2. 预热期

在收集样本前进行足够长的预热(burn-in)。

3. 采样间隔

避免连续样本之间的相关性,使用间隔采样。

4. 先验设置

适当的先验(如高斯先验 + 权重衰减)可以改善采样质量。


参考


相关主题

Footnotes

  1. Ma et al. (2015). “A Complete Recipe for Stochastic Gradient MCMC”. NIPS 2015.

  2. Welling & Teh (2011). “Bayesian Learning via Stochastic Gradient Langevin Dynamics”. ICML 2011.

  3. Chen et al. (2014). “Stochastic Gradient Hamiltonian Monte Carlo”. ICML 2014.