概述

概率推断(Probabilistic Inference)是根据已知信息(观测数据)推断未知量的过程,是统计学和机器学习的核心任务。1

推断的基本问题

在概率论框架下,我们通常面临以下推断问题:

问题类型描述示例
后验推断给定观测数据,推断参数的后验分布
预测推断给定新数据点,预测其属性
边缘推断计算边缘概率分布
条件推断在给定条件下推断某事件

贝叶斯推断框架

贝叶斯定理

贝叶斯推断的核心是贝叶斯定理:

其中:

  • 后验分布(Posterior)
  • 似然函数(Likelihood)
  • 先验分布(Prior)
  • 边际似然(Marginal Likelihood)

先验与后验

先验分布 编码了在观测数据之前的先验知识。

后验分布 综合了先验知识和观测数据的信息。

共轭先验:当先验分布与后验分布属于同一分布族时,称该先验为共轭先验。

似然函数共轭先验后验分布
伯努利BetaBeta
二项BetaBeta
泊松GammaGamma
高斯(均值)高斯高斯
高斯(方差)Inverse-GammaInverse-Gamma

预测分布

贝叶斯预测分布通过边缘化后验分布得到:

这称为后验预测分布(Posterior Predictive Distribution)。


推断的分类

精确推断

当后验分布可以解析计算时,使用精确推断方法。

解析求解

对于共轭先验模型,后验分布有解析形式:

import numpy as np
from scipy import stats
 
# Beta-Bernoulli模型
# 先验: Beta(α, β)
# 似然: Bernoulli(θ)
# 后验: Beta(α + sum(x), β + n - sum(x))
 
def bayesian_bernoulli_inference(data, alpha_prior=1, beta_prior=1):
    """
    Beta-Bernoulli模型的贝叶斯推断
    
    参数:
        data: 观测数据 (0或1)
        alpha_prior: Beta先验参数alpha
        beta_prior: Beta先验参数beta
    
    返回:
        alpha_post: 后验Beta参数alpha
        beta_post: 后验Beta参数beta
    """
    n = len(data)
    k = np.sum(data)
    
    alpha_post = alpha_prior + k
    beta_post = beta_prior + n - k
    
    return alpha_post, beta_post
 
# 示例
data = [1, 1, 0, 1, 1, 0, 1, 1, 1, 0]
alpha, beta = bayesian_bernoulli_inference(data)
print(f"后验分布: Beta({alpha}, {beta})")
 
# 计算后验均值和方差
posterior_mean = alpha / (alpha + beta)
posterior_var = (alpha * beta) / ((alpha + beta)**2 * (alpha + beta + 1))
print(f"后验均值: {posterior_mean:.4f}")
print(f"后验方差: {posterior_var:.4f}")

变量消除法

对于离散模型,可以使用变量消除法(Variable Elimination)计算边缘概率:

def variable_elimination(factors, eliminate_vars, elimination_order):
    """
    变量消除算法
    
    参数:
        factors: 因子列表
        eliminate_vars: 要消除的变量列表
        elimination_order: 消除顺序
    
    返回:
        最终结果因子
    """
    active_factors = list(factors)
    
    for var in elimination_order:
        # 1. 收集所有包含该变量的因子
        relevant = [f for f in active_factors if var in f.variables]
        
        # 2. 乘积
        product = factor_product(relevant)
        
        # 3. 边缘化(求和)
        summed = factor_sum_out(product, var)
        
        # 4. 更新活跃因子列表
        active_factors = [f for f in active_factors if var not in f.variables]
        active_factors.append(summed)
    
    return active_factors[0]

近似推断

当精确推断不可行时(如后验分布无法解析求解或计算复杂度太高),使用近似推断方法。

蒙特卡洛方法

重要性采样

重要性采样通过从提议分布采样来估计期望:

