Advanced Variational Inference: SVI, IWAEs, and the Reparameterization Trick

1. Background

1.1 Variational Inference Basics

The core idea of variational inference (VI) is to approximate the true posterior $p(z \mid x)$ with a parameterized distribution $q_\phi(z)$ by minimizing the KL divergence between the two:

$$\phi^* = \arg\min_\phi \mathrm{KL}\big(q_\phi(z)\,\|\,p(z \mid x)\big) = \arg\max_\phi \mathcal{L}(\phi),$$

where $\mathcal{L}(\phi) = \mathbb{E}_{q_\phi(z)}\big[\log p(x, z) - \log q_\phi(z)\big]$ is the evidence lower bound (ELBO).

See the variational-inference notes for details.
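
To make the notation concrete, here is a minimal Monte Carlo ELBO estimate for a toy model (standard normal prior, unit-variance Gaussian likelihood, Gaussian $q$); the function name and setup are illustrative, not from any particular library:

import torch

def elbo_estimate(x, q_mu, q_log_var, n_samples=64):
    """Monte Carlo ELBO for p(z)=N(0,1), p(x|z)=N(z,1), q(z)=N(q_mu, exp(q_log_var))."""
    std = (0.5 * q_log_var).exp()
    z = q_mu + std * torch.randn(n_samples)                      # reparameterized samples
    log_p_xz = torch.distributions.Normal(z, 1).log_prob(x)      # log p(x|z)
    log_p_z = torch.distributions.Normal(0, 1).log_prob(z)       # log p(z)
    log_q_z = torch.distributions.Normal(q_mu, std).log_prob(z)  # log q(z)
    return (log_p_xz + log_p_z - log_q_z).mean()                 # E_q[log p(x,z) - log q(z)]

print(elbo_estimate(torch.tensor(1.0), torch.tensor(0.5), torch.tensor(0.0)))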

1.2 Goals of This Chapter

| Topic | Content |
| --- | --- |
| Black-box variational inference (BBVI) | Monte Carlo gradient estimates; no analytic gradients required |
| Stochastic variational inference (SVI) | Variational inference for large-scale data |
| Importance-weighted autoencoders (IWAE) | Lower bounds tighter than the ELBO |
| Semi-implicit variational inference (SIVI) | Non-parametric variational families |
| Reparameterization trick | Reduces the variance of gradient estimates |
| Normalizing flows | Increase the expressiveness of the variational distribution |

2. Black-Box Variational Inference (BBVI)

2.1 The Core Problem

In standard variational inference we need the gradient of the ELBO with respect to the variational parameters:

$$\nabla_\phi \mathcal{L}(\phi) = \nabla_\phi\, \mathbb{E}_{q_\phi(z)}\big[f(z)\big],$$

where $f(z) = \log p(x, z) - \log q_\phi(z)$.

The problem: the gradient cannot simply be pushed inside the expectation, because the sampling distribution itself depends on $\phi$.

2.2 Strategy 1: Score-Function Gradients (REINFORCE)

The score-function identity:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}\big[f(z)\,\nabla_\phi \log q_\phi(z)\big]$$

import numpy as np

def score_function_gradient(f, q, n_samples=100):
    """Estimate the gradient with the score-function (REINFORCE) estimator.

    Caveat: the variance is typically high.

    Returns:
        gradient: Monte Carlo estimate of ∇_φ E_q[f(z)]
    """
    gradients = []
    for _ in range(n_samples):
        z = q.sample()           # sample from the variational distribution
        grad_log_q = q.score(z)  # ∇_φ log q_φ(z)
        gradients.append(f(z) * grad_log_q)

    return np.mean(gradients, axis=0)

Pros: very general; works for any $f$, including non-differentiable models.
Cons: the variance is typically high, so convergence is slow.

2.3 Strategy 2: Reparameterization Gradients

The reparameterization trick

Express the random variable $z \sim q_\phi(z)$ as a deterministic transformation $z = g(\phi, \epsilon)$, where $\epsilon \sim p(\epsilon)$ is noise independent of $\phi$:

$$z = g(\phi, \epsilon), \qquad \epsilon \sim p(\epsilon).$$

Then:

$$\nabla_\phi\, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}\big[\nabla_\phi f\big(g(\phi, \epsilon)\big)\big].$$

def reparam_gradient(grad_f, phi, n_samples=100):
    """Estimate the gradient with the reparameterization trick.

    Pros: variance is typically much lower than the score-function estimator.
    Cons: requires an explicit, differentiable reparameterization map.

    `reparametrize` (z = g(φ, ε)) and `reparam_jacobian` (∂g/∂φ) are assumed
    model-specific helpers; `grad_f` computes ∇_z f(z).

    Returns:
        gradient: Monte Carlo estimate of ∇_φ E_q[f(z)]
    """
    gradients = []
    for _ in range(n_samples):
        epsilon = np.random.randn(len(phi))
        z = reparametrize(phi, epsilon)  # z = g(φ, ε)
        # chain rule: ∇_φ f(g(φ, ε)) = (∂z/∂φ)ᵀ ∇_z f(z)
        gradients.append(reparam_jacobian(phi, epsilon).T @ grad_f(z))

    return np.mean(gradients, axis=0)
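
To see the variance gap concretely, here is a self-contained toy comparison of the two estimators for $\nabla_\mu\, \mathbb{E}_{\mathcal{N}(z;\mu,1)}[z^2]$, whose true value is $2\mu$ (the setup is our own illustration):

import numpy as np

rng = np.random.default_rng(0)
mu, n = 1.5, 10_000

# score-function estimator: f(z) ∇_μ log N(z; μ, 1) = z² (z - μ)
z = mu + rng.standard_normal(n)
score_samples = z**2 * (z - mu)

# reparameterization estimator: z = μ + ε, so ∇_μ f(z) = f'(z) = 2z
eps = rng.standard_normal(n)
reparam_samples = 2 * (mu + eps)

print("true gradient:", 2 * mu)
print(f"score:   mean {score_samples.mean():.3f}  var {score_samples.var():.3f}")
print(f"reparam: mean {reparam_samples.mean():.3f}  var {reparam_samples.var():.3f}")
# Both are unbiased, but the reparameterization estimator's variance is far smaller.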

2.4 Strategy 3: Rao-Blackwellization

Rao-Blackwellization uses conditional expectations to reduce variance:

If $\tilde{f}(z_1) = \mathbb{E}_{z_2}\big[\hat{f}(z_1, z_2) \mid z_1\big]$, then $\tilde{f}$ has the same expectation as $\hat{f}$:

$$\mathbb{E}\big[\tilde{f}(z_1)\big] = \mathbb{E}\big[\hat{f}(z_1, z_2)\big],$$

and the variances satisfy:

$$\mathrm{Var}\big(\tilde{f}(z_1)\big) = \mathrm{Var}\big(\hat{f}(z_1, z_2)\big) - \mathbb{E}\Big[\big(\hat{f}(z_1, z_2) - \tilde{f}(z_1)\big)^2\Big] \;\le\; \mathrm{Var}\big(\hat{f}(z_1, z_2)\big).$$
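
A minimal numerical illustration (a toy setup of our own): estimate $\mathbb{E}[z_1^2 + z_2] = 1$ for independent $z_1, z_2 \sim \mathcal{N}(0, 1)$, comparing the naive estimator against the Rao-Blackwellized one that integrates $z_2$ out analytically ($\mathbb{E}[z_2 \mid z_1] = 0$):

import numpy as np

rng = np.random.default_rng(0)
n = 10_000
z1 = rng.standard_normal(n)
z2 = rng.standard_normal(n)

naive = z1**2 + z2           # uses both samples
rao_blackwell = z1**2        # E[z1² + z2 | z1] = z1², computed analytically

print(f"naive:          mean {naive.mean():.3f}  var {naive.var():.3f}")
print(f"rao-blackwell:  mean {rao_blackwell.mean():.3f}  var {rao_blackwell.var():.3f}")
# Both are unbiased, but conditioning removes all of z2's contribution to the variance.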


3. Stochastic Variational Inference (SVI)

3.1 The Challenge of Large-Scale Data

In models such as Bayesian mixtures or LDA, the likelihood is a product over all data points:

$$p(x_{1:N} \mid z) = \prod_{i=1}^{N} p(x_i \mid z).$$

For a large dataset, computing the gradient over all of the data at every iteration is too expensive.

3.2 The Core Idea of SVI

Stochastic variational inference (stochastic VI) approximates the full gradient with a mini-batch of data:

$$\nabla_\phi \mathcal{L}(\phi) \approx \frac{N}{|B|} \sum_{i \in B} \nabla_\phi \mathcal{L}_i(\phi),$$

where $\mathcal{L}_i$ is the local ELBO of the $i$-th data point and $B$ is a randomly sampled mini-batch.

3.3 The SVI Algorithm

def svi(model, X, n_iter=10000, batch_size=100, lr=0.01):
    """Stochastic variational inference.

    Args:
        model: probabilistic model
        X: dataset of shape (N, D)
        n_iter: number of iterations
        batch_size: mini-batch size
        lr: learning rate
    """
    N = len(X)
    phi = initialize_variational_params()

    for t in range(n_iter):
        # sample a mini-batch
        batch_idx = np.random.choice(N, batch_size, replace=False)
        X_batch = X[batch_idx]

        # mini-batch gradient
        # (via the reparameterization trick or the score function)
        grad = compute_elbo_gradient(phi, X_batch, N)

        # stochastic gradient ascent
        phi = phi + lr * grad

        # learning-rate decay (optional)
        lr = lr * (1 - 1e-5)

        if t % 1000 == 0:
            elbo = compute_elbo(phi, X)
            print(f"Iter {t}: ELBO = {elbo:.4f}")

    return phi
 
 
def compute_elbo_gradient(phi, X_batch, N_total):
    """Mini-batch ELBO gradient using the reparameterization trick.

    `dim_z`, `mean_field_reparam`, and `local_objective` are model-specific
    (latent dimension, reparameterized sampling, and the per-data-point ELBO term).
    """
    batch_size = len(X_batch)
    scale = N_total / batch_size  # rescale so the mini-batch estimate is unbiased

    gradients = []
    for x in X_batch:
        # reparameterized sample
        epsilon = torch.randn(dim_z)
        z = mean_field_reparam(phi, epsilon)

        # local ELBO, rescaled to the full dataset
        local_elbo = scale * local_objective(z, x)

        # reparameterization gradient via autograd
        grad = torch.autograd.grad(local_elbo, phi)[0]
        gradients.append(grad)

    return torch.stack(gradients).mean(dim=0)

3.4 Adaptive Learning Rates

SVI is sensitive to the learning rate and needs careful tuning. Adaptive optimizers help:

| Method | Description |
| --- | --- |
| AdaGrad | Accumulates historical gradients |
| RMSProp | Exponentially weighted moving average |
| Adam | Combines momentum with RMSProp |

def svi_adam(model, X, n_iter=10000, batch_size=100):
    """SVI with Adam updates."""

    # Adam hyperparameters
    lr = 0.001
    beta1 = 0.9
    beta2 = 0.999
    epsilon = 1e-8

    N = len(X)
    phi = initialize_variational_params()
    m = 0  # first-moment estimate
    v = 0  # second-moment estimate

    for t in range(1, n_iter + 1):
        # mini-batch gradient
        batch_idx = np.random.choice(N, batch_size, replace=False)
        grad = compute_elbo_gradient(phi, X[batch_idx], N)

        # Adam update
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * (grad ** 2)

        # bias correction
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)

        # gradient ascent on the ELBO
        phi = phi + lr * m_hat / (np.sqrt(v_hat) + epsilon)

    return phi

3.5 Convergence Analysis

The convergence rate of SVI depends on:

| Factor | Effect |
| --- | --- |
| Batch size | Larger batches reduce gradient variance but increase per-step cost |
| Learning-rate schedule | Must satisfy the Robbins-Monro conditions $\sum_t \rho_t = \infty$ and $\sum_t \rho_t^2 < \infty$ |
| Model structure | The dimensionality of the latent variables affects convergence |
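
A schedule that satisfies these conditions is $\rho_t = (t + \tau)^{-\kappa}$ with $\kappa \in (0.5, 1]$, the form used in Hoffman et al.'s original SVI paper; a minimal sketch:

def robbins_monro_lr(t, tau=1.0, kappa=0.75):
    """Step size ρ_t = (t + τ)^(-κ); any κ ∈ (0.5, 1] gives Σρ_t = ∞ and Σρ_t² < ∞."""
    return (t + tau) ** (-kappa)

# e.g. replace the fixed `lr` in svi() above with:  lr = robbins_monro_lr(t)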

4. Importance-Weighted Autoencoders (IWAEs)

4.1 The Core Idea

IWAE uses importance sampling to obtain a tighter lower bound than the standard ELBO.

The importance-weighted ELBO:

$$\mathcal{L}_K = \mathbb{E}_{z_1, \ldots, z_K \sim q_\phi(z \mid x)}\left[\log \frac{1}{K} \sum_{k=1}^{K} w_k\right],$$

where $w_k = \dfrac{p_\theta(x, z_k)}{q_\phi(z_k \mid x)}$ are the importance weights.

4.2 Relationship to the Standard ELBO

Theorem: for any $K \ge 1$,

$$\log p(x) \;\ge\; \mathcal{L}_{K+1} \;\ge\; \mathcal{L}_K \;\ge\; \mathcal{L}_1 = \mathrm{ELBO}.$$

As $K \to \infty$:

$$\mathcal{L}_K \to \log p(x).$$
4.3 Why Is IWAE Better?

| Aspect | ELBO ($K=1$) | IWAE ($K>1$) |
| --- | --- | --- |
| Bound tightness | Looser | Tighter |
| Gradient estimates | Higher variance | Lower variance |
| Posterior approximation | Point-estimate-like | Smoother |
| Computational cost | $O(1)$ | $O(K)$ |
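
A quick numerical check of the bound ordering, on a toy conjugate model where $\log p(x)$ is known in closed form (prior $p(z)=\mathcal{N}(0,1)$, likelihood $p(x \mid z)=\mathcal{N}(z,1)$, so marginally $x \sim \mathcal{N}(0,2)$; the prior is deliberately used as a crude proposal $q$). The setup is our own illustration:

import numpy as np

rng = np.random.default_rng(0)
x = 2.0
log_px = -0.5 * x**2 / 2 - 0.5 * np.log(2 * np.pi * 2)   # exact log p(x)

for K in [1, 5, 50, 500]:
    z = rng.standard_normal((5000, K))                    # proposal q(z|x) = p(z)
    log_w = -0.5 * (x - z)**2 - 0.5 * np.log(2 * np.pi)   # log w_k = log p(x|z_k) since q = p(z)
    # stable log-mean-exp over the K samples, then average over the 5000 replicates
    m = log_w.max(axis=1, keepdims=True)
    L_K = np.mean(m[:, 0] + np.log(np.exp(log_w - m).mean(axis=1)))
    print(f"K={K:4d}  L_K = {L_K:+.4f}   (log p(x) = {log_px:+.4f})")
# L_K increases monotonically with K and approaches log p(x) from below.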

4.4 IWAE Implementation

import math
import torch
import torch.nn as nn

class IWAE(nn.Module):
    """Importance-weighted autoencoder.

    IWAE jointly learns the generative model p_θ and the inference network q_φ.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, k=5):
        super().__init__()
        self.k = k                      # number of importance samples
        self.latent_dim = latent_dim

        # encoder: q_φ(z|x)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)  # μ, log σ²
        )

        # decoder: p_θ(x|z)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * input_dim)  # μ, log σ²
        )

    def reparameterize(self, mu, log_var):
        """Reparameterization trick: z = μ + σ ⊙ ε."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Forward pass; computes the IWAE loss."""
        # encode
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        std = (0.5 * log_var).exp()

        # importance sampling
        log_weights = []
        for _ in range(self.k):
            # reparameterized sample
            z = self.reparameterize(mu, log_var)

            # log-likelihood log p_θ(x|z)
            h_dec = self.decoder(z)
            dec_mu, dec_log_var = h_dec.chunk(2, dim=-1)
            dec_std = (0.5 * dec_log_var).exp()
            log_p_xz = torch.distributions.Normal(dec_mu, dec_std).log_prob(x).sum(dim=-1)

            # prior log-density log p(z)
            log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)

            # inference-network log-density log q_φ(z|x)
            log_q_zx = torch.distributions.Normal(mu, std).log_prob(z).sum(dim=-1)

            # log importance weight
            log_w = log_p_xz + log_p_z - log_q_zx
            log_weights.append(log_w)

        # stack per-sample log-weights
        log_weights = torch.stack(log_weights, dim=1)  # (batch, k)

        # IWAE loss: -log(1/K Σ_k w_k), computed stably via log-sum-exp
        log_iwae_bound = torch.logsumexp(log_weights, dim=1) - math.log(self.k)
        loss = -log_iwae_bound.mean()

        # return everything for analysis; the mean of log w is the standard ELBO
        return {
            'loss': loss,
            'elbo': log_weights.mean(dim=1).mean(),
            'log_weights': log_weights
        }

    def sample(self, n_samples):
        """Sample from the generative model."""
        with torch.no_grad():
            z = torch.randn(n_samples, self.latent_dim)
            h = self.decoder(z)
            mu, log_var = h.chunk(2, dim=-1)
            x = mu + torch.randn_like(mu) * (0.5 * log_var).exp()
        return x

4.5 Theoretical Analysis of IWAE

Bound tightness

Proposition: for $K_2 > K_1 \ge 1$,

$$\log p(x) \;\ge\; \mathcal{L}_{K_2} \;\ge\; \mathcal{L}_{K_1}.$$

Proof: uses the strictness conditions of Jensen's inequality…

Gradient-estimator variance

Proposition: the variance of IWAE's gradient estimator decreases as $K$ grows (when appropriate normalization is used).


5. Semi-Implicit Variational Inference (SIVI)

5.1 Limitations of Standard VI

Standard VI uses an explicit parameterized distribution $q_\phi(z)$:

  • Expressiveness is limited by the chosen family (e.g. Gaussian)
  • Complex posterior structure cannot be captured

5.2 The Core Idea of SIVI

SIVI uses a semi-implicit distribution:

$$q_\phi(z) = \int q(z \mid \psi)\, q_\phi(\psi)\, d\psi,$$

where $q(z \mid \psi)$ is an explicit parameterized distribution and $\psi$ is a random variable sampled implicitly.

5.3 SIVI Implementation

class SIVI(nn.Module):
    """Semi-implicit variational inference (simplified sketch)."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # networks parameterizing the mean and variance
        self.mu_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

        self.log_var_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

        # noise-injection network (for semi-implicit sampling)
        self.noise_net = nn.Sequential(
            nn.Linear(input_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

    def sample(self, x, n_samples=1):
        """Sample from the semi-implicit distribution."""
        mu = self.mu_net(x)
        log_var = self.log_var_net(x)

        # semi-implicit sampling: add a learned noise offset
        eps = torch.randn_like(mu)
        noise = self.noise_net(torch.cat([x, eps], dim=-1))

        z = mu + eps * (0.5 * log_var).exp() + 0.1 * noise
        return z

    def elbo(self, x, n_samples=5):
        """Estimate the ELBO for SIVI.

        NOTE: evaluating log q via the conditional Gaussian below is a
        simplification; full SIVI optimizes bounds on log q because the
        marginal density is intractable. `decode_log_prob` is an assumed
        model-specific decoder method.
        """
        mu = self.mu_net(x)
        log_var = self.log_var_net(x)

        elbo_samples = []
        for _ in range(n_samples):
            z = self.sample(x)

            # the three ELBO terms
            log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
            log_q_zx = torch.distributions.Normal(mu, (0.5 * log_var).exp()).log_prob(z).sum(dim=-1)
            log_p_xz = self.decode_log_prob(x, z)

            elbo_samples.append(log_p_xz + log_p_z - log_q_zx)

        return torch.stack(elbo_samples).mean()

6. Normalizing Flows in VI

6.1 Basic Idea

Normalizing flows increase the expressiveness of the variational distribution through a chain of invertible transformations $z_K = f_K \circ \cdots \circ f_1(z_0)$:

$$\log q_K(z_K) = \log q_0(z_0) - \sum_{k=1}^{K} \log \left| \det \frac{\partial f_k}{\partial z_{k-1}} \right|,$$

where $z_0 \sim q_0$ comes from a simple base distribution (e.g. a diagonal Gaussian).

6.2 Planar Flows

class PlanarFlow(nn.Module):
    """Planar normalizing flow.

    Transformation: z' = z + u · h(wᵀz + b),
    where h is a nonlinearity (tanh here).
    """

    def __init__(self, dim):
        super().__init__()
        # note: invertibility requires wᵀu ≥ -1, not enforced in this sketch
        self.w = nn.Parameter(torch.randn(dim))
        self.u = nn.Parameter(torch.randn(dim))
        self.b = nn.Parameter(torch.randn(1))

    def forward(self, z):
        """Forward pass; z has shape (batch, dim)."""
        activation = torch.tanh(z @ self.w + self.b)       # (batch,)
        z_new = z + self.u * activation.unsqueeze(-1)      # broadcast u over the batch

        # log|det ∂z'/∂z| = log|1 + uᵀψ(z)|, with ψ(z) = h'(wᵀz + b) w
        psi = (1 - activation ** 2).unsqueeze(-1) * self.w  # (batch, dim)
        det = 1 + psi @ self.u                              # (batch,)
        log_det = torch.log(det.abs() + 1e-8)

        return z_new, log_det

    def inverse(self, z):
        """Inverse transform (no closed form; requires numerical iteration)."""
        raise NotImplementedError
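
As a usage sketch (our own minimal example), stack a few planar flows and evaluate $\log q_K(z_K)$ with the density formula from section 6.1:

import torch

flows = [PlanarFlow(2) for _ in range(4)]

z0 = torch.randn(128, 2)                                    # z_0 ~ q_0 = N(0, I)
log_q = torch.distributions.Normal(0, 1).log_prob(z0).sum(dim=-1)

z = z0
for flow in flows:
    z, log_det = flow(z)
    log_q = log_q - log_det                                 # subtract each log|det|

print(z.shape, log_q.shape)  # torch.Size([128, 2]) torch.Size([128])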

6.3 Sylvester Flows

class SylvesterFlow(nn.Module):
    """Sylvester normalizing flow (simplified).

    More flexible than a planar flow. NOTE: this sketch parameterizes the
    orthogonal factor as a free matrix initialized to the identity; the
    log-determinant below is exact only if `orthog` stays orthogonal (the
    original paper enforces this, e.g. via Householder reflections).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        # (nominally) orthogonal matrix
        self.orthog = nn.Parameter(torch.eye(dim))

        # diagonal scaling
        self.diag = nn.Parameter(torch.ones(dim))

        # shift vector
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, z):
        """Forward pass; z has shape (batch, dim): z' = Q(Dz) + shift."""
        z_new = (self.diag * z) @ self.orthog.T + self.shift

        # Jacobian = Q·diag(d); for orthogonal Q, log|det| = Σ log|d_i|
        log_det = torch.sum(torch.log(self.diag.abs() + 1e-8))

        return z_new, log_det

6.4 VAE + Normalizing Flows

class NormalizingFlowVAE(nn.Module):
    """VAE with a normalizing-flow posterior.

    `Encoder` and `Decoder` are assumed helper modules; `Decoder` is assumed
    to return a torch.distributions object so that log_prob is available.
    """

    def __init__(self, input_dim, latent_dim, n_flows=8):
        super().__init__()

        # encoder
        self.encoder = Encoder(input_dim, latent_dim)

        # parameters of the initial variational distribution
        self.loc_net = nn.Linear(latent_dim, latent_dim)
        self.scale_net = nn.Linear(latent_dim, latent_dim)

        # normalizing flows
        self.flows = nn.ModuleList([PlanarFlow(latent_dim) for _ in range(n_flows)])

        # decoder
        self.decoder = Decoder(latent_dim, input_dim)

    def forward(self, x):
        """Forward pass."""
        # encode
        h = self.encoder(x)
        q0_loc = self.loc_net(h)
        q0_scale = torch.exp(self.scale_net(h) / 2)

        # sample from the initial distribution
        z0 = q0_loc + q0_scale * torch.randn_like(q0_loc)

        # pass through the normalizing flows
        z = z0
        log_det_sum = 0
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_sum += log_det

        # decode
        x_rec = self.decoder(z)

        # ELBO: log q_K(z_K) = log q_0(z_0) - Σ log|det|, hence the +log_det_sum
        log_p_xz = x_rec.log_prob(x).sum(dim=-1)
        log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
        log_q_zx = torch.distributions.Normal(q0_loc, q0_scale).log_prob(z0).sum(dim=-1)

        elbo = log_p_xz + log_p_z - log_q_zx + log_det_sum

        return {
            'elbo': elbo.mean(),
            'x_rec': x_rec.mean,
            'z': z
        }

7. Practical Guide

7.1 Choosing a Gradient Estimator

| Method | Variance | Computational cost | Typical use case |
| --- | --- | --- | --- |
| Score function | High | Low | Discrete variables, non-differentiable models |
| Reparameterization | Low | Low | Continuous variables |
| IWAE | Medium-low | High ($O(K)$) | When a tighter bound is needed |
| Normalizing flow | Low | Medium | When a complex posterior is needed |

7.2 Initialization Strategies

def initialize_variational_params(model):
    """Initialize variational parameters (a name-based heuristic)."""

    for name, param in model.named_parameters():
        if 'log_var' in name or 'logit' in name:
            # start with a small variance
            nn.init.constant_(param, -5)
        elif 'loc' in name or 'mu' in name:
            # Xavier initialization
            nn.init.xavier_uniform_(param)
        elif 'weight' in name:
            nn.init.kaiming_normal_(param)
        elif 'bias' in name:
            nn.init.zeros_(param)

    return model

7.3 Debugging Tips

class DebugVI:
    """Diagnostics helper for variational inference training."""

    def __init__(self, model):
        self.model = model
        self.history = {
            'elbo': [],
            'kl': [],
            'reconstruction': [],
            'grad_norm': []
        }

    def train_step(self, x):
        """One training step with diagnostics."""
        output = self.model(x)

        # record metrics
        self.history['elbo'].append(output['elbo'].item())
        self.history['kl'].append(output['kl'].item())
        self.history['reconstruction'].append(output['reconstruction'].item())

        # gradient norm (sum of per-parameter norms, a cheap proxy)
        loss = output.get('loss', -output['elbo'])  # models may return either
        loss.backward()
        grad_norm = sum(p.grad.norm().item()
                        for p in self.model.parameters()
                        if p.grad is not None)
        self.history['grad_norm'].append(grad_norm)

        # sanity checks
        self._check_numerics()
        self._check_gradients(grad_norm)

        return loss

    def _check_numerics(self):
        """Check for NaN/Inf in the parameters."""
        for name, param in self.model.named_parameters():
            if torch.isnan(param).any():
                raise ValueError(f"NaN in {name}")
            if torch.isinf(param).any():
                raise ValueError(f"Inf in {name}")

    def _check_gradients(self, grad_norm):
        """Warn about suspicious gradient norms."""
        if grad_norm > 100:
            print(f"Warning: large gradient norm: {grad_norm}")
        if grad_norm < 0.01:
            print(f"Warning: small gradient norm: {grad_norm}")

    def plot_history(self):
        """Plot the training history."""
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        axes[0, 0].plot(self.history['elbo'])
        axes[0, 0].set_title('ELBO')

        axes[0, 1].plot(self.history['kl'])
        axes[0, 1].set_title('KL Divergence')

        axes[1, 0].plot(self.history['reconstruction'])
        axes[1, 0].set_title('Reconstruction Loss')

        axes[1, 1].plot(self.history['grad_norm'])
        axes[1, 1].set_title('Gradient Norm')

        plt.tight_layout()
        plt.savefig('vi_training_diagnostics.png')

8. Complete Example: A Variational Autoencoder

import torch
import torch.nn as nn
import torch.distributions as dist
 
class VAE(nn.Module):
    """A complete variational autoencoder implementation."""

    def __init__(self, input_dim, hidden_dim, latent_dim, n_flows=0):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_flows = n_flows

        # encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)
        )

        # decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * input_dim)
        )

        # normalizing flows (optional)
        if n_flows > 0:
            self.flows = nn.ModuleList([
                PlanarFlow(latent_dim) for _ in range(n_flows)
            ])
        else:
            self.flows = []

    def encode(self, x):
        """Encode x into the variational parameters μ, log σ²."""
        h = self.encoder(x)
        mu, log_var = h.chunk(2, dim=-1)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Reparameterization: z = μ + σ ⊙ ε."""
        std = (0.5 * log_var).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def flow_transform(self, z0, log_det_sum=0):
        """Push z0 through the normalizing flows, accumulating log-determinants."""
        z = z0
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_sum = log_det_sum + log_det
        return z, log_det_sum

    def decode(self, z):
        """Decode: return the distribution p_θ(x|z)."""
        h = self.decoder(z)
        mu, log_var = h.chunk(2, dim=-1)
        return dist.Normal(mu, (0.5 * log_var).exp())

    def forward(self, x):
        """Forward pass; computes the ELBO."""
        # encode
        mu, log_var = self.encode(x)

        # sample from the base distribution
        z0 = self.reparameterize(mu, log_var)

        # normalizing flows (if any); keep z0 for evaluating log q
        z, log_det_sum = z0, 0
        if self.n_flows > 0:
            z, log_det_sum = self.flow_transform(z0)

        # decode
        p_x_given_z = self.decode(z)

        # ELBO; log q is evaluated at the base sample z0, with the
        # flow log-determinants accounted for separately
        log_p_xz = p_x_given_z.log_prob(x).sum(dim=-1)
        log_p_z = dist.Normal(0, 1).log_prob(z).sum(dim=-1)
        log_q_zx = dist.Normal(mu, (0.5 * log_var).exp()).log_prob(z0).sum(dim=-1)

        elbo = log_p_xz + log_p_z - log_q_zx + log_det_sum

        return {
            'elbo': elbo.mean(),
            'reconstruction': log_p_xz.mean(),
            'kl': (log_q_zx - log_det_sum - log_p_z).mean(),  # exact KL estimate only without flows
            'x_rec': p_x_given_z.mean,
            'z': z
        }

    def sample(self, n_samples):
        """Sample from the prior and decode."""
        with torch.no_grad():
            z = torch.randn(n_samples, self.latent_dim)
            p_x_given_z = self.decode(z)
            x_samples = p_x_given_z.sample()
        return x_samples
 
 
# Training script
def train_vae(X, input_dim, hidden_dim=400, latent_dim=20,
              n_epochs=100, batch_size=128, lr=1e-3, n_flows=0):
    """Train a VAE."""

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = VAE(input_dim, hidden_dim, latent_dim, n_flows).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    dataset = torch.tensor(X, dtype=torch.float32)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    history = {'elbo': [], 'recon': [], 'kl': []}

    for epoch in range(n_epochs):
        epoch_elbo = 0
        epoch_recon = 0
        epoch_kl = 0

        for batch in dataloader:
            batch = batch.to(device)

            optimizer.zero_grad()
            output = model(batch)

            loss = -output['elbo']
            loss.backward()

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            epoch_elbo += output['elbo'].item()
            epoch_recon += output['reconstruction'].item()
            epoch_kl += output['kl'].item()

        n_batches = len(dataloader)
        history['elbo'].append(epoch_elbo / n_batches)
        history['recon'].append(epoch_recon / n_batches)
        history['kl'].append(epoch_kl / n_batches)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, ELBO: {history['elbo'][-1]:.2f}, "
                  f"Recon: {history['recon'][-1]:.2f}, KL: {history['kl'][-1]:.2f}")

    return model, history
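
A minimal usage sketch on synthetic data (shapes and hyperparameters are illustrative only):

import numpy as np

# synthetic data: 1,000 points in 50 dimensions
X = np.random.randn(1000, 50).astype('float32')

model, history = train_vae(X, input_dim=50, hidden_dim=128, latent_dim=10,
                           n_epochs=20, batch_size=64)

samples = model.sample(16)   # draw 16 samples from the trained generator
print(samples.shape)         # torch.Size([16, 50])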
