MCMC进阶:NUTS、切片采样与收敛诊断

1. 背景回顾

1.1 基础MCMC回顾

MCMC(Markov Chain Monte Carlo)方法通过构建马尔可夫链来从目标分布 中采样。经典方法包括:

方法原理优点缺点
Metropolis-Hastings提议-接受通用性强收敛慢
Gibbs采样条件采样无接受率问题需要可分解条件分布
Hamiltonian MC梯度辅助高效探索需手动设置步数

详见mcmc-methods

1.2 为什么需要进阶方法?

标准MH和Gibbs在高维问题中面临以下挑战:

  • 随机游走行为:每次移动受限
  • Accept rate vs 效率权衡:高接受率意味着小步长,导致随机游走
  • 参数调优困难:步长、提议分布等需要手动设置

NUTS和切片采样正是为了解决这些问题而提出的。


2. Hamiltonian Monte Carlo回顾

2.1 核心思想

HMC通过引入辅助动量变量 ,在扩展相空间 中执行哈密顿动力学:

哈密顿方程:

2.2 Leapfrog积分器

离散化哈密顿方程(辛欧拉法变体):

def leapfrog(q, r, epsilon, grad_U):
    """Leapfrog积分器
    
    Args:
        q: 位置
        r: 动量
        epsilon: 步长
        grad_U: 势能梯度函数 (-log p的梯度)
    
    Returns:
        q_new, r_new: 更新后的位置和动量
    """
    r_half = r - 0.5 * epsilon * grad_U(q)
    q_new = q + epsilon * r_half
    r_new = r_half - 0.5 * epsilon * grad_U(q_new)
    
    return q_new, r_new

2.3 HMC算法流程

def hmc_sampling(U, grad_U, q_init, n_samples, L, epsilon):
    """HMC采样
    
    Args:
        U: 势能函数 (-log p)
        grad_U: 势能梯度
        q_init: 初始位置
        n_samples: 采样数量
        L: 每步迭代数(轨迹长度)
        epsilon: 步长
    """
    samples = [q_init]
    q = q_init
    
    for _ in range(n_samples):
        # 采样动量
        r = np.random.randn(len(q))
        
        # 提议新状态
        q_new, r_new = q.copy(), r.copy()
        for _ in range(L):
            q_new, r_new = leapfrog(q_new, r_new, epsilon, grad_U)
        
        # Metropolis-Hastings接受
        current_H = U(q) + 0.5 * np.sum(r**2)
        proposed_H = U(q_new) + 0.5 * np.sum(r_new**2)
        
        if np.random.rand() < np.exp(current_H - proposed_H):
            q = q_new
        
        samples.append(q)
    
    return np.array(samples[1:])

3. NUTS:No-U-Turn Sampler

3.1 核心问题

HMC需要手动设置轨迹长度 ,这带来两个问题:

  • 太长:浪费计算资源
  • 太短:采样效率低,接近随机游走

NUTS 自动确定合适的轨迹长度,通过检测轨迹是否开始”掉头”(U-turn)。

3.2 NUTS算法详解

3.2.1 终止条件

NUTS通过以下条件检测U-turn:

其中 是轨迹两端, 是对应的动量。

直观理解:当轨迹两端开始相向运动时,继续走只会重复已探索的区域。

3.2.2 切片采样增强

NUTS引入切片变量 来避免接受-拒绝机制:

这使得算法可以”一次性”构建整个子树,而不是逐个提议。

3.2.3 二叉树构建

NUTS递归构建二叉树,深度 表示 个leapfrog步骤:

def build_tree(theta, r, u, v, j, epsilon):
    """递归构建二叉树
    
    Args:
        theta: 起始位置
        r: 起始动量
        u: 切片变量
        v: 方向 (+1 或 -1)
        j: 当前深度
        epsilon: 步长
    """
    if j == 0:
        # 叶子节点:单步leapfrog
        theta_prime, r_prime = leapfrog(theta, r, v * epsilon, grad_U)
        n_prime = 1 if u <= np.exp(L(theta_prime) - 0.5 * np.sum(r_prime**2)) else 0
        s_prime = 1
        return theta_prime, r_prime, theta_prime, r_prime, n_prime, s_prime, (theta_prime, r_prime)
    
    # 递归构建子树
    theta_minus, r_minus, theta_plus, r_plus, n, s, _ = build_tree(theta, r, u, v, j-1, epsilon)
    
    if v == -1:
        theta_minus, r_minus, n_prime, s_prime, children = build_tree(
            theta_minus, r_minus, u, v, j-1, epsilon
        )
    else:
        theta_plus, r_plus, n_prime, s_prime, children = build_tree(
            theta_plus, r_plus, u, v, j-1, epsilon
        )
    
    # 终止条件检查
    if s_prime == 1:
        if np.random.rand() < min(1, n_prime / (n + n_prime)):
            if v == -1:
                theta_new = theta_minus
            else:
                theta_new = theta_plus
    
    n = n + n_prime
    s = s_prime * (is_u_turn_ok(theta_minus, theta_plus, r_minus, r_plus))
    
    return theta_minus, r_minus, theta_plus, r_plus, n, s, theta_new

