变分推断

变分推断(Variational Inference, VI)是贝叶斯统计中一种高效的近似推断方法,通过优化一个近似分布来逼近真实后验分布。1 在现代机器学习中,变分推断是变分自编码器(VAE)、变分信息瓶颈(VIB)等模型的理论基础。

问题背景

贝叶斯推断

在贝叶斯框架下,我们关心后验分布:

其中:

  • 先验分布(Prior)
  • 似然函数(Likelihood)
  • 后验分布(Posterior)
  • 边缘似然(Evidence)

计算挑战

边缘似然 通常难以计算(积分无解析解),这导致:

  • 无法直接计算后验
  • 无法使用最大似然估计(MLE)

变分推断的核心思想:用一个简单的分布 来近似复杂的真实后验


变分推断原理

优化目标

寻找最优的近似分布

这等价于最小化两个分布之间的 KL 散度。

分解与推导

对 KL 散度进行展开:

代入贝叶斯公式

证据下界(ELBO)

重新整理得到:

由于 KL 散度非负:

就是证据下界(Evidence Lower Bound, ELBO)。


ELBO 的信息论解释

两种视角

视角一:重构 + 正则化

含义
重构损失: 应该能重建
正则项: 不应偏离先验太远

视角二:信息瓶颈

重新审视 ELBO:

  • :输入与表示之间的互信息(信息保留)
  • :与先验的偏离(压缩)

这与信息瓶颈理论的目标高度一致!

直观理解

┌──────────────────────────────────────────────────────┐
│                                                      │
│    log p(x)                                          │
│      │                    ╭────────── ELBO           │
│      │                   ╱                           │
│      │                  ╱                            │
│      │                 ╱   ════════════════         │
│      │                ╱                             │
│      │               ╱                              │
│      │              ╱                                │
│      │             ╱    KL(q || p(z|x))             │
│      │            ╱                                  │
│      └────────────╰─────────────────────────────────→ q 的复杂度
│                                                      │
└──────────────────────────────────────────────────────┘

平均场变分推断

分解假设

平均场假设(Mean Field Assumption)将近似分布分解为独立因子的乘积:

其中每个 只依赖对应的隐变量

坐标上升更新

对于平均场变分推断,可以通过坐标上升法(Coordinate Ascent)求解:

最优解

变分混合模型示例

import numpy as np
import scipy.special as sp
 
class VariationalGaussianMixture:
    """
    使用变分推断的高斯混合模型
    
    假设数据由 K 个高斯分布生成:
    - z ~ Cat(π): 混合系数
    - x|z=k ~ N(μ_k, Σ_k): 观测分布
    """
    def __init__(self, n_components, max_iter=100, tol=1e-6):
        self.K = n_components
        self.max_iter = max_iter
        self.tol = tol
        
    def _init_params(self, X):
        N, D = X.shape
        self.pi = np.ones(self.K) / self.K  # 混合系数
        self.mu = X[np.random.choice(N, self.K, False)]  # 均值
        self.var = np.ones((self.K, D))  # 方差
        
    def fit(self, X):
        self._init_params(X)
        N = X.shape[0]
        
        for it in range(self.max_iter):
            # E-step: 计算隐变量的后验
            log_rho = np.zeros((N, self.K))
            for k in range(self.K):
                log_rho[:, k] = np.log(self.pi[k]) + \
                                self._log_gaussian_pdf(X, self.mu[k], self.var[k])
            
            # 归一化(log-sum-exp 技巧)
            log_rho_max = np.max(log_rho, axis=1, keepdims=True)
            rho = np.exp(log_rho - log_rho_max)
            rho = rho / np.sum(rho, axis=1, keepdims=True)
            
            # 计算 ELBO
            elbo = self._compute_elbo(X, rho)
            
            # M-step: 更新参数
            Nk = np.sum(rho, axis=0)  # effective counts
            self.pi = Nk / N
            
            for k in range(self.K):
                # 更新均值
                self.mu[k] = np.sum(rho[:, k:k+1] * X, axis=0) / Nk[k]
                # 更新方差
                diff = X - self.mu[k]
                self.var[k] = np.sum(rho[:, k:k+1] * diff**2, axis=0) / Nk[k]
            
            # 检查收敛
            if it > 0 and abs(elbo - prev_elbo) < self.tol:
                break
            prev_elbo = elbo
                
        return self
    
    def _log_gaussian_pdf(self, X, mu, var):
        D = X.shape[1]
        return -0.5 * D * np.log(2*np.pi) - 0.5 * np.sum(np.log(var)) - \
               0.5 * np.sum((X - mu)**2 / var, axis=1)
    
    def _compute_elbo(self, X, rho):
        """计算 ELBO 的各部分"""
        N = X.shape[0]
        
        # 重构项
        recon = 0
        for k in range(self.K):
            recon += np.sum(rho[:, k] * self._log_gaussian_pdf(X, self.mu[k], self.var[k]))
        
        # KL 项 (pi)
        pi_kl = np.sum(rho * (np.log(self.pi + 1e-10) - np.log(rho + 1e-10)))
        
        return recon - pi_kl

