变分推断与概率图模型统一框架
1. 引言
变分推断(Variational Inference, VI)和概率图模型(Probabilistic Graphical Models, PGM)是贝叶斯推断的两大支柱。1 当这两者结合时,形成了一个强大的统一框架,能够处理大规模概率模型的推断问题。
本文从概率编程的角度出发,探讨变分推断与概率图模型的深层联系,揭示消息传递、平均场近似和变分自由能之间的统一性。
2. 概率图模型基础
2.1 因子图表示
概率图模型通过图结构编码变量之间的依赖关系:
其中 是因子(factor), 是配分函数。
因子图是表示概率分布的另一种方式:
- 变量节点:表示随机变量
- 因子节点:表示势函数
- 边:连接变量和因子
2.2 消息传递范式
在因子图上,信念传播(Belief Propagation)通过消息传递进行推断:
变量到因子的消息:
因子到变量的消息:
3. 变分推断基础
3.1 变分推断的核心思想
变分推断的核心是用参数化的分布族 去近似真实后验 ,通过最小化 KL 散度:
等价于最大化变分下界(ELBO):
3.2 平均场近似
平均场近似假设后验分布可以分解为独立因子的乘积:
这个假设与概率图模型的因子分解形式完全对应。
4. 统一框架:变分消息传递
4.1 从 BP 到 VI
信念传播和变分推断在数学上有着深刻的联系。考虑因子图上的变分推断:
因子势函数的变分:
KL 散度分解:
4.2 变分消息传递算法
变分消息传递(Variational Message Passing, VMP)将变分推断表述为消息传递:
因子节点更新:
变量节点更新:
4.3 算法收敛性
定理(变分消息传递收敛性)2:
如果因子图不含环,则变分消息传递算法在有限迭代内收敛到唯一的稳态分布,对应平均场近似的最优解。
5. 具体实例推导
5.1 高斯混合模型的变分推断
考虑高斯混合模型(GMM)的变分推断:
模型定义:
平均场近似:
更新 :
计算得到:
这正是 E 步 的更新公式。
5.2 因子分析与变分推断
因子分析模型:
后验分布:
其中:
变分推断版本:
更新规则与 EM 算法相同,但采用随机变分推断(SVI)进行大规模数据处理。
6. 变分推断与信念传播的联系
6.1 数学对应
| 变分推断 | 信念传播 | 含义 |
|---|---|---|
| ELBO | 自由能 | 优化目标 |
| 平均场分解 | 图分解 | 近似结构 |
| 坐标上升更新 | 消息传递 | 优化算法 |
| 期望计算 | 消息计算 | 计算步骤 |
6.2 环上的变分推断
当概率图模型包含环时,信念传播可能不收敛。变分推断提供了系统性的解决方案:
外推方法:
其中 是外推系数, 是标准消息更新。
期望传播(Expectation Propagation):
使用矩匹配代替消息传递,处理环状结构:
7. 变分推断的概率编程视角
7.1 概率编程语言
在概率编程语言(如 PyMC、Stan、Edward)中,变分推断和概率图模型被统一在同一个框架下:
# PyMC3 示例
import pymc3 as pm
with pm.Model() as model:
# 先验
theta = pm.Beta('theta', alpha=1, beta=1)
# 似然
y = pm.Bernoulli('y', p=theta, observed=data)
# 变分推断
approx = pm.fit(n=10000, method='advi')
trace = approx.sample(1000)7.2 自动变分推断
现代概率编程库实现了自动变分推断:
# Pyro 示例
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
def model(data):
# 先验
probs = pyro.sample('probs', dist.Beta(torch.ones(2), torch.ones(2)))
# 似然
with pyro.plate('data', len(data)):
return pyro.sample('obs', dist.Bernoulli(probs), obs=data)
def guide(data):
# 变分分布
alpha_q = pyro.param('alpha_q', torch.ones(2), constraint=dist.constraints.positive)
beta_q = pyro.param('beta_q', torch.ones(2), constraint=dist.constraints.positive)
return pyro.sample('probs', dist.Beta(alpha_q, beta_q))
# SVI 训练
svi = SVI(model, guide, pyro.optim.Adam({'lr': 0.01}), Trace_ELBO())
for step in range(1000):
loss = svi.step(data)7.3 黑箱变分推断(BBVI)
当模型的对数似然不可微或过于复杂时,使用黑箱变分推断:
其中 。
8. 随机变分推断(SVI)
8.1 大规模数据处理
当数据规模很大时,全数据批变分推断计算代价高昂。SVI 使用小批量数据:
其中 是小批量大小, 是总数据量。
8.2 PyTorch 实现
class StochasticVariationalInference:
def __init__(self, model, guide, optimizer, data_loader):
self.model = model
self.guide = guide
self.optimizer = optimizer
self.data_loader = data_loader
def step(self):
self.optimizer.zero_grad()
# 小批量数据
batch = next(iter(self.data_loader))
# 重参数化采样
z = self.guide.sample(batch)
# 计算 ELBO
elbo = self.model.elbo(batch, z)
# 反向传播
(-elbo).backward()
self.optimizer.step()
return elbo.item()9. 变分推断的理论保证
9.1 收敛性分析
定理(平均场变分推断的收敛性):
对于指数族分布的平均场变分推断,坐标上升算法收敛到唯一的全局最优解。
证明思路:
- ELBO 是凹函数(相对于每个变分参数)
- 坐标上升在每步都提升 ELBO
- ELBO 有下界(由熵项保证)
- 收敛到不动点,即平均场近似的最优解
9.2 近似误差分析
变分推断的近似误差由两部分组成:
- 结构误差:由平均场近似引入
- 参数误差:由有限样本估计引入
10. 与深度学习的统一
10.1 变分自编码器(VAE)
VAE 是变分推断与深度学习的完美结合:
使用变分下界:
10.2 概率图模型与神经网络的统一
graph TB A[概率图模型] --> B[变分推断] A --> C[信念传播] B --> D[平均场近似] B --> E[黑箱变分推断] C --> F[环状图扩展] D --> G[SVI] E --> H[重参数化技巧] F --> I[期望传播] G --> J[概率编程] H --> J I --> J J --> K[VAE] J --> L[深度概率模型]
11. 实践指南
11.1 变分分布选择
| 模型结构 | 推荐变分分布 | 说明 |
|---|---|---|
| 离散 latent | Categorical | 类别变量 |
| 非负参数 | Gamma, Log-Normal | 保证正性 |
| 概率参数 | Beta, Dirichlet | 概率参数 |
| 连续向量 | Gaussian, t-distribution | 通用选择 |
11.2 常用优化器
- Adam:通用首选
- RMSProp:适合方差不稳定的场景
- Natural Gradient:利用 Fisher 信息矩阵
11.3 诊断方法
- ELBO 监控:确保 ELBO 稳定上升
- 后验预测检验:比较预测分布与观测
- Geweke 检验:比较前后段样本