结构化变分推断
1. 概述
结构化变分推断(Structured Variational Inference) 是一类重要的变分推断方法,它放弃平均场假设,允许变分分布保留变量之间的相关性结构,从而提高近似精度。
核心问题:
- 平均场变分推断假设
- 这忽略了变量间的所有相关性
- 对于强耦合系统,平均场近似可能很差
解决方案:
- 利用问题的结构化先验
- 设计保留相关性的变分族
- 在表达力和计算复杂度间取得平衡
典型应用:
- 隐马尔可夫模型(HMM)
- 因子分析(FA)
- 混合模型(Mixture Models)
- 动态系统(Kalman Filter)
2. 从平均场到结构化
2.1 平均场近似的局限性
平均场变分推断:
问题示例:考虑两个强相关的变量
设真实分布:
平均场近似 :
由于假设独立性,。
但真实分布的边缘是退化的(协方差接近1),平均场给出了错误的边缘方差。
2.2 结构化近似的思想
核心观察:问题通常有已知结构:
- HMM中的马尔可夫链结构
- 因子分析中的低秩结构
- 混合模型中的簇结构
结构化变分推断利用这些已知结构设计变分族:
其中 保留了我们关心的相关性。
3. 推理网络与归一化流
3.1 推理网络(Inference Networks)
编码器网络:
优点:
- 可以捕获任意复杂的条件分布
- 利用深度学习的表达能力
- 端到端可训练
缺点:
- 近似精度取决于网络容量
- 缺乏理论保证
3.2 归一化流(Normalizing Flows)
可逆变换:
其中 是简单分布(如高斯), 是可逆变换。
变量变换公式:
常见流:
| 流类型 | 变换 | 行列式计算 |
|---|---|---|
| 仿射流 | $\prod | |
| Planar流 | ||
| 径向流 | ||
| IAF |
3.3 代码实现
import torch
import torch.nn as nn
from torch.distributions import Distribution, Transform
class PlanarFlow(Transform):
"""Planar流:z' = z + u * tanh(w^T z + b)"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# 参数
self.u = nn.Parameter(torch.randn(dim))
self.w = nn.Parameter(torch.randn(dim))
self.b = nn.Parameter(torch.zeros(1))
self._initialize()
def _initialize(self):
"""确保可逆性初始化"""
# w^T u >= -1
w_dot_u = torch.dot(self.w, self.u)
if w_dot_u < -1:
self.u.data = self.u * (-1 / w_dot_u + 1e-3)
def _call(self, z):
"""前向变换"""
activation = torch.tanh(torch.dot(self.w, z) + self.b)
return z + self.u * activation
def _inverse(self, z):
"""逆向变换(需要数值求解)"""
# 简化的逆向实现
def fixed_point(alpha):
return z - self.u * torch.tanh(torch.dot(self.w, alpha) + self.b)
# 迭代求解
alpha = z.clone()
for _ in range(10):
alpha = fixed_point(alpha)
return alpha
def log_abs_det_jacobian(self, z):
"""log |det|的计算"""
activation = torch.tanh(torch.dot(self.w, z) + self.b)
derivative = 1 - activation ** 2
psi = self.w * derivative
# 行列式近似
u_dot_psi = torch.dot(self.u, psi)
return torch.log(torch.abs(1 + u_dot_psi) + 1e-10)
class InverseAutoregressiveFlow(Distribution):
"""逆自回归流(IAF)"""
def __init__(self, base_dist, autoregressive_net, T=4):
"""
Args:
base_dist: 基础分布
autoregressive_net: 自回归网络,返回(mu, log_std)
T: 流层数
"""
super().__init__()
self.base_dist = base_dist
self.T = T
self.ar_net = autoregressive_net
def sample(self, n=1):
z0 = self.base_dist.sample((n,))
z = z0
for t in range(self.T):
m, log_s = self.ar_net(z)
z = m + torch.exp(log_s) * z
return z
def log_prob(self, z):
"""计算log概率(需要逆向采样)"""
# 简化的实现
log_q0 = self.base_dist.log_prob(z).sum(dim=-1)
return log_q0
class NormalizingFlowVAE(nn.Module):
"""带归一化流的VAE"""
def __init__(self, encoder, decoder, flow_depth=4):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.flows = nn.ModuleList([PlanarFlow(dim=latent_dim) for _ in range(flow_depth)])
def forward(self, x):
# 编码
q0_mean, q0_logvar = self.encoder(x)
q0_std = torch.exp(0.5 * q0_logvar)
# 从基础分布采样
z0 = q0_mean + q0_std * torch.randn_like(q0_mean)
# 应用流变换
z = z0
log_det = 0
for flow in self.flows:
z = flow(z)
log_det += flow.log_abs_det_jacobian(z)
# 解码
x_recon = self.decoder(z)
# ELBO
log_px_z = self.reconstruction_loss(x_recon, x)
log_qz = torch.distributions.Normal(q0_mean, q0_std).log_prob(z0).sum(dim=-1)
log_pz = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
elbo = log_px_z + log_pz - log_qz - log_det
return elbo.mean(), x_recon4. 高斯过程变分推断
4.1 变分高斯过程
高斯过程回归:
精确推断问题:需要 协方差矩阵求逆。
变分方法:引入诱导点 :
其中 是变分分布。
4.2 稀疏变分高斯过程(SVGP)
变分目标:
诱导点似然:
其中:
4.3 实现
class SparseVariationalGP(nn.Module):
"""稀疏变分高斯过程"""
def __init__(self, kernel, inducing_points, noise_std=0.1):
super().__init__()
self.kernel = kernel
self.inducing_points = nn.Parameter(inducing_points) # Z
self.noise_std = noise_std
# 变分参数
self.m = nn.Parameter(torch.randn(len(inducing_points)))
self.L = nn.Parameter(torch.eye(len(inducing_points))) # S = LL^T
def forward(self, x, y=None):
n = len(x)
m = len(self.inducing_points)
# 核矩阵
K_XZ = self.kernel(x, self.inducing_points)
K_ZZ = self.kernel(self.inducing_points, self.inducing_points)
K_ZX = K_XZ.t()
# Cholesky分解
L = torch.linalg.cholesky(K_ZZ + 1e-5 * torch.eye(m))
# A = K_XZ @ K_ZZ^{-1}
A = torch.linalg.solve_triangular(L, K_XZ.t(), upper=False)
A = torch.linalg.solve_triangular(L.t(), A, upper=True)
A = A.t()
# 变分均值和方差
mu = A @ self.m
# S @ K_ZZ^{-1} @ K_ZX
S_Kzx = torch.linalg.solve_triangular(L.t(), K_ZX)
S_Kzx = torch.linalg.solve_triangular(L, S_Kzx, upper=False)
var_f = self.kernel.diag(x) - (A * A).sum(dim=1) + (A * S_Kzx.t()).sum(dim=1)
var_f = torch.clamp(var_f, min=0)
if y is None:
return torch.distributions.Normal(mu, torch.sqrt(var_f + self.noise_std**2))
# ELBO
log_lik = torch.distributions.Normal(mu, torch.sqrt(var_f + self.noise_std**2)).log_prob(y).sum()
# KL(q(u) || p(u))
# p(u) = N(0, K_ZZ)
# q(u) = N(m, S) = N(m, LL^T)
KL = 0.5 * (
torch.trace(torch.linalg.solve_triangular(L.t(), torch.linalg.solve_triangular(L, K_ZZ), upper=True)) +
self.m @ torch.linalg.solve_triangular(L.t(), torch.linalg.solve_triangular(L, self.m.unsqueeze(1)), upper=True).squeeze() -
m +
2 * torch.sum(torch.log(torch.diag(L)))
)
elbo = log_lik - KL
return -elbo # 损失5. 隐马尔可夫模型的变分推断
5.1 HMM结构
状态转移:
发射概率:
变分分布:
5.2 Viterbi变分推断
保持链结构:
def hmm_variational_inference(observations, trans_prior, emit_params, n_iter=100):
"""
HMM的变分推断(保持链结构)
"""
T, K = len(observations), trans_prior.shape[0]
# 初始化变分参数
pi = torch.ones(T, K) / K # q(z_1)
A = torch.ones(T-1, K, K) / K # q(z_t | z_{t-1})
for iteration in range(n_iter):
# E步:更新隐藏状态后验
for t in range(T):
if t == 0:
# q(z_1) ∝ π_0 * emit(z_1)
pi[t] = trans_prior[0] * emit_params[observations[t]]
else:
# q(z_t, z_{t-1}) ∝ q(z_{t-1}) * A * emit(z_t)
for j in range(K):
A[t-1, j] = pi[t-1] * trans_prior[1:, j] * emit_params[observations[t]]
# 归一化
A[t-1] = A[t-1] / A[t-1].sum(dim=1, keepdim=True)
pi[t] = (A[t-1] * trans_prior[1:, None]).sum(dim=0)
# M步:更新参数(简化版本省略)
if check_convergence(pi, A):
break
return pi, A6. 因子分析变分推断
6.1 因子分析模型
潜在变量模型:
其中 是因子加载矩阵, 是特异方差矩阵。
6.2 结构化变分近似
利用低秩结构:
其中 是稠密协方差(保留因子相关性)。
完全平均场(失去低秩结构):
6.3 变分EM算法
class VariationalFactorAnalysis(nn.Module):
"""因子分析的变分推断"""
def __init__(self, n_features, n_factors):
super().__init__()
self.n_features = n_features
self.n_factors = n_factors
# 参数
self.W = nn.Parameter(torch.randn(n_features, n_factors) * 0.1)
self.mu = nn.Parameter(torch.zeros(n_features))
self.log_psi = nn.Parameter(torch.zeros(n_features)) # diag(Psi)
# 变分参数
self.m = nn.Parameter(torch.zeros(n_factors))
self.L = nn.Parameter(torch.eye(n_factors)) # Sigma = L @ L.T
@property
def Psi(self):
return torch.diag(torch.exp(self.log_psi))
@property
def Sigma(self):
return self.L @ self.L.t()
def e_step(self, X):
"""
E步:更新变分参数
"""
n = X.shape[0]
D, K = self.n_features, self.n_factors
# W^T @ Psi^{-1} @ W
Psi_inv = torch.diag(1 / torch.exp(self.log_psi))
WtPsiinvW = self.W.t() @ Psi_inv @ self.W
WtPsiinvX = self.W.t() @ Psi_inv @ (X - self.mu).t()
# Sigma_q = (I + W^T Psi^{-1} W)^{-1}
precision = torch.eye(K) + WtPsiinvW
Sigma_q = torch.linalg.inv(precision)
# m_q = Sigma_q @ W^T Psi^{-1} @ (X - mu)
m_q = Sigma_q @ WtPsiinvX
# E[z] = m_q
# E[zz^T] = Sigma_q + m_q m_q^T
Ez = m_q.t()
EzzT = (Sigma_q.unsqueeze(0) + Ez.unsqueeze(-1) @ Ez.unsqueeze(-2)).sum(dim=0)
return Ez, EzzT
def m_step(self, X, Ez, EzzT):
"""
M步:更新模型参数
"""
n = X.shape[0]
X_centered = X - self.mu
# 更新W
W_new = (X_centered.t() @ Ez) @ torch.linalg.inv(EzzT)
self.W.data = W_new
# 更新Psi
X_recon = X_centered @ Ez.t()
residuals = ((X_centered ** 2).sum(dim=1, keepdim=True) - 2 * (X_centered @ Ez.t() * X_centered @ Ez.t()).sum(dim=1)).mean()
psi_new = (residuals / self.n_features).clamp(min=1e-6)
self.log_psi.data = torch.log(psi_new * torch.ones(self.n_features))
def fit(self, X, n_iter=100):
"""
变分EM算法
"""
for iteration in range(n_iter):
# E步
Ez, EzzT = self.e_step(X)
# M步
self.m_step(X, Ez, EzzT)
if iteration % 10 == 0:
loss = self.elbo(X, Ez, EzzT)
print(f'Iter {iteration}: ELBO = {loss:.4f}')
def elbo(self, X, Ez, EzzT):
"""计算ELBO"""
n = X.shape[0]
D = self.n_features
X_centered = X - self.mu
Psi_inv = torch.diag(1 / torch.exp(self.log_psi))
# Reconstruction项
WtPsiinvW = self.W.t() @ Psi_inv @ self.W
WtPsiinvX = self.W.t() @ Psi_inv @ X_centered.t()
recon = -0.5 * n * D * torch.log(torch.tensor(2 * torch.pi))
recon += -0.5 * n * torch.trace(Psi_inv @ (self.W @ EzzT @ self.W.t()))
recon += torch.trace(WtPsiinvW @ EzzT) * (X_centered ** 2).mean()
recon += -torch.trace(WtPsiinvX @ WtPsiinvX.t())
# KL(q||p)
Sigma_q = self.L @ self.L.t()
kl = 0.5 * (
torch.trace(Sigma_q) +
self.m @ torch.linalg.inv(Sigma_q) @ self.m -
self.n_factors +
torch.log(torch.det(Sigma_q))
)
return recon - kl7. 期望传播与变分消息传递
7.1 期望传播框架
EP核心思想:将复杂后验分解为近似因子的乘积:
其中 是近似因子。
7.2 结构化EP
利用问题的图结构:
def structured_ep(factor_graph, initial_beliefs, n_iter=50):
"""
结构化EP:利用因子图结构
Args:
factor_graph: 因子图
initial_beliefs: 初始信念
"""
beliefs = initial_beliefs.copy()
for iteration in range(n_iter):
# 更新每个因子的近似
for factor in factor_graph.factors:
# 计算充分统计量
stats = compute_sufficient_stats(beliefs, factor)
# 更新近似因子
new_factor = update_approximation(factor, stats)
# 更新信念
for var in factor.variables:
beliefs[var] = update_belief(beliefs[var], factor, new_factor)
if check_convergence(beliefs):
break
return beliefs8. 比较与实践指南
8.1 方法选择矩阵
| 方法 | 保留相关性 | 计算复杂度 | 适用场景 |
|---|---|---|---|
| 平均场 | 无 | O(n) | 弱耦合系统 |
| 归一化流 | 中等 | O(n·K) | 中等复杂度 |
| 高斯过程 | 完整 | O(m²n) | 核方法/函数空间 |
| HMM结构 | 链式 | O(T·K²) | 时序数据 |
| 因子分析 | 低秩 | O(D·K²) | 降维/表示学习 |
8.2 实践建议
-
评估近似质量:
- 使用重要性采样估计真实ELBO
- 比较多次运行的方差
- 检查相关性结构是否被保留
-
计算-精度权衡:
- 复杂变分族 → 更精确但更慢
- 从简单开始,逐步增加复杂度
-
诊断:
- KL散度分解(parks vs individual)
- 尾巴质量(通过采样检验)
- 自一致性检查
9. 总结
结构化变分推断的核心思想:
- 超越平均场:利用已知结构保留变量相关性
- 表达力-计算权衡:在保持可行性的同时提高精度
- 灵活设计:根据问题结构定制变分族
主要方法:
- 归一化流:可逆变换扩展变分族
- 稀疏高斯过程:利用低秩结构
- HMM/FA变分:利用领域特定结构
- EP消息传递:因子图上的结构化更新
选择原则:
- 知道强相关结构 → 结构化变分
- 不知道结构 → 平均场或归一化流
- 需要可扩展性 → 稀疏/诱导点方法