结构化变分推断

1. 概述

结构化变分推断(Structured Variational Inference) 是一类重要的变分推断方法,它放弃平均场假设,允许变分分布保留变量之间的相关性结构,从而提高近似精度。

核心问题

  • 平均场变分推断假设
  • 这忽略了变量间的所有相关性
  • 对于强耦合系统,平均场近似可能很差

解决方案

  • 利用问题的结构化先验
  • 设计保留相关性的变分族
  • 在表达力和计算复杂度间取得平衡

典型应用

  • 隐马尔可夫模型(HMM)
  • 因子分析(FA)
  • 混合模型(Mixture Models)
  • 动态系统(Kalman Filter)

12

2. 从平均场到结构化

2.1 平均场近似的局限性

平均场变分推断

问题示例:考虑两个强相关的变量

设真实分布:

平均场近似

由于假设独立性,

但真实分布的边缘是退化的(协方差接近1),平均场给出了错误的边缘方差。

2.2 结构化近似的思想

核心观察:问题通常有已知结构

  • HMM中的马尔可夫链结构
  • 因子分析中的低秩结构
  • 混合模型中的簇结构

结构化变分推断利用这些已知结构设计变分族:

其中 保留了我们关心的相关性。

3. 推理网络与归一化流

3.1 推理网络(Inference Networks)

编码器网络

优点

  • 可以捕获任意复杂的条件分布
  • 利用深度学习的表达能力
  • 端到端可训练

缺点

  • 近似精度取决于网络容量
  • 缺乏理论保证

3.2 归一化流(Normalizing Flows)

可逆变换

其中 是简单分布(如高斯), 是可逆变换。

变量变换公式

常见流

流类型变换行列式计算
仿射流$\prod
Planar流
径向流
IAF

3.3 代码实现

import torch
import torch.nn as nn
from torch.distributions import Distribution, Transform
 
class PlanarFlow(Transform):
    """Planar流:z' = z + u * tanh(w^T z + b)"""
    
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        
        # 参数
        self.u = nn.Parameter(torch.randn(dim))
        self.w = nn.Parameter(torch.randn(dim))
        self.b = nn.Parameter(torch.zeros(1))
        
        self._initialize()
    
    def _initialize(self):
        """确保可逆性初始化"""
        # w^T u >= -1
        w_dot_u = torch.dot(self.w, self.u)
        if w_dot_u < -1:
            self.u.data = self.u * (-1 / w_dot_u + 1e-3)
    
    def _call(self, z):
        """前向变换"""
        activation = torch.tanh(torch.dot(self.w, z) + self.b)
        return z + self.u * activation
    
    def _inverse(self, z):
        """逆向变换(需要数值求解)"""
        # 简化的逆向实现
        def fixed_point(alpha):
            return z - self.u * torch.tanh(torch.dot(self.w, alpha) + self.b)
        
        # 迭代求解
        alpha = z.clone()
        for _ in range(10):
            alpha = fixed_point(alpha)
        return alpha
    
    def log_abs_det_jacobian(self, z):
        """log |det|的计算"""
        activation = torch.tanh(torch.dot(self.w, z) + self.b)
        derivative = 1 - activation ** 2
        psi = self.w * derivative
        
        # 行列式近似
        u_dot_psi = torch.dot(self.u, psi)
        return torch.log(torch.abs(1 + u_dot_psi) + 1e-10)
 
 
class InverseAutoregressiveFlow(Distribution):
    """逆自回归流(IAF)"""
    
    def __init__(self, base_dist, autoregressive_net, T=4):
        """
        Args:
            base_dist: 基础分布
            autoregressive_net: 自回归网络,返回(mu, log_std)
            T: 流层数
        """
        super().__init__()
        self.base_dist = base_dist
        self.T = T
        self.ar_net = autoregressive_net
    
    def sample(self, n=1):
        z0 = self.base_dist.sample((n,))
        z = z0
        
        for t in range(self.T):
            m, log_s = self.ar_net(z)
            z = m + torch.exp(log_s) * z
        
        return z
    
    def log_prob(self, z):
        """计算log概率(需要逆向采样)"""
        # 简化的实现
        log_q0 = self.base_dist.log_prob(z).sum(dim=-1)
        return log_q0
 
 
