概述

变分推断(Variational Inference, VI)是概率图模型中最重要的近似推断方法之一,它将后验分布推断问题转化为优化问题。1

在深度学习时代,变分推断成为连接贝叶斯神经网络与经典神经网络的关键桥梁——变分自编码器(VAE)、变分循环网络、变分图神经网络等模型都建立在变分推断的数学框架之上。


变分推断的基本框架

问题设定

给定观测数据 和潜在变量 ,我们希望:

  1. 学习参数:最大化边缘似然
  2. 推断后验:计算后验分布

边缘似然的分解:

精确推断的困难

在大多数实际问题中,积分 不可计算的

  • 潜在空间维度高(,
  • 后验分布 没有解析形式
  • 配分函数 难以计算

变分推断的核心思想

用一个简单的分布 去近似复杂的真实后验

然后将推断问题转化为优化问题:找到最优的 使得 最接近。

变分族的选择

常见的变分分布族:

变分族形式优点缺点
均值场(Mean-Field)独立、易于计算过于简化
高斯变分平滑、连续参数多
归一化流表达能力强计算复杂
摊销分布$q_\phi(zx) = \text{NN}_\phi(x)$共享参数

证据下界(ELBO)

KL散度推导

我们用KL散度衡量两个分布的差异:

利用贝叶斯定理

证据下界

重新整理得:

关键洞察:由于KL散度非负,我们得到证据下界(Evidence Lower Bound, ELBO):

ELBO的两种形式

形式1:期望形式

解释

  • 第一项:重构似然的期望(重构损失)
  • 第二项:先验与后验的KL散度(正则化项)

形式2:信息论形式

解释:ELBO是边缘似然减去真实后验与变分后验的KL散度。最小化后验近似误差等价于最大化ELBO。

最大化ELBO的目标


变分推断的优化方法

坐标上升变分推断(CAVI)

对于均值场变分族,可以交替优化每个局部变分参数:

def cavi_update(j, X, q):
    """
    CAVI更新规则
    
    Args:
        j: 更新的变量索引
        X: 观测数据
        q: 当前变分分布
    
    Returns:
        new_q_j: 更新后的变分分布
    """
    # 计算期望(除z_j外的所有其他变量)
    expected_logjoint = 0
    for sample in range(num_samples):
        z_sample = q.sample()
        expected_logjoint += np.log(p(X, z_sample))
    
    expected_logjoint /= num_samples
    
    # 归一化
    new_q_j = np.exp(expected_logjoint)
    new_q_j /= new_q_j.sum()  # 归一化
    
    return new_q_j

随机变分推断(SVI)

当数据规模很大时,使用随机梯度上升:

其中 是学习率,通常使用自适应学习率调度。

重参数化技巧

为了计算 的梯度,使用重参数化技巧

高斯分布的例子


变分自编码器(VAE)

VAE的概率模型

VAE假设数据生成过程如下:

其中:

  • :标准高斯先验
  • :解码器分布(通常是高斯或伯努利)
  • :变分近似后验(编码器)

VAE的ELBO

PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli
 
class Encoder(nn.Module):
    """变分编码器"""
    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # 均值
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)   # 对数方差
    
    def forward(self, x):
        h = F.relu(self.fc1(x))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
 
 
class Decoder(nn.Module):
    """变分解码器"""
    def __init__(self, latent_dim, output_dim, hidden_dim=400):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        # 输出logits用于伯努利分布
        logits = self.fc2(h)
        return logits
 
 
