概述

变分推断(Variational Inference, VI)是一种将推断问题转化为优化问题的近似推断方法。通过优化一个变分分布来近似真实后验分布。1

变分推断的基本思想

设我们想要求解后验分布 ,其中 是隐变量, 是观测数据。由于精确推断往往不可行,我们引入一个变分分布 来近似

优化目标:最小化 之间的 KL 散度:


数学基础

证据下界(ELBO)

将 KL 散度展开:

利用贝叶斯定理

重新整理得到:

由于 KL 散度非负,我们有:

ELBO(Evidence Lower BOund)是边际似然的下界,最大化 ELBO 等价于最小化 KL 散度。

ELBO 的另一种形式

ELBO 也可以写成期望的形式:

展开联合分布:

这包含两项:

  • 重建项 — 重建观测数据的期望似然
  • 正则化项 — 变分分布与先验的 KL 散度

ELBO 的性质

import numpy as np
 
def compute_elbo(X, model, q_z, n_samples=1000):
    """
    计算ELBO的蒙特卡洛估计
    
    参数:
        X: 观测数据
        model: 生成模型
        q_z: 变分分布
        n_samples: 蒙特卡洛样本数
    """
    # 从变分分布采样
    z_samples = q_z.sample(n_samples)
    
    # 计算重建项
    log_px_given_z = model.log_likelihood(X, z_samples)
    reconstruction = np.mean(log_px_given_z)
    
    # 计算KL项
    log_pz = model.prior.log_prob(z_samples)
    log_qz = q_z.log_prob(z_samples)
    kl = np.mean(log_pz - log_qz)
    
    elbo = reconstruction - kl
    
    return elbo, reconstruction, kl
 
# 示例:检查ELBO与边际似然的关系
# log P(X) = ELBO + D_KL(Q || P)
# 由于 D_KL >= 0,ELBO <= log P(X)

平均场变分家族

平均场假设

平均场变分家族假设隐变量之间相互独立:

其中 是隐变量的划分。

坐标上升变分推断(CAVI)

CAVI 算法通过固定其他变量来更新每个变分因子:

更新公式:对于第 个因子,

这称为变分更新规则

def cavi_update(X, model, q_i, other_q, i):
    """
    CAVI 更新规则
    
    参数:
        X: 观测数据
        model: 模型
        q_i: 第i个变分因子
        other_q: 其他变分因子的乘积
        i: 要更新的因子索引
    
    返回:
        更新后的变分参数
    """
    # 计算期望
    # E_{-Q_i}[log P(X, Z)] = E_{Q_{-i}}[log P(X, Z_i, Z_{-i})]
    
    def expected_log_joint(z_i):
        # 对其他隐变量积分
        return np.mean([model.log_joint(X, z_i, z_other) 
                        for z_other in other_q.sample(1000)])
    
    # 更新变分参数
    new_params = compute_variational_params(expected_log_joint)
    return new_params
 
def cavi_algorithm(X, model, n_iterations=100, tol=1e-4):
    """
    CAVI 算法主循环
    """
    q = initialize_variational_factors(model)
    
    elbo_history = []
    for iteration in range(n_iterations):
        # 依次更新每个变分因子
        for i in range(model.n_latent):
            q[i] = cavi_update(X, model, q[i], 
                              [q[j] for j in range(model.n_latent) if j != i], 
                              i)
        
        # 计算 ELBO
        elbo = compute_elbo(X, model, q)
        elbo_history.append(elbo)
        
        # 检查收敛
        if len(elbo_history) > 1:
            if abs(elbo_history[-1] - elbo_history[-2]) < tol:
                break
    
    return q, elbo_history

变分推断的变体

随机变分推断(SVI)

当数据量大时,使用随机优化来最大化 ELBO:

class StochasticVariationalInference:
    """
    随机变分推断
    """
    def __init__(self, model, q, optimizer, batch_size=32):
        self.model = model
        self.q = q
        self.optimizer = optimizer
        self.batch_size = batch_size
    
    def step(self, X, global_step):
        """
        一步随机梯度更新
        """
        # 采样一个小批量
        batch_idx = np.random.choice(len(X), self.batch_size, replace=False)
        X_batch = X[batch_idx]
        
        # 重参数化采样
        eps = torch.randn(self.batch_size, self.q.dim)
        z = self.q.loc + eps * self.q.scale
        
        # 计算 ELBO 的蒙特卡洛估计
        elbo = self.compute_elbo(X_batch, z)
        
        # 反向传播
        elbo.backward()
        
        # 更新参数
        self.optimizer.step()
        self.optimizer.zero_grad()
        
        return elbo.item()
    
    def compute_elbo(self, X, z):
        """
        计算 ELBO
        """
        # 重建项
        log_px_given_z = self.model.likelihood(X, z)
        
        # KL 项
        log_pz = torch.distributions.Normal(0, 1).log_prob(z)
        log_qz = self.q.log_prob(z)
        
        return log_px_given_z.mean() - (log_pz - log_qz).mean()

黑盒变分推断(BBVI)

BBVI 使用分数函数梯度:

def bbvi_gradient(q, f, n_samples=100):
    """
    BBVI 梯度估计
    
    参数:
        q: 变分分布
        f: 目标函数
        n_samples: 样本数
    """
    samples = q.sample(n_samples)
    log_q = q.log_prob(samples)
    
    # 分数函数梯度
    # ∇_φ E_q_φ[f(z)] ≈ (1/N) Σ f(z_i) ∇_φ log q_φ(z_i)
    gradients = []
    for i in range(n_samples):
        grads = torch.autograd.grad(
            log_q[i], 
            q.parameters(),
            retain_graph=(i < n_samples - 1)
        )
        gradients.append(f(samples[i]) * grads[0])
    
    return torch.stack(gradients).mean(dim=0)

重参数化梯度

当变分分布可微时,使用重参数化技巧:

def reparameterized_gradient(q, f, n_samples=100):
    """
    重参数化梯度估计
    
    对于 q_φ(z) = N(μ_φ, σ²_φ),
    采样 z = μ_φ + σ_φ * ε, ε ~ N(0,1)
    """
    mu, log_sigma = q.params
    sigma = torch.exp(log_sigma)
    
    samples = []
    for _ in range(n_samples):
        eps = torch.randn_like(mu)
        z = mu + sigma * eps
        samples.append(z)
    
    samples = torch.stack(samples)
    
    # 计算 f(z) 的梯度
    f_values = f(samples)
    
    # 重参数化梯度
    # ∇_φ f(z(φ, ε)) = ∂f/∂z * ∂z/∂φ
    gradients = torch.autograd.grad(
        f_values.sum(),
        [mu, log_sigma],
        retain_graph=True
    )
    
    return gradients

平均场变分推断的详细推导

高斯混合模型(GMM)

以高斯混合模型为例展示平均场 VI:

模型

  • 混合权重:
  • 簇分配:
  • 观测:

变分分布

更新

其中

更新

class VariationalGMM:
    """
    高斯混合模型的变分推断
    """
    def __init__(self, X, n_clusters, alpha=1.0):
        self.X = X
        self.n_clusters = n_clusters
        self.n_samples = len(X)
        self.alpha = alpha
        
        # 初始化变分参数
        self.phi = np.random.dirichlet(np.ones(n_clusters), size=self.n_samples)
        self.alpha_pi = np.ones(n_clusters) * alpha
    
    def e_step(self):
        """
        E步:更新簇分配
        """
        # 更新 phi
        for i in range(self.n_samples):
            log_resp = (np.log(self.alpha_pi) + 
                       self.compute_log_likelihood(self.X[i]))
            log_resp -= np.max(log_resp)  # 数值稳定性
            self.phi[i] = np.exp(log_resp)
            self.phi[i] /= np.sum(self.phi[i])
    
    def m_step(self):
        """
        M步:更新混合权重
        """
        # 更新 Dirichlet 参数
        self.alpha_pi = self.alpha + np.sum(self.phi, axis=0)
    
    def compute_log_likelihood(self, x):
        """
        计算每个簇的对数似然
        """
        return -0.5 * np.sum((x - self.mu)**2 / self.sigma**2, axis=1) - \
               0.5 * np.log(self.sigma**2)
    
    def fit(self, n_iterations=100):
        """
        运行变分EM算法
        """
        for _ in range(n_iterations):
            self.e_step()
            self.m_step()
        
        return self.phi