def importance_sampling(p, q, f, n_samples=10000):
    """
    重要性采样
    
    参数:
        p: 目标分布(未归一化)
        q: 提议分布
        f: 目标函数
        n_samples: 采样数量
    """
    samples = q.sample(n_samples)
    weights = p(samples) / q(samples)
    
    # 归一化权重
    weights = weights / np.sum(weights)
    
    # 加权估计
    estimate = np.sum(weights * f(samples))
    
    return estimate
 
# 示例:估计高斯分布的期望
from scipy.stats import norm
 
def target(x):
    return norm.pdf(x, loc=0, scale=1)
 
proposal = norm(loc=0, scale=2)
samples = proposal.rvs(10000)
weights = target(samples) / proposal.pdf(samples)
weights = weights / np.sum(weights)
 
estimate = np.sum(weights * samples)
print(f"重要性采样估计: {estimate:.4f}")
print(f"真实值: 0")

马尔可夫链蒙特卡洛(MCMC)

MCMC通过构建马尔可夫链来采样后验分布。

Metropolis-Hastings算法

def metropolis_hastings(p, q, n_samples, proposal_std=1.0):
    """
    Metropolis-Hastings采样
    
    参数:
        p: 目标分布(未归一化)
        q: 提议分布
        n_samples: 采样数量
        proposal_std: 提议分布的标准差
    """
    samples = []
    current = np.random.randn()
    
    for _ in range(n_samples):
        # 从提议分布采样
        proposal = np.random.randn() * proposal_std + current
        
        # 计算接受率
        p_current = p(current)
        p_proposal = p(proposal)
        
        # 对于对称提议分布
        alpha = min(1, p_proposal / p_current)
        
        # 接受/拒绝
        if np.random.rand() < alpha:
            current = proposal
        
        samples.append(current)
    
    return np.array(samples)

Gibbs采样

def gibbs_sampling(p_conditional, n_samples, init_state):
    """
    Gibbs采样
    
    参数:
        p_conditional: 条件概率函数 p(x_i | x_{-i})
        n_samples: 采样数量
        init_state: 初始状态
    """
    n_dims = len(init_state)
    current = np.array(init_state, dtype=float)
    samples = [current.copy()]
    
    for _ in range(n_samples):
        for i in range(n_dims):
            # 从条件分布采样
            current[i] = p_conditional(current, i)
        samples.append(current.copy())
    
    return np.array(samples)

推断的计算复杂度

图模型中的推断复杂度

推断的计算复杂度取决于图模型的结构:

图结构推断复杂度说明
树形结构变量消除或信念传播高效
链式结构K为状态数
低树宽图tw为树宽
一般图P完全NP难

精确推断的局限性

精确推断在以下情况不可行:

  1. 连续变量:后验分布维度高
  2. 大规模离散模型:状态空间指数爆炸
  3. 模型选择:需要积分所有参数

推断在机器学习中的应用

监督学习中的推断

在监督学习中,推断用于:

  • 参数估计:点估计(MLE、MAP)或分布估计(贝叶斯)
  • 预测
  • 不确定性量化:预测置信区间

无监督学习中的推断

在无监督学习中,推断用于:

  • 聚类:推断数据点所属的簇
  • 降维:推断隐变量
  • 密度估计:推断数据分布

生成模型中的推断

生成模型(如VAE、GAN)中的推断:

  • 编码:将观测数据映射到隐变量空间
  • 先验匹配:确保隐变量服从先验分布
  • 后验推断:给定观测,推断隐变量

推断与其他概念的联系

推断与优化

许多推断问题可以转化为优化问题:

推断方法对应优化
变分推断最小化KL散度
信念传播消息传递优化
期望最大化坐标上升
拉普拉斯近似梯度优化

推断与信息论

推断与信息论有深刻联系:

  • 互信息
  • KL散度:衡量推断分布与真实分布的差异
  • :不确定性度量

详见 信息论


参考

Footnotes

  1. Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer.