MCMC进阶:NUTS、切片采样与收敛诊断
1. 背景回顾
1.1 基础MCMC回顾
MCMC(Markov Chain Monte Carlo)方法通过构建马尔可夫链来从目标分布 中采样。经典方法包括:
| 方法 | 原理 | 优点 | 缺点 |
|---|---|---|---|
| Metropolis-Hastings | 提议-接受 | 通用性强 | 收敛慢 |
| Gibbs采样 | 条件采样 | 无接受率问题 | 需要可分解条件分布 |
| Hamiltonian MC | 梯度辅助 | 高效探索 | 需手动设置步数 |
详见mcmc-methods。
1.2 为什么需要进阶方法?
标准MH和Gibbs在高维问题中面临以下挑战:
- 随机游走行为:每次移动受限
- Accept rate vs 效率权衡:高接受率意味着小步长,导致随机游走
- 参数调优困难:步长、提议分布等需要手动设置
NUTS和切片采样正是为了解决这些问题而提出的。
2. Hamiltonian Monte Carlo回顾
2.1 核心思想
HMC通过引入辅助动量变量 ,在扩展相空间 中执行哈密顿动力学:
哈密顿方程:
2.2 Leapfrog积分器
离散化哈密顿方程(辛欧拉法变体):
def leapfrog(q, r, epsilon, grad_U):
"""Leapfrog积分器
Args:
q: 位置
r: 动量
epsilon: 步长
grad_U: 势能梯度函数 (-log p的梯度)
Returns:
q_new, r_new: 更新后的位置和动量
"""
r_half = r - 0.5 * epsilon * grad_U(q)
q_new = q + epsilon * r_half
r_new = r_half - 0.5 * epsilon * grad_U(q_new)
return q_new, r_new2.3 HMC算法流程
def hmc_sampling(U, grad_U, q_init, n_samples, L, epsilon):
"""HMC采样
Args:
U: 势能函数 (-log p)
grad_U: 势能梯度
q_init: 初始位置
n_samples: 采样数量
L: 每步迭代数(轨迹长度)
epsilon: 步长
"""
samples = [q_init]
q = q_init
for _ in range(n_samples):
# 采样动量
r = np.random.randn(len(q))
# 提议新状态
q_new, r_new = q.copy(), r.copy()
for _ in range(L):
q_new, r_new = leapfrog(q_new, r_new, epsilon, grad_U)
# Metropolis-Hastings接受
current_H = U(q) + 0.5 * np.sum(r**2)
proposed_H = U(q_new) + 0.5 * np.sum(r_new**2)
if np.random.rand() < np.exp(current_H - proposed_H):
q = q_new
samples.append(q)
return np.array(samples[1:])3. NUTS:No-U-Turn Sampler
3.1 核心问题
HMC需要手动设置轨迹长度 ,这带来两个问题:
- 太长:浪费计算资源
- 太短:采样效率低,接近随机游走
NUTS 自动确定合适的轨迹长度,通过检测轨迹是否开始”掉头”(U-turn)。
3.2 NUTS算法详解
3.2.1 终止条件
NUTS通过以下条件检测U-turn:
其中 是轨迹两端, 是对应的动量。
直观理解:当轨迹两端开始相向运动时,继续走只会重复已探索的区域。
3.2.2 切片采样增强
NUTS引入切片变量 来避免接受-拒绝机制:
这使得算法可以”一次性”构建整个子树,而不是逐个提议。
3.2.3 二叉树构建
NUTS递归构建二叉树,深度 表示 个leapfrog步骤:
def build_tree(theta, r, u, v, j, epsilon):
"""递归构建二叉树
Args:
theta: 起始位置
r: 起始动量
u: 切片变量
v: 方向 (+1 或 -1)
j: 当前深度
epsilon: 步长
"""
if j == 0:
# 叶子节点:单步leapfrog
theta_prime, r_prime = leapfrog(theta, r, v * epsilon, grad_U)
n_prime = 1 if u <= np.exp(L(theta_prime) - 0.5 * np.sum(r_prime**2)) else 0
s_prime = 1
return theta_prime, r_prime, theta_prime, r_prime, n_prime, s_prime, (theta_prime, r_prime)
# 递归构建子树
theta_minus, r_minus, theta_plus, r_plus, n, s, _ = build_tree(theta, r, u, v, j-1, epsilon)
if v == -1:
theta_minus, r_minus, n_prime, s_prime, children = build_tree(
theta_minus, r_minus, u, v, j-1, epsilon
)
else:
theta_plus, r_plus, n_prime, s_prime, children = build_tree(
theta_plus, r_plus, u, v, j-1, epsilon
)
# 终止条件检查
if s_prime == 1:
if np.random.rand() < min(1, n_prime / (n + n_prime)):
if v == -1:
theta_new = theta_minus
else:
theta_new = theta_plus
n = n + n_prime
s = s_prime * (is_u_turn_ok(theta_minus, theta_plus, r_minus, r_plus))
return theta_minus, r_minus, theta_plus, r_plus, n, s, theta_new3.3 NUTS完整算法
def nuts_sampler(log_p, grad_log_p, theta0, n_samples, max_depth=10):
"""NUTS采样器
Args:
log_p: 对数概率函数
grad_log_p: 对数概率梯度
theta0: 初始位置
n_samples: 采样数量
max_depth: 最大树深度
"""
# NUTS使用Dual Averaging自动调参
epsilon, mu, gamma, t0, kappa = initialize_dual_averaging()
samples = [theta0]
theta = theta0
epsilon_bar = 1.0
H_bar = 0
for m in range(n_samples):
# 采样动量
r0 = np.random.randn(len(theta))
# 切片采样
u = np.random.uniform(0, np.exp(log_p(theta) - 0.5 * np.sum(r0**2)))
# 初始化
theta_minus = theta_plus = theta
r_minus = r_plus = r0
j = 0
n = 1
s = 1
while s == 1:
# 选择方向
v = np.random.choice([-1, 1])
if v == -1:
theta_minus, r_minus, _, _, n_prime, s_prime, _ = build_tree(
theta_minus, r_minus, u, v, j, epsilon
)
else:
_, _, theta_plus, r_plus, n_prime, s_prime, _ = build_tree(
theta_plus, r_plus, u, v, j, epsilon
)
# 更新
if s_prime == 1:
if np.random.rand() < min(1, n_prime / n):
if v == -1:
theta = theta_minus
else:
theta = theta_plus
n = n + n_prime
s = s_prime * check_u_turn(theta_minus, theta_plus, r_minus, r_plus)
j = j + 1
if j >= max_depth:
break
# 更新步长
H_bar = (1 - 1/(m + t0)) * H_bar + (1/(m + t0)) * (0.65 - accept_rate)
epsilon = 2 ** (m < 0.5 * (n_samples + 1)) * np.exp(-H_bar)
epsilon_bar = np.exp(m**(-kappa) * np.log(epsilon) + (1 - m**(-kappa)) * np.log(epsilon_bar))
samples.append(theta)
return np.array(samples[1:])3.4 NUTS的理论基础
黎曼几何视角
HMC/NUTS的深层几何含义:
| 概念 | 欧几里得空间 | 黎曼流形 |
|---|---|---|
| 度量 | 恒等矩阵 | Fisher信息矩阵 |
| 距离 | ||
| 动量分布 | ||
| Leapfrog | 标准形式 | 需要修正(半隐式辛积分) |
Riemannian HMC (RHMC):使用局部度量 的HMC,可以更好地适应参数空间的局部曲率。
4. 切片采样(Slice Sampling)
4.1 核心思想
切片采样通过引入辅助变量将1-D分布的采样转化为2-D均匀分布的采样:
然后从边缘分布中采样:
4.2 单变量切片采样
def slice_sampling_1d(log_p, x0, w=1.0, max_steps=100):
"""单变量切片采样
Args:
log_p: 对数概率密度函数
x0: 初始点
w: 步长参数(建议区间宽度)
max_steps: 最大收缩步数
"""
x = x0
for _ in range(max_steps):
# 步骤1:计算对数密度
log_pdf = log_p(x)
# 步骤2:采样辅助变量u
u = np.random.uniform(0, np.exp(log_pdf))
log_u = np.log(u)
# 步骤3:找到x的邻域 [x_l, x_r]
x_l = x - np.random.uniform(0, w)
x_r = x_l + w
# 步骤4:收缩直到找到有效区间
while log_p(x_l) > log_u:
x_l -= w
while log_p(x_r) > log_u:
x_r += w
# 步骤5:在区间内采样
while True:
x_new = np.random.uniform(x_l, x_r)
if log_p(x_new) > log_u:
return x_new
elif x_new < x:
x_l = x_new
else:
x_r = x_new4.3 多变量切片采样
方向采样法:
def slice_sampling_multivariate(log_p, x0, n_steps=10):
"""多变量切片采样
使用随机方向采样方法
"""
x = x0.copy()
D = len(x)
for _ in range(n_steps):
# 计算当前密度
log_pdf = log_p(x)
# 切片
u = np.random.uniform(0, np.exp(log_pdf))
log_u = np.log(u)
# 选择随机方向
direction = np.random.randn(D)
direction = direction / np.linalg.norm(direction)
# 沿方向找区间
x_d = np.dot(x, direction) # 投影
# 简单的收缩逻辑
x_new = x.copy() # 实际实现需要更复杂的区间收缩
x = x_new
return x4.4 切片采样的优势
| 特性 | 说明 |
|---|---|
| 自适应步长 | 自动调整提议宽度 |
| 无接受率问题 | 始终接受(理论上) |
| 无梯度需求 | 适用于黑盒函数 |
| 自动缩放 | 适应目标分布尺度 |
5. MCMC收敛诊断
5.1 诊断的重要性
MCMC链需要足够长才能”忘记”初始位置,并充分探索目标分布。
5.2 Gelman-Rubin诊断 ()
经典
定义:
其中:
- :链内方差
- :后验方差估计
- :链间方差
解读:
- :收敛
- :可能未收敛
改进的Rank-based
问题:经典 在以下情况会失效:
- 链有不同方差
- 分布有重尾(如Cauchy)
- 单链未充分混合
改进方案(Vehtari et al., 2019):
-
Rank归一化:
其中 是所有样本的rank。 -
Folded (检测方差差异):
def compute_rank_rhat(samples):
"""计算改进的Rank-based R-hat
Args:
samples: shape (M, N) 的样本,M是链数,N是每链样本数
Returns:
rhat: 诊断值
"""
M, N = samples.shape
S = M * N
# 所有样本排序
all_samples = samples.flatten()
ranks = np.argsort(np.argsort(all_samples)) + 1
# 重新组织为 (M, N)
ranks = ranks.reshape(M, N)
# Rank归一化
z = stats.norm.ppf((ranks - 0.375) / (S + 0.25))
# 计算链内方差
z_mean_per_chain = z.mean(axis=1, keepdims=True)
W = ((z - z_mean_per_chain) ** 2).sum() / (M * N - M)
# 计算后验方差
z_mean = z.mean()
B = N * ((z_mean_per_chain - z_mean) ** 2).sum() / (M - 1)
var_hat = (N - 1) / N * W + B / N
rhat = np.sqrt(var_hat / W)
return rhat5.3 有效样本量(ESS)
Bulk-ESS
Bulk-ESS测量分布中心的有效样本量:
其中 是滞后 的自相关。
def compute_ess(samples, min_lag=1, max_lag=None):
"""计算有效样本量
Args:
samples: shape (M, N) 的样本
min_lag: 最小滞后
max_lag: 最大滞后
Returns:
ess: Bulk-ESS
"""
M, N = samples.shape
# 展平所有链
flat_samples = samples.flatten()
n = len(flat_samples)
if max_lag is None:
max_lag = n // 2
# 计算自相关
mean = flat_samples.mean()
var = flat_samples.var()
acov = np.zeros(max_lag - min_lag + 1)
acov[0] = ((flat_samples - mean) ** 2).mean()
for lag in range(min_lag, max_lag + 1):
acov[lag - min_lag + 1] = ((flat_samples[:-lag] - mean) *
(flat_samples[lag:] - mean)).mean()
# 拖尾自相关
rho = np.ones(len(acov))
for i, acov_i in enumerate(acov):
if acov_i <= 0:
rho[i] = 0
else:
rho[i] = acov_i / acov[0]
if i > 0 and acov_i < 0:
break
# ESS计算
ess = M * N / (1 + 2 * rho[1:].sum())
return int(ess)Tail-ESS
Tail-ESS测量分布尾部的ESS(通常取2.5%和97.5%分位数附近):
对于关键统计量(均值、方差、尾部),建议:
- Bulk-ESS ≥ 400(确保中心估计稳定)
- Tail-ESS ≥ 400(确保尾部形状准确)
5.4 Geweke检验
Geweke检验比较链开头和结尾部分的均值:
def geweke_test(chain, first_frac=0.1, last_frac=0.5):
"""Geweke收敛诊断
Args:
chain: 单条链的样本
first_frac: 开头部分的比例
last_frac: 结尾部分的比例
Returns:
z_score: 如果|z| < 2,认为收敛
"""
n = len(chain)
first = chain[:int(n * first_frac)]
last = chain[int(n * (1 - last_frac)):]
mean1, var1 = first.mean(), first.var() / len(first)
mean2, var2 = last.mean(), last.var() / len(last)
z = (mean1 - mean2) / np.sqrt(var1 + var2)
return z5.5 实践指南
诊断流程
def diagnose_mcmc(samples, var_names=None):
"""完整的MCMC诊断
Args:
samples: dict of {var_name: (M, N) array}
var_names: 变量名列表
"""
if var_names is None:
var_names = list(samples.keys())
results = {}
for name in var_names:
chain = samples[name] # shape: (M, N)
# 计算诊断量
rhat = compute_rank_rhat(chain)
ess = compute_ess(chain)
tail_ess = compute_tail_ess(chain)
# Monte Carlo误差
mcse = chain.var() / np.sqrt(ess)
results[name] = {
'Rhat': rhat,
'Bulk-ESS': ess,
'Tail-ESS': tail_ess,
'MCSE': mcse
}
# 打印报告
print(f"\n{name}:")
print(f" R-hat: {rhat:.4f} {'✓' if rhat < 1.05 else '⚠'}")
print(f" Bulk-ESS: {ess} {'✓' if ess > 400 else '⚠'}")
print(f" Tail-ESS: {tail_ess} {'✓' if tail_ess > 400 else '⚠'}")
print(f" MCSE: {mcse:.6f}")
return results收敛判断标准
| 指标 | 收敛标准 | 不收敛时处理 |
|---|---|---|
| < 1.05 (推荐), < 1.1 (可接受) | 增加迭代次数、使用不同初始化 | |
| Bulk-ESS | ≥ 400 (推荐), ≥ 100 (最小) | 增加链数、增加迭代次数 |
| Tail-ESS | ≥ 400 (推荐) | 同上 |
常见问题与解决方案
| 问题 | 症状 | 解决方案 |
|---|---|---|
| 未混合 | ESS低、自相关高 | 减小步长、重新参数化 |
| 多模态 | 链只在部分模态 | 使用并行Tempering |
| 高度相关 | 慢速混合 | 使用几何平均参数化 |
| 初始化敏感 | 不同初始化结果不同 | 延长burn-in、增加迭代 |
6. 工具与实践
6.1 PyMC实现
import pymc as pm
import numpy as np
# 简单示例:线性回归的贝叶斯推断
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + 1 + 0.5 * np.random.randn(100)
with pm.Model() as model:
# 先验
alpha = pm.Normal('alpha', mu=0, sigma=10)
beta = pm.Normal('beta', mu=0, sigma=10)
sigma = pm.HalfNormal('sigma', sigma=1)
# 似然
mu = alpha + beta * x
likelihood = pm.Normal('y', mu=mu, sigma=sigma, observed=y)
# NUTS采样
trace = pm.sample(2000, tune=1000, chains=4, cores=2)
# 诊断
print(pm.summary(trace))
pm.plot_trace(trace)6.2 Stan实现
data {
int<lower=0> N;
vector[N] x;
vector[N] y;
}
parameters {
real alpha;
real beta;
real<lower=0> sigma;
}
model {
// 先验
alpha ~ normal(0, 10);
beta ~ normal(0, 10);
sigma ~ normal(0, 1);
// 似然
y ~ normal(alpha + beta * x, sigma);
}6.3 NumPyro实现(支持GPU加速)
import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
def model(x, y=None):
alpha = numpyro.sample('alpha', dist.Normal(0, 10))
beta = numpyro.sample('beta', dist.Normal(0, 10))
sigma = numpyro.sample('sigma', dist.HalfNormal(1))
mu = alpha + beta * x
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
# NUTS采样
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(random.PRNGKey(0), x=jnp.array(x), y=jnp.array(y))