概述
变分推断(Variational Inference, VI)是概率图模型中最重要的近似推断方法之一,它将后验分布推断问题转化为优化问题。1
在深度学习时代,变分推断成为连接贝叶斯神经网络与经典神经网络的关键桥梁——变分自编码器(VAE)、变分循环网络、变分图神经网络等模型都建立在变分推断的数学框架之上。
变分推断的基本框架
问题设定
给定观测数据 和潜在变量 ,我们希望:
- 学习参数:最大化边缘似然
- 推断后验:计算后验分布
边缘似然的分解:
精确推断的困难
在大多数实际问题中,积分 是不可计算的:
- 潜在空间维度高(, )
- 后验分布 没有解析形式
- 配分函数 难以计算
变分推断的核心思想
用一个简单的分布 去近似复杂的真实后验 :
然后将推断问题转化为优化问题:找到最优的 使得 与 最接近。
变分族的选择
常见的变分分布族:
| 变分族 | 形式 | 优点 | 缺点 |
|---|---|---|---|
| 均值场(Mean-Field) | 独立、易于计算 | 过于简化 | |
| 高斯变分 | 平滑、连续 | 参数多 | |
| 归一化流 | 表达能力强 | 计算复杂 | |
| 摊销分布 | $q_\phi(z | x) = \text{NN}_\phi(x)$ | 共享参数 |
证据下界(ELBO)
KL散度推导
我们用KL散度衡量两个分布的差异:
利用贝叶斯定理 :
证据下界
重新整理得:
关键洞察:由于KL散度非负,我们得到证据下界(Evidence Lower Bound, ELBO):
ELBO的两种形式
形式1:期望形式
解释:
- 第一项:重构似然的期望(重构损失)
- 第二项:先验与后验的KL散度(正则化项)
形式2:信息论形式
解释:ELBO是边缘似然减去真实后验与变分后验的KL散度。最小化后验近似误差等价于最大化ELBO。
最大化ELBO的目标
变分推断的优化方法
坐标上升变分推断(CAVI)
对于均值场变分族,可以交替优化每个局部变分参数:
def cavi_update(j, X, q):
"""
CAVI更新规则
Args:
j: 更新的变量索引
X: 观测数据
q: 当前变分分布
Returns:
new_q_j: 更新后的变分分布
"""
# 计算期望(除z_j外的所有其他变量)
expected_logjoint = 0
for sample in range(num_samples):
z_sample = q.sample()
expected_logjoint += np.log(p(X, z_sample))
expected_logjoint /= num_samples
# 归一化
new_q_j = np.exp(expected_logjoint)
new_q_j /= new_q_j.sum() # 归一化
return new_q_j随机变分推断(SVI)
当数据规模很大时,使用随机梯度上升:
其中 是学习率,通常使用自适应学习率调度。
重参数化技巧
为了计算 的梯度,使用重参数化技巧:
高斯分布的例子:
变分自编码器(VAE)
VAE的概率模型
VAE假设数据生成过程如下:
其中:
- :标准高斯先验
- :解码器分布(通常是高斯或伯努利)
- :变分近似后验(编码器)
VAE的ELBO
PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, Bernoulli
class Encoder(nn.Module):
"""变分编码器"""
def __init__(self, input_dim, latent_dim, hidden_dim=400):
super().__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim) # 均值
self.fc_logvar = nn.Linear(hidden_dim, latent_dim) # 对数方差
def forward(self, x):
h = F.relu(self.fc1(x))
mu = self.fc_mu(h)
logvar = self.fc_logvar(h)
return mu, logvar
class Decoder(nn.Module):
"""变分解码器"""
def __init__(self, latent_dim, output_dim, hidden_dim=400):
super().__init__()
self.fc1 = nn.Linear(latent_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, z):
h = F.relu(self.fc1(z))
# 输出logits用于伯努利分布
logits = self.fc2(h)
return logits
class VAE(nn.Module):
"""
变分自编码器
使用重参数化技巧实现梯度反向传播
"""
def __init__(self, input_dim, latent_dim, hidden_dim=400):
super().__init__()
self.latent_dim = latent_dim
self.encoder = Encoder(input_dim, latent_dim, hidden_dim)
self.decoder = Decoder(latent_dim, input_dim, hidden_dim)
def reparameterize(self, mu, logvar):
"""
重参数化技巧
z = μ + σ * ε, 其中 ε ~ N(0, I)
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
# 编码
mu, logvar = self.encoder(x)
# 重参数化采样
z = self.reparameterize(mu, logvar)
# 解码
logits = self.decoder(z)
return logits, mu, logvar
def elbo_loss(self, x, logits, mu, logvar):
"""
计算ELBO损失
ELBO = 重构损失 - KL散度
"""
# 重构损失(伯努利分布的负对数似然)
# 对于二值图像数据,sigmoid激活后用BCE
x_prob = torch.sigmoid(logits)
recon_loss = F.binary_cross_entropy(x_prob, x, reduction='sum')
# KL散度:q(z|x) || p(z)
# 对于高斯分布,KL(N(μ,σ²) || N(0,I)) 有闭式解
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss
def loss_function(self, x):
"""完整损失计算"""
logits, mu, logvar = self.forward(x)
return self.elbo_loss(x, logits, mu, logvar)
def sample(self, num_samples, device):
"""
从先验采样并解码
"""
with torch.no_grad():
z = torch.randn(num_samples, self.latent_dim, device=device)
logits = self.decoder(z)
samples = torch.sigmoid(logits)
return samples
def encode(self, x):
"""编码到潜在空间"""
mu, logvar = self.encoder(x)
return self.reparameterize(mu, logvar)
def decode(self, z):
"""从潜在空间解码"""
logits = self.decoder(z)
return torch.sigmoid(logits)
def train_vae(model, dataloader, optimizer, device, epoch):
model.train()
total_loss = 0
total_recon = 0
total_kl = 0
for batch_idx, (data, _) in enumerate(dataloader):
data = data.view(-1, input_dim).to(device)
optimizer.zero_grad()
# 前向传播
logits, mu, logvar = model(data)
# 计算损失
recon_loss = F.binary_cross_entropy_with_logits(logits, data, reduction='sum')
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl_loss
# 反向传播
loss.backward()
optimizer.step()
total_loss += loss.item()
total_recon += recon_loss.item()
total_kl += kl_loss.item()
n_samples = len(dataloader.dataset)
print(f"Epoch {epoch}: Loss={total_loss/n_samples:.4f}, "
f"Recon={total_recon/n_samples:.4f}, KL={total_kl/n_samples:.4f}")摊销变分推断(Amortized VI)
摊销的动机
传统变分推断为每个数据点独立优化变分参数 ,计算复杂度为 。
摊销变分推断使用一个参数化函数(编码器网络):
这使得:
- 推理成本从 降到 (给定网络前向传播)
- 参数共享:所有数据点共享
- 泛化能力:对未见数据也能推断后验
摊销推断的权衡
| 方面 | 独立VI | 摊销VI |
|---|---|---|
| 灵活性 | 每个数据点独立优化 | 共享参数 |
| 计算效率 | ||
| 表达能力 | 高(独立参数) | 中等(共享函数) |
| 泛化 | 无 | 有 |
对抗性变分推断
当变分分布族不够表达时,可以使用GAN风格的对抗训练:
class AdversarialVAE(nn.Module):
"""
对抗变分自编码器
使用判别器迫使q(z|x)接近p(z|x)
"""
def __init__(self, input_dim, latent_dim, hidden_dim=400):
super().__init__()
# 编码器
self.encoder = Encoder(input_dim, latent_dim, hidden_dim)
# 解码器
self.decoder = Decoder(latent_dim, input_dim, hidden_dim)
# 判别器(区分q(z|x)和p(z))
self.discriminator = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def encode(self, x):
mu, logvar = self.encoder(x)
z = self.reparameterize(mu, logvar)
return z
def forward(self, x):
z = self.encode(x)
logits = self.decoder(z)
return logits, z
def discriminator_loss(self, x):
"""判别器损失:鼓励q(z|x)接近p(z)"""
# 从后验采样
z_posterior = self.encode(x)
# 从先验采样
z_prior = torch.randn_like(z_posterior)
# 判别器输出
d_posterior = self.discriminator(z_posterior)
d_prior = self.discriminator(z_prior)
# 对抗损失
return -torch.mean(d_posterior) + torch.mean(d_prior)
def generator_loss(self, x):
"""生成器损失(欺骗判别器)"""
z_posterior = self.encode(x)
d_posterior = self.discriminator(z_posterior)
return -torch.mean(d_posterior)变分推断在深度学习中的应用
1. 变分循环网络
将变分推断扩展到序列模型:
class VariationalLSTM(nn.Module):
"""
变分LSTM
用于序列数据的潜在变量建模
"""
def __init__(self, input_dim, hidden_dim, latent_dim):
super().__init__()
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
# 编码器(从隐藏状态推断潜在变量)
self.q_z = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 2 * latent_dim) # mu和logvar
)
# LSTM
self.lstm = nn.LSTM(input_dim + latent_dim, hidden_dim, batch_first=True)
# 解码器
self.decoder = nn.Linear(hidden_dim, input_dim)
def forward(self, x, h=None):
batch_size, seq_len, _ = x.shape
if h is None:
h = (torch.zeros(1, batch_size, self.hidden_dim),
torch.zeros(1, batch_size, self.hidden_dim))
outputs = []
for t in range(seq_len):
# 推断潜在变量
mu_logvar = self.q_z(h[0][0])
mu, logvar = mu_logvar.chunk(2, dim=-1)
z = self.reparameterize(mu, logvar)
# 输入 + 潜在变量
input_t = torch.cat([x[:, t:t+1], z.unsqueeze(1)], dim=-1)
# LSTM前向
out, h = self.lstm(input_t, h)
outputs.append(self.decoder(out))
return torch.cat(outputs, dim=1)2. 变分图神经网络
在图神经网络中引入潜在变量:
class VariationalGNN(nn.Module):
"""
变分图神经网络
用于图数据的生成和推断
"""
def __init__(self, node_dim, edge_dim, latent_dim, hidden_dim):
super().__init__()
# 节点编码器
self.node_encoder = nn.Sequential(
nn.Linear(node_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 2 * latent_dim)
)
# GNN层
self.gnn = MessagePassingLayer(hidden_dim, edge_dim)
# 解码器
self.decoder = nn.Linear(hidden_dim, node_dim)
def forward(self, x, edge_index):
# 推断潜在变量
mu_logvar = self.node_encoder(x)
mu, logvar = mu_logvar.chunk(2, dim=-1)
z = self.reparameterize(mu, logvar)
# GNN消息传递
h = self.gnn(z, edge_index)
# 解码
x_recon = self.decoder(h)
return x_recon, mu, logvar
def loss(self, x, edge_index):
x_recon, mu, logvar = self.forward(x, edge_index)
# 重构损失
recon_loss = F.mse_loss(x_recon, x)
# KL散度
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss3. 变分dropout与贝叶斯神经网络
变分dropout提供了一种贝叶斯视角的dropout解释:
class VariationalDropout(nn.Module):
"""
变分dropout(Gal & Ghahramani, 2016)
Dropout等价于变分推断中的KL正则化项
"""
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_features, out_features))
self.bias = nn.Parameter(torch.zeros(out_features))
# Dropout率
self.log_alpha = nn.Parameter(torch.zeros(1))
@property
def alpha(self):
return torch.sigmoid(self.log_alpha)
def forward(self, x, training=True):
if training:
# 变分dropout:随机掩码
mask = torch.bernoulli(1 - self.alpha.expand_as(x))
x_dropout = x * mask / (1 - self.alpha)
else:
x_dropout = x
return F.linear(x_dropout, self.weight, self.bias)
def kl_divergence(self):
"""
变分dropout的KL散度
等价于额外的正则化项
"""
return self.alpha.pow(2) / (1 - self.alpha.pow(2) + 1e-8)信息论视角
ELBO的信息论分解
ELBO可以进一步分解为信息论量:
互信息项:
β-VAE
调节ELBO中KL项的权重:
- :标准VAE
- :更强调先验正则化( disentanglement)
- :更强调重构(更清晰的重建)
class BetaVAE(nn.Module):
def __init__(self, input_dim, latent_dim, beta=1.0):
super().__init__()
self.beta = beta
self.vae = VAE(input_dim, latent_dim)
def loss_function(self, x):
logits, mu, logvar = self.vae(x)
recon_loss = F.binary_cross_entropy_with_logits(logits, x, reduction='sum')
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# 加权KL散度
return recon_loss + self.beta * kl_loss与现有wiki内容的联系
| 主题 | 相关文件 |
|---|---|
| 概率图模型 | probabilistic-graphical-models-comprehensive |
| 贝叶斯神经网络 | bayesian-neural-networks |
| 变分推断基础 | variational-inference |
| 信息论基础 | information-theory |
| 归一化流 | normalizing-flows-variational |
参考
相关阅读
Footnotes
-
Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational inference: A review for statisticians. Journal of the American statistical Association, 112(518), 859-877. ↩