3.3 NUTS完整算法

def nuts_sampler(log_p, grad_log_p, theta0, n_samples, max_depth=10):
    """NUTS采样器
    
    Args:
        log_p: 对数概率函数
        grad_log_p: 对数概率梯度
        theta0: 初始位置
        n_samples: 采样数量
        max_depth: 最大树深度
    """
    # NUTS使用Dual Averaging自动调参
    epsilon, mu, gamma, t0, kappa = initialize_dual_averaging()
    
    samples = [theta0]
    theta = theta0
    epsilon_bar = 1.0
    H_bar = 0
    
    for m in range(n_samples):
        # 采样动量
        r0 = np.random.randn(len(theta))
        
        # 切片采样
        u = np.random.uniform(0, np.exp(log_p(theta) - 0.5 * np.sum(r0**2)))
        
        # 初始化
        theta_minus = theta_plus = theta
        r_minus = r_plus = r0
        j = 0
        n = 1
        s = 1
        
        while s == 1:
            # 选择方向
            v = np.random.choice([-1, 1])
            
            if v == -1:
                theta_minus, r_minus, _, _, n_prime, s_prime, _ = build_tree(
                    theta_minus, r_minus, u, v, j, epsilon
                )
            else:
                _, _, theta_plus, r_plus, n_prime, s_prime, _ = build_tree(
                    theta_plus, r_plus, u, v, j, epsilon
                )
            
            # 更新
            if s_prime == 1:
                if np.random.rand() < min(1, n_prime / n):
                    if v == -1:
                        theta = theta_minus
                    else:
                        theta = theta_plus
            
            n = n + n_prime
            s = s_prime * check_u_turn(theta_minus, theta_plus, r_minus, r_plus)
            j = j + 1
            
            if j >= max_depth:
                break
        
        # 更新步长
        H_bar = (1 - 1/(m + t0)) * H_bar + (1/(m + t0)) * (0.65 - accept_rate)
        epsilon = 2 ** (m < 0.5 * (n_samples + 1)) * np.exp(-H_bar)
        epsilon_bar = np.exp(m**(-kappa) * np.log(epsilon) + (1 - m**(-kappa)) * np.log(epsilon_bar))
        
        samples.append(theta)
    
    return np.array(samples[1:])

3.4 NUTS的理论基础

黎曼几何视角

HMC/NUTS的深层几何含义:

概念欧几里得空间黎曼流形
度量恒等矩阵 Fisher信息矩阵
距离
动量分布
Leapfrog标准形式需要修正(半隐式辛积分)

Riemannian HMC (RHMC):使用局部度量 的HMC,可以更好地适应参数空间的局部曲率。


4. 切片采样(Slice Sampling)

4.1 核心思想

切片采样通过引入辅助变量将1-D分布的采样转化为2-D均匀分布的采样:

然后从边缘分布中采样:

4.2 单变量切片采样

def slice_sampling_1d(log_p, x0, w=1.0, max_steps=100):
    """单变量切片采样
    
    Args:
        log_p: 对数概率密度函数
        x0: 初始点
        w: 步长参数(建议区间宽度)
        max_steps: 最大收缩步数
    """
    x = x0
    
    for _ in range(max_steps):
        # 步骤1:计算对数密度
        log_pdf = log_p(x)
        
        # 步骤2:采样辅助变量u
        u = np.random.uniform(0, np.exp(log_pdf))
        log_u = np.log(u)
        
        # 步骤3:找到x的邻域 [x_l, x_r]
        x_l = x - np.random.uniform(0, w)
        x_r = x_l + w
        
        # 步骤4:收缩直到找到有效区间
        while log_p(x_l) > log_u:
            x_l -= w
        while log_p(x_r) > log_u:
            x_r += w
        
        # 步骤5:在区间内采样
        while True:
            x_new = np.random.uniform(x_l, x_r)
            if log_p(x_new) > log_u:
                return x_new
            elif x_new < x:
                x_l = x_new
            else:
                x_r = x_new