变分推断 vs MCMC

特性变分推断MCMC
速度快(优化问题)慢(采样)
精度有偏近似无偏(渐近)
扩展性易扩展到大数据困难
收敛诊断直接困难
适用场景大规模、快速原型高精度需求

在机器学习中的应用

1. 变分自编码器(VAE)

VAE 使用变分推断来学习隐变量表示:2

class VAE(nn.Module):
    """
    变分自编码器
    
    ELBO = E_q[log p(x|z)] - D_KL(q(z|x) || p(z))
    """
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # 编码器:q(z|x)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * latent_dim)  # mu and log_var
        )
        # 解码器:p(x|z)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim)
        )
        
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        # 编码
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        z = self.reparameterize(mu, log_var)
        
        # 解码
        x_recon = self.decoder(z)
        
        return x_recon, mu, log_var
    
    def loss(self, x):
        x_recon, mu, log_var = self.forward(x)
        
        # 重构损失(二次误差或 BCE)
        recon_loss = F.mse_loss(x_recon, x, reduction='sum')
        
        # KL 散度:D_KL(N(mu, sigma) || N(0, I))
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        # ELBO
        elbo = recon_loss + kl_loss
        
        return elbo / x.size(0), recon_loss / x.size(0), kl_loss / x.size(0)

2. 变分信息瓶颈(VIB)

信息瓶颈理论

3. 贝叶斯神经网络

使用变分推断进行贝叶斯后验近似:

class BayesianLinear(nn.Module):
    """
    贝叶斯线性层
    
    使用变分推断近似权重后验 q(w|θ)
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # 变分参数
        self.weight_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.weight_log_var = nn.Parameter(torch.zeros(out_features, in_features))
        
        # 先验(标准高斯)
        self.prior_mu = torch.zeros(out_features, in_features)
        self.prior_log_var = torch.zeros(out_features, in_features)
        
    def forward(self, x):
        # 从近似后验采样权重
        weight = self.weight_mu + torch.randn_like(self.weight_mu) * \
                 torch.exp(0.5 * self.weight_log_var)
        return F.linear(x, weight)
    
    def kl_loss(self):
        """计算与先验的 KL 散度"""
        q_mean, q_log_var = self.weight_mu, self.weight_log_var
        p_mean, p_log_var = self.prior_mu, self.prior_log_var
        
        kl = 0.5 * torch.sum(
            p_log_var - q_log_var + 
            (q_log_var.exp() + (q_mean - p_mean).pow(2)) / p_log_var.exp() - 1
        )
        return kl

进阶主题

黑盒变分推断(BBVI)

当期望无法解析计算时,使用随机梯度变分推断(SGVI):

共轭性

模型先验似然后验
Beta-BinomialBetaBinomialBeta
Dirichlet-MultinomialDirichletMultinomialDirichlet
Gaussian-GaussianNormalNormalNormal

共轭先验允许解析计算后验,简化变分推断。


核心公式速查

概念公式
KL 散度目标
ELBO 定义
ELBO 等价形式
KL 散度展开
平均场最优解

参考

相关文章

Footnotes

  1. Blei, D.M., Kucukelbir, A., & McAuliffe, J.D. (2017). “Variational Inference: A Review for Statisticians”. Journal of the American Statistical Association, 112(518), 859-877.

  2. Kingma, D.P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. ICLR.