class NormalizingFlowVAE(nn.Module):
    """带归一化流的VAE"""
    
    def __init__(self, encoder, decoder, flow_depth=4):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.flows = nn.ModuleList([PlanarFlow(dim=latent_dim) for _ in range(flow_depth)])
    
    def forward(self, x):
        # 编码
        q0_mean, q0_logvar = self.encoder(x)
        q0_std = torch.exp(0.5 * q0_logvar)
        
        # 从基础分布采样
        z0 = q0_mean + q0_std * torch.randn_like(q0_mean)
        
        # 应用流变换
        z = z0
        log_det = 0
        for flow in self.flows:
            z = flow(z)
            log_det += flow.log_abs_det_jacobian(z)
        
        # 解码
        x_recon = self.decoder(z)
        
        # ELBO
        log_px_z = self.reconstruction_loss(x_recon, x)
        log_qz = torch.distributions.Normal(q0_mean, q0_std).log_prob(z0).sum(dim=-1)
        log_pz = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
        
        elbo = log_px_z + log_pz - log_qz - log_det
        
        return elbo.mean(), x_recon

4. 高斯过程变分推断

4.1 变分高斯过程

高斯过程回归

精确推断问题:需要 协方差矩阵求逆。

变分方法:引入诱导点

其中 是变分分布。

4.2 稀疏变分高斯过程(SVGP)

变分目标

诱导点似然

其中:

4.3 实现

class SparseVariationalGP(nn.Module):
    """稀疏变分高斯过程"""
    
    def __init__(self, kernel, inducing_points, noise_std=0.1):
        super().__init__()
        self.kernel = kernel
        self.inducing_points = nn.Parameter(inducing_points)  # Z
        self.noise_std = noise_std
        
        # 变分参数
        self.m = nn.Parameter(torch.randn(len(inducing_points)))
        self.L = nn.Parameter(torch.eye(len(inducing_points)))  # S = LL^T
        
    def forward(self, x, y=None):
        n = len(x)
        m = len(self.inducing_points)
        
        # 核矩阵
        K_XZ = self.kernel(x, self.inducing_points)
        K_ZZ = self.kernel(self.inducing_points, self.inducing_points)
        K_ZX = K_XZ.t()
        
        # Cholesky分解
        L = torch.linalg.cholesky(K_ZZ + 1e-5 * torch.eye(m))
        
        # A = K_XZ @ K_ZZ^{-1}
        A = torch.linalg.solve_triangular(L, K_XZ.t(), upper=False)
        A = torch.linalg.solve_triangular(L.t(), A, upper=True)
        A = A.t()
        
        # 变分均值和方差
        mu = A @ self.m
        
        # S @ K_ZZ^{-1} @ K_ZX
        S_Kzx = torch.linalg.solve_triangular(L.t(), K_ZX)
        S_Kzx = torch.linalg.solve_triangular(L, S_Kzx, upper=False)
        
        var_f = self.kernel.diag(x) - (A * A).sum(dim=1) + (A * S_Kzx.t()).sum(dim=1)
        var_f = torch.clamp(var_f, min=0)
        
        if y is None:
            return torch.distributions.Normal(mu, torch.sqrt(var_f + self.noise_std**2))
        
        # ELBO
        log_lik = torch.distributions.Normal(mu, torch.sqrt(var_f + self.noise_std**2)).log_prob(y).sum()
        
        # KL(q(u) || p(u))
        # p(u) = N(0, K_ZZ)
        # q(u) = N(m, S) = N(m, LL^T)
        KL = 0.5 * (
            torch.trace(torch.linalg.solve_triangular(L.t(), torch.linalg.solve_triangular(L, K_ZZ), upper=True)) +
            self.m @ torch.linalg.solve_triangular(L.t(), torch.linalg.solve_triangular(L, self.m.unsqueeze(1)), upper=True).squeeze() -
            m +
            2 * torch.sum(torch.log(torch.diag(L)))
        )
        
        elbo = log_lik - KL
        
        return -elbo  # 损失

5. 隐马尔可夫模型的变分推断

5.1 HMM结构

状态转移

发射概率

变分分布

5.2 Viterbi变分推断

保持链结构

