概述
重要性加权自编码器(Importance Weighted Autoencoder,IWAE)是Burda et al. (2016)提出的一种变分推断方法,通过使用多个样本估计ELBO来获得更紧的下界,从而提高近似精度。1
IWAE的核心思想是利用重要性采样(Importance Sampling)的思想,在变分推断中引入多个隐变量样本,使估计的变分下界更接近真实的对数似然。
标准VAE的局限性
ELBO的紧性
标准VAE优化的证据下界(ELBO)为:
这个下界相对于对数似然 的”紧性”取决于近似后验 与真实后验 的接近程度。
单样本估计的高方差
当使用单个样本估计期望时:
这种估计的方差很大,导致:
- 梯度估计噪声高
- 训练不稳定
- 下界可能远离真实对数似然
IWAE的核心思想
多样本ELBO
IWAE使用 个独立样本来估计ELBO:
与标准ELBO的关系
当 时,IWAE退化为标准ELBO:
下界的紧性
IWAE的样本估计提供了更紧的下界:
随着 增加,下界单调递增,并收敛到真实对数似然。
理论基础:重要性采样
重要性采样回顾
给定目标分布 和提议分布 ,重要性采样通过提议分布采样来估计期望:
IWAE的解释
在IWAE中,真实后验 可以通过 来估计:
利用这个恒等式,可以构造更紧的下界。
紧性证明
根据Jensen不等式:
而根据重要性采样的中心极限定理,多样本估计的方差更小,因此更接近真实的 。
IWAE的训练目标
优化目标
IWAE直接优化多样本ELBO:
蒙特卡洛估计
使用有限样本近似期望:
其中 是重要性权重。
梯度估计方法
1. REINFORCE(对数导数技巧)
最基本的梯度估计方法:
但这种方法方差很大。
2. 归一化权重梯度
更稳定的梯度估计:
展开为:
3. 直接梯度
将权重看作常数,直接对 求导:
这种方法被称为直通估计器(Straight-Through Estimator)。
低方差梯度估计:IMQ和REBAR
IMQ (Inverse Multinomial Quadrature)
使用解析分布近似重要性权重:
其中 是某种距离度量。
REBAR方法
结合reinforced和baseline技术:
def rebar_gradient(log_w, z, z_noise, phi):
"""
REBAR梯度估计器
log_w: 真实对数权重
z: 从q采样
z_noise: 从辅助分布采样
"""
# 基线项
baseline = tf.stop_gradient(log_w.mean())
# 真实梯度
grad_true = tf.gradients(log_w, phi)
# 辅助梯度
log_q_noise = q.log_prob(z_noise)
grad_aux = tf.gradients(log_q_noise, phi)
# REBAR估计
return (log_w - baseline) * grad_true + tf.stop_gradient(log_w) * grad_auxIWAE的实现
PyTorch实现
import torch
import torch.nn as nn
from torch.distributions import Normal
class IWAE(nn.Module):
def __init__(self, encoder, decoder, latent_dim, k=5):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.latent_dim = latent_dim
self.k = k # 样本数量
def forward(self, x):
batch_size = x.shape[0]
# 编码得到高斯参数
mu, log_var = self.encoder(x)
std = torch.exp(0.5 * log_var)
# 从高斯分布采样K个样本
# [K, batch_size, latent_dim]
z = mu + std * torch.randn(self.k, batch_size, self.latent_dim,
device=x.device)
# 重参数化技巧
z_flat = z.view(-1, self.latent_dim)
x_flat = x.unsqueeze(0).expand(self.k, -1, -1).reshape(-1, x.shape[1])
# 解码
logits = self.decoder(z_flat)
# 计算对数似然
log_p_x_given_z = -nn.functional.binary_cross_entropy_with_logits(
logits, x_flat, reduction='none'
).view(self.k, batch_size, -1).sum(dim=-1)
# 计算先验对数似然
log_p_z = Normal(0, 1).log_prob(z).sum(dim=-1)
# 计算变分后验对数似然
log_q_z_given_x = Normal(mu, std).log_prob(z).sum(dim=-1)
# 重要性权重
log_weights = log_p_x_given_z + log_p_z - log_q_z_given_x
# IWAE损失(负ELBO)
# 使用log-sum-exp技巧提高数值稳定性
log_loss = torch.logsumexp(log_weights, dim=0) - torch.log(torch.tensor(self.k))
return -log_loss.mean()
def elbo(self, x):
"""返回标准ELBO用于比较"""
with torch.no_grad():
mu, log_var = self.encoder(x)
std = torch.exp(0.5 * log_var)
z = mu + std * torch.randn_like(mu)
logits = self.decoder(z)
log_p_x_given_z = -nn.functional.binary_cross_entropy_with_logits(
logits, x, reduction='none'
).sum(dim=-1)
log_p_z = Normal(0, 1).log_prob(z).sum(dim=-1)
log_q_z_given_x = Normal(mu, std).log_prob(z).sum(dim=-1)
return (log_p_x_given_z + log_p_z - log_q_z_given_x).mean()TensorFlow/Pyro实现
Pyro库提供了IWAE的开箱即用支持:
import pyro
import pyro.distributions as dist
from pyro.infer import ImportanceSampling, EmpiricalMarginal
def model(x):
# 先验
z = pyro.sample("z", dist.Normal(0, 1).expand([latent_dim]))
# 生成
logits = decoder(z)
pyro.sample("obs", dist.Bernoulli(logits=logits), obs=x)
def guide(x):
# 变分后验
mu, log_var = encoder(x)
pyro.sample("z", dist.Normal(mu, log_var.exp()))
# IWAE推断
elbo = ImportanceSampling(model, guide, num_samples=10)
loss = elbo.loss(model, guide, x)IWAE与标准VAE的对比
理论对比
| 特性 | 标准VAE (K=1) | IWAE (K>1) |
|---|---|---|
| ELBO | 紧性较低 | 紧性更高 |
| 梯度方差 | 高 | 较低 |
| 计算成本 | ||
| 隐表示质量 | 一般 | 更好 |
| 生成样本 | 中等 | 更高质量 |
实验对比
在MNIST数据集上的典型结果:
| 方法 | 测试ELBO | 样本数K |
|---|---|---|
| VAE | -86.5 | 1 |
| IWAE | -83.2 | 5 |
| IWAE | -81.8 | 10 |
| IWAE | -80.5 | 50 |
定性观察
- 隐空间结构:IWAE学习更连贯的隐空间结构
- 插值质量:隐空间插值更加平滑
- 后验覆盖:后验分布更好地覆盖真实后验
IWAE的扩展
1. 多层IWAE
堆叠多层隐变量:
2. 流增强IWAE
在变分后验中引入归一化流:
class FlowIWAE(IWAE):
def __init__(self, encoder, decoder, flow, k=5):
super().__init__(encoder, decoder, k)
self.flow = flow
def forward(self, x):
# 编码
mu, log_var = self.encoder(x)
std = torch.exp(0.5 * log_var)
# 采样并通过流变换
z0 = mu + std * torch.randn(self.k, *mu.shape, device=x.device)
zK, log_det = self.flow(z0)
# 计算ELBO(包含流变换的雅可比行列式)
# ...3. 双向IWAE
使用双向推断网络,允许从数据和隐变量双向生成。
IWAE的局限性
计算成本
每个样本都需要通过解码器,计算成本随 线性增长:
梯度消失问题
当 很大时,权重可能高度不平衡,导致少数样本主导梯度:
收敛速度
IWAE通常需要更多的训练迭代才能达到最优。
实践建议
1. 样本数量选择
| 数据规模 | 推荐K |
|---|---|
| 小数据集 | 1-5 |
| 中等数据集 | 5-10 |
| 大规模实验 | 10-50 |
2. 学习率调整
IWAE可能需要更小的学习率,因为梯度方差虽然降低,但仍然存在。
3. 早停策略
监控验证集ELBO,当连续几个epoch没有提升时停止训练。
4. 与标准ELBO对比
定期计算标准ELBO(K=1)来监控真实对数似然的近似程度。
参考
相关主题
- variational-inference-advanced - 变分推断进阶
- normalizing-flows-variational - 归一化流与变分推断
- bayesian-neural-networks-advanced-inference - 贝叶斯神经网络高级推断
- probabilistic-circuits-fundamentals - 概率电路基础
- em-algorithm - EM算法
Footnotes
-
Burda et al. (2016). “Importance Weighted Autoencoders”. ICLR 2016. ↩