概述

重要性加权自编码器(Importance Weighted Autoencoder,IWAE)是Burda et al. (2016)提出的一种变分推断方法,通过使用多个样本估计ELBO来获得更紧的下界,从而提高近似精度。1

IWAE的核心思想是利用重要性采样(Importance Sampling)的思想,在变分推断中引入多个隐变量样本,使估计的变分下界更接近真实的对数似然。


标准VAE的局限性

ELBO的紧性

标准VAE优化的证据下界(ELBO)为:

这个下界相对于对数似然 的”紧性”取决于近似后验 与真实后验 的接近程度。

单样本估计的高方差

当使用单个样本估计期望时:

这种估计的方差很大,导致:

  1. 梯度估计噪声高
  2. 训练不稳定
  3. 下界可能远离真实对数似然

IWAE的核心思想

多样本ELBO

IWAE使用 个独立样本来估计ELBO:

与标准ELBO的关系

时,IWAE退化为标准ELBO:

下界的紧性

IWAE的样本估计提供了更紧的下界

随着 增加,下界单调递增,并收敛到真实对数似然。


理论基础:重要性采样

重要性采样回顾

给定目标分布 和提议分布 ,重要性采样通过提议分布采样来估计期望:

IWAE的解释

在IWAE中,真实后验 可以通过 来估计:

利用这个恒等式,可以构造更紧的下界。

紧性证明

根据Jensen不等式:

而根据重要性采样的中心极限定理,多样本估计的方差更小,因此更接近真实的


IWAE的训练目标

优化目标

IWAE直接优化多样本ELBO:

蒙特卡洛估计

使用有限样本近似期望:

其中 是重要性权重。


梯度估计方法

1. REINFORCE(对数导数技巧)

最基本的梯度估计方法:

但这种方法方差很大。

2. 归一化权重梯度

更稳定的梯度估计:

展开为:

3. 直接梯度

将权重看作常数,直接对 求导:

这种方法被称为直通估计器(Straight-Through Estimator)。


低方差梯度估计:IMQ和REBAR

IMQ (Inverse Multinomial Quadrature)

使用解析分布近似重要性权重:

其中 是某种距离度量。

REBAR方法

结合reinforced和baseline技术:

def rebar_gradient(log_w, z, z_noise, phi):
    """
    REBAR梯度估计器
    log_w: 真实对数权重
    z: 从q采样
    z_noise: 从辅助分布采样
    """
    # 基线项
    baseline = tf.stop_gradient(log_w.mean())
    
    # 真实梯度
    grad_true = tf.gradients(log_w, phi)
    
    # 辅助梯度
    log_q_noise = q.log_prob(z_noise)
    grad_aux = tf.gradients(log_q_noise, phi)
    
    # REBAR估计
    return (log_w - baseline) * grad_true + tf.stop_gradient(log_w) * grad_aux

IWAE的实现

PyTorch实现

import torch
import torch.nn as nn
from torch.distributions import Normal
 