def hmm_variational_inference(observations, trans_prior, emit_params, n_iter=100):
    """
    HMM的变分推断(保持链结构)
    """
    T, K = len(observations), trans_prior.shape[0]
    
    # 初始化变分参数
    pi = torch.ones(T, K) / K          # q(z_1)
    A = torch.ones(T-1, K, K) / K      # q(z_t | z_{t-1})
    
    for iteration in range(n_iter):
        # E步:更新隐藏状态后验
        for t in range(T):
            if t == 0:
                # q(z_1) ∝ π_0 * emit(z_1)
                pi[t] = trans_prior[0] * emit_params[observations[t]]
            else:
                # q(z_t, z_{t-1}) ∝ q(z_{t-1}) * A * emit(z_t)
                for j in range(K):
                    A[t-1, j] = pi[t-1] * trans_prior[1:, j] * emit_params[observations[t]]
                
                # 归一化
                A[t-1] = A[t-1] / A[t-1].sum(dim=1, keepdim=True)
                pi[t] = (A[t-1] * trans_prior[1:, None]).sum(dim=0)
        
        # M步:更新参数(简化版本省略)
        
        if check_convergence(pi, A):
            break
    
    return pi, A

6. 因子分析变分推断

6.1 因子分析模型

潜在变量模型

其中 是因子加载矩阵, 是特异方差矩阵。

6.2 结构化变分近似

利用低秩结构

其中 是稠密协方差(保留因子相关性)。

完全平均场(失去低秩结构):

6.3 变分EM算法

class VariationalFactorAnalysis(nn.Module):
    """因子分析的变分推断"""
    
    def __init__(self, n_features, n_factors):
        super().__init__()
        self.n_features = n_features
        self.n_factors = n_factors
        
        # 参数
        self.W = nn.Parameter(torch.randn(n_features, n_factors) * 0.1)
        self.mu = nn.Parameter(torch.zeros(n_features))
        self.log_psi = nn.Parameter(torch.zeros(n_features))  # diag(Psi)
        
        # 变分参数
        self.m = nn.Parameter(torch.zeros(n_factors))
        self.L = nn.Parameter(torch.eye(n_factors))  # Sigma = L @ L.T
        
    @property
    def Psi(self):
        return torch.diag(torch.exp(self.log_psi))
    
    @property
    def Sigma(self):
        return self.L @ self.L.t()
    
    def e_step(self, X):
        """
        E步:更新变分参数
        """
        n = X.shape[0]
        D, K = self.n_features, self.n_factors
        
        # W^T @ Psi^{-1} @ W
        Psi_inv = torch.diag(1 / torch.exp(self.log_psi))
        WtPsiinvW = self.W.t() @ Psi_inv @ self.W
        WtPsiinvX = self.W.t() @ Psi_inv @ (X - self.mu).t()
        
        # Sigma_q = (I + W^T Psi^{-1} W)^{-1}
        precision = torch.eye(K) + WtPsiinvW
        Sigma_q = torch.linalg.inv(precision)
        
        # m_q = Sigma_q @ W^T Psi^{-1} @ (X - mu)
        m_q = Sigma_q @ WtPsiinvX
        
        # E[z] = m_q
        # E[zz^T] = Sigma_q + m_q m_q^T
        Ez = m_q.t()
        EzzT = (Sigma_q.unsqueeze(0) + Ez.unsqueeze(-1) @ Ez.unsqueeze(-2)).sum(dim=0)
        
        return Ez, EzzT
    
    def m_step(self, X, Ez, EzzT):
        """
        M步:更新模型参数
        """
        n = X.shape[0]
        X_centered = X - self.mu
        
        # 更新W
        W_new = (X_centered.t() @ Ez) @ torch.linalg.inv(EzzT)
        self.W.data = W_new
        
        # 更新Psi
        X_recon = X_centered @ Ez.t()
        residuals = ((X_centered ** 2).sum(dim=1, keepdim=True) - 2 * (X_centered @ Ez.t() * X_centered @ Ez.t()).sum(dim=1)).mean()
        psi_new = (residuals / self.n_features).clamp(min=1e-6)
        self.log_psi.data = torch.log(psi_new * torch.ones(self.n_features))
    
    def fit(self, X, n_iter=100):
        """
        变分EM算法
        """
        for iteration in range(n_iter):
            # E步
            Ez, EzzT = self.e_step(X)
            
            # M步
            self.m_step(X, Ez, EzzT)
            
            if iteration % 10 == 0:
                loss = self.elbo(X, Ez, EzzT)
                print(f'Iter {iteration}: ELBO = {loss:.4f}')
    
    def elbo(self, X, Ez, EzzT):
        """计算ELBO"""
        n = X.shape[0]
        D = self.n_features
        
        X_centered = X - self.mu
        Psi_inv = torch.diag(1 / torch.exp(self.log_psi))
        
        # Reconstruction项
        WtPsiinvW = self.W.t() @ Psi_inv @ self.W
        WtPsiinvX = self.W.t() @ Psi_inv @ X_centered.t()
        
        recon = -0.5 * n * D * torch.log(torch.tensor(2 * torch.pi))
        recon += -0.5 * n * torch.trace(Psi_inv @ (self.W @ EzzT @ self.W.t()))
        recon += torch.trace(WtPsiinvW @ EzzT) * (X_centered ** 2).mean()
        recon += -torch.trace(WtPsiinvX @ WtPsiinvX.t())
        
        # KL(q||p)
        Sigma_q = self.L @ self.L.t()
        kl = 0.5 * (
            torch.trace(Sigma_q) + 
            self.m @ torch.linalg.inv(Sigma_q) @ self.m -
            self.n_factors +
            torch.log(torch.det(Sigma_q))
        )
        
        return recon - kl

