A Deep Dive into Neural Variational Inference
Neural Variational Inference (NVI) is the paradigm that tightly couples variational inference with neural networks: a parameterized network approximates an otherwise intractable posterior distribution, enabling end-to-end probabilistic inference.[1] This document works through the mathematical foundations of variational inference, the probabilistic interpretation of neural networks, and the implementation details of modern variational methods.
1. Variational Inference Fundamentals
1.1 Problem Formulation
In Bayesian inference we want to compute the posterior distribution:

$$p(z \mid x) = \frac{p(x \mid z)\, p(z)}{p(x)}, \qquad p(x) = \int p(x \mid z)\, p(z)\, dz$$

where:
- $p(z)$: the prior distribution
- $p(x \mid z)$: the likelihood
- $p(x)$: the marginal likelihood (evidence)

Core difficulty: the marginal likelihood $p(x)$ is usually analytically intractable, so the posterior cannot be computed directly.
1.2 Jensen's Inequality and the ELBO
Jensen's inequality: for a convex function $f$ and probability distribution $p$:

$$f(\mathbb{E}_p[x]) \le \mathbb{E}_p[f(x)]$$

(for a concave function such as $\log$, the inequality reverses). Applying Jensen's inequality to $\log p(x)$:

$$\log p(x) = \log \mathbb{E}_{q(z \mid x)}\!\left[\frac{p(x, z)}{q(z \mid x)}\right] \ge \mathbb{E}_{q(z \mid x)}\!\left[\log \frac{p(x, z)}{q(z \mid x)}\right]$$

Evidence Lower Bound (ELBO):

$$\mathcal{L}(q) = \mathbb{E}_{q(z \mid x)}\big[\log p(x, z) - \log q(z \mid x)\big]$$
1.3 Decomposing the ELBO
The ELBO decomposes into two meaningful terms:

$$\mathcal{L}(q) = \underbrace{\mathbb{E}_{q(z \mid x)}[\log p(x \mid z)]}_{\text{reconstruction}} \; - \; \underbrace{\mathrm{KL}\big(q(z \mid x) \,\|\, p(z)\big)}_{\text{regularization}}$$

| Term | Formula | Role |
|---|---|---|
| Reconstruction | $\mathbb{E}_{q(z \mid x)}[\log p(x \mid z)]$ | Ensures the variational distribution can reconstruct the data |
| Regularization | $-\mathrm{KL}(q(z \mid x) \,\|\, p(z))$ | Keeps the variational distribution close to the prior |
1.4 Properties of the KL Divergence
Definition:

$$\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{q}\!\left[\log \frac{q(z)}{p(z)}\right] = \int q(z) \log \frac{q(z)}{p(z)}\, dz$$

Properties:
- Non-negativity: $\mathrm{KL}(q \,\|\, p) \ge 0$, with equality iff $q = p$
- Asymmetry: $\mathrm{KL}(q \,\|\, p) \ne \mathrm{KL}(p \,\|\, q)$ in general
- Additivity: for independent factors, $\mathrm{KL}\big(\prod_i q_i \,\|\, \prod_i p_i\big) = \sum_i \mathrm{KL}(q_i \,\|\, p_i)$

KL divergence between Gaussians (closed form):
For $q = \mathcal{N}(\mu_1, \sigma_1^2)$ and $p = \mathcal{N}(\mu_2, \sigma_2^2)$:

$$\mathrm{KL}(q \,\|\, p) = \log \frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$
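This closed form is easy to check numerically. Below is a minimal sketch (the helper name `gaussian_kl` is ours) that verifies it against `torch.distributions.kl_divergence` for diagonal Gaussians parameterized by log-variance:

```python
import torch
from torch.distributions import Normal, kl_divergence

def gaussian_kl(mu1, log_var1, mu2, log_var2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2)), elementwise."""
    return 0.5 * (
        log_var2 - log_var1
        + (log_var1.exp() + (mu1 - mu2) ** 2) / log_var2.exp()
        - 1.0
    )

mu1, lv1 = torch.randn(5), torch.randn(5)
mu2, lv2 = torch.randn(5), torch.randn(5)
manual = gaussian_kl(mu1, lv1, mu2, lv2)
builtin = kl_divergence(Normal(mu1, (0.5 * lv1).exp()),
                        Normal(mu2, (0.5 * lv2).exp()))
assert torch.allclose(manual, builtin, atol=1e-5)
```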
2. VAE from a Probabilistic-Graphical-Model Perspective
2.1 The VAE Generative Model
The Variational Autoencoder (VAE) defines a hierarchical generative model.[2]
Generative process:

$$z \sim p(z) = \mathcal{N}(0, I), \qquad x \sim p_\theta(x \mid z)$$

Probabilistic graphical model:
```
  p(z)          p(x|z)
┌─────────┐  ┌─────────┐
│         │  │         │
↓         │  ↓         │
z ──────────→ x
↑
│
q(z|x)
```
2.2 The Inference Network
Since the true posterior $p(z \mid x)$ is intractable, the VAE introduces an inference network $q_\phi(z \mid x)$ to approximate it:

$$q_\phi(z \mid x) = \mathcal{N}\!\big(z;\, \mu_\phi(x),\, \mathrm{diag}(\sigma_\phi^2(x))\big)$$

Inference network architecture (encoder):
```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """VAE encoder: inference network q(z|x)."""
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Shared encoding trunk
        self.shared = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Heads for the mean and log-variance
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = self.shared(x)
        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        # Parameterizing the log-variance keeps the variance positive after exp
        return mu, log_var
```
2.3 The Generative Network
Generative network architecture (decoder):
```python
class Decoder(nn.Module):
    """VAE decoder: generative distribution p(x|z)."""
    def __init__(self, latent_dim, hidden_dim, output_dim, distribution='bernoulli'):
        super().__init__()
        self.distribution = distribution
        # Shared trunk ends at hidden_dim so the output heads can branch off it
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        if distribution == 'bernoulli':
            # Single head; sigmoid applied in forward
            self.fc_out = nn.Linear(hidden_dim, output_dim)
        elif distribution == 'gaussian':
            # Heads for the output mean and log-variance
            self.fc_mu = nn.Linear(hidden_dim, output_dim)
            self.fc_log_var = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = self.fc(z)
        if self.distribution == 'bernoulli':
            return torch.sigmoid(self.fc_out(h))
        elif self.distribution == 'gaussian':
            return self.fc_mu(h), self.fc_log_var(h)
```
2.4 The VAE Objective
Evidence lower bound:

$$\mathcal{L}(\theta, \phi; x) = \mathbb{E}_{q_\phi(z \mid x)}[\log p_\theta(x \mid z)] - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$

Full objective (maximized over the dataset):

$$\max_{\theta, \phi} \; \sum_{i=1}^{N} \mathcal{L}(\theta, \phi; x^{(i)})$$
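For concreteness, here is a minimal sketch of this objective as a per-batch training loss, assuming a Bernoulli decoder and using the closed-form Gaussian KL from Section 1.4 (the helper name `vae_loss` is ours):

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_prob, mu, log_var):
    """Negative ELBO for a Bernoulli decoder with a N(0, I) prior."""
    # Reconstruction term: -E_q[log p(x|z)], one Monte Carlo sample
    recon = F.binary_cross_entropy(x_prob, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, I)) in closed form
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon + kl
```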
3. The Mathematics of the Reparameterization Trick
3.1 The Problem in Mathematical Form
We want the gradient of an expectation under the variational distribution $q_\phi(z)$:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)]$$

The difficulty with direct differentiation: the sampling distribution itself depends on $\phi$, so the chain rule cannot be applied through the sampling step.
3.2 The Reparameterization Transform
Core idea: move the randomness into an independent noise variable:

$$z = g_\phi(\epsilon), \qquad \epsilon \sim p(\epsilon)$$

Common reparameterizations:

| Distribution | Reparameterized form |
|---|---|
| Gaussian | $z = \mu + \sigma \odot \epsilon, \;\; \epsilon \sim \mathcal{N}(0, I)$ |
| Laplace | $z = \mu - b\, \mathrm{sign}(u) \ln(1 - 2\lvert u \rvert), \;\; u \sim \mathcal{U}(-\tfrac{1}{2}, \tfrac{1}{2})$ |
| Categorical | Gumbel-Softmax relaxation (Section 3.5) |
| Mixture | Sample a component index (e.g. via Gumbel-Softmax), then reparameterize within the chosen component |
3.3 Gradient Derivation
Theorem: for a differentiable reparameterization $z = g_\phi(\epsilon)$:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}\big[\nabla_\phi f(g_\phi(\epsilon))\big]$$

Proof: since $\epsilon \sim p(\epsilon)$ does not depend on $\phi$,

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] = \nabla_\phi \int p(\epsilon)\, f(g_\phi(\epsilon))\, d\epsilon = \int p(\epsilon)\, \nabla_\phi f(g_\phi(\epsilon))\, d\epsilon = \mathbb{E}_{p(\epsilon)}\big[\nabla_\phi f(g_\phi(\epsilon))\big]$$

where exchanging the gradient and the integral is justified under standard regularity conditions.
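The theorem is easy to verify numerically. In this small check (our own example), $f(z) = z^2$ with $z \sim \mathcal{N}(\mu, \sigma^2)$ gives $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so the exact gradients are $2\mu$ and $2\sigma$:

```python
import torch

mu = torch.tensor(1.5, requires_grad=True)
sigma = torch.tensor(0.8, requires_grad=True)

eps = torch.randn(100_000)   # epsilon ~ N(0, 1), independent of mu and sigma
z = mu + sigma * eps         # reparameterized samples
loss = (z ** 2).mean()       # Monte Carlo estimate of E[z^2]
loss.backward()

print(mu.grad)     # ~ 2 * mu = 3.0
print(sigma.grad)  # ~ 2 * sigma = 1.6
```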
3.4 Variance Analysis
The key practical advantage of reparameterization: lower-variance gradient estimates.
Score function gradient:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}\big[f(z)\, \nabla_\phi \log q_\phi(z)\big]$$

Problem: the score function estimator's variance can blow up when $f(z)$ takes large values.
Reparameterization gradient:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}\big[\nabla_z f(z)\, \nabla_\phi g_\phi(\epsilon)\big]$$

Advantage: the gradient flows through the deterministic path $z = g_\phi(\epsilon)$ and uses $\nabla_z f$, rather than weighting $\nabla_\phi \log q_\phi$ by raw function values.
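The gap is easy to see empirically. The sketch below (our own construction) compares the two estimators for $f(z) = z^2$, $z \sim \mathcal{N}(\mu, 1)$: both are unbiased for $2\mu$, but the score-function estimator has markedly higher variance (about 222 vs 4 in this example):

```python
import torch

mu = torch.tensor(3.0)
n_trials, n_samples = 1000, 1

score_grads, reparam_grads = [], []
for _ in range(n_trials):
    eps = torch.randn(n_samples)
    z = mu + eps  # sigma = 1
    # Score function: f(z) * d/dmu log N(z; mu, 1) = z^2 * (z - mu)
    score_grads.append((z ** 2 * (z - mu)).mean())
    # Reparameterization: d/dmu f(mu + eps) = 2 * (mu + eps)
    reparam_grads.append((2 * z).mean())

print(torch.stack(score_grads).var())    # large (~222)
print(torch.stack(reparam_grads).var())  # small (~4)
```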
3.5 Gumbel-Softmax Reparameterization
For discrete distributions, the Gumbel-Softmax relaxation provides a reparameterization:

$$y_i = \frac{\exp\big((\log \pi_i + g_i)/\tau\big)}{\sum_{j} \exp\big((\log \pi_j + g_j)/\tau\big)}$$

where $g_i \sim \mathrm{Gumbel}(0, 1)$ and $\tau$ is the temperature parameter.
```python
import torch
import torch.nn.functional as F

def gumbel_softmax(logits, temperature, hard=False):
    """
    Gumbel-Softmax reparameterization.
    Args:
        logits: (batch, n_categories) unnormalized log-probabilities
        temperature: temperature parameter tau
        hard: whether to return a hard one-hot output
    Returns:
        soft (or straight-through hard) samples
    """
    # Sample Gumbel(0, 1) noise as -log(Exponential(1))
    gumbels = -torch.empty_like(logits).exponential_().log()
    gumbels = (logits + gumbels) / temperature
    # Softmax relaxation
    soft = F.softmax(gumbels, dim=-1)
    if hard:
        # Hard one-hot in the forward pass, soft gradients (straight-through)
        hard_onehot = F.one_hot(gumbels.argmax(dim=-1), logits.size(-1)).float()
        return (hard_onehot - soft).detach() + soft
    else:
        return soft
```
4. A Deeper Analysis of the Evidence Lower Bound (ELBO)
4.1 The ELBO and the True Marginal Likelihood
Theorem: for any variational distribution $q(z \mid x)$:

$$\log p(x) = \mathcal{L}(q) + \mathrm{KL}\big(q(z \mid x) \,\|\, p(z \mid x)\big)$$

Corollaries (verified on a toy conjugate model in the sketch below):
- $\mathcal{L}(q) \le \log p(x)$ (the ELBO is a lower bound)
- when $q(z \mid x) = p(z \mid x)$, the KL term vanishes and $\mathcal{L}(q) = \log p(x)$
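A quick check on a conjugate toy model of our own: with $z \sim \mathcal{N}(0,1)$ and $x \mid z \sim \mathcal{N}(z,1)$, the evidence is $p(x) = \mathcal{N}(x; 0, 2)$ and the true posterior is $\mathcal{N}(x/2, 1/2)$, so both corollaries can be tested exactly:

```python
import math

def elbo(x, m, s2):
    """Closed-form ELBO for q(z) = N(m, s2) on the toy model above."""
    recon = -0.5 * (math.log(2 * math.pi) + (x - m) ** 2 + s2)  # E_q[log p(x|z)]
    kl = 0.5 * (s2 + m ** 2 - 1 - math.log(s2))                 # KL(q || N(0,1))
    return recon - kl

x = 1.3
log_px = -0.5 * (math.log(2 * math.pi * 2) + x ** 2 / 2)  # log N(x; 0, 2)

print(elbo(x, 0.0, 1.0) <= log_px)                # True: strict lower bound
print(abs(elbo(x, x / 2, 0.5) - log_px) < 1e-9)   # True: tight at the posterior
```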
4.2 Equivalent Forms of the ELBO
Form 1 (standard): $\mathcal{L} = \mathbb{E}_{q}[\log p(x, z) - \log q(z \mid x)]$
Form 2 (reconstruction + KL): $\mathcal{L} = \mathbb{E}_{q}[\log p(x \mid z)] - \mathrm{KL}(q(z \mid x) \,\|\, p(z))$
Form 3 (information-theoretic): $\mathcal{L} = \log p(x) - \mathrm{KL}(q(z \mid x) \,\|\, p(z \mid x))$
Form 4 (importance sampling): $\mathcal{L} = \mathbb{E}_{q}\!\left[\log \dfrac{p(x, z)}{q(z \mid x)}\right]$, the single-sample case of the $K$-sample importance-weighted bound $\mathbb{E}\!\left[\log \frac{1}{K} \sum_{k=1}^{K} \frac{p(x, z_k)}{q(z_k \mid x)}\right]$
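Form 4 is what underlies tighter importance-weighted bounds (IWAE). A minimal sketch of the $K$-sample estimator, assuming per-sample log-densities are already available (the function name `iwae_bound` is ours):

```python
import math
import torch

def iwae_bound(log_p_xz, log_q_z):
    """
    K-sample importance-weighted lower bound.
    Args:
        log_p_xz: (K, batch) log p(x, z_k) for K posterior samples
        log_q_z:  (K, batch) log q(z_k | x)
    Returns:
        (batch,) estimates of log (1/K) sum_k p(x, z_k) / q(z_k | x)
    """
    K = log_p_xz.size(0)
    log_w = log_p_xz - log_q_z               # log importance weights
    return torch.logsumexp(log_w, dim=0) - math.log(K)
```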
4.3 Decomposition Properties
Additivity over data points: for i.i.d. data,

$$\mathcal{L}(\theta, \phi; \mathcal{D}) = \sum_{i=1}^{N} \mathcal{L}(\theta, \phi; x^{(i)})$$

Local ELBO: each data point contributes its own bound,

$$\mathcal{L}(\theta, \phi; x^{(i)}) = \mathbb{E}_{q_\phi(z \mid x^{(i)})}\big[\log p_\theta(x^{(i)} \mid z)\big] - \mathrm{KL}\big(q_\phi(z \mid x^{(i)}) \,\|\, p(z)\big)$$

which is what makes minibatch (stochastic) optimization possible.
4.4 How Tight Is the ELBO?
Definition: the tightness gap $\log p(x) - \mathcal{L}(q) = \mathrm{KL}\big(q(z \mid x) \,\|\, p(z \mid x)\big)$ measures how far the bound sits below the true log-evidence.
Optimization objective: maximizing $\mathcal{L}$ is equivalent to:
- maximizing the reconstruction likelihood
- minimizing $\mathrm{KL}(q(z \mid x) \,\|\, p(z))$
The built-in tension:
- the reconstruction term encourages $q(z \mid x)$ to concentrate where the likelihood is high
- the KL term pulls $q(z \mid x)$ toward the prior
5. Choosing and Designing the Variational Distribution
5.1 Mean-Field Approximation
Assumption: the variational distribution factorizes into independent factors:

$$q(z) = \prod_{i=1}^{D} q_i(z_i)$$

Advantage: simplifies computation and optimization.
Drawback: ignores correlations between variables.
```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MeanFieldVariational(nn.Module):
    """Mean-field variational distribution with independent factors."""
    def __init__(self, dims, distribution='gaussian'):
        super().__init__()
        self.dims = dims
        self.distribution = distribution
        if distribution == 'gaussian':
            # Each factor is an independent Gaussian
            self.mus = nn.Parameter(torch.randn(dims))
            self.log_vars = nn.Parameter(torch.zeros(dims))

    def sample(self, n_samples=1):
        """Reparameterized sampling."""
        if self.distribution == 'gaussian':
            std = (0.5 * self.log_vars).exp()
            eps = torch.randn(n_samples, self.dims)
            return self.mus + eps * std

    def log_prob(self, z):
        """Log-density of the factorized Gaussian."""
        if self.distribution == 'gaussian':
            log_prob = -0.5 * (
                np.log(2 * np.pi) +
                self.log_vars +
                (z - self.mus) ** 2 / self.log_vars.exp()
            )
            return log_prob.sum(dim=-1)
```
5.2 Hierarchical Variational Distributions
Motivation: capture correlations between variables by introducing an auxiliary top-level variable.
```python
class HierarchicalVariational(nn.Module):
    """Hierarchical variational distribution: sample λ, then z given λ."""
    def __init__(self, latent_dim, hidden_dim):
        super().__init__()
        # Network for the top-level distribution q(λ|x)
        # (as in the original, the conditioning input is assumed to have latent_dim features)
        self.prior_net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)  # μ, log_var
        )
        # Network for the conditional q(z|λ)
        self.posterior_net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim)
        )

    def forward(self, x, n_samples=1):
        """Sample from the hierarchical variational distribution."""
        # Sample the top-level variable λ
        lambda_mu, lambda_log_var = self.prior_net(x).chunk(2, dim=-1)
        lambda_std = (0.5 * lambda_log_var).exp()
        eps = torch.randn_like(lambda_mu)
        lambda_sample = lambda_mu + eps * lambda_std
        # Given λ, sample z
        z_mu, z_log_var = self.posterior_net(lambda_sample).chunk(2, dim=-1)
        z_std = (0.5 * z_log_var).exp()
        eps = torch.randn(n_samples, *z_mu.shape)  # broadcasts over samples
        z_samples = z_mu + eps * z_std
        return z_samples, lambda_sample
```
5.3 Normalizing-Flow Variational Distributions
Use invertible transforms to increase expressiveness:
```python
class NormalizingFlowVariational(nn.Module):
    """Variational distribution built from a normalizing flow."""
    def __init__(self, latent_dim, n_flows=4):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_flows = n_flows
        # Base distribution parameters
        self.base_mu = nn.Parameter(torch.zeros(latent_dim))
        self.base_log_var = nn.Parameter(torch.zeros(latent_dim))
        # Flow layers
        self.flows = nn.ModuleList([
            PlanarFlow(latent_dim) for _ in range(n_flows)
        ])

    def forward(self, n_samples=1):
        """Sample z_K by pushing base samples through the flow."""
        std = (0.5 * self.base_log_var).exp()
        z = self.base_mu + torch.randn(n_samples, self.latent_dim) * std
        log_det_sum = torch.zeros(n_samples, device=z.device)
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_sum += log_det
        return z, log_det_sum

    def log_prob(self, z0):
        """
        Log-density of the flow output, evaluated from a base sample z0.
        Change of variables: log q_K(z_K) = log q_0(z_0) - Σ_k log|det J_k|.
        """
        diff = z0 - self.base_mu
        base_log_prob = -0.5 * (
            np.log(2 * np.pi) +
            self.base_log_var +
            diff ** 2 / self.base_log_var.exp()
        ).sum(dim=-1)
        log_det_sum = torch.zeros(z0.size(0), device=z0.device)
        z_cur = z0
        for flow in self.flows:
            z_cur, log_det = flow(z_cur)
            log_det_sum += log_det
        # Subtract (not add) the accumulated log-determinants
        return base_log_prob - log_det_sum

class PlanarFlow(nn.Module):
    """Planar normalizing flow: z' = z + u * h(w^T z + b)."""
    def __init__(self, dim):
        super().__init__()
        self.w = nn.Parameter(torch.randn(dim))
        self.u = nn.Parameter(torch.randn(dim))
        self.b = nn.Parameter(torch.zeros(1))

    def forward(self, z):
        """
        Returns:
            z_new: transformed samples
            log_det: log|det(dz'/dz)|
        """
        # h(w^T z + b), shape (batch,)
        activation = torch.tanh(z @ self.w + self.b)
        # Forward transform; broadcast the scalar activation over dimensions
        z_new = z + self.u * activation.unsqueeze(-1)
        # det(dz'/dz) = 1 + u^T psi, with psi = h'(w^T z + b) * w
        psi = (1 - activation ** 2).unsqueeze(-1) * self.w  # (batch, dim)
        det = 1 + psi @ self.u
        log_det = torch.log(det.abs() + 1e-8)
        return z_new, log_det
```
5.4 Mixture Variational Distributions
```python
class MixtureVariational(nn.Module):
    """Mixture-of-Gaussians variational distribution."""
    def __init__(self, latent_dim, n_components, hidden_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.n_components = n_components
        # Component parameters
        self.component_mus = nn.Parameter(
            torch.randn(n_components, latent_dim)
        )
        self.component_log_vars = nn.Parameter(
            torch.zeros(n_components, latent_dim)
        )
        # Mixture weights (logits)
        self.mixture_weights = nn.Parameter(torch.zeros(n_components))

    def sample(self, n_samples=1):
        """Ancestral sampling: pick a component, then sample within it."""
        weights = F.softmax(self.mixture_weights, dim=0)
        component_idx = torch.multinomial(weights, n_samples, replacement=True)
        samples = torch.randn(n_samples, self.latent_dim)
        for i in range(n_samples):
            comp = component_idx[i]
            std = (0.5 * self.component_log_vars[comp]).exp()
            samples[i] = self.component_mus[comp] + samples[i] * std
        return samples

    def log_prob(self, z):
        """Log-density of the mixture via logsumexp."""
        # Normalized log-weights (the raw logits must go through log-softmax)
        log_weights = F.log_softmax(self.mixture_weights, dim=0)
        log_probs = []
        for k in range(self.n_components):
            diff = z - self.component_mus[k]
            log_prob_k = -0.5 * (
                np.log(2 * np.pi) +
                self.component_log_vars[k] +
                diff ** 2 / self.component_log_vars[k].exp()
            ).sum(dim=-1)
            log_probs.append(log_prob_k + log_weights[k])
        log_probs = torch.stack(log_probs, dim=-1)
        return torch.logsumexp(log_probs, dim=-1)
```
6. Gradient Estimation Methods
6.1 Score Function Gradients (REINFORCE)
Score function identity:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}\big[f(z)\, \nabla_\phi \log q_\phi(z)\big]$$

Monte Carlo estimate:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] \approx \frac{1}{S} \sum_{s=1}^{S} f\big(z^{(s)}\big)\, \nabla_\phi \log q_\phi\big(z^{(s)}\big), \qquad z^{(s)} \sim q_\phi$$
```python
def score_function_gradient(f, q, n_samples=100):
    """
    Score function (REINFORCE) gradient estimator.
    Suitable for: discrete variables, non-differentiable models,
    complex likelihoods.
    Args:
        f: function f: Z → R
        q: variational distribution q_φ(z) (an nn.Module)
        n_samples: number of Monte Carlo samples
    Returns:
        gradient estimate (for the first parameter of q, for illustration)
    """
    gradients = []
    params = list(q.parameters())
    for _ in range(n_samples):
        # Detach so gradients flow only through log q, not the sample path
        z = q.sample().detach()
        log_q = q.log_prob(z)          # log q_φ(z)
        grad_log_q = torch.autograd.grad(
            log_q.sum(),
            params,
            retain_graph=True
        )[0]  # note: first parameter only, for simplicity
        # f(z) * ∇_φ log q_φ(z); f(z) is treated as a constant
        gradients.append(f(z).detach() * grad_log_q)
    return torch.stack(gradients).mean(dim=0)
```
6.2 Score Functions and Score Matching
The score function (with respect to the sample) is $s(z) = \nabla_z \log q(z)$.
Score matching objective (reconstructed here in its standard Hyvärinen form):

$$J(\phi) = \mathbb{E}_{p(z)}\!\left[\mathrm{tr}\big(\nabla_z s_\phi(z)\big) + \tfrac{1}{2}\, \lVert s_\phi(z) \rVert^2\right]$$
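Where a score network $s_\phi$ is available, the objective can be implemented directly. A minimal sketch under that assumption (helper names are ours; the exact Jacobian trace is computed dimension by dimension, so this is only practical for low-dimensional $z$):

```python
import torch

def score_matching_loss(score_net, z):
    """Hyvärinen score matching: E[tr(∇_z s(z)) + 0.5 ||s(z)||^2]."""
    z = z.detach().requires_grad_(True)
    s = score_net(z)                       # s_φ(z), shape (batch, dim)
    norm_term = 0.5 * (s ** 2).sum(dim=-1)
    # Exact trace of the Jacobian, one autograd pass per dimension
    trace = torch.zeros(z.size(0))
    for i in range(s.size(-1)):
        trace = trace + torch.autograd.grad(
            s[:, i].sum(), z, create_graph=True
        )[0][:, i]
    return (trace + norm_term).mean()
```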
6.3 Reparameterization Gradients
The method of choice for continuous variables:
```python
def reparameterize(mu, log_var):
    """Gaussian reparameterization: z = μ + σ * ε, ε ~ N(0, I)."""
    std = (0.5 * log_var).exp()
    eps = torch.randn_like(std)
    return mu + eps * std

def reparameterization_gradient(f, mu, log_var, n_samples=100):
    """
    Reparameterization gradient estimator.
    Suitable for: continuous variables and differentiable models.
    Args:
        f: function f: Z → R
        mu, log_var: variational parameters (leaf tensors, requires_grad=True)
        n_samples: number of Monte Carlo samples
    Returns:
        gradient estimate with respect to mu
    """
    gradients = []
    for _ in range(n_samples):
        z = reparameterize(mu, log_var)  # randomness isolated in ε
        f_val = f(z)
        f_val.backward()                 # gradient flows through z to mu, log_var
        gradients.append(mu.grad.clone())
        mu.grad.zero_()
        if log_var.grad is not None:
            log_var.grad.zero_()
    return torch.stack(gradients).mean(dim=0)
```
6.4 Pathwise Derivative Gradients
Pathwise derivatives: propagate gradients through the entire sampling path.
For $z = g_\phi(\epsilon)$:

$$\nabla_\phi \, \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{p(\epsilon)}\!\left[\nabla_z f(z)\big|_{z = g_\phi(\epsilon)}\; \nabla_\phi\, g_\phi(\epsilon)\right]$$
6.5 Method Comparison and Selection

| Method | Variance | Bias | Typical use case |
|---|---|---|---|
| Score function | High | Unbiased | Discrete variables, non-differentiable models |
| Reparameterization | Low | Unbiased | Continuous variables, differentiable models |
| Pathwise | Low | Unbiased | Continuous variables, complex sampling paths |
| RELAX | Medium | Unbiased (uses a learned control variate) | Mixed discrete-continuous |
6.6 Variance Reduction Techniques
```python
class VarianceNormalizedEstimator:
    """
    Baseline (control variate) gradient estimator.
    Reduces the variance of score-function gradients to speed up convergence.
    """
    def __init__(self, baseline_net=None):
        self.baseline_net = baseline_net  # baseline network for variance reduction

    def estimate(self, f, q, n_samples=100):
        """Estimate the gradient of E_q[f(z)] with a baseline."""
        # Detach so gradients flow only through log q
        z_samples = q.sample(n_samples).detach()
        log_q = q.log_prob(z_samples)
        f_vals = f(z_samples)
        # Subtracting a baseline b(z) leaves the estimator unbiased
        if self.baseline_net is not None:
            baseline = self.baseline_net(z_samples)
            f_centered = f_vals - baseline.detach()
        else:
            # Fall back to the batch mean as a constant baseline
            f_centered = f_vals - f_vals.mean()
        # Score function gradient (first parameter only, for illustration)
        grad_log_q = torch.autograd.grad(
            log_q.sum(),
            list(q.parameters()),
            retain_graph=True
        )[0]
        return (f_centered * grad_log_q).mean()

    def update_baseline(self, targets, predictions):
        """Fit the baseline network, typically by MSE regression on f(z)."""
        if self.baseline_net is not None:
            loss = F.mse_loss(predictions, targets)
            loss.backward()  # caller is expected to step an optimizer
```
7. Recent Advances: Normalizing Flows, Continuous Mixtures, and Neural Variational Inference
7.1 Continuous Normalizing Flows
Continuous Normalizing Flows (CNF) replace a discrete stack of invertible layers with an ODE whose solution transports samples.
Probability flow ODE:

$$\frac{dz(t)}{dt} = f\big(z(t), t\big), \qquad \frac{d \log p(z(t))}{dt} = -\mathrm{tr}\!\left(\frac{\partial f}{\partial z(t)}\right)$$
```python
class CNF(nn.Module):
    """Continuous normalizing flow (sketch; requires the torchdiffeq package)."""
    def __init__(self, dim, hidden_dim=64):
        super().__init__()
        self.dim = dim
        # Velocity field network f(z, t)
        self.velocity_net = nn.Sequential(
            nn.Linear(dim + 1, hidden_dim),  # +1 for time
            nn.Softplus(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, dim)
        )

    def velocity(self, z, t):
        """Velocity field f(z, t), with time appended as a feature."""
        t_emb = t * torch.ones(z.size(0), 1, device=z.device)
        return self.velocity_net(torch.cat([z, t_emb], dim=-1))

    def ode_func(self, t, state):
        """Augmented ODE: evolve z and the log-density change jointly."""
        z, _ = state
        with torch.enable_grad():
            z = z.detach().requires_grad_(True)
            v = self.velocity(z, t)
            # Exact divergence tr(∂f/∂z), one autograd pass per dimension;
            # large-scale implementations (e.g. FFJORD) replace this with
            # Hutchinson's trace estimator
            div = torch.zeros(z.size(0), device=z.device)
            for i in range(self.dim):
                div = div + torch.autograd.grad(
                    v[:, i].sum(), z, create_graph=True
                )[0][:, i]
        return v, -div

    def forward(self, z0, t_span):
        """
        Integrate z0 → z1 together with the accumulated divergence.
        Returns z1 and log|det dz1/dz0| = ∫ tr(∂f/∂z) dt.
        """
        from torchdiffeq import odeint
        delta_logp0 = torch.zeros(z0.size(0), device=z0.device)
        zs, delta_logps = odeint(
            self.ode_func, (z0, delta_logp0), t_span, method='dopri5'
        )
        z1 = zs[-1]
        log_det = -delta_logps[-1]  # = ∫ tr(∂f/∂z) dt
        return z1, log_det
```
7.2 Neural Mixture Models
Neural Mixture Models parameterize a mixture variational posterior with neural networks:
```python
class NeuralMixtureVAE(nn.Module):
    """VAE with a neural mixture-of-Gaussians posterior."""
    def __init__(self, input_dim, latent_dim, n_components=4):
        super().__init__()
        self.n_components = n_components
        # Encoder trunk
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        # Mixture-weight head
        self.pi_net = nn.Linear(128, n_components)
        # Per-component mean/log-variance heads
        self.component_nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 2 * latent_dim)  # μ, log σ²
            )
            for _ in range(n_components)
        ])
        # Decoder (outputs Gaussian parameters for x)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 2 * input_dim)
        )

    def forward(self, x):
        """Forward pass."""
        # Encode
        h = self.encoder(x)
        # Mixture weights
        pi = F.softmax(self.pi_net(h), dim=-1)
        # Sample once from each component
        z_samples, log_probs = [], []
        for k in range(self.n_components):
            mu, log_var = self.component_nets[k](h).chunk(2, dim=-1)
            std = (0.5 * log_var).exp()
            z_k = mu + std * torch.randn_like(mu)
            z_samples.append(z_k)
            # Log-density of the sample under its own component
            log_prob_k = -0.5 * (
                np.log(2 * np.pi) +
                log_var +
                (z_k - mu) ** 2 / log_var.exp()
            ).sum(dim=-1)
            log_probs.append(log_prob_k)
        z_samples = torch.stack(z_samples, dim=0)  # (K, batch, latent)
        log_probs = torch.stack(log_probs, dim=0)  # (K, batch)
        # Component log-weights added to per-component log-densities
        log_pi = torch.log(pi.T + 1e-8)            # (K, batch)
        weighted_log_probs = log_probs + log_pi
        # Decode the mixture-averaged latent code
        z_mean = (pi.unsqueeze(-1) * z_samples.permute(1, 0, 2)).sum(dim=1)
        x_rec = self.decoder(z_mean)
        return {
            'z_samples': z_samples,
            'pi': pi,
            'x_rec': x_rec,
            'log_weights': weighted_log_probs
        }
```
7.3 Adversarial Variational Inference
Adversarial Variational Inference (AVI):
```python
class AdversarialVI(nn.Module):
    """
    Adversarial variational inference (AVB-style).
    A discriminator separates (x, z) pairs drawn from the variational
    posterior from pairs drawn from the prior; its output estimates the
    density ratio needed for the KL term.
    """
    def __init__(self, encoder, decoder, discriminator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

    def discriminator_loss(self, x, z):
        """
        Discriminator loss.
        Goal: distinguish variational-posterior samples from prior samples.
        """
        z_from_q = z
        z_from_prior = torch.randn_like(z)  # samples from the prior p(z)
        d_q = self.discriminator(x, z_from_q)
        d_prior = self.discriminator(x, z_from_prior)
        # Binary cross-entropy over the two sample sources
        loss_d = -0.5 * (
            torch.log(d_q + 1e-8) +
            torch.log(1 - d_prior + 1e-8)
        ).mean()
        return loss_d

    def generator_loss(self, x, z):
        """
        Generator (encoder + decoder) loss.
        Goal: fool the discriminator, plus a reconstruction term.
        """
        d = self.discriminator(x, z)
        loss_g = -torch.log(d + 1e-8).mean()
        x_rec = self.decoder(z)
        loss_rec = F.mse_loss(x_rec, x)
        return loss_g + loss_rec
```
8. Complete PyTorch Implementation
8.1 Full VAE Implementation
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist

class VAE(nn.Module):
    """
    Full variational autoencoder implementation.
    Supports:
    - reparameterized sampling
    - configurable encoder/decoder depth
    - Gaussian or Bernoulli likelihoods
    - an optional normalizing-flow posterior
    """
    def __init__(
        self,
        input_dim: int,
        latent_dim: int,
        hidden_dim: int = 256,
        encoder_depth: int = 2,
        decoder_depth: int = 2,
        distribution: str = 'gaussian',
        use_flows: bool = False,
        n_flows: int = 0
    ):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.distribution = distribution
        # Encoder
        encoder_layers = []
        prev_dim = input_dim
        for _ in range(encoder_depth):
            encoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        self.encoder = nn.Sequential(*encoder_layers)
        # Latent-space heads
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        decoder_layers = []
        prev_dim = latent_dim
        for _ in range(decoder_depth):
            decoder_layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU()
            ])
            prev_dim = hidden_dim
        self.decoder = nn.Sequential(*decoder_layers)
        # Output heads
        if distribution == 'gaussian':
            self.fc_out_mu = nn.Linear(hidden_dim, input_dim)
            self.fc_out_log_var = nn.Linear(hidden_dim, input_dim)
        elif distribution == 'bernoulli':
            self.fc_out = nn.Linear(hidden_dim, input_dim)
        # Optional normalizing flow on the posterior (PlanarFlow from Section 5.3)
        self.use_flows = use_flows
        if use_flows:
            self.flows = nn.ModuleList([
                PlanarFlow(latent_dim) for _ in range(n_flows)
            ])

    def encode(self, x):
        """Encode: compute posterior parameters."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_log_var(h)

    def reparameterize(self, mu, log_var):
        """Reparameterized sampling."""
        std = (0.5 * log_var).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def flow_transform(self, z0):
        """Push base samples through the flow, accumulating log-determinants."""
        log_det_sum = torch.zeros(z0.size(0), device=z0.device)
        z = z0
        for flow in self.flows:
            z, log_det = flow(z)
            log_det_sum += log_det
        return z, log_det_sum

    def decode(self, z):
        """Decode: compute the parameters of the generative distribution."""
        h = self.decoder(z)
        if self.distribution == 'gaussian':
            return self.fc_out_mu(h), self.fc_out_log_var(h)
        elif self.distribution == 'bernoulli':
            return torch.sigmoid(self.fc_out(h)), None

    def forward(self, x, n_samples=1):
        """
        Forward pass.
        Args:
            x: (batch_size, input_dim) input data
            n_samples: samples per data point (a single sample is used here)
        Returns:
            dict with the ELBO, its components, and the latent code
        """
        # Encode
        mu, log_var = self.encode(x)
        std = (0.5 * log_var).exp()
        # Reparameterized base sample z0; evaluate log q at z0 (pre-flow)
        z0 = self.reparameterize(mu, log_var)
        log_qz_given_x = dist.Normal(mu, std).log_prob(z0).sum(dim=-1)
        # Normalizing flow (if used): change of variables lowers log q
        z = z0
        log_det_flow = torch.zeros(x.size(0), device=x.device)
        if self.use_flows:
            z, log_det_flow = self.flow_transform(z0)
        log_qz_given_x = log_qz_given_x - log_det_flow
        # Decode
        recon = self.decode(z)
        # Reconstruction log-likelihood
        if self.distribution == 'gaussian':
            recon_mu, recon_log_var = recon
            recon_std = (0.5 * recon_log_var).exp()  # Normal takes std, not variance
            log_px_given_z = dist.Normal(recon_mu, recon_std).log_prob(x).sum(dim=-1)
        elif self.distribution == 'bernoulli':
            p_x_given_z, _ = recon
            log_px_given_z = dist.Bernoulli(p_x_given_z).log_prob(x).sum(dim=-1)
        # Prior log-density at the (post-flow) sample
        log_pz = dist.Normal(0., 1.).log_prob(z).sum(dim=-1)
        # ELBO = E_q[log p(x|z) + log p(z) - log q(z|x)]
        elbo = log_px_given_z + log_pz - log_qz_given_x
        return {
            'elbo': elbo,
            'reconstruction': log_px_given_z,
            'kl': log_qz_given_x - log_pz,
            'z': z,
            'mu': mu,
            'log_var': log_var
        }

    def loss(self, x):
        """Loss = negative ELBO."""
        return -self.forward(x)['elbo'].mean()

    def sample(self, n_samples, temperature=1.0):
        """Sample from the prior and decode.

        Note: the flow parameterizes the posterior q(z|x), not the prior,
        so prior samples are decoded directly.
        """
        with torch.no_grad():
            z = torch.randn(n_samples, self.latent_dim) * temperature
            recon = self.decode(z)
            if self.distribution == 'gaussian':
                x_samples = recon[0]  # use the output mean
            elif self.distribution == 'bernoulli':
                x_samples = torch.bernoulli(recon[0])
            return x_samples, z

    def reconstruct(self, x):
        """Reconstruct inputs."""
        with torch.no_grad():
            mu, log_var = self.encode(x)
            z = self.reparameterize(mu, log_var)
            recon = self.decode(z)
            return recon[0]
```
```python
class VAETrainer:
    """
    VAE trainer.
    Supports:
    - learning-rate scheduling
    - gradient clipping
    - loss-history logging
    """
    def __init__(
        self,
        model: VAE,
        optimizer_class: type = torch.optim.Adam,
        lr: float = 1e-3,
        beta: float = 1.0,
        recon_weight: float = 1.0
    ):
        self.model = model
        self.beta = beta  # weight on the KL term (beta-VAE style)
        self.recon_weight = recon_weight
        self.optimizer = optimizer_class(model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            patience=10,
            factor=0.5
        )
        self.history = {
            'loss': [],
            'recon_loss': [],
            'kl_loss': [],
            'elbo': []
        }

    def train_step(self, x):
        """Single training step."""
        self.optimizer.zero_grad()
        # Forward pass
        output = self.model(x)
        # recon_loss is the negative reconstruction log-likelihood
        recon_loss = -output['reconstruction'].mean()
        kl_loss = output['kl'].mean()
        elbo = output['elbo'].mean()
        # Minimize the (weighted) negative ELBO
        loss = self.recon_weight * recon_loss + self.beta * kl_loss
        # Backward pass
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        # Logging
        self.history['loss'].append(loss.item())
        self.history['recon_loss'].append(recon_loss.item())
        self.history['kl_loss'].append(kl_loss.item())
        self.history['elbo'].append(elbo.item())
        return loss.item(), recon_loss.item(), kl_loss.item()

    def train(self, dataloader, n_epochs, eval_loader=None):
        """Full training loop."""
        for epoch in range(n_epochs):
            epoch_loss = epoch_recon = epoch_kl = 0.0
            n_batches = 0
            for x in dataloader:
                if isinstance(x, (list, tuple)):
                    x = x[0]
                x = x.view(-1, self.model.input_dim)
                loss, recon, kl = self.train_step(x)
                epoch_loss += loss
                epoch_recon += recon
                epoch_kl += kl
                n_batches += 1
            # Epoch averages
            avg_loss = epoch_loss / n_batches
            avg_recon = epoch_recon / n_batches
            avg_kl = epoch_kl / n_batches
            # Learning-rate scheduling on the epoch loss
            self.scheduler.step(avg_loss)
            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{n_epochs}")
                print(f"  Loss:  {avg_loss:.4f}")
                print(f"  Recon: {avg_recon:.4f}")
                print(f"  KL:    {avg_kl:.4f}")
                print(f"  ELBO:  {self.history['elbo'][-1]:.4f}")
                print()
        return self.history
```
```python
def demo_vae():
    """VAE demo on bimodal 2-D data."""
    # Generate bimodal data
    torch.manual_seed(42)
    n_samples = 2000
    data1 = torch.randn(n_samples // 2, 2) + torch.tensor([2.0, 2.0])
    data2 = torch.randn(n_samples // 2, 2) + torch.tensor([-2.0, -2.0])
    data = torch.cat([data1, data2], dim=0)
    # Shuffle
    data = data[torch.randperm(len(data))]
    # Dataset and loader
    dataset = torch.utils.data.TensorDataset(data)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    # Build the VAE
    vae = VAE(
        input_dim=2,
        latent_dim=2,
        hidden_dim=64,
        encoder_depth=2,
        decoder_depth=2,
        distribution='gaussian'
    )
    # Train
    trainer = VAETrainer(vae, lr=1e-3, beta=1.0)
    history = trainer.train(dataloader, n_epochs=100)
    # Sample from the trained model
    samples, z = vae.sample(n_samples=500)
    print("\nGenerated sample statistics:")
    print(f"  Mean:     {samples.mean(dim=0)}")
    print(f"  Variance: {samples.var(dim=0)}")
    return vae, history

if __name__ == "__main__":
    vae, history = demo_vae()
```
9. Summary and Connections
9.1 Key Takeaways

| Topic | Core formula |
|---|---|
| ELBO | $\mathcal{L} = \mathbb{E}_{q}[\log p(x \mid z)] - \mathrm{KL}(q(z \mid x) \,\|\, p(z))$ |
| Reparameterization | $z = \mu + \sigma \odot \epsilon, \;\; \epsilon \sim \mathcal{N}(0, I)$ |
| KL divergence | $\mathrm{KL}(q \,\|\, p) = \mathbb{E}_{q}[\log q(z) - \log p(z)]$ |
| Gumbel-Softmax | $y_i = \mathrm{softmax}\big((\log \pi_i + g_i)/\tau\big), \;\; g_i \sim \mathrm{Gumbel}(0, 1)$ |
9.2 Connections to Related Documents

| Related topic | Connection |
|---|---|
| Bayesian networks | Foundations of probabilistic graphical models |
| Advanced variational inference | SVI, IWAE, normalizing flows |
| Bayesian neural networks | Training BNNs with variational inference |
| Bayes by Backprop | A neural-network instantiation of variational inference |
| MC Dropout | A Bayesian interpretation of dropout |
| Probabilistic circuit fundamentals | Circuits as variational posteriors |