变分推断与概率图模型统一框架

1. 引言

变分推断(Variational Inference, VI)和概率图模型(Probabilistic Graphical Models, PGM)是贝叶斯推断的两大支柱。1 当这两者结合时,形成了一个强大的统一框架,能够处理大规模概率模型的推断问题。

本文从概率编程的角度出发,探讨变分推断与概率图模型的深层联系,揭示消息传递、平均场近似和变分自由能之间的统一性。

2. 概率图模型基础

2.1 因子图表示

概率图模型通过图结构编码变量之间的依赖关系:

其中 是因子(factor), 是配分函数。

因子图是表示概率分布的另一种方式:

  • 变量节点:表示随机变量
  • 因子节点:表示势函数
  • :连接变量和因子

2.2 消息传递范式

在因子图上,信念传播(Belief Propagation)通过消息传递进行推断:

变量到因子的消息

因子到变量的消息

3. 变分推断基础

3.1 变分推断的核心思想

变分推断的核心是用参数化的分布族 去近似真实后验 ,通过最小化 KL 散度:

等价于最大化变分下界(ELBO)

3.2 平均场近似

平均场近似假设后验分布可以分解为独立因子的乘积:

这个假设与概率图模型的因子分解形式完全对应。

4. 统一框架:变分消息传递

4.1 从 BP 到 VI

信念传播和变分推断在数学上有着深刻的联系。考虑因子图上的变分推断:

因子势函数的变分

KL 散度分解

4.2 变分消息传递算法

变分消息传递(Variational Message Passing, VMP)将变分推断表述为消息传递:

因子节点更新

变量节点更新

4.3 算法收敛性

定理(变分消息传递收敛性)2

如果因子图不含环,则变分消息传递算法在有限迭代内收敛到唯一的稳态分布,对应平均场近似的最优解。

5. 具体实例推导

5.1 高斯混合模型的变分推断

考虑高斯混合模型(GMM)的变分推断:

模型定义

平均场近似

更新

计算得到:

这正是 E 步 的更新公式。

5.2 因子分析与变分推断

因子分析模型:

后验分布

其中:

变分推断版本

更新规则与 EM 算法相同,但采用随机变分推断(SVI)进行大规模数据处理。

6. 变分推断与信念传播的联系

6.1 数学对应

变分推断信念传播含义
ELBO自由能优化目标
平均场分解图分解近似结构
坐标上升更新消息传递优化算法
期望计算消息计算计算步骤

6.2 环上的变分推断

当概率图模型包含环时,信念传播可能不收敛。变分推断提供了系统性的解决方案:

外推方法

其中 是外推系数, 是标准消息更新。

期望传播(Expectation Propagation)
使用矩匹配代替消息传递,处理环状结构:

7. 变分推断的概率编程视角

7.1 概率编程语言

在概率编程语言(如 PyMC、Stan、Edward)中,变分推断和概率图模型被统一在同一个框架下:

# PyMC3 示例
import pymc3 as pm
 
with pm.Model() as model:
    # 先验
    theta = pm.Beta('theta', alpha=1, beta=1)
    
    # 似然
    y = pm.Bernoulli('y', p=theta, observed=data)
    
    # 变分推断
    approx = pm.fit(n=10000, method='advi')
    trace = approx.sample(1000)

7.2 自动变分推断

现代概率编程库实现了自动变分推断

# Pyro 示例
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
 
def model(data):
    # 先验
    probs = pyro.sample('probs', dist.Beta(torch.ones(2), torch.ones(2)))
    # 似然
    with pyro.plate('data', len(data)):
        return pyro.sample('obs', dist.Bernoulli(probs), obs=data)
 
def guide(data):
    # 变分分布
    alpha_q = pyro.param('alpha_q', torch.ones(2), constraint=dist.constraints.positive)
    beta_q = pyro.param('beta_q', torch.ones(2), constraint=dist.constraints.positive)
    return pyro.sample('probs', dist.Beta(alpha_q, beta_q))
 
# SVI 训练
svi = SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), Trace_ELBO())
for step in range(1000):
    loss = svi.step(data)

7.3 黑箱变分推断(BBVI)

当模型的对数似然不可微或过于复杂时,使用黑箱变分推断:

其中

8. 随机变分推断(SVI)

8.1 大规模数据处理

当数据规模很大时,全数据批变分推断计算代价高昂。SVI 使用小批量数据:

其中 是小批量大小, 是总数据量。

8.2 PyTorch 实现

class StochasticVariationalInference:
    def __init__(self, model, guide, optimizer, data_loader):
        self.model = model
        self.guide = guide
        self.optimizer = optimizer
        self.data_loader = data_loader
    
    def step(self):
        self.optimizer.zero_grad()
        
        # 小批量数据
        batch = next(iter(self.data_loader))
        
        # 重参数化采样
        z = self.guide.sample(batch)
        
        # 计算 ELBO
        elbo = self.model.elbo(batch, z)
        
        # 反向传播
        (-elbo).backward()
        self.optimizer.step()
        
        return elbo.item()

9. 变分推断的理论保证

9.1 收敛性分析

定理(平均场变分推断的收敛性)

对于指数族分布的平均场变分推断,坐标上升算法收敛到唯一的全局最优解。

证明思路

  1. ELBO 是凹函数(相对于每个变分参数)
  2. 坐标上升在每步都提升 ELBO
  3. ELBO 有下界(由熵项保证)
  4. 收敛到不动点,即平均场近似的最优解

9.2 近似误差分析

变分推断的近似误差由两部分组成:

  • 结构误差:由平均场近似引入
  • 参数误差:由有限样本估计引入

10. 与深度学习的统一

10.1 变分自编码器(VAE)

VAE 是变分推断与深度学习的完美结合:

使用变分下界:

10.2 概率图模型与神经网络的统一

graph TB
    A[概率图模型] --> B[变分推断]
    A --> C[信念传播]
    B --> D[平均场近似]
    B --> E[黑箱变分推断]
    C --> F[环状图扩展]
    D --> G[SVI]
    E --> H[重参数化技巧]
    F --> I[期望传播]
    G --> J[概率编程]
    H --> J
    I --> J
    J --> K[VAE]
    J --> L[深度概率模型]

11. 实践指南

11.1 变分分布选择

模型结构推荐变分分布说明
离散 latentCategorical类别变量
非负参数Gamma, Log-Normal保证正性
概率参数Beta, Dirichlet概率参数
连续向量Gaussian, t-distribution通用选择

11.2 常用优化器

  • Adam:通用首选
  • RMSProp:适合方差不稳定的场景
  • Natural Gradient:利用 Fisher 信息矩阵

11.3 诊断方法

  1. ELBO 监控:确保 ELBO 稳定上升
  2. 后验预测检验:比较预测分布与观测
  3. Geweke 检验:比较前后段样本

12. 参考资料

Footnotes

  1. Blei et al. (2017). “Variational Inference: A Review for Statisticians.” JASA 2017.

  2. Winn & Bishop (2005). “Variational Message Passing.” JMLR 2005.