变分推断进阶:SVI、IWAEs与重参数化技巧
1. 背景回顾
1.1 变分推断基础
变分推断(Variational Inference, VI)的核心思想是用一个参数化的近似分布 去近似真实后验 ,通过最小化两者之间的KL散度:
其中 是证据下界(ELBO)。
1.2 本章目标
| 主题 | 内容 |
|---|---|
| 黑盒变分推断(BBVI) | 无需解析梯度,使用Monte Carlo估计 |
| 随机变分推断(SVI) | 大规模数据的变分推断 |
| 重要性加权自编码器(IWAE) | 比ELBO更紧的下界 |
| 半隐式变分推断(SIVI) | 非参数化变分族 |
| 重参数化技巧 | 降低梯度估计方差 |
| 归一化流 | 增强变分分布表达能力 |
2. 黑盒变分推断(BBVI)
2.1 核心问题
在标准变分推断中,我们需要计算对ELBO的梯度:
其中 。
问题:梯度不能直接通过期望内部传递。
2.2 策略1:score function梯度(REINFORCE)
score function恒等式:
def score_function_gradient(f, q, n_samples=100):
"""使用score function估计梯度
问题:方差通常很高
Returns:
gradient: 梯度估计
"""
gradients = []
for _ in range(n_samples):
z = q.sample() # 从变分分布采样
grad_log_q = q.score(z) # ∇_φ log q_φ(z)
gradients.append(f(z) * grad_log_q)
return np.mean(gradients, axis=0)优点:通用性强,适用于任意
缺点:方差通常很高,收敛慢
2.3 策略2:重参数化梯度
重参数化技巧(Reparameterization Trick):
将随机变量 表示为确定性变换 ,其中 是独立于 的噪声:
然后:
def reparam_gradient(f, phi, n_samples=100):
"""使用重参数化技巧估计梯度
优点:方差通常较低
缺点:需要能写出显式的重参数化映射
Returns:
gradient: 梯度估计
"""
gradients = []
for _ in range(n_samples):
epsilon = np.random.randn(len(phi))
z = reparametrize(phi, epsilon) # z = g(φ, ε)
gradients.append(f(z)) # 直接对f求梯度
return np.mean(gradients, axis=0)2.4 策略3:Rao-Blackwellization
Rao-Blackwellization 利用条件期望来降低方差:
如果 ,则:
方差满足:
3. 随机变分推断(SVI)
3.1 大规模数据的挑战
在贝叶斯混合模型或LDA等模型中,数据似然是所有数据点的乘积:
对于大规模数据集,每次迭代计算全部数据的梯度代价太高。
3.2 SVI核心思想
随机梯度变分推断(Stochastic VI)使用小批量(mini-batch)数据来近似完整梯度:
其中 是第 个数据点的局部ELBO。
3.3 SVI算法
def svi(model, X, n_iter=10000, batch_size=100, lr=0.01):
"""随机变分推断
Args:
model: 概率模型
X: 数据集 (N, D)
n_iter: 迭代次数
batch_size: 小批量大小
lr: 学习率
"""
N = len(X)
phi = initialize_variational_params()
for t in range(n_iter):
# 采样小批量
batch_idx = np.random.choice(N, batch_size, replace=False)
X_batch = X[batch_idx]
# 计算(小批量)梯度
# 使用重参数化或score function
grad = compute_elbo_gradient(phi, X_batch, N)
# 随机梯度上升
phi = phi + lr * grad
# 学习率衰减(可选)
lr = lr * (1 - 1e-5)
if t % 1000 == 0:
elbo = compute_elbo(phi, X)
print(f"Iter {t}: ELBO = {elbo:.4f}")
return phi
def compute_elbo_gradient(phi, X_batch, N_total):
"""计算小批量ELBO梯度
使用重参数化技巧
"""
batch_size = len(X_batch)
scale = N_total / batch_size
gradients = []
for x in X_batch:
# 重参数化采样
epsilon = np.random.randn(dim_z)
z = mean_field_reparam(phi, epsilon)
# 计算局部ELBO
local_elbo = scale * local_objective(z, x)
# 重参数化梯度
grad = torch.autograd.grad(local_elbo, phi)[0]
gradients.append(grad)
return np.mean(gradients, axis=0)3.4 自适应学习率
SVI需要仔细调参学习率。可以使用:
| 方法 | 说明 |
|---|---|
| AdaGrad | 累积历史梯度 |
| RMSProp | 指数加权移动平均 |
| Adam | 结合动量与RMSProp |
def svi_adam(model, X, n_iter=10000, batch_size=100):
"""带Adam优化的SVI"""
# Adam超参数
lr = 0.001
beta1 = 0.9
beta2 = 0.999
epsilon = 1e-8
m = 0 # 一阶矩估计
v = 0 # 二阶矩估计
t = 0 # 时间步
for t in range(1, n_iter + 1):
# 获取小批量梯度
grad = compute_elbo_gradient(phi, X, batch_size)
# Adam更新
m = beta1 * m + (1 - beta1) * grad
v = beta2 * v + (1 - beta2) * (grad ** 2)
# 偏差校正
m_hat = m / (1 - beta1 ** t)
v_hat = v / (1 - beta2 ** t)
phi = phi + lr * m_hat / (np.sqrt(v_hat) + epsilon)
return phi3.5 收敛速率分析
SVI的收敛速率依赖于:
| 因素 | 影响 |
|---|---|
| 批量大小 | 越大方差越小,但计算成本增加 |
| 学习率调度 | 需满足 |
| 模型结构 | 隐变量维度影响收敛 |
4. 重要性加权自编码器(IWAEs)
4.1 核心思想
IWAE使用重要性采样来获得比标准ELBO更紧的下界。
重要性加权的ELBO:
其中 是重要性权重。
4.2 与标准ELBO的关系
定理:对于任意 :
当 时:
4.3 为什么IWAE更好?
| 方面 | ELBO () | IWAE () |
|---|---|---|
| 下界紧度 | 松散 | 更紧 |
| 梯度估计 | 高方差 | 更低方差 |
| 后验近似 | 点估计 | 更平滑 |
| 计算成本 |
4.4 IWAE实现
class IWAE(torch.nn.Module):
"""重要性加权自编码器
IWAE同时学习生成模型p_θ和推理网络q_φ
"""
def __init__(self, input_dim, hidden_dim, latent_dim, k=5):
super().__init__()
self.k = k # 重要性采样数
# 编码器:q_φ(z|x)
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2 * latent_dim) # μ, log σ
)
# 解码器:p_θ(x|z)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2 * input_dim) # μ, log σ
)
def reparameterize(self, mu, log_var):
"""重参数化技巧"""
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
"""前向传播,计算IWAE损失"""
batch_size = x.size(0)
# 编码
h = self.encoder(x)
mu, log_var = h.chunk(2, dim=-1)
# 重要性采样
log_weights = []
for _ in range(self.k):
# 重参数化采样
z = self.reparameterize(mu, log_var)
# 计算对数似然 log p_θ(x|z)
h_dec = self.decoder(z)
dec_mu, dec_log_var = h_dec.chunk(2, dim=-1)
log_p_xz = torch.distributions.Normal(dec_mu, dec_log_var.exp()).log_prob(x)
log_p_xz = log_p_xz.sum(dim=-1)
# 计算先验对数似然 log p(z)
log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
# 计算推理网络对数似然 log q_φ(z|x)
log_q_zx = torch.distributions.Normal(mu, log_var.exp()).log_prob(z).sum(dim=-1)
# 重要性权重
log_w = log_p_xz + log_p_z - log_q_zx
log_weights.append(log_w)
# 堆叠并归一化
log_weights = torch.stack(log_weights, dim=1) # (batch, k)
# IWAE损失:-log(1/K * Σ w_k)
# 使用log-sum-exp数值稳定计算
log_ub = torch.logsumexp(log_weights, dim=1) - np.log(self.k)
loss = -log_ub.mean()
# 返回所有信息用于分析
return {
'loss': loss,
'elbo': log_weights.mean(dim=1).mean(),
'log_weights': log_weights
}
def sample(self, n_samples):
"""从生成分布采样"""
with torch.no_grad():
z = torch.randn(n_samples, self.latent_dim)
h = self.decoder(z)
mu, log_var = h.chunk(2, dim=-1)
x = mu + torch.randn_like(mu) * log_var.exp()
return x4.5 IWAE的理论分析
下界紧度
命题:设 ,则:
证明:利用Jensen不等式的严格性条件…
梯度估计方差
命题:IWAE的梯度方差随 增大而减小(当使用适当的归一化时)。
5. 半隐式变分推断(SIVI)
5.1 标准VI的局限
标准VI使用参数化的显式分布 :
- 表达能力受限于分布族(如高斯)
- 无法捕捉复杂的后验结构
5.2 SIVI核心思想
SIVI使用半隐式分布:
其中 是参数化分布, 是隐式采样的随机变量。
5.3 SIVI实现
class SIVI(nn.Module):
"""半隐式变分推断"""
def __init__(self, input_dim, hidden_dim, latent_dim):
super().__init__()
# 参数化均值和方差网络
self.mu_net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim)
)
self.log_var_net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim)
)
# 噪声注入网络(用于半隐式采样)
self.noise_net = nn.Sequential(
nn.Linear(input_dim + latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim)
)
def sample(self, x, n_samples=1):
"""从半隐式分布采样"""
mu = self.mu_net(x)
log_var = self.log_var_net(x)
# 半隐式采样:添加学习的噪声偏移
eps = torch.randn_like(mu)
noise = self.noise_net(torch.cat([x, eps], dim=-1))
z = mu + eps * log_var.exp() + noise * 0.1
return z
def elbo(self, x, n_samples=5):
"""计算SIVI的ELBO"""
mu = self.mu_net(x)
log_var = self.log_var_net(x)
elbo_samples = []
for _ in range(n_samples):
z = self.sample(x)
# 计算各部分
log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
log_q_zx = torch.distributions.Normal(mu, log_var.exp()).log_prob(z).sum(dim=-1)
log_p_xz = self.decode_log_prob(x, z)
elbo_samples.append(log_p_xz + log_p_z - log_q_zx)
return torch.stack(elbo_samples).mean()6. 归一化流在VI中的应用
6.1 基本思想
归一化流通过可逆变换 来增强变分分布的表达能力:
其中 。
6.2 Planar流
class PlanarFlow(nn.Module):
"""Planar归一化流
变换: z = x + u * h(w^T x + b)
其中 h 是非线性激活函数
"""
def __init__(self, dim):
super().__init__()
self.w = nn.Parameter(torch.randn(dim))
self.u = nn.Parameter(torch.randn(dim))
self.b = nn.Parameter(torch.randn(1))
def forward(self, z):
"""前向传播"""
activation = torch.tanh(z @ self.w + self.b)
z_new = z + self.u * activation
# 计算对数行列式
psi = (1 - activation**2) * self.w
det = 1 + torch.sum(self.u * psi)
log_det = torch.log(det.abs() + 1e-8)
return z_new, log_det
def inverse(self, z):
"""逆变换(数值近似)"""
# 需要数值迭代求解
raise NotImplementedError6.3 Sylvester流
class SylvesterFlow(nn.Module):
"""Sylvester归一化流
比Planar流更灵活,支持多参数
"""
def __init__(self, dim, num_batched=None):
super().__init__()
self.dim = dim
self.num_batched = num_batched
# 正交矩阵参数化
self.orthog = nn.Parameter(torch.eye(dim))
# 对角缩放
self.diag = nn.Parameter(torch.ones(dim))
# 移位向量
self.shift = nn.Parameter(torch.zeros(dim))
def forward(self, z):
"""前向传播"""
z_new = self.orthog @ (self.diag * z) + self.shift
# 对数行列式 = sum(log(|diag|)) = sum(log(|diag|))
log_det = torch.sum(torch.log(self.diag.abs() + 1e-8))
return z_new, log_det6.4 VAE + 归一化流
class NormalizingFlowVAE(nn.Module):
"""带归一化流的VAE"""
def __init__(self, input_dim, latent_dim, n_flows=8):
super().__init__()
# 编码器
self.encoder = Encoder(input_dim, latent_dim)
# 初始变分分布参数
self.loc_net = nn.Linear(latent_dim, latent_dim)
self.scale_net = nn.Linear(latent_dim, latent_dim)
# 归一化流
self.flows = nn.ModuleList([PlanarFlow(latent_dim) for _ in range(n_flows)])
# 解码器
self.decoder = Decoder(latent_dim, input_dim)
def forward(self, x, n_samples=1):
"""前向传播"""
# 编码
h = self.encoder(x)
q0_loc = self.loc_net(h)
q0_scale = torch.exp(self.scale_net(h) / 2)
# 从初始分布采样
z0 = q0_loc + q0_scale * torch.randn_like(q0_loc)
# 通过归一化流
z = z0
log_det_sum = 0
for flow in self.flows:
z, log_det = flow(z)
log_det_sum += log_det
# 解码
x_rec = self.decoder(z)
# 计算ELBO
log_p_xz = x_rec.log_prob(x).sum(dim=-1)
log_p_z = torch.distributions.Normal(0, 1).log_prob(z).sum(dim=-1)
log_q_zx = torch.distributions.Normal(q0_loc, q0_scale).log_prob(z0).sum(dim=-1)
elbo = log_p_xz + log_p_z - log_q_zx + log_det_sum
return {
'elbo': elbo.mean(),
'x_rec': x_rec.mean,
'z': z
}7. 实践指南
7.1 梯度估计器选择
| 方法 | 方差 | 计算成本 | 适用场景 |
|---|---|---|---|
| Score Function | 高 | 中 | 离散变量、不可微模型 |
| Reparameterization | 低 | 中 | 连续变量 |
| IWAE | 中-低 | 高() | 需要更紧下界 |
| Normalizing Flow | 低 | 高 | 需要复杂后验 |
7.2 初始化策略
def initialize_variational_params(model):
"""变分参数初始化"""
for name, param in model.named_parameters():
if 'log_var' in name or 'logit' in name:
# 初始化小方差
nn.init.constant_(param, -5)
elif 'loc' in name or 'mu' in name:
# Xavier初始化
nn.init.xavier_uniform_(param)
elif 'weight' in name:
nn.init.kaiming_normal_(param)
elif 'bias' in name:
nn.init.zeros_(param)
return model7.3 调试技巧
class DebugVI:
"""变分推断调试工具"""
def __init__(self, model):
self.model = model
self.history = {
'elbo': [],
'kl': [],
'reconstruction': [],
'grad_norm': []
}
def train_step(self, x):
"""带诊断的训练步骤"""
output = self.model(x)
# 记录指标
self.history['elbo'].append(output['elbo'].item())
self.history['kl'].append(output['kl'].item())
self.history['reconstruction'].append(output['reconstruction'].item())
# 计算梯度范数
output['loss'].backward()
grad_norm = sum(p.grad.norm().item()
for p in self.model.parameters()
if p.grad is not None)
self.history['grad_norm'].append(grad_norm)
# 诊断检查
self._check_numerics()
self._check_gradients(grad_norm)
return output['loss']
def _check_numerics(self):
"""检查数值稳定性"""
for name, param in self.model.named_parameters():
if torch.isnan(param).any():
raise ValueError(f"NaN in {name}")
if torch.isinf(param).any():
raise ValueError(f"Inf in {name}")
def _check_gradients(self, grad_norm):
"""检查梯度"""
if grad_norm > 100:
print(f"Warning: Large gradient norm: {grad_norm}")
if grad_norm < 0.01:
print(f"Warning: Small gradient norm: {grad_norm}")
def plot_history(self):
"""可视化训练历史"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes[0, 0].plot(self.history['elbo'])
axes[0, 0].set_title('ELBO')
axes[0, 1].plot(self.history['kl'])
axes[0, 1].set_title('KL Divergence')
axes[1, 0].plot(self.history['reconstruction'])
axes[1, 0].set_title('Reconstruction Loss')
axes[1, 1].plot(self.history['grad_norm'])
axes[1, 1].set_title('Gradient Norm')
plt.tight_layout()
plt.savefig('vi_training_diagnostics.png')8. 完整示例:变分自编码器
import torch
import torch.nn as nn
import torch.distributions as dist
class VAE(nn.Module):
"""完整的变分自编码器实现"""
def __init__(self, input_dim, hidden_dim, latent_dim, n_flows=0):
super().__init__()
self.latent_dim = latent_dim
self.n_flows = n_flows
# 编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2 * latent_dim)
)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2 * input_dim)
)
# 归一化流(可选)
if n_flows > 0:
self.flows = nn.ModuleList([
PlanarFlow(latent_dim) for _ in range(n_flows)
])
else:
self.flows = []
def encode(self, x):
"""编码"""
h = self.encoder(x)
mu, log_var = h.chunk(2, dim=-1)
return mu, log_var
def reparameterize(self, mu, log_var):
"""重参数化"""
std = (0.5 * log_var).exp()
eps = torch.randn_like(std)
return mu + eps * std
def flow_transform(self, z0, log_det_sum=0):
"""归一化流变换"""
z = z0
for flow in self.flows:
z, log_det = flow(z)
log_det_sum = log_det_sum + log_det
return z, log_det_sum
def decode(self, z):
"""解码"""
h = self.decoder(z)
mu, log_var = h.chunk(2, dim=-1)
return dist.Normal(mu, log_var.exp())
def forward(self, x, k=1):
"""前向传播,计算ELBO"""
# 编码
mu, log_var = self.encode(x)
# 采样
z = self.reparameterize(mu, log_var)
# 归一化流(如果使用)
log_det_sum = 0
if self.n_flows > 0:
z, log_det_sum = self.flow_transform(z, log_det_sum)
# 解码
p_x_given_z = self.decode(z)
# 计算ELBO
log_p_xz = p_x_given_z.log_prob(x).sum(dim=-1)
log_p_z = dist.Normal(0, 1).log_prob(z).sum(dim=-1)
log_q_zx = dist.Normal(mu, log_var.exp()).log_prob(z).sum(dim=-1)
elbo = log_p_xz + log_p_z - log_q_zx + log_det_sum
return {
'elbo': elbo.mean(),
'reconstruction': log_p_xz.mean(),
'kl': (log_q_zx - log_p_z).mean(),
'x_rec': p_x_given_z.mean,
'z': z
}
def sample(self, n_samples):
"""从先验采样并生成"""
with torch.no_grad():
z = torch.randn(n_samples, self.latent_dim)
p_x_given_z = self.decode(z)
x_samples = p_x_given_z.sample()
return x_samples
# 训练脚本
def train_vae(X, input_dim, hidden_dim=400, latent_dim=20,
n_epochs=100, batch_size=128, lr=1e-3, n_flows=0):
"""训练VAE"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim, hidden_dim, latent_dim, n_flows).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
dataset = torch.tensor(X, dtype=torch.float32)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
history = {'elbo': [], 'recon': [], 'kl': []}
for epoch in range(n_epochs):
epoch_elbo = 0
epoch_recon = 0
epoch_kl = 0
for batch in dataloader:
batch = batch.to(device)
optimizer.zero_grad()
output = model(batch)
loss = -output['elbo']
loss.backward()
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
epoch_elbo += output['elbo'].item()
epoch_recon += output['reconstruction'].item()
epoch_kl += output['kl'].item()
n_batches = len(dataloader)
history['elbo'].append(epoch_elbo / n_batches)
history['recon'].append(epoch_recon / n_batches)
history['kl'].append(epoch_kl / n_batches)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{n_epochs}, ELBO: {history['elbo'][-1]:.2f}, "
f"Recon: {history['recon'][-1]:.2f}, KL: {history['kl'][-1]:.2f}")
return model, history