4.3 多变量切片采样

方向采样法

def slice_sampling_multivariate(log_p, x0, n_steps=10):
    """多变量切片采样
    
    使用随机方向采样方法
    """
    x = x0.copy()
    D = len(x)
    
    for _ in range(n_steps):
        # 计算当前密度
        log_pdf = log_p(x)
        
        # 切片
        u = np.random.uniform(0, np.exp(log_pdf))
        log_u = np.log(u)
        
        # 选择随机方向
        direction = np.random.randn(D)
        direction = direction / np.linalg.norm(direction)
        
        # 沿方向找区间
        x_d = np.dot(x, direction)  # 投影
        
        # 简单的收缩逻辑
        x_new = x.copy()  # 实际实现需要更复杂的区间收缩
        x = x_new
    
    return x

4.4 切片采样的优势

特性说明
自适应步长自动调整提议宽度
无接受率问题始终接受(理论上)
无梯度需求适用于黑盒函数
自动缩放适应目标分布尺度

5. MCMC收敛诊断

5.1 诊断的重要性

MCMC链需要足够长才能”忘记”初始位置,并充分探索目标分布。

5.2 Gelman-Rubin诊断 ()

经典

定义

其中:

  • :链内方差
  • :后验方差估计
  • :链间方差

解读

  • :收敛
  • :可能未收敛

改进的Rank-based

问题:经典 在以下情况会失效:

  1. 链有不同方差
  2. 分布有重尾(如Cauchy)
  3. 单链未充分混合

改进方案(Vehtari et al., 2019):

  1. Rank归一化

    其中 是所有样本的rank。

  2. Folded (检测方差差异):

def compute_rank_rhat(samples):
    """计算改进的Rank-based R-hat
    
    Args:
        samples: shape (M, N) 的样本,M是链数,N是每链样本数
    
    Returns:
        rhat: 诊断值
    """
    M, N = samples.shape
    S = M * N
    
    # 所有样本排序
    all_samples = samples.flatten()
    ranks = np.argsort(np.argsort(all_samples)) + 1
    
    # 重新组织为 (M, N)
    ranks = ranks.reshape(M, N)
    
    # Rank归一化
    z = stats.norm.ppf((ranks - 0.375) / (S + 0.25))
    
    # 计算链内方差
    z_mean_per_chain = z.mean(axis=1, keepdims=True)
    W = ((z - z_mean_per_chain) ** 2).sum() / (M * N - M)
    
    # 计算后验方差
    z_mean = z.mean()
    B = N * ((z_mean_per_chain - z_mean) ** 2).sum() / (M - 1)
    var_hat = (N - 1) / N * W + B / N
    
    rhat = np.sqrt(var_hat / W)
    return rhat

5.3 有效样本量(ESS)

Bulk-ESS

Bulk-ESS测量分布中心的有效样本量:

其中 是滞后 的自相关。

def compute_ess(samples, min_lag=1, max_lag=None):
    """计算有效样本量
    
    Args:
        samples: shape (M, N) 的样本
        min_lag: 最小滞后
        max_lag: 最大滞后
    
    Returns:
        ess: Bulk-ESS
    """
    M, N = samples.shape
    
    # 展平所有链
    flat_samples = samples.flatten()
    n = len(flat_samples)
    
    if max_lag is None:
        max_lag = n // 2
    
    # 计算自相关
    mean = flat_samples.mean()
    var = flat_samples.var()
    
    acov = np.zeros(max_lag - min_lag + 1)
    acov[0] = ((flat_samples - mean) ** 2).mean()
    
    for lag in range(min_lag, max_lag + 1):
        acov[lag - min_lag + 1] = ((flat_samples[:-lag] - mean) * 
                                   (flat_samples[lag:] - mean)).mean()
    
    # 拖尾自相关
    rho = np.ones(len(acov))
    for i, acov_i in enumerate(acov):
        if acov_i <= 0:
            rho[i] = 0
        else:
            rho[i] = acov_i / acov[0]
        
        if i > 0 and acov_i < 0:
            break
    
    # ESS计算
    ess = M * N / (1 + 2 * rho[1:].sum())
    
    return int(ess)

