Neural Variational Inference: A Deep Dive

Neural Variational Inference (NVI) is a paradigm that tightly couples variational inference with neural networks: a parameterized network approximates an intractable posterior distribution, enabling end-to-end probabilistic inference.[1] This document examines the mathematical foundations of variational inference, the probabilistic interpretation of neural networks, and the implementation details of modern variational methods.

1. Variational Inference Fundamentals

1.1 Problem Formulation

In Bayesian inference we want to compute the posterior distribution:

$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)}$$

where:

  • $p(z)$: the prior distribution
  • $p(x \mid z)$: the likelihood
  • $p(x) = \int p(x \mid z)\, p(z)\, dz$: the marginal likelihood (the evidence)

Core difficulty: the marginal likelihood $p(x)$ is usually analytically intractable, so the posterior cannot be computed directly.
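To make the difficulty concrete, the sketch below (an illustrative toy model, not part of the original text) estimates the evidence $p(x)$ of a one-dimensional latent-variable model by naive Monte Carlo over the prior. The estimator is unbiased, but when the likelihood is peaked most prior samples contribute almost nothing, so the variance shrinks slowly with the sample count — exactly the failure mode that motivates a learned proposal $q(z \mid x)$:

import torch

# Toy model (hypothetical): z ~ N(0, 1), x | z ~ N(z, 0.1^2)
def log_likelihood(x, z, sigma=0.1):
    return torch.distributions.Normal(z, sigma).log_prob(x)

def naive_evidence(x, n_samples):
    """Unbiased MC estimate of p(x) = E_{p(z)}[p(x|z)]."""
    z = torch.randn(n_samples)                 # samples from the prior p(z)
    return log_likelihood(x, z).exp().mean()   # average of p(x|z)

x = torch.tensor(2.0)
for n in (10, 1_000, 100_000):
    estimates = torch.stack([naive_evidence(x, n) for _ in range(20)])
    print(f"n={n:>6}: mean={estimates.mean():.4e}, std={estimates.std():.4e}")
# The spread shrinks slowly with n: most z ~ p(z) miss the likelihood peak.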

1.2 Jensen's Inequality and the ELBO

Jensen's inequality: for a concave function $f$ (such as $\log$) and a probability distribution $q$,

$$f\big(\mathbb{E}_{q}[X]\big) \;\ge\; \mathbb{E}_{q}\big[f(X)\big]$$

Applying Jensen's inequality to the log evidence:

$$\log p(x) = \log \int q(z \mid x)\, \frac{p(x, z)}{q(z \mid x)}\, dz \;\ge\; \mathbb{E}_{q(z \mid x)}\!\left[\log \frac{p(x, z)}{q(z \mid x)}\right]$$

Evidence Lower Bound (ELBO):

$$\mathcal{L}(q) = \mathbb{E}_{q(z \mid x)}\big[\log p(x, z) - \log q(z \mid x)\big]$$

1.3 Decomposition of the ELBO

The ELBO decomposes into two meaningful terms:

$$\mathcal{L}(q) = \underbrace{\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big]}_{\text{reconstruction}} \;-\; \underbrace{\mathrm{KL}\big(q(z \mid x) \,\Vert\, p(z)\big)}_{\text{regularization}}$$

| Term | Meaning | Role |
|---|---|---|
| $\mathbb{E}_{q}[\log p(x \mid z)]$ | Reconstruction term | Ensures the variational distribution can reconstruct the data |
| $\mathrm{KL}(q(z \mid x) \Vert p(z))$ | Regularization term | Keeps the variational distribution close to the prior |

1.4 Properties of the KL Divergence

Definition:

$$\mathrm{KL}(q \,\Vert\, p) = \mathbb{E}_{q}\!\left[\log \frac{q(z)}{p(z)}\right] = \int q(z) \log \frac{q(z)}{p(z)}\, dz$$

Properties:

  1. Non-negativity: $\mathrm{KL}(q \,\Vert\, p) \ge 0$, with equality iff $q = p$
  2. Asymmetry: in general $\mathrm{KL}(q \,\Vert\, p) \ne \mathrm{KL}(p \,\Vert\, q)$
  3. Additivity over independent factors: $\mathrm{KL}\big(\prod_i q_i \,\Vert\, \prod_i p_i\big) = \sum_i \mathrm{KL}(q_i \,\Vert\, p_i)$

KL divergence between Gaussians (closed form):

For $q = \mathcal{N}(\mu_1, \sigma_1^2)$ and $p = \mathcal{N}(\mu_2, \sigma_2^2)$:

$$\mathrm{KL}(q \,\Vert\, p) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

In the common VAE case with a standard-normal prior this reduces, per dimension, to $-\tfrac{1}{2}\big(1 + \log \sigma^2 - \mu^2 - \sigma^2\big)$.
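As a quick sanity check of the closed form, the following sketch (illustrative, not from the original text) compares it against torch.distributions.kl_divergence:

import torch
import torch.distributions as dist

mu = torch.tensor([0.5, -1.0])
log_var = torch.tensor([0.2, -0.3])
var = log_var.exp()

# Closed form of KL(N(mu, sigma^2) || N(0, I)), summed over dimensions
kl_closed = -0.5 * (1 + log_var - mu ** 2 - var).sum()

# Reference value from torch.distributions (Normal takes a std, not a variance)
q = dist.Normal(mu, var.sqrt())
p = dist.Normal(torch.zeros(2), torch.ones(2))
kl_ref = dist.kl_divergence(q, p).sum()

print(kl_closed.item(), kl_ref.item())  # the two numbers agree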


2. A Probabilistic Graphical View of the VAE

2.1 The VAE Generative Model

The Variational Autoencoder (VAE) defines a hierarchical generative model:[2]

Generative process:

$$z \sim p(z) = \mathcal{N}(0, I), \qquad x \sim p_\theta(x \mid z), \qquad p_\theta(x, z) = p_\theta(x \mid z)\, p(z)$$

Probabilistic graphical model (generative edges point into the variables; the inference edge runs from x back to z):

        p(z)        p_θ(x|z)
         │             │
         ▼             ▼
         z ──────────► x
         ▲             │
         └── q_φ(z|x) ─┘

2.2 The Inference Network

Because the true posterior $p_\theta(z \mid x)$ is intractable, the VAE introduces an inference network $q_\phi(z \mid x)$ to approximate it:

$$q_\phi(z \mid x) = \mathcal{N}\big(z;\, \mu_\phi(x),\, \mathrm{diag}(\sigma_\phi^2(x))\big)$$

Inference network architecture (the encoder):

class Encoder(nn.Module):
    """VAE encoder: the inference network q(z|x)."""
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        
        # Shared feature layers
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Heads for the mean and log-variance
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
    
    def forward(self, x):
        h = self.shared(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        
        # Predicting the log-variance (rather than the variance) keeps
        # the variance positive after exponentiation downstream
        return mu, log_var

2.3 The Generative Network

Generative network architecture (the decoder):

class Decoder(nn.Module):
    """VAE decoder: the generative distribution p(x|z)."""
    
    def __init__(self, latent_dim, hidden_dim, output_dim, distribution='bernoulli'):
        super().__init__()
        self.distribution = distribution
        
        # Shared trunk ends at hidden_dim so each output head
        # can map to the right dimensionality
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        if distribution == 'bernoulli':
            # Single head producing Bernoulli probabilities via sigmoid
            self.fc_out = nn.Linear(hidden_dim, output_dim)
        elif distribution == 'gaussian':
            # Separate heads for the mean and log-variance
            self.fc_mu = nn.Linear(hidden_dim, output_dim)
            self.fc_log_var = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, z):
        h = self.fc(z)
        
        if self.distribution == 'bernoulli':
            p_x_given_z = torch.sigmoid(self.fc_out(h))
            return p_x_given_z
        elif self.distribution == 'gaussian':
            mu = self.fc_mu(h)
            log_var = self.fc_log_var(h)
            return mu, log_var

2.4 The VAE Objective

Evidence lower bound:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x) \,\Vert\, p(z)\big)$$

Full objective (maximized over the whole dataset):

$$\max_{\theta, \phi} \; \sum_{i=1}^{N} \mathcal{L}\big(\theta, \phi; x^{(i)}\big)$$


3. The Mathematics of the Reparameterization Trick

3.1 The Problem

We want the gradient of an expectation taken under the variational distribution $q_\phi(z)$:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}\big[f(z)\big]$$

The difficulty with direct differentiation: the sampling distribution itself depends on $\phi$, so the gradient cannot simply be pushed inside the expectation and the chain rule applied to samples.

3.2 The Reparameterization Transform

Core idea: move the randomness into an independent noise variable:

$$z = g_\phi(\epsilon), \qquad \epsilon \sim p(\epsilon)$$

Common reparameterizations:

| Distribution | Reparameterized form |
|---|---|
| Gaussian | $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$ |
| Laplace | $z = \mu - b\, \mathrm{sign}(u) \ln(1 - 2\lvert u \rvert)$, $u \sim \mathcal{U}(-\tfrac{1}{2}, \tfrac{1}{2})$ |
| Categorical | Gumbel-Softmax relaxation (Section 3.5) |
| Mixture | sample a component index, then reparameterize within the chosen component |

3.3 Gradient Derivation

Theorem: for a differentiable reparameterization $z = g_\phi(\epsilon)$ with $\epsilon \sim p(\epsilon)$ independent of $\phi$,

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}\big[\nabla_\phi\, f\big(g_\phi(\epsilon)\big)\big]$$

Proof: substituting $z = g_\phi(\epsilon)$ turns the expectation over $q_\phi$ into an expectation over the fixed distribution $p(\epsilon)$; since $p(\epsilon)$ does not depend on $\phi$, the gradient moves inside the expectation and the chain rule applies to the deterministic map $g_\phi$.
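A quick numerical check of the theorem (illustrative code, not from the original text): for $z \sim \mathcal{N}(\mu, \sigma^2)$ and $f(z) = z^2$ we know $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so the exact gradients are $\partial_\mu = 2\mu$ and $\partial_{\log\sigma^2} = \sigma^2$; the reparameterized Monte Carlo estimate recovers both:

import torch

mu = torch.tensor(1.5, requires_grad=True)
log_var = torch.tensor(-0.5, requires_grad=True)

# Reparameterized MC estimate of E[z^2] with z = mu + sigma * eps
eps = torch.randn(100_000)
z = mu + (0.5 * log_var).exp() * eps
(z ** 2).mean().backward()

print(mu.grad)       # ~= 2 * mu = 3.0
print(log_var.grad)  # ~= sigma^2 = exp(log_var) ~= 0.6065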

3.4 Variance Analysis

The advantage of reparameterization: lower-variance gradient estimates.

Score-function gradient:

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{q_\phi}\big[f(z)\, \nabla_\phi \log q_\phi(z)\big]$$

Problem: the variance of the score-function estimator blows up when $\lvert f(z) \rvert$ is large, because $f$ enters only as a multiplicative weight on the score.

Reparameterized gradient:

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{p(\epsilon)}\Big[\nabla_z f(z)\big|_{z = g_\phi(\epsilon)}\; \nabla_\phi\, g_\phi(\epsilon)\Big]$$

Advantage: the gradient flows through $\nabla_z f$, exploiting the local structure of $f$ rather than treating it as a black-box weight.
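The gap is easy to see empirically. The sketch below (a hypothetical example) estimates $\nabla_\mu \mathbb{E}[z^2]$ for $z \sim \mathcal{N}(\mu, 1)$ with both estimators; both are unbiased (true value $2\mu$), but their spreads differ dramatically:

import torch

mu, n, trials = 2.0, 100, 500
score_grads, reparam_grads = [], []

for _ in range(trials):
    z = mu + torch.randn(n)                    # sigma = 1

    # Score-function estimator: f(z) * d/dmu log N(z; mu, 1) = z^2 * (z - mu)
    score_grads.append((z ** 2 * (z - mu)).mean())

    # Reparameterized estimator: d/dmu (mu + eps)^2 = 2 * z
    reparam_grads.append((2 * z).mean())

print("true gradient :", 2 * mu)
print("score function : std %.3f" % torch.stack(score_grads).std())
print("reparameterized: std %.3f" % torch.stack(reparam_grads).std())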

3.5 Gumbel-Softmax Reparameterization

For discrete distributions, the Gumbel-Softmax relaxation provides a reparameterization:

$$y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j} \exp\big((\log \pi_j + g_j)/\tau\big)}, \qquad g_i \sim \mathrm{Gumbel}(0, 1)$$

where $\tau$ is the temperature parameter: as $\tau \to 0$ the samples approach one-hot vectors, while larger $\tau$ gives smoother samples and lower-variance gradients.

def gumbel_softmax(logits, temperature, hard=False):
    """
    Gumbel-Softmax reparameterization.
    
    Args:
        logits: (batch, n_categories) unnormalized log-probabilities
        temperature: temperature parameter τ
        hard: whether to return a hard one-hot sample
    
    Returns:
        a soft (or straight-through hard) sample from the relaxation
    """
    # Sample Gumbel(0, 1) noise: -log(Exponential(1)) is Gumbel-distributed
    gumbels = -torch.empty_like(logits).exponential_().log()
    gumbels = (logits + gumbels) / temperature
    
    # Softmax gives the relaxed (soft) sample
    soft = F.softmax(gumbels, dim=-1)
    
    if hard:
        # Straight-through: hard one-hot forward pass, soft gradients backward
        hard_onehot = F.one_hot(gumbels.argmax(dim=-1), logits.size(-1)).float()
        return (hard_onehot - soft).detach() + soft
    else:
        return soft
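Usage sketch (illustrative): with hard=True the forward pass is exactly one-hot while gradients still flow to the logits through the soft relaxation; in practice τ is typically annealed from around 1.0 toward a small value over training:

import torch

logits = torch.randn(4, 10, requires_grad=True)
y = gumbel_softmax(logits, temperature=0.5, hard=True)

print(y.sum(dim=-1))            # each row sums to 1 (one-hot forward pass)
loss = (y * torch.randn(10)).sum()  # any downstream loss
loss.backward()
print(logits.grad is not None)  # True: gradients pass through the relaxation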

4. The Evidence Lower Bound in Depth

4.1 Relation Between the ELBO and the True Marginal Likelihood

Theorem: for any variational distribution $q(z \mid x)$,

$$\log p(x) = \mathcal{L}(q) + \mathrm{KL}\big(q(z \mid x) \,\Vert\, p(z \mid x)\big)$$

Corollaries:

  1. $\mathcal{L}(q) \le \log p(x)$ (the ELBO is a lower bound, since the KL term is non-negative)
  2. When $q(z \mid x) = p(z \mid x)$, the bound is tight: $\mathcal{L}(q) = \log p(x)$

4.2 Equivalent Forms of the ELBO

Form 1 (standard): $\mathcal{L} = \mathbb{E}_{q(z \mid x)}\big[\log p(x, z) - \log q(z \mid x)\big]$

Form 2 (reconstruction + KL): $\mathcal{L} = \mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] - \mathrm{KL}\big(q(z \mid x) \,\Vert\, p(z)\big)$

Form 3 (information-theoretic): $\mathcal{L} = \log p(x) - \mathrm{KL}\big(q(z \mid x) \,\Vert\, p(z \mid x)\big)$

Form 4 (importance sampling): $\mathcal{L} = \mathbb{E}_{q(z \mid x)}\big[\log w(z)\big]$ with importance weight $w(z) = \frac{p(x, z)}{q(z \mid x)}$; averaging $K$ weights inside the logarithm yields the tighter importance-weighted (IWAE) bound $\mathcal{L}_K = \mathbb{E}\big[\log \tfrac{1}{K} \sum_{k} w(z_k)\big]$.
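A minimal sketch of the Form 4 bound (assuming diagonal-Gaussian encoder outputs mu, log_var and a decoder log-likelihood callable log_px_given_z — these names are placeholder assumptions, not an API defined in this document):

import math
import torch
import torch.distributions as dist

def iwae_bound(x, mu, log_var, log_px_given_z, K=8):
    """
    L_K = E[ log (1/K) Σ_k p(x, z_k) / q(z_k|x) ],  z_k ~ q(z|x).

    mu, log_var:    (batch, latent) encoder outputs (assumed given)
    log_px_given_z: callable (x, z) -> (K, batch) decoder log-likelihood
    """
    q = dist.Normal(mu, (0.5 * log_var).exp())
    p = dist.Normal(torch.zeros_like(mu), torch.ones_like(mu))

    z = q.rsample((K,))                      # (K, batch, latent), reparameterized
    log_w = (log_px_given_z(x, z)            # log p(x | z_k)
             + p.log_prob(z).sum(-1)         # + log p(z_k)
             - q.log_prob(z).sum(-1))        # - log q(z_k | x)

    # log-mean-exp over the K importance samples; K = 1 recovers the ELBO
    return torch.logsumexp(log_w, dim=0) - math.log(K)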

4.3 Decomposition Properties

Additivity over data points:

$$\mathcal{L}(\mathcal{D}) = \sum_{i=1}^{N} \mathcal{L}\big(x^{(i)}\big)$$

Local ELBO: each data point contributes its own term $\mathcal{L}(x^{(i)}) = \mathbb{E}_{q(z \mid x^{(i)})}\big[\log p(x^{(i)} \mid z)\big] - \mathrm{KL}\big(q(z \mid x^{(i)}) \,\Vert\, p(z)\big)$, as sketched below.
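Because the bound is a sum of per-datapoint terms, rescaling a mini-batch gives an unbiased estimate of the full-dataset ELBO, which is the basis of stochastic variational inference. A minimal sketch (model.elbo returning per-example bounds is an assumed interface, not one defined in this document):

def minibatch_elbo(model, x_batch, dataset_size):
    """Unbiased full-dataset estimate: L(D) ≈ (N / |B|) · Σ_{x∈B} L(x)."""
    per_example = model.elbo(x_batch)       # (batch,) per-example ELBOs
    return dataset_size * per_example.mean()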

4.4 Tightness of the ELBO

Define the gap $\Delta(q) = \log p(x) - \mathcal{L}(q) = \mathrm{KL}\big(q(z \mid x) \,\Vert\, p(z \mid x)\big) \ge 0$, which measures how tight the bound is.

Optimization goal: maximizing $\mathcal{L}$ is equivalent to jointly

  1. maximizing the reconstruction likelihood $\mathbb{E}_{q}[\log p(x \mid z)]$
  2. minimizing $\mathrm{KL}\big(q(z \mid x) \,\Vert\, p(z)\big)$

Tension between the two:

  • the reconstruction term encourages $q(z \mid x)$ to concentrate on high-likelihood regions
  • the KL term pulls $q(z \mid x)$ toward the prior

5. Choosing and Designing the Variational Distribution

5.1 Mean-Field Approximation

Assumption: the variational distribution factorizes into a product of independent factors:

$$q(z) = \prod_{i=1}^{D} q_i(z_i)$$

Advantage: simple computation, easy to optimize.
Disadvantage: ignores correlations between variables.

class MeanFieldVariational(nn.Module):
    """Mean-field variational distribution."""
    
    def __init__(self, dims, distribution='gaussian'):
        super().__init__()
        self.dims = dims
        self.distribution = distribution
        
        if distribution == 'gaussian':
            # Each factor is an independent Gaussian
            self.mus = nn.Parameter(torch.randn(dims))
            self.log_vars = nn.Parameter(torch.zeros(dims))
    
    def sample(self, n_samples=1):
        """Draw reparameterized samples."""
        if self.distribution == 'gaussian':
            std = (0.5 * self.log_vars).exp()
            eps = torch.randn(n_samples, self.dims)
            return self.mus + eps * std
    
    def log_prob(self, z):
        """Log-density, summed over the independent factors."""
        if self.distribution == 'gaussian':
            log_prob = -0.5 * (
                np.log(2 * np.pi) + 
                self.log_vars + 
                (z - self.mus) ** 2 / self.log_vars.exp()
            )
            return log_prob.sum(dim=-1)
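Usage sketch (illustrative): fitting the mean-field distribution to an unnormalized target log-density by stochastic gradient ascent on the ELBO, using the reparameterized samples from sample():

import torch

q = MeanFieldVariational(dims=2)
opt = torch.optim.Adam(q.parameters(), lr=1e-2)

def log_target(z):
    # Unnormalized target (hypothetical): a Gaussian centred at (1, -1)
    return -0.5 * ((z[:, 0] - 1.0) ** 2 + 2.0 * (z[:, 1] + 1.0) ** 2)

for step in range(2000):
    opt.zero_grad()
    z = q.sample(n_samples=64)
    elbo = (log_target(z) - q.log_prob(z)).mean()   # E_q[log p̃(z) - log q(z)]
    (-elbo).backward()
    opt.step()

print(q.mus)  # converges toward the target mode (1, -1)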

5.2 Hierarchical Variational Distributions

Motivation: capture correlations between latent variables by introducing an auxiliary top-level variable.

class HierarchicalVariational(nn.Module):
    """Hierarchical variational distribution q(z, λ | x) = q(λ | x) q(z | λ)."""
    
    def __init__(self, input_dim, latent_dim, hidden_dim):
        super().__init__()
        
        # Top-level distribution q(λ | x)
        self.prior_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)  # μ, log_var of λ
        )
        
        # Conditional distribution q(z | λ)
        self.posterior_net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)  # μ, log_var of z
        )
    
    def forward(self, x, n_samples=1):
        """Sample from the hierarchical variational distribution."""
        # Parameters of the top level
        prior_params = self.prior_net(x)
        lambda_mu, lambda_log_var = prior_params.chunk(2, dim=-1)
        lambda_std = (0.5 * lambda_log_var).exp()
        
        # Sample λ
        eps = torch.randn_like(lambda_mu)
        lambda_sample = lambda_mu + eps * lambda_std
        
        # Given λ, sample z
        posterior_params = self.posterior_net(lambda_sample)
        z_mu, z_log_var = posterior_params.chunk(2, dim=-1)
        z_std = (0.5 * z_log_var).exp()
        
        eps = torch.randn(n_samples, *z_mu.shape)
        z_samples = z_mu + eps * z_std
        
        return z_samples, lambda_sample

5.3 Normalizing-Flow Variational Distributions

Invertible transformations increase the expressiveness of the variational family. Composing $z_K = f_K \circ \cdots \circ f_1(z_0)$ with $z_0 \sim q_0$ gives, by the change-of-variables formula,

$$\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \log \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|$$

class NormalizingFlowVariational(nn.Module):
    """Variational distribution built from a normalizing flow."""
    
    def __init__(self, latent_dim, n_flows=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_flows = n_flows
        
        # Base distribution q0
        self.base_mu = nn.Parameter(torch.zeros(latent_dim))
        self.base_log_var = nn.Parameter(torch.zeros(latent_dim))
        
        # Flow layers
        self.flows = nn.ModuleList([
            PlanarFlow(latent_dim) for _ in range(n_flows)
        ])
    
    def forward(self, n_samples=1):
        """Sample z_K and return it together with its log-density log q_K(z_K)."""
        # Sample z0 from the base distribution
        std = (0.5 * self.base_log_var).exp()
        z0 = self.base_mu + torch.randn(n_samples, self.latent_dim) * std
        
        # Log-density of z0 under the base distribution
        diff = z0 - self.base_mu
        log_q = -0.5 * (
            np.log(2 * np.pi) + 
            self.base_log_var + 
            diff ** 2 / self.base_log_var.exp()
        ).sum(dim=-1)
        
        # Push through the flow; each layer subtracts its log-determinant.
        # (Planar flows are not analytically invertible, so the density is
        # tracked along the forward pass rather than through an inverse.)
        z = z0
        for flow in self.flows:
            z, log_det = flow(z)
            log_q = log_q - log_det
        
        return z, log_q
 
 
class PlanarFlow(nn.Module):
    """Planar normalizing flow."""
    
    def __init__(self, dim):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim))
        self.u = nn.Parameter(torch.randn(dim))
        self.b = nn.Parameter(torch.zeros(1))
    
    def forward(self, z):
        """
        Transform: z' = z + u * h(w^T z + b), with h = tanh
        
        Returns:
            z_new: transformed samples
            log_det: log|det(dz'/dz)|
        """
        # (batch,) pre-activations
        activation = torch.tanh(z @ self.w + self.b)
        
        # Forward transform; unsqueeze so (batch, 1) * (dim,) broadcasts
        z_new = z + self.u * activation.unsqueeze(-1)
        
        # Log-determinant:
        # dz'/dz = I + u h'(w^T z + b) w^T  =>  det = 1 + u^T psi(z)
        # (For guaranteed invertibility u is usually constrained via the
        # û reparameterization; omitted here for brevity.)
        psi = (1 - activation ** 2).unsqueeze(-1) * self.w   # (batch, dim)
        det = 1 + psi @ self.u                               # (batch,)
        log_det = torch.log(det.abs() + 1e-8)
        
        return z_new, log_det

5.4 Mixture Variational Distributions

class MixtureVariational(nn.Module):
    """Mixture-of-Gaussians variational distribution."""
    
    def __init__(self, latent_dim, n_components, hidden_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_components = n_components
        
        # Component parameters
        self.component_mus = nn.Parameter(
            torch.randn(n_components, latent_dim)
        )
        self.component_log_vars = nn.Parameter(
            torch.zeros(n_components, latent_dim)
        )
        
        # Mixture weights (logits)
        self.mixture_weights = nn.Parameter(torch.zeros(n_components))
    
    def sample(self, n_samples=1):
        """Sample from the mixture."""
        # Choose components
        weights = F.softmax(self.mixture_weights, dim=0)
        component_idx = torch.multinomial(weights, n_samples, replacement=True)
        
        # Sample from the selected components
        samples = torch.randn(n_samples, self.latent_dim)
        for i in range(n_samples):
            comp = component_idx[i]
            std = (0.5 * self.component_log_vars[comp]).exp()
            samples[i] = self.component_mus[comp] + samples[i] * std
        
        return samples
    
    def log_prob(self, z):
        """Log-density of the mixture via log-sum-exp."""
        # Normalized log mixture weights (raw logits would bias the density)
        log_w = F.log_softmax(self.mixture_weights, dim=0)
        
        log_probs = []
        for k in range(self.n_components):
            diff = z - self.component_mus[k]
            log_prob_k = -0.5 * (
                np.log(2 * np.pi) + 
                self.component_log_vars[k] + 
                diff ** 2 / self.component_log_vars[k].exp()
            ).sum(dim=-1)
            log_probs.append(log_prob_k + log_w[k])
        
        log_probs = torch.stack(log_probs, dim=-1)
        return torch.logsumexp(log_probs, dim=-1)

6. Gradient Estimation Methods

6.1 Score-Function Gradients (REINFORCE)

Score-function identity:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}\big[f(z)\, \nabla_\phi \log q_\phi(z)\big]$$

Monte Carlo estimate:

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(z)] \;\approx\; \frac{1}{S} \sum_{s=1}^{S} f\big(z^{(s)}\big)\, \nabla_\phi \log q_\phi\big(z^{(s)}\big), \qquad z^{(s)} \sim q_\phi$$

def score_function_gradient(f, q, n_samples=100):
    """
    Score-function (REINFORCE) gradient estimate.
    
    Suitable for: discrete variables, non-differentiable models,
    black-box likelihoods.
    
    Args:
        f: function f: Z → ℝ
        q: variational distribution q_φ(z) (an nn.Module; for clarity this
           sketch assumes a single parameter tensor)
        n_samples: number of Monte Carlo samples
    
    Returns:
        gradient: the gradient estimate
    """
    params = list(q.parameters())
    gradients = []
    
    for _ in range(n_samples):
        z = q.sample().detach()   # sample z ~ q_φ; no gradient through z itself
        log_q = q.log_prob(z)     # log q_φ(z)
        grad_log_q = torch.autograd.grad(
            log_q.sum(), 
            params,
            retain_graph=True
        )[0]
        
        # f(z) * ∇_φ log q_φ(z)
        f_val = f(z)
        gradients.append(f_val * grad_log_q)
    
    return torch.stack(gradients).mean(dim=0)

6.2 Score Matching

Score function (here the gradient of the log-density with respect to $z$, not $\phi$):

$$s_\theta(z) = \nabla_z \log p_\theta(z)$$

Score-matching objective (Hyvärinen, 2005): match the model score to the data score,

$$J(\theta) = \tfrac{1}{2}\, \mathbb{E}_{p_{\text{data}}}\Big[\big\lVert s_\theta(z) - \nabla_z \log p_{\text{data}}(z) \big\rVert^2\Big] = \mathbb{E}_{p_{\text{data}}}\Big[\mathrm{tr}\big(\nabla_z s_\theta(z)\big) + \tfrac{1}{2}\lVert s_\theta(z)\rVert^2\Big] + \text{const}$$

where the second form removes the unknown data score via integration by parts and can be evaluated from samples alone.
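A minimal sketch of the integration-by-parts objective for a score network (the network s_theta and all other names here are illustrative assumptions, not part of the original text):

import torch

def score_matching_loss(s_theta, z):
    """
    E[ tr(∇_z s_θ(z)) + ½ ||s_θ(z)||² ], estimated on a batch z.

    The trace is computed exactly, one coordinate at a time, which is fine
    for small dimensions (Hutchinson's estimator scales better).
    """
    z = z.detach().requires_grad_(True)
    s = s_theta(z)                                    # (batch, dim)

    trace = torch.zeros(z.size(0), device=z.device)
    for i in range(z.size(1)):
        # ∂ s_i / ∂ z_i for every batch element
        grad_i = torch.autograd.grad(
            s[:, i].sum(), z, create_graph=True
        )[0][:, i]
        trace = trace + grad_i

    return (trace + 0.5 * (s ** 2).sum(dim=-1)).mean()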

6.3 Reparameterization Gradients

The method of choice for continuous variables:

def reparameterization_gradient(f, mu, log_var, n_samples=100):
    """
    Reparameterized gradient estimate for a diagonal-Gaussian q_φ.
    
    Suitable for: continuous variables, differentiable models.
    
    Args:
        f: differentiable function f: Z → ℝ (scalar output)
        mu, log_var: variational parameters (tensors with requires_grad=True)
    
    Returns:
        (grad_mu, grad_log_var): the gradient estimates
    """
    grads_mu, grads_log_var = [], []
    
    for _ in range(n_samples):
        # Reparameterized sample; the graph connects z back to (mu, log_var)
        z = reparameterize(mu, log_var)
        
        # Differentiate f(z) back to the variational parameters
        grad_mu, grad_log_var = torch.autograd.grad(f(z), (mu, log_var))
        grads_mu.append(grad_mu)
        grads_log_var.append(grad_log_var)
    
    return (torch.stack(grads_mu).mean(dim=0),
            torch.stack(grads_log_var).mean(dim=0))
 
 
def reparameterize(mu, log_var):
    """
    Gaussian reparameterization:
    
    z = μ + σ * ε,  ε ~ N(0, I)
    """
    std = (0.5 * log_var).exp()
    eps = torch.randn_like(std)
    return mu + eps * std

6.4 Pathwise Derivatives

Pathwise derivative: propagate gradients through the entire sampling path.

For $z = g_\phi(\epsilon)$, $\epsilon \sim p(\epsilon)$:

$$\nabla_\phi\, \mathbb{E}_{q_\phi}[f(z)] = \mathbb{E}_{p(\epsilon)}\Big[\nabla_z f(z)\big|_{z = g_\phi(\epsilon)}\; \nabla_\phi\, g_\phi(\epsilon)\Big]$$

Simple reparameterization (Section 6.3) is the basic instance; the same idea extends to longer or implicit sampling paths.

6.5 Method Comparison and Selection

| Method | Variance | Bias | Typical use case |
|---|---|---|---|
| Score function | High | Unbiased | Discrete variables, non-differentiable models |
| Reparameterization | Low | Unbiased | Continuous variables, differentiable models |
| Pathwise | Low | Unbiased | Continuous variables, complex sampling paths |
| RELAX | Medium | Unbiased | Mixed discrete-continuous |

6.6 Variance Reduction with Baselines

class VarianceNormalizedEstimator:
    """
    Baseline-based (control-variate) gradient estimator.
    
    Reduces the variance of score-function gradients and speeds up
    convergence.
    """
    
    def __init__(self, baseline_net=None):
        self.baseline_net = baseline_net  # baseline network for variance reduction
    
    def estimate(self, f, q, n_samples=100):
        """Estimate ∇_φ E_q[f(z)] with a baseline-centred score-function surrogate."""
        z_samples = q.sample(n_samples).detach()  # REINFORCE: no grad through z
        log_q = q.log_prob(z_samples)
        
        f_vals = f(z_samples)
        
        # Subtract a baseline if one is available
        if self.baseline_net is not None:
            baseline = self.baseline_net(z_samples)
            f_centered = f_vals - baseline.detach()
        else:
            # The batch mean is a simple, nearly-free baseline
            f_centered = f_vals - f_vals.mean()
        
        # Surrogate whose gradient is the centred score-function estimator
        surrogate = (f_centered.detach() * log_q).mean()
        gradient = torch.autograd.grad(
            surrogate,
            list(q.parameters()),
            retain_graph=True
        )
        
        return gradient
    
    def update_baseline(self, targets, predictions):
        """
        Update the baseline network (typically trained as an MSE predictor
        of f; an optimizer step is assumed to follow this call).
        """
        if self.baseline_net is not None:
            loss = F.mse_loss(predictions, targets)
            loss.backward()

7. Recent Developments: Normalizing Flows, Continuous Mixtures, Neural Variational Inference

7.1 Continuous Normalizing Flows

Continuous Normalizing Flows (CNF) replace a stack of discrete flow layers with a learned ODE.

Probability-flow ODE (instantaneous change of variables):

$$\frac{dz(t)}{dt} = f_\theta\big(z(t), t\big), \qquad \frac{d \log p\big(z(t)\big)}{dt} = -\,\mathrm{tr}\!\left(\frac{\partial f_\theta}{\partial z(t)}\right)$$

class CNF(nn.Module):
    """Continuous normalizing flow."""
    
    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        
        # Velocity-field network f(z, t)
        self.velocity_net = nn.Sequential(
            nn.Linear(dim + 1, hidden_dim),  # +1 for time
            nn.Softplus(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, dim)
        )
    
    def velocity(self, z, t):
        """Velocity field f(z, t)."""
        # Condition on time by concatenation
        t_emb = t * torch.ones(z.size(0), 1, device=z.device)
        z_t = torch.cat([z, t_emb], dim=-1)
        return self.velocity_net(z_t)
    
    def forward(self, z0, t_span):
        """
        Forward pass z0 → z1 with a numerical ODE solver
        (requires the torchdiffeq package).
        """
        from torchdiffeq import odeint
        
        solution = odeint(
            self.ode_func,
            z0,
            t_span,
            method='dopri5'
        )
        
        z1 = solution[-1]
        
        # log|det dz1/dz0| = -∫ tr(∂f/∂z) dt, accumulated along the path
        log_det = self.compute_log_det(solution, t_span)
        
        return z1, log_det
    
    def ode_func(self, t, z):
        """ODE right-hand side."""
        return self.velocity(z, t)
    
    def compute_log_det(self, trajectory, t_span):
        """
        Integrate the negative divergence along the saved trajectory
        (Euler quadrature). The trace is estimated with Hutchinson's
        trick: tr(J) ≈ E[ε^T J ε], ε ~ N(0, I).
        """
        log_det = torch.zeros(trajectory.size(1), device=trajectory.device)
        
        for i in range(len(t_span) - 1):
            dt = t_span[i + 1] - t_span[i]
            z_t = trajectory[i].detach().requires_grad_(True)
            
            v = self.velocity(z_t, t_span[i])
            eps = torch.randn_like(z_t)
            # ε^T (∂v/∂z): one vector-Jacobian product per step
            vjp = torch.autograd.grad((v * eps).sum(), z_t)[0]
            div_v = (vjp * eps).sum(dim=-1)
            
            log_det = log_det - div_v * dt
        
        return log_det

7.2 Neural Mixture Models

Neural Mixture Models let networks produce both the mixture weights and the component parameters of the posterior:

$$q_\phi(z \mid x) = \sum_{k=1}^{K} \pi_k(x)\, \mathcal{N}\big(z;\, \mu_k(x),\, \sigma_k^2(x)\big)$$

class NeuralMixtureVAE(nn.Module):
    """VAE with a neural mixture posterior."""
    
    def __init__(self, input_dim, latent_dim, n_components=4):
        super().__init__()
        self.n_components = n_components
        
        # Encoder trunk
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        
        # Mixture weights
        self.pi_net = nn.Linear(128, n_components)
        
        # Per-component mean/variance heads
        self.component_nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 2 * latent_dim)  # μ, log σ²
            )
            for _ in range(n_components)
        ])
        
        # Decoder (Gaussian likelihood: outputs μ and log σ²)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * input_dim)
        )
    
    def forward(self, x):
        """Forward pass."""
        # Encode
        h = self.encoder(x)
        
        # Mixture weights
        logits = self.pi_net(h)
        pi = F.softmax(logits, dim=-1)
        
        # Sample from every component
        z_samples = []
        log_probs = []
        
        for k in range(self.n_components):
            params = self.component_nets[k](h)
            mu, log_var = params.chunk(2, dim=-1)
            std = (0.5 * log_var).exp()
            
            # Reparameterized sample
            z_k = mu + std * torch.randn_like(mu)
            z_samples.append(z_k)
            
            # Component log-density
            log_prob_k = -0.5 * (
                np.log(2 * np.pi) + 
                log_var + 
                (z_k - mu) ** 2 / log_var.exp()
            ).sum(dim=-1)
            log_probs.append(log_prob_k)
        
        z_samples = torch.stack(z_samples, dim=0)  # (K, batch, latent)
        log_probs = torch.stack(log_probs, dim=0)  # (K, batch)
        
        # Weighted component log-densities
        log_pi = torch.log(pi.T + 1e-8)            # (K, batch)
        weighted_log_probs = log_probs + log_pi
        
        # Marginal log q(z|x) via log-sum-exp over components
        log_qz = torch.logsumexp(weighted_log_probs, dim=0)
        
        # Decode the mixture mean: weights (batch, K, 1) × samples (batch, K, latent)
        z_mean = (pi.unsqueeze(-1) * z_samples.permute(1, 0, 2)).sum(dim=1)
        recon_mu, recon_log_var = self.decoder(z_mean).chunk(2, dim=-1)
        
        return {
            'z_samples': z_samples,
            'pi': pi,
            'x_rec': recon_mu,
            'x_rec_log_var': recon_log_var,
            'log_qz': log_qz,
            'log_weights': weighted_log_probs
        }

7.3 Adversarial Variational Inference

Adversarial Variational Inference (AVI) replaces analytic density terms with a learned discriminator that contrasts posterior samples against a reference distribution:

class AdversarialVI(nn.Module):
    """
    Adversarial variational inference.
    
    A discriminator estimates the density ratio between the variational
    posterior and a reference distribution.
    """
    
    def __init__(self, encoder, decoder, discriminator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator
    
    def discriminator_loss(self, x, z):
        """
        Discriminator loss.
        
        Goal: distinguish samples from the variational posterior from
        samples of the reference distribution.
        """
        # Samples from the variational posterior
        z_from_q = z
        
        # Samples from the prior act as the contrast distribution
        # (as in adversarial autoencoders, this drives q(z|x) toward p(z))
        z_from_prior = torch.randn_like(z)
        
        # Discriminator outputs (assumed to be probabilities in (0, 1))
        d_q = self.discriminator(x, z_from_q)
        d_prior = self.discriminator(x, z_from_prior)
        
        # Binary cross-entropy style loss
        loss_d = -0.5 * (
            torch.log(d_q + 1e-8) + 
            torch.log(1 - d_prior + 1e-8)
        ).mean()
        
        return loss_d
    
    def generator_loss(self, x, z):
        """
        Generator (encoder + decoder) loss.
        
        Goal: fool the discriminator.
        """
        d = self.discriminator(x, z)
        loss_g = -torch.log(d + 1e-8).mean()
        
        # Add a reconstruction term
        x_rec = self.decoder(z)
        loss_rec = F.mse_loss(x_rec, x)
        
        return loss_g + loss_rec

8. Complete PyTorch Implementation

8.1 Full VAE Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import math
import numpy as np
 
class VAE(nn.Module):
    """
    Complete variational autoencoder implementation.
    
    Supports:
    - reparameterized sampling
    - configurable encoder/decoder depth
    - Gaussian or Bernoulli likelihoods
    - optional normalizing flows on the posterior
    """
    
    def __init__(
        self, 
        input_dim: int,
        latent_dim: int,
        hidden_dim: int = 256,
        encoder_depth: int = 2,
        decoder_depth: int = 2,
        distribution: str = 'gaussian',
        use_flows: bool = False,
        n_flows: int = 0
    ):
        super().__init__()
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.distribution = distribution
        
        # Encoder
        encoder_layers = []
        prev_dim = input_dim
        for _ in range(encoder_depth):
            encoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        self.encoder = nn.Sequential(*encoder_layers)
        
        # Latent-space heads
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        decoder_layers = []
        prev_dim = latent_dim
        for _ in range(decoder_depth):
            decoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        self.decoder = nn.Sequential(*decoder_layers)
        
        # Output heads
        if distribution == 'gaussian':
            self.fc_out_mu = nn.Linear(hidden_dim, input_dim)
            self.fc_out_log_var = nn.Linear(hidden_dim, input_dim)
        elif distribution == 'bernoulli':
            self.fc_out = nn.Linear(hidden_dim, input_dim)
        
        # Normalizing flows (optional)
        self.use_flows = use_flows
        if use_flows:
            self.flows = nn.ModuleList([
                PlanarFlow(latent_dim) for _ in range(n_flows)
            ])
    
    def encode(self, x):
        """Encode: compute the posterior parameters."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        """Reparameterized sampling."""
        std = (0.5 * log_var).exp()
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def flow_transform(self, z0):
        """
        Apply the normalizing flows, accumulating log-determinants.
        (Planar flows are not analytically invertible, so only the
        forward direction is provided.)
        """
        log_det_sum = torch.zeros(z0.size(0), device=z0.device)
        z = z0
        
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_sum += log_det
        
        return z, log_det_sum
    
    def decode(self, z):
        """Decode: compute the generative distribution's parameters."""
        h = self.decoder(z)
        
        if self.distribution == 'gaussian':
            mu = self.fc_out_mu(h)
            log_var = self.fc_out_log_var(h)
            return mu, log_var
        elif self.distribution == 'bernoulli':
            logits = self.fc_out(h)
            return torch.sigmoid(logits), None
    
    def forward(self, x, n_samples=1):
        """
        Forward pass.
        
        Args:
            x: (batch_size, input_dim) input data
            n_samples: number of samples per data point
        
        Returns:
            dict containing the ELBO, its components, and the latents
        """
        # Encode
        mu, log_var = self.encode(x)
        
        # Reparameterized sample from the base posterior q0
        z0 = self.reparameterize(mu, log_var)
        
        # Normalizing flows (optional); z is the flowed sample z_K
        z = z0
        log_det_flow = torch.zeros(x.size(0), device=x.device)
        if self.use_flows:
            z, log_det_flow = self.flow_transform(z0)
        
        # Decode
        recon = self.decode(z)
        
        # Reconstruction term log p(x|z)
        if self.distribution == 'gaussian':
            recon_mu, recon_log_var = recon
            # Normal takes a standard deviation: σ = exp(0.5 * log σ²)
            log_px_given_z = dist.Normal(
                recon_mu, (0.5 * recon_log_var).exp()
            ).log_prob(x).sum(dim=-1)
        elif self.distribution == 'bernoulli':
            p_x_given_z, _ = recon
            log_px_given_z = dist.Bernoulli(probs=p_x_given_z).log_prob(x).sum(dim=-1)
        
        # Prior log-density of the flowed sample
        log_pz = dist.Normal(0, 1).log_prob(z).sum(dim=-1)
        
        # Posterior log-density of the base sample (again σ = exp(0.5 * log σ²));
        # with flows, log q_K(z_K) = log q0(z0) - log_det_flow
        log_qz_given_x = dist.Normal(mu, (0.5 * log_var).exp()).log_prob(z0).sum(dim=-1)
        
        # ELBO
        elbo = log_px_given_z + log_pz - log_qz_given_x + log_det_flow
        
        return {
            'elbo': elbo,
            'reconstruction': log_px_given_z,
            'kl': log_qz_given_x - log_pz - log_det_flow,
            'z': z,
            'mu': mu,
            'log_var': log_var
        }
    
    def loss(self, x):
        """Loss = negative ELBO."""
        output = self.forward(x)
        return -output['elbo'].mean()
    
    def sample(self, n_samples, temperature=1.0):
        """Sample from the prior and decode."""
        with torch.no_grad():
            # Sample from the prior; the prior lives in the decoder's input
            # space, so no flow inversion is needed (planar flows are not
            # analytically invertible anyway)
            z = torch.randn(n_samples, self.latent_dim) * temperature
            
            # Decode
            recon = self.decode(z)
            
            if self.distribution == 'gaussian':
                mu, log_var = recon
                x_samples = mu
            elif self.distribution == 'bernoulli':
                p, _ = recon
                x_samples = torch.bernoulli(p)
            
            return x_samples, z
    
    def reconstruct(self, x):
        """Reconstruct the input."""
        with torch.no_grad():
            mu, log_var = self.encode(x)
            z = self.reparameterize(mu, log_var)
            recon = self.decode(z)
            
            # Mean (gaussian) or probabilities (bernoulli)
            return recon[0]
 
 
class VAETrainer:
    """
    VAE trainer.
    
    Supports:
    - learning-rate scheduling
    - gradient clipping
    - loss-history tracking
    """
    
    def __init__(
        self, 
        model: VAE,
        optimizer_class: type = torch.optim.Adam,
        lr: float = 1e-3,
        beta: float = 1.0,
        recon_weight: float = 1.0
    ):
        self.model = model
        self.beta = beta  # weight on the KL term (β-VAE style)
        self.recon_weight = recon_weight
        
        self.optimizer = optimizer_class(model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, 
            patience=10,
            factor=0.5
        )
        
        self.history = {
            'loss': [],
            'recon_loss': [],
            'kl_loss': [],
            'elbo': []
        }
    
    def train_step(self, x):
        """One optimization step."""
        self.optimizer.zero_grad()
        
        # Forward pass
        output = self.model(x)
        
        # Loss pieces (recon_loss is already the negative reconstruction term)
        recon_loss = -output['reconstruction'].mean()
        kl_loss = output['kl'].mean()
        elbo = output['elbo'].mean()
        
        # Minimize the (weighted) negative ELBO
        loss = self.recon_weight * recon_loss + self.beta * kl_loss
        
        # Backward pass
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        
        self.optimizer.step()
        
        # Bookkeeping
        self.history['loss'].append(loss.item())
        self.history['recon_loss'].append(recon_loss.item())
        self.history['kl_loss'].append(kl_loss.item())
        self.history['elbo'].append(elbo.item())
        
        return loss.item(), recon_loss.item(), kl_loss.item()
    
    def train(self, dataloader, n_epochs, eval_loader=None):
        """Full training loop."""
        for epoch in range(n_epochs):
            epoch_loss = 0
            epoch_recon = 0
            epoch_kl = 0
            n_batches = 0
            
            for x in dataloader:
                if isinstance(x, (list, tuple)):
                    x = x[0]
                x = x.view(-1, self.model.input_dim)
                
                loss, recon, kl = self.train_step(x)
                
                epoch_loss += loss
                epoch_recon += recon
                epoch_kl += kl
                n_batches += 1
            
            # Epoch averages
            avg_loss = epoch_loss / n_batches
            avg_recon = epoch_recon / n_batches
            avg_kl = epoch_kl / n_batches
            
            # Learning-rate schedule
            self.scheduler.step(avg_loss)
            
            # Periodic logging
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{n_epochs}")
                print(f"  Loss: {avg_loss:.4f}")
                print(f"  Recon: {avg_recon:.4f}")
                print(f"  KL: {avg_kl:.4f}")
                print(f"  ELBO: {self.history['elbo'][-1]:.4f}")
                print()
        
        return self.history
 
 
def demo_vae():
    """VAE demo on a bimodal 2-D dataset."""
    # Generate bimodal data
    torch.manual_seed(42)
    
    n_samples = 2000
    data1 = torch.randn(n_samples // 2, 2) + torch.tensor([2.0, 2.0])
    data2 = torch.randn(n_samples // 2, 2) + torch.tensor([-2.0, -2.0])
    data = torch.cat([data1, data2], dim=0)
    
    # Shuffle
    perm = torch.randperm(len(data))
    data = data[perm]
    
    # Dataset and loader
    dataset = torch.utils.data.TensorDataset(data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    
    # Build the VAE
    vae = VAE(
        input_dim=2,
        latent_dim=2,
        hidden_dim=64,
        encoder_depth=2,
        decoder_depth=2,
        distribution='gaussian'
    )
    
    # Train
    trainer = VAETrainer(vae, lr=1e-3, beta=1.0)
    history = trainer.train(dataloader, n_epochs=100)
    
    # Sample from the trained model
    samples, z = vae.sample(n_samples=500)
    
    print("\nGenerated sample statistics:")
    print(f"  mean: {samples.mean(dim=0)}")
    print(f"  variance: {samples.var(dim=0)}")
    
    return vae, history
 
 
if __name__ == "__main__":
    vae, history = demo_vae()

9. Summary and Connections

9.1 Key Takeaways

| Topic | Core formula |
|---|---|
| ELBO | $\mathcal{L} = \mathbb{E}_{q}[\log p(x \mid z)] - \mathrm{KL}(q(z \mid x) \Vert p(z))$ |
| Reparameterization | $z = \mu + \sigma \odot \epsilon$, $\epsilon \sim \mathcal{N}(0, I)$ |
| KL divergence (Gaussian vs. $\mathcal{N}(0, I)$) | $-\tfrac{1}{2}\sum_j \big(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$ |
| Gumbel-Softmax | $y_i = \mathrm{softmax}\big((\log \pi + g)/\tau\big)_i$ |

9.2 Connections to Related Documents

| Related topic | Connection |
|---|---|
| Bayesian networks | Foundations of probabilistic graphical models |
| Advanced variational inference | SVI, IWAE, normalizing flows |
| Bayesian neural networks | Training BNNs with variational inference |
| Bayes by Backprop | A neural-network realization of variational inference |
| MC Dropout | The Bayesian interpretation of dropout |
| Probabilistic circuit foundations | Circuits as variational posteriors |

Footnotes

  1. Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. ICLR 2014.

  2. Rezende, D. J., Mohamed, S., & Wierstra, D. (2014). “Stochastic Backpropagation and Approximate Inference in Deep Generative Models”. ICML 2014.