变分推断与采样方法对比
1 引言
变分推断(Variational Inference, VI)和蒙特卡洛采样(Monte Carlo Sampling)是近似贝叶斯推断的两大核心范式。1 理解它们的优缺点对于实际问题中选择合适的方法至关重要。
核心问题:
- 何时选择变分推断?
- 何时选择采样方法?
- 能否结合两者的优势?
本章从理论、效率、表达能力三个维度进行深度对比分析。
2 两种范式概述
2.1 变分推断:优化视角
变分推断将推断问题转化为优化问题:
目标:找到最接近真实后验的近似分布
等价形式:最大化ELBO
特点:
- 将随机问题转化为确定性问题
- 利用优化算法(梯度下降)
- 需要选择近似族
2.2 采样方法:随机视角
采样方法通过随机模拟来近似后验分布:
目标:从后验分布 采样
估计量:
特点:
- 渐近无偏()
- 不需要选择近似族
- 收敛速度慢()
2.3 直观对比
变分推断 采样方法
│ │
▼ ▼
┌──────┐ ┌──────┐
│优化器│ │随机游走│
└──┬───┘ └──┬───┘
│ │
▼ ▼
快速但有偏 慢但无偏
(固定近似族) (任意精度)
3 理论基础对比
3.1 渐近性质
| 方面 | 变分推断 | 采样方法 |
|---|---|---|
| 偏差 | 有偏(近似族限制) | 无偏(渐近) |
| 方差 | 优化方差 | 蒙特卡洛方差 |
| 收敛 | 确定性优化 | 随机过程 |
变分推断的偏差来源:
- 近似族 的表达能力
- 优化器的局部最优
采样方法的方差来源:
- 有限样本数
- 自相关(相关采样)
3.2 复杂度分析
设数据量为 ,参数维度为 。
变分推断复杂度:
- 每步:(似然计算)+(KL项)
- 总迭代:通常数百到数千步
采样方法复杂度:
- 每次接受:(梯度计算)+ (提议生成)
- 收敛样本数:可能需要数万到百万样本
3.3 维度依赖
维度灾难的影响:
| 维度 | 变分推断 | 采样方法 |
|---|---|---|
| 低维 | 可能欠拟合 | 高效 |
| 中维 | 可行 | 可行(需小心) |
| 高维 | 可扩展(mean-field) | 挑战大 |
关键洞察:
- 变分推断在高维时可通过结构化近似族保持可行
- MCMC在高维时面临慢混合问题
4 计算效率对比
4.1 收敛速度
变分推断:
- 早期快速收敛
- 后期缓慢改进
- 通常 100-1000 次迭代即可
采样方法:
- 预热期无效
- 渐近收敛但速度慢
- 通常需要 10000+ 样本
4.2 梯度利用
变分推断:
- 直接利用梯度信息
- 自然梯度、随机梯度
- 高效的优化算法
采样方法:
- HMC等利用梯度
- 随机游走类方法不利用梯度
- 效率差异巨大
4.3 可扩展性
数据可扩展性:
| 方法 | 小数据集 | 大数据集 |
|---|---|---|
| 变分推断 | ✓ | ✓(随机变分推断) |
| MCMC | ✓ | △(需要特殊处理) |
随机变分推断(SVI):
其中 是基于小批量数据的无偏估计。
4.4 并行化
变分推断:
- 数据并行:天然适合
- 模型并行:需小心处理
采样方法:
- 多链独立:天然并行
- 链内并行:较难
5 表达能力对比
5.1 近似族的影响
变分推断的表达能力完全取决于近似族 的选择。
常见近似族:
| 近似族 | 表达能力 | 计算复杂度 |
|---|---|---|
| 平均场 | 低 | 低 |
| 均值场变分 | 低-中 | 中 |
| 结构化变分 | 中 | 中-高 |
| 归一化流 | 高 | 高 |
| 混合模型 | 高 | 高 |
5.2 平均场假设的限制
平均场变分分布:
问题:
- 忽略参数间相关性
- 后验协方差估计不准确
- 在高度相关的问题上表现差
5.3 高表达能力变分族
归一化流(Normalizing Flows):
通过可逆变换逐步增加表达能力。
树状变分分布:
保留团结构的相关性。
5.4 采样的表达能力
采样方法的表达能力:
- 无偏:理论上可以表示任意分布
- 实际:受链混合速度限制
在高度多峰分布上:
平均场VI MCMC
│ │
▼ ▼
┌─────┐ ┌─────┐
│单峰 │ │多峰 │
│(有偏)│ │(无偏)│
└─────┘ └─────┘
6 混合方法
6.1 变分指导的采样
使用变分后验指导MCMC的提议分布:
提议分布:
优势:
- 提议更接近目标分布
- 提高接受率
- 加速混合
def variational_guided_mcmc(target_log_prob, vi_mean, vi_cov,
n_samples=10000, warmup=1000):
"""
变分引导的MCMC采样
"""
# 使用变分后验作为提议
def proposal(mean, cov, current):
return np.random.multivariate_normal(mean, cov)
# 或者作为辅助分布
def modified_target(theta):
log_post = target_log_prob(theta)
log_vi = multivariate_normal.logpdf(theta, vi_mean, vi_cov)
return log_post - log_vi # 调整目标6.2 采样增强的变分
使用MCMC步骤增强变分分布:
随机变分推断+重参数化MCMC:
class SVIMCMC:
"""采样增强的变分推断"""
def __init__(self, target_log_prob, n_particles=10):
self.target = target_log_prob
self.n_particles = n_particles
def step(self, q_params, data_batch):
# E步:MCMC更新粒子
particles = self.mcmc_update(q_params, self.n_particles)
# M步:更新变分参数
loss = self.compute_elbo(particles, q_params, data_batch)
return {'loss': loss, 'particles': particles}
def mcmc_update(self, q_params, n_particles):
"""MCMC更新变分粒子"""
# 使用随机游走或HMC
# ...6.3 正则化流增强
使用采样技术训练更灵活的变分分布:
流正则化:
其中 是基于采样的正则化项。
7 实践选择指南
7.1 决策流程
开始
│
├─ 数据规模
│ ├─ 小数据(< 10K)──> 两者皆可
│ └─ 大数据(> 10K)──> 变分推断(SVI)
│
├─ 维度
│ ├─ 低维(< 10)──> MCMC
│ ├─ 中维(10-100)──> HMC/变分
│ └─ 高维(> 100)──> 结构化变分
│
├─ 后验复杂度
│ ├─ 单峰/近似高斯──> 变分推断
│ ├─ 多峰/复杂──> MCMC
│ └─ 高度相关──> HMC
│
└─ 计算资源
├─ GPU/TPU ──> 两者皆可
└─ CPU ──> 变分推断更快
7.2 具体场景推荐
| 场景 | 推荐方法 | 理由 |
|---|---|---|
| 实时推断 | 变分推断 | 一次优化,多次使用 |
| 批量贝叶斯分析 | MCMC | 精确后验估计 |
| 深度学习集成 | 变分推断 | 可微分、端到端 |
| 小数据集 | MCMC | 需要无偏估计 |
| 稀疏模型 | 变分推断 | 自然稀疏先验 |
| 隐变量模型 | EM + 变分 | 半变分方法 |
7.3 软件工具对比
| 工具 | 方法 | 特点 |
|---|---|---|
| PyMC | MCMC | 成熟、易用 |
| Stan | HMC/NUTS | 高效、诊断完善 |
| PyTorch-VIP | 变分推断 | 深度学习集成 |
| Edward2 | 变分推断 | 灵活、自定义 |
| NumPyro | MCMC/变分 | JAX后端、并行 |
8 深度学习中的应用
8.1 变分自编码器(VAE)
VAE是变分推断的典型应用:
8.2 贝叶斯神经网络
变分方法:
- bayesian-neural-networks — BNN基础
- bayes-by-backprop — 变分权重不确定性
采样方法:
- bayesian-last-layer-deep-learning — 最后一层贝叶斯
- SGD作为随机采样:sgd-as-bayesian-inference
8.3 生成模型
| 模型 | 方法 | 特点 |
|---|---|---|
| VAE | 变分推断 | 可微分、重参数化 |
| GAN | 对抗训练 | 无显式推断 |
| Flow | 可逆变换 | 可精确推断 |
| Diffusion | 采样 | 迭代去噪 |
| EBM | MCMC采样 | 灵活但慢 |
详见 diffusion-model 和 generative-adversarial-network。
9 理论联系
9.1 变分推断的理论保证
变分推断的局限性:
即使优化到最优,仍然存在最优近似误差:
除非 。
PAC-Bayes视角:
变分分布 的泛化误差有上界:
9.2 MCMC的渐近保证
遍历定理:
对于不可约、非周期的马尔可夫链:
收敛速度:
- 几何收敛:
- 谱隙 决定收敛速度
9.3 两种方法的统一视角
信息论统一:
变分推断最小化KL散度,采样方法最小化某种熵相关目标。
优化视角统一:
两者都可以理解为优化问题:
- VI:优化参数化的分布
- MCMC:优化采样路径
10 代码实现对比
10.1 变分推断实现
import torch
import torch.nn as nn
import torch.optim as optim
class VariationalLinearRegression:
"""贝叶斯线性回归的变分推断"""
def __init__(self, input_dim, prior_std=1.0):
self.prior_std = prior_std
# 变分参数
self.mean = nn.Parameter(torch.zeros(input_dim))
self.log_std = nn.Parameter(torch.zeros(input_dim))
def forward(self, x):
"""前向传播(重参数化)"""
std = torch.exp(self.log_std)
eps = torch.randn_like(x) * std
return x @ (self.mean + eps)
def elbo(self, x, y, n_samples=10):
"""证据下界"""
# 重参数化采样
samples = []
for _ in range(n_samples):
std = torch.exp(self.log_std)
eps = torch.randn_like(self.mean)
w = self.mean + eps * std
samples.append(w)
# 重建损失
recon = sum(self.compute_log_likelihood(x, y, w)
for w in samples) / n_samples
# KL项
prior = torch.distributions.Normal(0, self.prior_std)
post = torch.distributions.Normal(self.mean, torch.exp(self.log_std))
kl = torch.distributions.kl.kl_divergence(post, prior).sum()
return recon - kl / len(x)
def fit(self, x, y, epochs=1000):
"""训练"""
optimizer = optim.Adam([self.mean, self.log_std], lr=0.01)
for epoch in range(epochs):
optimizer.zero_grad()
loss = -self.elbo(x, y) # 最小化负ELBO
loss.backward()
optimizer.step()10.2 MCMC实现
class MCMCLinearRegression:
"""贝叶斯线性回归的MCMC推断"""
def __init__(self, log_posterior):
self.log_posterior = log_posterior
def metropolis_hastings(self, x, y, n_samples=10000,
proposal_std=0.1, warmup=1000):
"""Metropolis-Hastings采样"""
d = x.shape[1]
samples = []
# 初始化
theta = torch.zeros(d)
for i in range(n_samples + warmup):
# 提议
proposal = theta + torch.randn(d) * proposal_std
# 接受率
log_alpha = (self.log_posterior(proposal, x, y) -
self.log_posterior(theta, x, y))
if torch.rand(1).log() < log_alpha:
theta = proposal
if i >= warmup:
samples.append(theta.clone())
return torch.stack(samples)10.3 选择指南代码
def select_inference_method(n_samples, n_dims, n_data,
posterior_complexity='medium',
time_budget='medium'):
"""
自动选择推断方法
"""
# 决策树
if n_data > 50000:
return "Variational (SVI)", "Large dataset requires scalable method"
if n_dims > 500:
return "Structured Variational", "High dimension with structured approximation"
if posterior_complexity == 'simple':
return "Mean-field Variational", "Simple posterior, fast approximation"
if posterior_complexity == 'complex' and time_budget == 'large':
return "MCMC (HMC/NUTS)", "Complex posterior needs unbiased sampling"
if time_budget == 'small':
return "Mean-field Variational", "Fast approximation under time constraint"
return "Variational + MCMC", "Hybrid approach for best of both worlds"11 总结
本章深度对比了变分推断和采样方法:
11.1 关键对比点
| 方面 | 变分推断 | 采样方法 |
|---|---|---|
| 精度 | 有偏但实用 | 无偏渐近 |
| 速度 | 快 | 慢 |
| 扩展性 | 好 | 一般 |
| 表达能力 | 受限于近似族 | 原则上无限 |
| 实现难度 | 中等 | 较高 |
11.2 选择原则
- 需要快速结果 → 变分推断
- 需要精确贝叶斯 → MCMC
- 高维大数据 → 结构化变分
- 复杂后验 → MCMC或混合方法
11.3 未来趋势
- 自动变分推断(AVI)
- 神经MCMC(可学习的提议)
- 统一框架:变分-采样连续体
参考文献
Footnotes
-
Blei, D.M., Kucukelbir, A., & McAuliffe, J.D. (2017). Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518), 859-877. ↩