变分推断与采样方法对比

1 引言

变分推断(Variational Inference, VI)和蒙特卡洛采样(Monte Carlo Sampling)是近似贝叶斯推断的两大核心范式。1 理解它们的优缺点对于实际问题中选择合适的方法至关重要。

核心问题

  • 何时选择变分推断?
  • 何时选择采样方法?
  • 能否结合两者的优势?

本章从理论、效率、表达能力三个维度进行深度对比分析。


2 两种范式概述

2.1 变分推断:优化视角

变分推断将推断问题转化为优化问题

目标:找到最接近真实后验的近似分布

等价形式:最大化ELBO

特点

  • 将随机问题转化为确定性问题
  • 利用优化算法(梯度下降)
  • 需要选择近似族

2.2 采样方法:随机视角

采样方法通过随机模拟来近似后验分布:

目标:从后验分布 采样

估计量

特点

  • 渐近无偏(
  • 不需要选择近似族
  • 收敛速度慢(

2.3 直观对比

变分推断                    采样方法
     │                          │
     ▼                          ▼
  ┌──────┐                  ┌──────┐
  │优化器│                  │随机游走│
  └──┬───┘                  └──┬───┘
     │                          │
     ▼                          ▼
  快速但有偏              慢但无偏
  (固定近似族)            (任意精度)

3 理论基础对比

3.1 渐近性质

方面变分推断采样方法
偏差有偏(近似族限制)无偏(渐近)
方差优化方差蒙特卡洛方差
收敛确定性优化随机过程

变分推断的偏差来源

  • 近似族 的表达能力
  • 优化器的局部最优

采样方法的方差来源

  • 有限样本数
  • 自相关(相关采样)

3.2 复杂度分析

设数据量为 ,参数维度为

变分推断复杂度

  • 每步:(似然计算)+(KL项)
  • 总迭代:通常数百到数千步

采样方法复杂度

  • 每次接受:(梯度计算)+ (提议生成)
  • 收敛样本数:可能需要数万到百万样本

3.3 维度依赖

维度灾难的影响

维度变分推断采样方法
低维可能欠拟合高效
中维可行可行(需小心)
高维可扩展(mean-field)挑战大

关键洞察

  • 变分推断在高维时可通过结构化近似族保持可行
  • MCMC在高维时面临慢混合问题

4 计算效率对比

4.1 收敛速度

变分推断

  • 早期快速收敛
  • 后期缓慢改进
  • 通常 100-1000 次迭代即可

采样方法

  • 预热期无效
  • 渐近收敛但速度慢
  • 通常需要 10000+ 样本

4.2 梯度利用

变分推断

  • 直接利用梯度信息
  • 自然梯度、随机梯度
  • 高效的优化算法

采样方法

  • HMC等利用梯度
  • 随机游走类方法不利用梯度
  • 效率差异巨大

4.3 可扩展性

数据可扩展性

方法小数据集大数据集
变分推断✓(随机变分推断)
MCMC△(需要特殊处理)

随机变分推断(SVI)

其中 是基于小批量数据的无偏估计。

4.4 并行化

变分推断

  • 数据并行:天然适合
  • 模型并行:需小心处理

采样方法

  • 多链独立:天然并行
  • 链内并行:较难

5 表达能力对比

5.1 近似族的影响

变分推断的表达能力完全取决于近似族 的选择。

常见近似族

近似族表达能力计算复杂度
平均场
均值场变分低-中
结构化变分中-高
归一化流
混合模型

5.2 平均场假设的限制

平均场变分分布

问题

  • 忽略参数间相关性
  • 后验协方差估计不准确
  • 在高度相关的问题上表现差

5.3 高表达能力变分族

归一化流(Normalizing Flows)

通过可逆变换逐步增加表达能力。

树状变分分布

保留团结构的相关性。

5.4 采样的表达能力

采样方法的表达能力

  • 无偏:理论上可以表示任意分布
  • 实际:受链混合速度限制

在高度多峰分布上

          平均场VI              MCMC
              │                  │
              ▼                  ▼
           ┌─────┐           ┌─────┐
           │单峰 │           │多峰 │
           │(有偏)│           │(无偏)│
           └─────┘           └─────┘

6 混合方法

6.1 变分指导的采样

使用变分后验指导MCMC的提议分布:

提议分布

优势

  • 提议更接近目标分布
  • 提高接受率
  • 加速混合
def variational_guided_mcmc(target_log_prob, vi_mean, vi_cov, 
                            n_samples=10000, warmup=1000):
    """
    变分引导的MCMC采样
    """
    # 使用变分后验作为提议
    def proposal(mean, cov, current):
        return np.random.multivariate_normal(mean, cov)
    
    # 或者作为辅助分布
    def modified_target(theta):
        log_post = target_log_prob(theta)
        log_vi = multivariate_normal.logpdf(theta, vi_mean, vi_cov)
        return log_post - log_vi  # 调整目标

6.2 采样增强的变分

使用MCMC步骤增强变分分布:

随机变分推断+重参数化MCMC

class SVIMCMC:
    """采样增强的变分推断"""
    
    def __init__(self, target_log_prob, n_particles=10):
        self.target = target_log_prob
        self.n_particles = n_particles
    
    def step(self, q_params, data_batch):
        # E步:MCMC更新粒子
        particles = self.mcmc_update(q_params, self.n_particles)
        
        # M步:更新变分参数
        loss = self.compute_elbo(particles, q_params, data_batch)
        
        return {'loss': loss, 'particles': particles}
    
    def mcmc_update(self, q_params, n_particles):
        """MCMC更新变分粒子"""
        # 使用随机游走或HMC
        # ...

6.3 正则化流增强

使用采样技术训练更灵活的变分分布:

流正则化

其中 是基于采样的正则化项。


7 实践选择指南

7.1 决策流程

开始
   │
   ├─ 数据规模
   │    ├─ 小数据(< 10K)──> 两者皆可
   │    └─ 大数据(> 10K)──> 变分推断(SVI)
   │
   ├─ 维度
   │    ├─ 低维(< 10)──> MCMC
   │    ├─ 中维(10-100)──> HMC/变分
   │    └─ 高维(> 100)──> 结构化变分
   │
   ├─ 后验复杂度
   │    ├─ 单峰/近似高斯──> 变分推断
   │    ├─ 多峰/复杂──> MCMC
   │    └─ 高度相关──> HMC
   │
   └─ 计算资源
        ├─ GPU/TPU ──> 两者皆可
        └─ CPU ──> 变分推断更快

7.2 具体场景推荐

场景推荐方法理由
实时推断变分推断一次优化,多次使用
批量贝叶斯分析MCMC精确后验估计
深度学习集成变分推断可微分、端到端
小数据集MCMC需要无偏估计
稀疏模型变分推断自然稀疏先验
隐变量模型EM + 变分半变分方法

7.3 软件工具对比

工具方法特点
PyMCMCMC成熟、易用
StanHMC/NUTS高效、诊断完善
PyTorch-VIP变分推断深度学习集成
Edward2变分推断灵活、自定义
NumPyroMCMC/变分JAX后端、并行

8 深度学习中的应用

8.1 变分自编码器(VAE)

VAE是变分推断的典型应用:

8.2 贝叶斯神经网络

变分方法

采样方法

8.3 生成模型

模型方法特点
VAE变分推断可微分、重参数化
GAN对抗训练无显式推断
Flow可逆变换可精确推断
Diffusion采样迭代去噪
EBMMCMC采样灵活但慢

详见 diffusion-modelgenerative-adversarial-network


9 理论联系

9.1 变分推断的理论保证

变分推断的局限性

即使优化到最优,仍然存在最优近似误差

除非

PAC-Bayes视角

变分分布 的泛化误差有上界:

9.2 MCMC的渐近保证

遍历定理

对于不可约、非周期的马尔可夫链:

收敛速度

  • 几何收敛:
  • 谱隙 决定收敛速度

9.3 两种方法的统一视角

信息论统一

变分推断最小化KL散度,采样方法最小化某种熵相关目标。

优化视角统一

两者都可以理解为优化问题:

  • VI:优化参数化的分布
  • MCMC:优化采样路径

10 代码实现对比

10.1 变分推断实现

import torch
import torch.nn as nn
import torch.optim as optim
 
class VariationalLinearRegression:
    """贝叶斯线性回归的变分推断"""
    
    def __init__(self, input_dim, prior_std=1.0):
        self.prior_std = prior_std
        
        # 变分参数
        self.mean = nn.Parameter(torch.zeros(input_dim))
        self.log_std = nn.Parameter(torch.zeros(input_dim))
    
    def forward(self, x):
        """前向传播(重参数化)"""
        std = torch.exp(self.log_std)
        eps = torch.randn_like(x) * std
        return x @ (self.mean + eps)
    
    def elbo(self, x, y, n_samples=10):
        """证据下界"""
        # 重参数化采样
        samples = []
        for _ in range(n_samples):
            std = torch.exp(self.log_std)
            eps = torch.randn_like(self.mean)
            w = self.mean + eps * std
            samples.append(w)
        
        # 重建损失
        recon = sum(self.compute_log_likelihood(x, y, w) 
                   for w in samples) / n_samples
        
        # KL项
        prior = torch.distributions.Normal(0, self.prior_std)
        post = torch.distributions.Normal(self.mean, torch.exp(self.log_std))
        kl = torch.distributions.kl.kl_divergence(post, prior).sum()
        
        return recon - kl / len(x)
    
    def fit(self, x, y, epochs=1000):
        """训练"""
        optimizer = optim.Adam([self.mean, self.log_std], lr=0.01)
        
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss = -self.elbo(x, y)  # 最小化负ELBO
            loss.backward()
            optimizer.step()

10.2 MCMC实现

class MCMCLinearRegression:
    """贝叶斯线性回归的MCMC推断"""
    
    def __init__(self, log_posterior):
        self.log_posterior = log_posterior
    
    def metropolis_hastings(self, x, y, n_samples=10000, 
                           proposal_std=0.1, warmup=1000):
        """Metropolis-Hastings采样"""
        d = x.shape[1]
        samples = []
        
        # 初始化
        theta = torch.zeros(d)
        
        for i in range(n_samples + warmup):
            # 提议
            proposal = theta + torch.randn(d) * proposal_std
            
            # 接受率
            log_alpha = (self.log_posterior(proposal, x, y) - 
                        self.log_posterior(theta, x, y))
            
            if torch.rand(1).log() < log_alpha:
                theta = proposal
            
            if i >= warmup:
                samples.append(theta.clone())
        
        return torch.stack(samples)

10.3 选择指南代码

def select_inference_method(n_samples, n_dims, n_data, 
                          posterior_complexity='medium',
                          time_budget='medium'):
    """
    自动选择推断方法
    """
    # 决策树
    if n_data > 50000:
        return "Variational (SVI)", "Large dataset requires scalable method"
    
    if n_dims > 500:
        return "Structured Variational", "High dimension with structured approximation"
    
    if posterior_complexity == 'simple':
        return "Mean-field Variational", "Simple posterior, fast approximation"
    
    if posterior_complexity == 'complex' and time_budget == 'large':
        return "MCMC (HMC/NUTS)", "Complex posterior needs unbiased sampling"
    
    if time_budget == 'small':
        return "Mean-field Variational", "Fast approximation under time constraint"
    
    return "Variational + MCMC", "Hybrid approach for best of both worlds"

11 总结

本章深度对比了变分推断和采样方法:

11.1 关键对比点

方面变分推断采样方法
精度有偏但实用无偏渐近
速度
扩展性一般
表达能力受限于近似族原则上无限
实现难度中等较高

11.2 选择原则

  • 需要快速结果 → 变分推断
  • 需要精确贝叶斯 → MCMC
  • 高维大数据 → 结构化变分
  • 复杂后验 → MCMC或混合方法

11.3 未来趋势

  • 自动变分推断(AVI)
  • 神经MCMC(可学习的提议)
  • 统一框架:变分-采样连续体

参考文献

Footnotes

  1. Blei, D.M., Kucukelbir, A., & McAuliffe, J.D. (2017). Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518), 859-877.