Tail-ESS

Tail-ESS测量分布尾部的ESS(通常取2.5%和97.5%分位数附近):

对于关键统计量(均值、方差、尾部),建议:

  • Bulk-ESS ≥ 400(确保中心估计稳定)
  • Tail-ESS ≥ 400(确保尾部形状准确)

5.4 Geweke检验

Geweke检验比较链开头和结尾部分的均值:

def geweke_test(chain, first_frac=0.1, last_frac=0.5):
    """Geweke收敛诊断
    
    Args:
        chain: 单条链的样本
        first_frac: 开头部分的比例
        last_frac: 结尾部分的比例
    
    Returns:
        z_score: 如果|z| < 2,认为收敛
    """
    n = len(chain)
    first = chain[:int(n * first_frac)]
    last = chain[int(n * (1 - last_frac)):]
    
    mean1, var1 = first.mean(), first.var() / len(first)
    mean2, var2 = last.mean(), last.var() / len(last)
    
    z = (mean1 - mean2) / np.sqrt(var1 + var2)
    return z

5.5 实践指南

诊断流程

def diagnose_mcmc(samples, var_names=None):
    """完整的MCMC诊断
    
    Args:
        samples: dict of {var_name: (M, N) array}
        var_names: 变量名列表
    """
    if var_names is None:
        var_names = list(samples.keys())
    
    results = {}
    
    for name in var_names:
        chain = samples[name]  # shape: (M, N)
        
        # 计算诊断量
        rhat = compute_rank_rhat(chain)
        ess = compute_ess(chain)
        tail_ess = compute_tail_ess(chain)
        
        # Monte Carlo误差
        mcse = chain.var() / np.sqrt(ess)
        
        results[name] = {
            'Rhat': rhat,
            'Bulk-ESS': ess,
            'Tail-ESS': tail_ess,
            'MCSE': mcse
        }
        
        # 打印报告
        print(f"\n{name}:")
        print(f"  R-hat: {rhat:.4f} {'✓' if rhat < 1.05 else '⚠'}")
        print(f"  Bulk-ESS: {ess} {'✓' if ess > 400 else '⚠'}")
        print(f"  Tail-ESS: {tail_ess} {'✓' if tail_ess > 400 else '⚠'}")
        print(f"  MCSE: {mcse:.6f}")
    
    return results

收敛判断标准

指标收敛标准不收敛时处理
< 1.05 (推荐), < 1.1 (可接受)增加迭代次数、使用不同初始化
Bulk-ESS≥ 400 (推荐), ≥ 100 (最小)增加链数、增加迭代次数
Tail-ESS≥ 400 (推荐)同上

常见问题与解决方案

问题症状解决方案
未混合ESS低、自相关高减小步长、重新参数化
多模态链只在部分模态使用并行Tempering
高度相关慢速混合使用几何平均参数化
初始化敏感不同初始化结果不同延长burn-in、增加迭代

6. 工具与实践

6.1 PyMC实现

import pymc as pm
import numpy as np
 
# 简单示例:线性回归的贝叶斯推断
np.random.seed(42)
x = np.random.randn(100)
y = 2 * x + 1 + 0.5 * np.random.randn(100)
 
with pm.Model() as model:
    # 先验
    alpha = pm.Normal('alpha', mu=0, sigma=10)
    beta = pm.Normal('beta', mu=0, sigma=10)
    sigma = pm.HalfNormal('sigma', sigma=1)
    
    # 似然
    mu = alpha + beta * x
    likelihood = pm.Normal('y', mu=mu, sigma=sigma, observed=y)
    
    # NUTS采样
    trace = pm.sample(2000, tune=1000, chains=4, cores=2)
    
    # 诊断
    print(pm.summary(trace))
    pm.plot_trace(trace)

6.2 Stan实现

data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
}
 
parameters {
  real alpha;
  real beta;
  real<lower=0> sigma;
}
 
model {
  // 先验
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 1);
  
  // 似然
  y ~ normal(alpha + beta * x, sigma);
}

6.3 NumPyro实现(支持GPU加速)

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
 
def model(x, y=None):
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.HalfNormal(1))
    
    mu = alpha + beta * x
    numpyro.sample('y', dist.Normal(mu, sigma), obs=y)
 
# NUTS采样
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=4)
mcmc.run(random.PRNGKey(0), x=jnp.array(x), y=jnp.array(y))

参考文献