class IWAE(nn.Module):
    def __init__(self, encoder, decoder, latent_dim, k=5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = latent_dim
        self.k = k  # 样本数量
    
    def forward(self, x):
        batch_size = x.shape[0]
        
        # 编码得到高斯参数
        mu, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        
        # 从高斯分布采样K个样本
        # [K, batch_size, latent_dim]
        z = mu + std * torch.randn(self.k, batch_size, self.latent_dim, 
                                   device=x.device)
        
        # 重参数化技巧
        z_flat = z.view(-1, self.latent_dim)
        x_flat = x.unsqueeze(0).expand(self.k, -1, -1).reshape(-1, x.shape[1])
        
        # 解码
        logits = self.decoder(z_flat)
        
        # 计算对数似然
        log_p_x_given_z = -nn.functional.binary_cross_entropy_with_logits(
            logits, x_flat, reduction='none'
        ).view(self.k, batch_size, -1).sum(dim=-1)
        
        # 计算先验对数似然
        log_p_z = Normal(0, 1).log_prob(z).sum(dim=-1)
        
        # 计算变分后验对数似然
        log_q_z_given_x = Normal(mu, std).log_prob(z).sum(dim=-1)
        
        # 重要性权重
        log_weights = log_p_x_given_z + log_p_z - log_q_z_given_x
        
        # IWAE损失(负ELBO)
        # 使用log-sum-exp技巧提高数值稳定性
        log_loss = torch.logsumexp(log_weights, dim=0) - torch.log(torch.tensor(self.k))
        
        return -log_loss.mean()
    
    def elbo(self, x):
        """返回标准ELBO用于比较"""
        with torch.no_grad():
            mu, log_var = self.encoder(x)
            std = torch.exp(0.5 * log_var)
            z = mu + std * torch.randn_like(mu)
            
            logits = self.decoder(z)
            log_p_x_given_z = -nn.functional.binary_cross_entropy_with_logits(
                logits, x, reduction='none'
            ).sum(dim=-1)
            log_p_z = Normal(0, 1).log_prob(z).sum(dim=-1)
            log_q_z_given_x = Normal(mu, std).log_prob(z).sum(dim=-1)
            
            return (log_p_x_given_z + log_p_z - log_q_z_given_x).mean()

TensorFlow/Pyro实现

Pyro库提供了IWAE的开箱即用支持:

import pyro
import pyro.distributions as dist
from pyro.infer import ImportanceSampling, EmpiricalMarginal
 
def model(x):
    # 先验
    z = pyro.sample("z", dist.Normal(0, 1).expand([latent_dim]))
    # 生成
    logits = decoder(z)
    pyro.sample("obs", dist.Bernoulli(logits=logits), obs=x)
 
def guide(x):
    # 变分后验
    mu, log_var = encoder(x)
    pyro.sample("z", dist.Normal(mu, log_var.exp()))
 
# IWAE推断
elbo = ImportanceSampling(model, guide, num_samples=10)
loss = elbo.loss(model, guide, x)

IWAE与标准VAE的对比

理论对比

特性标准VAE (K=1)IWAE (K>1)
ELBO紧性较低紧性更高
梯度方差较低
计算成本
隐表示质量一般更好
生成样本中等更高质量

实验对比

在MNIST数据集上的典型结果:

方法测试ELBO样本数K
VAE-86.51
IWAE-83.25
IWAE-81.810
IWAE-80.550

定性观察

  1. 隐空间结构:IWAE学习更连贯的隐空间结构
  2. 插值质量:隐空间插值更加平滑
  3. 后验覆盖:后验分布更好地覆盖真实后验

IWAE的扩展

1. 多层IWAE

堆叠多层隐变量:

2. 流增强IWAE

在变分后验中引入归一化流:

class FlowIWAE(IWAE):
    def __init__(self, encoder, decoder, flow, k=5):
        super().__init__(encoder, decoder, k)
        self.flow = flow
    
    def forward(self, x):
        # 编码
        mu, log_var = self.encoder(x)
        std = torch.exp(0.5 * log_var)
        
        # 采样并通过流变换
        z0 = mu + std * torch.randn(self.k, *mu.shape, device=x.device)
        zK, log_det = self.flow(z0)
        
        # 计算ELBO(包含流变换的雅可比行列式)
        # ...

3. 双向IWAE

使用双向推断网络,允许从数据和隐变量双向生成。


IWAE的局限性

计算成本

每个样本都需要通过解码器,计算成本随 线性增长:

梯度消失问题

很大时,权重可能高度不平衡,导致少数样本主导梯度:

收敛速度

IWAE通常需要更多的训练迭代才能达到最优。


实践建议

1. 样本数量选择

数据规模推荐K
小数据集1-5
中等数据集5-10
大规模实验10-50

2. 学习率调整

IWAE可能需要更小的学习率,因为梯度方差虽然降低,但仍然存在。

3. 早停策略

监控验证集ELBO,当连续几个epoch没有提升时停止训练。

4. 与标准ELBO对比

定期计算标准ELBO(K=1)来监控真实对数似然的近似程度。


参考


相关主题

Footnotes

  1. Burda et al. (2016). “Importance Weighted Autoencoders”. ICLR 2016.