变分推断与EM算法的联系

EM vs VI

方面EM算法变分推断
隐变量优化推断
参数推断优化
目标最大化似然最大化ELBO
收敛性局部最优局部最优

变分EM

变分EM结合了两者的优点:

def variational_em(model, data, n_iterations=100):
    """
    变分EM算法
    """
    for iteration in range(n_iterations):
        # E步:变分推断(优化隐变量分布)
        q_z = variational_inference(model, data)
        
        # M步:优化参数
        params = m_step(model, data, q_z)
        model.update_params(params)
    
    return model

变分推断的实现技巧

数值稳定性

def stable_elbo(log_p, log_q):
    """
    数值稳定的 ELBO 计算
    
    使用 log-sum-exp 技巧
    """
    # KL = sum(exp(log_p) * (log_p - log_q))
    # = sum(exp(log_p) * log_p) - sum(exp(log_p) * log_q)
    
    # 重建项:使用 log-sum-exp
    reconstruction = np.mean(log_p)
    
    # KL 项
    kl = np.mean(log_p - log_q)
    
    return reconstruction - kl
 
def log_mean_exp(x, axis=None):
    """
    数值稳定的 log(mean(exp(x)))
    """
    max_x = np.max(x, axis=axis, keepdims=True)
    return max_x + np.log(np.mean(np.exp(x - max_x), axis=axis))

调参建议

  1. 初始化:使用 -means 或随机初始化
  2. 学习率:使用学习率衰减或 Adam
  3. 早停:监控验证集上的 ELBO
  4. 重采样:定期重新采样以减少方差

与其他近似方法的对比

VI vs MCMC

方面变分推断MCMC
计算复杂度 每迭代 采样
收敛保证局部最优渐近收敛
近似质量有偏差无偏(渐近)
实现难度中等较高
扩展性一般

VI vs 拉普拉斯近似

方面变分推断拉普拉斯近似
近似分布任意变分族高斯分布
灵活性
计算成本中等
理论基础信息论渐近理论

应用场景

变分自编码器(VAE)

VAE 是变分推断在生成模型中的典型应用:

class VAE(nn.Module):
    """
    变分自编码器
    """
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_log_var = nn.Linear(128, latent_dim)
        
        # 解码器
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim)
        )
    
    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var
    
    def elbo_loss(self, x, x_recon, mu, log_var, beta=1.0):
        """
        ELBO 损失
        
        ELBO = -L = - Reconstruction - KL
        """
        # 重建项(负对数似然)
        recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
        
        # KL 项
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        return recon_loss + beta * kl_loss

变分Dropout

变分Dropout将Dropout解释为贝叶斯推断:

class VariationalDropout(nn.Module):
    """
    变分Dropout(Kingma et al., 2015)
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # 变分参数
        self.w_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.b_mu = nn.Parameter(torch.zeros(out_features))
        
        self.w_log_alpha = nn.Parameter(torch.zeros(out_features, in_features))
        self.b_log_alpha = nn.Parameter(torch.zeros(out_features))
    
    def forward(self, x, sample=True):
        if sample or self.training:
            # 重参数化采样
            alpha = torch.exp(self.w_log_alpha)
            std = (self.w_mu ** 2) * alpha
            
            # 权重采样
            w = self.w_mu + torch.randn_like(self.w_mu) * std.sqrt()
            b = self.b_mu + torch.randn_like(self.b_mu) * \
                (self.b_mu ** 2 * torch.exp(self.b_log_alpha)).sqrt()
        else:
            # 近似后验均值
            w = self.w_mu
            b = self.b_mu
        
        return F.linear(x, w, b)
    
    def kl_divergence(self):
        """
        计算 KL 散度
        """
        alpha = torch.exp(self.w_log_alpha)
        kl = 0.5 * (alpha * self.w_mu ** 2 / (self.w_mu ** 2 + alpha) - 
                    torch.log(self.w_mu ** 2 + alpha) + 
                    torch.log(self.w_mu ** 2) + 1)
        return kl.sum()

参考


相关链接

Footnotes

  1. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association.