概述
变分推断(Variational Inference, VI)是一种将推断问题转化为优化问题的近似推断方法。通过优化一个变分分布来近似真实后验分布。1
变分推断的基本思想
设我们想要求解后验分布 ,其中 是隐变量, 是观测数据。由于精确推断往往不可行,我们引入一个变分分布 来近似 。
优化目标:最小化 与 之间的 KL 散度:
数学基础
证据下界(ELBO)
将 KL 散度展开:
利用贝叶斯定理 :
重新整理得到:
由于 KL 散度非负,我们有:
ELBO(Evidence Lower BOund)是边际似然的下界,最大化 ELBO 等价于最小化 KL 散度。
ELBO 的另一种形式
ELBO 也可以写成期望的形式:
展开联合分布:
这包含两项:
- 重建项: — 重建观测数据的期望似然
- 正则化项: — 变分分布与先验的 KL 散度
ELBO 的性质
import numpy as np
def compute_elbo(X, model, q_z, n_samples=1000):
"""
计算ELBO的蒙特卡洛估计
参数:
X: 观测数据
model: 生成模型
q_z: 变分分布
n_samples: 蒙特卡洛样本数
"""
# 从变分分布采样
z_samples = q_z.sample(n_samples)
# 计算重建项
log_px_given_z = model.log_likelihood(X, z_samples)
reconstruction = np.mean(log_px_given_z)
# 计算KL项
log_pz = model.prior.log_prob(z_samples)
log_qz = q_z.log_prob(z_samples)
kl = np.mean(log_pz - log_qz)
elbo = reconstruction - kl
return elbo, reconstruction, kl
# 示例:检查ELBO与边际似然的关系
# log P(X) = ELBO + D_KL(Q || P)
# 由于 D_KL >= 0,ELBO <= log P(X)平均场变分家族
平均场假设
平均场变分家族假设隐变量之间相互独立:
其中 是隐变量的划分。
坐标上升变分推断(CAVI)
CAVI 算法通过固定其他变量来更新每个变分因子:
更新公式:对于第 个因子,
这称为变分更新规则。
def cavi_update(X, model, q_i, other_q, i):
"""
CAVI 更新规则
参数:
X: 观测数据
model: 模型
q_i: 第i个变分因子
other_q: 其他变分因子的乘积
i: 要更新的因子索引
返回:
更新后的变分参数
"""
# 计算期望
# E_{-Q_i}[log P(X, Z)] = E_{Q_{-i}}[log P(X, Z_i, Z_{-i})]
def expected_log_joint(z_i):
# 对其他隐变量积分
return np.mean([model.log_joint(X, z_i, z_other)
for z_other in other_q.sample(1000)])
# 更新变分参数
new_params = compute_variational_params(expected_log_joint)
return new_params
def cavi_algorithm(X, model, n_iterations=100, tol=1e-4):
"""
CAVI 算法主循环
"""
q = initialize_variational_factors(model)
elbo_history = []
for iteration in range(n_iterations):
# 依次更新每个变分因子
for i in range(model.n_latent):
q[i] = cavi_update(X, model, q[i],
[q[j] for j in range(model.n_latent) if j != i],
i)
# 计算 ELBO
elbo = compute_elbo(X, model, q)
elbo_history.append(elbo)
# 检查收敛
if len(elbo_history) > 1:
if abs(elbo_history[-1] - elbo_history[-2]) < tol:
break
return q, elbo_history变分推断的变体
随机变分推断(SVI)
当数据量大时,使用随机优化来最大化 ELBO:
class StochasticVariationalInference:
"""
随机变分推断
"""
def __init__(self, model, q, optimizer, batch_size=32):
self.model = model
self.q = q
self.optimizer = optimizer
self.batch_size = batch_size
def step(self, X, global_step):
"""
一步随机梯度更新
"""
# 采样一个小批量
batch_idx = np.random.choice(len(X), self.batch_size, replace=False)
X_batch = X[batch_idx]
# 重参数化采样
eps = torch.randn(self.batch_size, self.q.dim)
z = self.q.loc + eps * self.q.scale
# 计算 ELBO 的蒙特卡洛估计
elbo = self.compute_elbo(X_batch, z)
# 反向传播
elbo.backward()
# 更新参数
self.optimizer.step()
self.optimizer.zero_grad()
return elbo.item()
def compute_elbo(self, X, z):
"""
计算 ELBO
"""
# 重建项
log_px_given_z = self.model.likelihood(X, z)
# KL 项
log_pz = torch.distributions.Normal(0, 1).log_prob(z)
log_qz = self.q.log_prob(z)
return log_px_given_z.mean() - (log_pz - log_qz).mean()黑盒变分推断(BBVI)
BBVI 使用分数函数梯度:
def bbvi_gradient(q, f, n_samples=100):
"""
BBVI 梯度估计
参数:
q: 变分分布
f: 目标函数
n_samples: 样本数
"""
samples = q.sample(n_samples)
log_q = q.log_prob(samples)
# 分数函数梯度
# ∇_φ E_q_φ[f(z)] ≈ (1/N) Σ f(z_i) ∇_φ log q_φ(z_i)
gradients = []
for i in range(n_samples):
grads = torch.autograd.grad(
log_q[i],
q.parameters(),
retain_graph=(i < n_samples - 1)
)
gradients.append(f(samples[i]) * grads[0])
return torch.stack(gradients).mean(dim=0)重参数化梯度
当变分分布可微时,使用重参数化技巧:
def reparameterized_gradient(q, f, n_samples=100):
"""
重参数化梯度估计
对于 q_φ(z) = N(μ_φ, σ²_φ),
采样 z = μ_φ + σ_φ * ε, ε ~ N(0,1)
"""
mu, log_sigma = q.params
sigma = torch.exp(log_sigma)
samples = []
for _ in range(n_samples):
eps = torch.randn_like(mu)
z = mu + sigma * eps
samples.append(z)
samples = torch.stack(samples)
# 计算 f(z) 的梯度
f_values = f(samples)
# 重参数化梯度
# ∇_φ f(z(φ, ε)) = ∂f/∂z * ∂z/∂φ
gradients = torch.autograd.grad(
f_values.sum(),
[mu, log_sigma],
retain_graph=True
)
return gradients平均场变分推断的详细推导
高斯混合模型(GMM)
以高斯混合模型为例展示平均场 VI:
模型:
- 混合权重:
- 簇分配:
- 观测:
变分分布:
更新 :
其中
更新 :
class VariationalGMM:
"""
高斯混合模型的变分推断
"""
def __init__(self, X, n_clusters, alpha=1.0):
self.X = X
self.n_clusters = n_clusters
self.n_samples = len(X)
self.alpha = alpha
# 初始化变分参数
self.phi = np.random.dirichlet(np.ones(n_clusters), size=self.n_samples)
self.alpha_pi = np.ones(n_clusters) * alpha
def e_step(self):
"""
E步:更新簇分配
"""
# 更新 phi
for i in range(self.n_samples):
log_resp = (np.log(self.alpha_pi) +
self.compute_log_likelihood(self.X[i]))
log_resp -= np.max(log_resp) # 数值稳定性
self.phi[i] = np.exp(log_resp)
self.phi[i] /= np.sum(self.phi[i])
def m_step(self):
"""
M步:更新混合权重
"""
# 更新 Dirichlet 参数
self.alpha_pi = self.alpha + np.sum(self.phi, axis=0)
def compute_log_likelihood(self, x):
"""
计算每个簇的对数似然
"""
return -0.5 * np.sum((x - self.mu)**2 / self.sigma**2, axis=1) - \
0.5 * np.log(self.sigma**2)
def fit(self, n_iterations=100):
"""
运行变分EM算法
"""
for _ in range(n_iterations):
self.e_step()
self.m_step()
return self.phi变分推断与EM算法的联系
EM vs VI
| 方面 | EM算法 | 变分推断 |
|---|---|---|
| 隐变量 | 优化 | 推断 |
| 参数 | 推断 | 优化 |
| 目标 | 最大化似然 | 最大化ELBO |
| 收敛性 | 局部最优 | 局部最优 |
变分EM
变分EM结合了两者的优点:
def variational_em(model, data, n_iterations=100):
"""
变分EM算法
"""
for iteration in range(n_iterations):
# E步:变分推断(优化隐变量分布)
q_z = variational_inference(model, data)
# M步:优化参数
params = m_step(model, data, q_z)
model.update_params(params)
return model变分推断的实现技巧
数值稳定性
def stable_elbo(log_p, log_q):
"""
数值稳定的 ELBO 计算
使用 log-sum-exp 技巧
"""
# KL = sum(exp(log_p) * (log_p - log_q))
# = sum(exp(log_p) * log_p) - sum(exp(log_p) * log_q)
# 重建项:使用 log-sum-exp
reconstruction = np.mean(log_p)
# KL 项
kl = np.mean(log_p - log_q)
return reconstruction - kl
def log_mean_exp(x, axis=None):
"""
数值稳定的 log(mean(exp(x)))
"""
max_x = np.max(x, axis=axis, keepdims=True)
return max_x + np.log(np.mean(np.exp(x - max_x), axis=axis))调参建议
- 初始化:使用 -means 或随机初始化
- 学习率:使用学习率衰减或 Adam
- 早停:监控验证集上的 ELBO
- 重采样:定期重新采样以减少方差
与其他近似方法的对比
VI vs MCMC
| 方面 | 变分推断 | MCMC |
|---|---|---|
| 计算复杂度 | 每迭代 | 采样 |
| 收敛保证 | 局部最优 | 渐近收敛 |
| 近似质量 | 有偏差 | 无偏(渐近) |
| 实现难度 | 中等 | 较高 |
| 扩展性 | 好 | 一般 |
VI vs 拉普拉斯近似
| 方面 | 变分推断 | 拉普拉斯近似 |
|---|---|---|
| 近似分布 | 任意变分族 | 高斯分布 |
| 灵活性 | 高 | 低 |
| 计算成本 | 中等 | 低 |
| 理论基础 | 信息论 | 渐近理论 |
应用场景
变分自编码器(VAE)
VAE 是变分推断在生成模型中的典型应用:
class VAE(nn.Module):
"""
变分自编码器
"""
def __init__(self, input_dim, latent_dim):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU()
)
self.fc_mu = nn.Linear(128, latent_dim)
self.fc_log_var = nn.Linear(128, latent_dim)
# 解码器
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, input_dim)
)
def encode(self, x):
h = self.encoder(x)
mu = self.fc_mu(h)
log_var = self.fc_log_var(h)
return mu, log_var
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
return self.decoder(z)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_recon = self.decode(z)
return x_recon, mu, log_var
def elbo_loss(self, x, x_recon, mu, log_var, beta=1.0):
"""
ELBO 损失
ELBO = -L = - Reconstruction - KL
"""
# 重建项(负对数似然)
recon_loss = nn.functional.mse_loss(x_recon, x, reduction='sum')
# KL 项
kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return recon_loss + beta * kl_loss变分Dropout
变分Dropout将Dropout解释为贝叶斯推断:
class VariationalDropout(nn.Module):
"""
变分Dropout(Kingma et al., 2015)
"""
def __init__(self, in_features, out_features):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# 变分参数
self.w_mu = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
self.b_mu = nn.Parameter(torch.zeros(out_features))
self.w_log_alpha = nn.Parameter(torch.zeros(out_features, in_features))
self.b_log_alpha = nn.Parameter(torch.zeros(out_features))
def forward(self, x, sample=True):
if sample or self.training:
# 重参数化采样
alpha = torch.exp(self.w_log_alpha)
std = (self.w_mu ** 2) * alpha
# 权重采样
w = self.w_mu + torch.randn_like(self.w_mu) * std.sqrt()
b = self.b_mu + torch.randn_like(self.b_mu) * \
(self.b_mu ** 2 * torch.exp(self.b_log_alpha)).sqrt()
else:
# 近似后验均值
w = self.w_mu
b = self.b_mu
return F.linear(x, w, b)
def kl_divergence(self):
"""
计算 KL 散度
"""
alpha = torch.exp(self.w_log_alpha)
kl = 0.5 * (alpha * self.w_mu ** 2 / (self.w_mu ** 2 + alpha) -
torch.log(self.w_mu ** 2 + alpha) +
torch.log(self.w_mu ** 2) + 1)
return kl.sum()参考
相关链接
Footnotes
-
Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association. ↩