class VAE(nn.Module):
    """
    变分自编码器
    
    使用重参数化技巧实现梯度反向传播
    """
    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = Encoder(input_dim, latent_dim, hidden_dim)
        self.decoder = Decoder(latent_dim, input_dim, hidden_dim)
    
    def reparameterize(self, mu, logvar):
        """
        重参数化技巧
        
        z = μ + σ * ε, 其中 ε ~ N(0, I)
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        # 编码
        mu, logvar = self.encoder(x)
        
        # 重参数化采样
        z = self.reparameterize(mu, logvar)
        
        # 解码
        logits = self.decoder(z)
        
        return logits, mu, logvar
    
    def elbo_loss(self, x, logits, mu, logvar):
        """
        计算ELBO损失
        
        ELBO = 重构损失 - KL散度
        """
        # 重构损失(伯努利分布的负对数似然)
        # 对于二值图像数据,sigmoid激活后用BCE
        x_prob = torch.sigmoid(logits)
        recon_loss = F.binary_cross_entropy(x_prob, x, reduction='sum')
        
        # KL散度:q(z|x) || p(z)
        # 对于高斯分布,KL(N(μ,σ²) || N(0,I)) 有闭式解
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        return recon_loss + kl_loss
    
    def loss_function(self, x):
        """完整损失计算"""
        logits, mu, logvar = self.forward(x)
        return self.elbo_loss(x, logits, mu, logvar)
    
    def sample(self, num_samples, device):
        """
        从先验采样并解码
        """
        with torch.no_grad():
            z = torch.randn(num_samples, self.latent_dim, device=device)
            logits = self.decoder(z)
            samples = torch.sigmoid(logits)
        return samples
    
    def encode(self, x):
        """编码到潜在空间"""
        mu, logvar = self.encoder(x)
        return self.reparameterize(mu, logvar)
    
    def decode(self, z):
        """从潜在空间解码"""
        logits = self.decoder(z)
        return torch.sigmoid(logits)
 
 
def train_vae(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    total_recon = 0
    total_kl = 0
    
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.view(-1, input_dim).to(device)
        
        optimizer.zero_grad()
        
        # 前向传播
        logits, mu, logvar = model(data)
        
        # 计算损失
        recon_loss = F.binary_cross_entropy_with_logits(logits, data, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        loss = recon_loss + kl_loss
        
        # 反向传播
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_recon += recon_loss.item()
        total_kl += kl_loss.item()
    
    n_samples = len(dataloader.dataset)
    print(f"Epoch {epoch}: Loss={total_loss/n_samples:.4f}, "
          f"Recon={total_recon/n_samples:.4f}, KL={total_kl/n_samples:.4f}")

摊销变分推断(Amortized VI)

摊销的动机

传统变分推断为每个数据点独立优化变分参数 ,计算复杂度为

摊销变分推断使用一个参数化函数(编码器网络):

这使得:

  • 推理成本从 降到 (给定网络前向传播)
  • 参数共享:所有数据点共享
  • 泛化能力:对未见数据也能推断后验

摊销推断的权衡

方面独立VI摊销VI
灵活性每个数据点独立优化共享参数
计算效率
表达能力高(独立参数)中等(共享函数)
泛化

对抗性变分推断

当变分分布族不够表达时,可以使用GAN风格的对抗训练:

class AdversarialVAE(nn.Module):
    """
    对抗变分自编码器
    使用判别器迫使q(z|x)接近p(z|x)
    """
    def __init__(self, input_dim, latent_dim, hidden_dim=400):
        super().__init__()
        # 编码器
        self.encoder = Encoder(input_dim, latent_dim, hidden_dim)
        
        # 解码器
        self.decoder = Decoder(latent_dim, input_dim, hidden_dim)
        
        # 判别器(区分q(z|x)和p(z))
        self.discriminator = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def encode(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return z
    
    def forward(self, x):
        z = self.encode(x)
        logits = self.decoder(z)
        return logits, z
    
    def discriminator_loss(self, x):
        """判别器损失:鼓励q(z|x)接近p(z)"""
        # 从后验采样
        z_posterior = self.encode(x)
        
        # 从先验采样
        z_prior = torch.randn_like(z_posterior)
        
        # 判别器输出
        d_posterior = self.discriminator(z_posterior)
        d_prior = self.discriminator(z_prior)
        
        # 对抗损失
        return -torch.mean(d_posterior) + torch.mean(d_prior)
    
    def generator_loss(self, x):
        """生成器损失(欺骗判别器)"""
        z_posterior = self.encode(x)
        d_posterior = self.discriminator(z_posterior)
        return -torch.mean(d_posterior)

变分推断在深度学习中的应用

1. 变分循环网络

将变分推断扩展到序列模型:

class VariationalLSTM(nn.Module):
    """
    变分LSTM
    用于序列数据的潜在变量建模
    """
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        # 编码器(从隐藏状态推断潜在变量)
        self.q_z = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 2 * latent_dim)  # mu和logvar
        )
        
        # LSTM
        self.lstm = nn.LSTM(input_dim + latent_dim, hidden_dim, batch_first=True)
        
        # 解码器
        self.decoder = nn.Linear(hidden_dim, input_dim)
    
    def forward(self, x, h=None):
        batch_size, seq_len, _ = x.shape
        
        if h is None:
            h = (torch.zeros(1, batch_size, self.hidden_dim),
                 torch.zeros(1, batch_size, self.hidden_dim))
        
        outputs = []
        for t in range(seq_len):
            # 推断潜在变量
            mu_logvar = self.q_z(h[0][0])
            mu, logvar = mu_logvar.chunk(2, dim=-1)
            z = self.reparameterize(mu, logvar)
            
            # 输入 + 潜在变量
            input_t = torch.cat([x[:, t:t+1], z.unsqueeze(1)], dim=-1)
            
            # LSTM前向
            out, h = self.lstm(input_t, h)
            outputs.append(self.decoder(out))
        
        return torch.cat(outputs, dim=1)

2. 变分图神经网络

在图神经网络中引入潜在变量:

class VariationalGNN(nn.Module):
    """
    变分图神经网络
    用于图数据的生成和推断
    """
    def __init__(self, node_dim, edge_dim, latent_dim, hidden_dim):
        super().__init__()
        
        # 节点编码器
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)
        )
        
        # GNN层
        self.gnn = MessagePassingLayer(hidden_dim, edge_dim)
        
        # 解码器
        self.decoder = nn.Linear(hidden_dim, node_dim)
    
    def forward(self, x, edge_index):
        # 推断潜在变量
        mu_logvar = self.node_encoder(x)
        mu, logvar = mu_logvar.chunk(2, dim=-1)
        z = self.reparameterize(mu, logvar)
        
        # GNN消息传递
        h = self.gnn(z, edge_index)
        
        # 解码
        x_recon = self.decoder(h)
        
        return x_recon, mu, logvar
    
    def loss(self, x, edge_index):
        x_recon, mu, logvar = self.forward(x, edge_index)
        
        # 重构损失
        recon_loss = F.mse_loss(x_recon, x)
        
        # KL散度
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        return recon_loss + kl_loss

3. 变分dropout与贝叶斯神经网络

变分dropout提供了一种贝叶斯视角的dropout解释:

class VariationalDropout(nn.Module):
    """
    变分dropout(Gal & Ghahramani, 2016)
    
    Dropout等价于变分推断中的KL正则化项
    """
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_features, out_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
        
        # Dropout率
        self.log_alpha = nn.Parameter(torch.zeros(1))
    
    @property
    def alpha(self):
        return torch.sigmoid(self.log_alpha)
    
    def forward(self, x, training=True):
        if training:
            # 变分dropout:随机掩码
            mask = torch.bernoulli(1 - self.alpha.expand_as(x))
            x_dropout = x * mask / (1 - self.alpha)
        else:
            x_dropout = x
        
        return F.linear(x_dropout, self.weight, self.bias)
    
    def kl_divergence(self):
        """
        变分dropout的KL散度
        等价于额外的正则化项
        """
        return self.alpha.pow(2) / (1 - self.alpha.pow(2) + 1e-8)

信息论视角

ELBO的信息论分解

ELBO可以进一步分解为信息论量:

互信息项

β-VAE

调节ELBO中KL项的权重:

  • :标准VAE
  • :更强调先验正则化( disentanglement)
  • :更强调重构(更清晰的重建)
class BetaVAE(nn.Module):
    def __init__(self, input_dim, latent_dim, beta=1.0):
        super().__init__()
        self.beta = beta
        self.vae = VAE(input_dim, latent_dim)
    
    def loss_function(self, x):
        logits, mu, logvar = self.vae(x)
        
        recon_loss = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        # 加权KL散度
        return recon_loss + self.beta * kl_loss

与现有wiki内容的联系

主题相关文件
概率图模型probabilistic-graphical-models-comprehensive
贝叶斯神经网络bayesian-neural-networks
变分推断基础variational-inference
信息论基础information-theory
归一化流normalizing-flows-variational

参考


相关阅读

Footnotes

  1. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for statisticians. Journal of the American statistical Association, 112(518), 859-877.