7. 期望传播与变分消息传递

7.1 期望传播框架

EP核心思想:将复杂后验分解为近似因子的乘积:

其中 是近似因子。

7.2 结构化EP

利用问题的图结构

def structured_ep(factor_graph, initial_beliefs, n_iter=50):
    """
    结构化EP:利用因子图结构
    
    Args:
        factor_graph: 因子图
        initial_beliefs: 初始信念
    """
    beliefs = initial_beliefs.copy()
    
    for iteration in range(n_iter):
        # 更新每个因子的近似
        for factor in factor_graph.factors:
            # 计算充分统计量
            stats = compute_sufficient_stats(beliefs, factor)
            
            # 更新近似因子
            new_factor = update_approximation(factor, stats)
            
            # 更新信念
            for var in factor.variables:
                beliefs[var] = update_belief(beliefs[var], factor, new_factor)
        
        if check_convergence(beliefs):
            break
    
    return beliefs

8. 比较与实践指南

8.1 方法选择矩阵

方法保留相关性计算复杂度适用场景
平均场O(n)弱耦合系统
归一化流中等O(n·K)中等复杂度
高斯过程完整O(m²n)核方法/函数空间
HMM结构链式O(T·K²)时序数据
因子分析低秩O(D·K²)降维/表示学习

8.2 实践建议

  1. 评估近似质量

    • 使用重要性采样估计真实ELBO
    • 比较多次运行的方差
    • 检查相关性结构是否被保留
  2. 计算-精度权衡

    • 复杂变分族 → 更精确但更慢
    • 从简单开始,逐步增加复杂度
  3. 诊断

    • KL散度分解(parks vs individual)
    • 尾巴质量(通过采样检验)
    • 自一致性检查

9. 总结

结构化变分推断的核心思想

  1. 超越平均场:利用已知结构保留变量相关性
  2. 表达力-计算权衡:在保持可行性的同时提高精度
  3. 灵活设计:根据问题结构定制变分族

主要方法

  • 归一化流:可逆变换扩展变分族
  • 稀疏高斯过程:利用低秩结构
  • HMM/FA变分:利用领域特定结构
  • EP消息传递:因子图上的结构化更新

选择原则

  • 知道强相关结构 → 结构化变分
  • 不知道结构 → 平均场或归一化流
  • 需要可扩展性 → 稀疏/诱导点方法

参考文献

Footnotes

  1. Saul, L. K., et al. (1996). Exploiting Tractable Substructures in Undirected Graphs. Statistical Methods in MRF.

  2. Hoffman, M. D., et al. (2013). Stochastic Block Model Variational Inference. NeurIPS 2013.