概述
重参数化技巧(Reparameterization Trick) 是变分推断和贝叶斯深度学习中的核心技术,它将随机采样操作转化为确定性操作与独立噪声的组合,从而使得梯度可以通过随机变量进行反向传播。1
┌─────────────────────────────────────────────────────────────────┐
│ 重参数化技巧示意 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ 直接采样(不可导): │
│ │
│ θ ~ p(θ) → forward(θ) → loss │
│ ↑ │
│ └── 随机采样,无法求导 │
│ │
│ 重参数化(可导): │
│ │
│ ε ~ p(ε) → θ = μ + σ · ε → forward(θ) → loss │
│ ↑ │
│ └── 确定性变换,可求导 │
│ │
└─────────────────────────────────────────────────────────────────┘
1. 变分推断基础回顾
1.1 问题设置
在贝叶斯推断中,我们希望计算后验分布:
但分母中的边缘似然 通常难以计算。
1.2 变分推断的核心思想
使用参数化的近似分布 来逼近真实后验 ,通过优化变分参数 最小化两者的差异。
1.3 ELBO目标函数
证据下界(Evidence Lower Bound, ELBO):
- 第一项:重构似然(数据拟合度)
- 第二项:KL散度(正则化,使近似后验接近先验)
def elbo(x, model, q_theta, p_theta):
"""
计算ELBO
x: 观测数据
model: 生成模型 p(x|θ)
q_theta: 变分分布 q(θ)
p_theta: 先验分布 p(θ)
"""
# 从变分分布采样
theta = q_theta.sample()
# 重构损失
log_px_given_theta = model.log_likelihood(x, theta)
# KL散度
kl_div = q_theta.kl_divergence(p_theta)
# ELBO
return log_px_given_theta - kl_div2. 重参数化技巧详解
2.1 核心思想
重参数化的核心思想是将随机性抽离出来,用一个独立的标准随机变量表示:
2.2 数学推导
设变分分布为 ,我们希望计算:
直接求导会遇到采样操作,无法传播梯度。
重参数化后:
现在梯度可以直接通过 传播!
2.3 通用形式
对于任意可参数的分布 ,重参数化的条件是存在可逆变换 :
常见的重参数化形式:
| 分布 | 重参数化 | 条件 |
|---|---|---|
| 高斯 | ||
| 指数族 | 指数-仿射变换 | 存在闭式变换 |
| 伯努利 | 概率-Logit变换 | Gumbel-Softmax |
| 离散 | Gumbel-Softmax | 温度参数 |
2.4 PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class VariationalLinear(nn.Module):
"""
变分线性层:使用重参数化技巧
"""
def __init__(self, in_features, out_features, prior_mean=0.0, prior_std=1.0):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# 变分参数
self.weight_mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
self.weight_logvar = nn.Parameter(torch.zeros(out_features, in_features) - 3) # log(0.1²)
self.bias_mean = nn.Parameter(torch.zeros(out_features))
self.bias_logvar = nn.Parameter(torch.zeros(out_features) - 3)
# 先验参数
self.register_buffer('prior_mean', torch.tensor(prior_mean))
self.register_buffer('prior_std', torch.tensor(prior_std))
def sample_weights(self):
"""
重参数化采样
θ = μ + σ · ε
"""
# 权重采样
weight_std = torch.exp(0.5 * self.weight_logvar)
weight_epsilon = torch.randn_like(self.weight_mean)
weight = self.weight_mean + weight_std * weight_epsilon
# 偏置采样
bias_std = torch.exp(0.5 * self.bias_logvar)
bias_epsilon = torch.randn_like(self.bias_mean)
bias = self.bias_mean + bias_std * bias_epsilon
return weight, bias
def forward(self, x):
weight, bias = self.sample_weights()
return F.linear(x, weight, bias)
def kl_divergence(self):
"""
计算与先验的KL散度
KL(N(μ,σ²) || N(μ₀,σ₀²)) = log(σ₀/σ) + (σ² + (μ-μ₀)²)/(2σ₀²) - 1/2
"""
prior_var = self.prior_std ** 2
# 权重KL
weight_var = torch.exp(self.weight_logvar) ** 2
kl_weight = 0.5 * (
torch.log(prior_var / weight_var) +
(weight_var + (self.weight_mean - self.prior_mean) ** 2) / prior_var -
1
)
# 偏置KL
bias_var = torch.exp(self.bias_logvar) ** 2
kl_bias = 0.5 * (
torch.log(prior_var / bias_var) +
(bias_var + (self.bias_mean - self.prior_mean) ** 2) / prior_var -
1
)
return kl_weight.sum() + kl_bias.sum()
class BayesianMLP(nn.Module):
"""
贝叶斯多层感知机
"""
def __init__(self, input_dim, hidden_dims, output_dim):
super().__init__()
layers = []
prev_dim = input_dim
for hidden_dim in hidden_dims:
layers.append(VariationalLinear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim
layers.append(VariationalLinear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
def elbo_loss(self, x, y, num_samples=1):
"""
ELBO损失 = 重构损失 + KL散度
"""
batch_size = x.shape[0]
# 重构损失:多次采样平均
log_likelihoods = []
kl_total = 0
for _ in range(num_samples):
output = self.forward(x)
log_likelihood = F.cross_entropy(output, y, reduction='sum')
log_likelihoods.append(log_likelihood)
# 累加KL散度
for module in self.modules():
if isinstance(module, VariationalLinear):
kl_total += module.kl_divergence()
# 平均重构损失
avg_log_likelihood = torch.stack(log_likelihoods).mean()
# ELBO = log p(x|θ) - KL(q(θ)||p(θ))
# 需要归一化
return -(avg_log_likelihood - kl_total) / batch_size2.5 方差分析
重参数化技巧的梯度估计方差:
优点:
- 梯度估计方差较低(相对于Score Function Estimator)
- 收敛更快
缺点:
- 仅适用于可重参数化的分布(通常为连续分布)
3. 局部重参数化技巧
3.1 问题背景
标准重参数化的计算量问题:对于神经网络,需要对每个权重独立采样。
假设网络有 个参数,每次前向传播需要 次采样操作。
3.2 核心思想
局部重参数化(Local Reparameterization Trick) 的关键洞见是:直接在激活值上采样,而非权重上采样。2
标准重参数化:
权重采样: W ~ q(W) → 每层 O(|W|) 个随机变量
局部重参数化:
激活采样: a ~ q(a) → 每层 O(|a|) 个随机变量 (通常 |a| << |W|)
3.3 数学推导
考虑线性变换 ,其中:
标准方法:从 的变分分布中采样,然后计算
局部重参数化:直接在 的分布上采样
由于 是 的线性组合,若 独立,则:
def local_reparameterize(weight_mean, weight_logvar, input_x):
"""
局部重参数化
直接在激活值上采样,而非权重上采样
z ~ N(μ_z, σ²_z) 其中:
μ_z = mean(W) @ x + b
σ²_z = var(W) @ (x²)
"""
# 计算激活的均值
mu_z = F.linear(input_x, weight_mean)
# 计算激活的方差
weight_var = torch.exp(weight_logvar)
# F.linear(x², var(W)) 近似激活方差
mu_z2 = F.linear(input_x ** 2, weight_var)
# 在激活空间采样
std_z = torch.sqrt(mu_z2 + 1e-8) # 加小常数防止数值问题
epsilon = torch.randn_like(mu_z)
z = mu_z + std_z * epsilon
return z3.4 方差对比
| 方法 | 梯度方差 | 计算复杂度 |
|---|---|---|
| 标准重参数化 | 缩放 | |
| 局部重参数化 | 降低 | |
| Score Function | 较高 |
局部重参数化的方差降低原理:
当 较大时,激活值的方差通过求和被平均化,因此梯度方差降低。
3.5 PyTorch实现
class LocalReparameterizedLinear(nn.Module):
"""
使用局部重参数化技巧的线性层
"""
def __init__(self, in_features, out_features, prior_std=1.0):
super().__init__()
# 变分参数
self.weight_mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
self.weight_logvar = nn.Parameter(torch.zeros(out_features, in_features) - 3)
self.bias_mean = nn.Parameter(torch.zeros(out_features))
# 先验标准差
self.prior_std = prior_std
def forward(self, x):
# 局部重参数化:直接在激活空间采样
z = self._local_reparameterize(x)
return z
def _local_reparameterize(self, x):
"""
局部重参数化实现
"""
# 激活均值
mu_z = F.linear(x, self.weight_mean, self.bias_mean)
# 激活方差
weight_var = torch.exp(self.weight_logvar)
mu_z2 = F.linear(x ** 2, weight_var) # E[(Wx)²] = Var(Wx) + (E[Wx])²
# 标准差
std_z = torch.sqrt(mu_z2 - F.linear(x, self.weight_mean) ** 2 + 1e-8)
# 采样
epsilon = torch.randn_like(mu_z)
return mu_z + std_z * epsilon
def kl_divergence(self):
"""
计算与高斯先验的KL散度
"""
prior_var = self.prior_std ** 2
# KL = log(σ₀/σ) + (σ² + μ²)/(2σ₀²) - 1/2
weight_var = torch.exp(self.weight_logvar)
kl = 0.5 * (
torch.log(prior_var / weight_var) +
(weight_var + self.weight_mean ** 2) / prior_var -
1
)
return kl.sum()
class LocalReparameterizedMLP(nn.Module):
"""
使用局部重参数化的贝叶斯MLP
"""
def __init__(self, input_dim, hidden_dims, output_dim):
super().__init__()
layers = []
prev_dim = input_dim
for hidden_dim in hidden_dims:
layers.append(LocalReparameterizedLinear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim
layers.append(LocalReparameterizedLinear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x):
return self.network(x)
def elbo_loss(self, x, y, beta=1.0):
"""
β-VAE风格的ELBO损失
"""
# 前向传播(内部采样)
output = self.forward(x)
# 重构损失
recon_loss = F.cross_entropy(output, y, reduction='mean')
# KL损失
kl_loss = 0
for module in self.modules():
if isinstance(module, LocalReparameterizedLinear):
kl_loss += module.kl_divergence()
# β加权
return recon_loss + beta * kl_loss / x.shape[0]4. 与VAE的关系
4.1 VAE中的重参数化
变分自编码器(VAE)使用重参数化来处理离散潜在变量:
class VAE(nn.Module):
"""
标准VAE实现
"""
def __init__(self, input_dim, latent_dim):
super().__init__()
self.latent_dim = latent_dim
# 编码器 q(z|x)
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, latent_dim * 2) # 均值 + 对数方差
)
# 解码器 p(x|z)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid()
)
def reparameterize(self, mu, logvar):
"""
重参数化技巧
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
# 编码
h = self.encoder(x)
mu, logvar = h.chunk(2, dim=-1)
# 重参数化采样
z = self.reparameterize(mu, logvar)
# 解码
x_recon = self.decoder(z)
return x_recon, mu, logvar
def loss(self, x, x_recon, mu, logvar):
"""
VAE损失 = 重构损失 + KL散度
"""
# 二元交叉熵重构损失
recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
# KL散度: KL(N(μ,σ) || N(0,I)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss4.2 对比:重参数化 vs SGVB
SGVB(Stochastic Gradient Variational Bayes) 估计器:
其中 使用重参数化。
5. 梯度估计方法对比
5.1 三种梯度估计器
| 方法 | 公式 | 适用场景 | 方差 |
|---|---|---|---|
| Score Function (REINFORCE) | 任意分布 | 高 | |
| Path Derivative (重参数化) | 可重参数化 | 低 | |
| 局部重参数化 | 激活空间采样 | 可重参数化 | 更低 |
5.2 Score Function vs 重参数化
# Score Function估计器
def score_function_gradient(f, q, phi, num_samples=100):
"""
Score Function (REINFORCE) 梯度估计
"""
gradients = []
for _ in range(num_samples):
theta = q.sample() # 从变分分布采样
log_prob = q.log_prob(theta) # log q(θ; φ)
grad_log_prob = torch.autograd.grad(log_prob, phi, create_graph=True)[0]
gradients.append(f(theta) * grad_log_prob)
return torch.stack(gradients).mean()
# 重参数化梯度估计
def reparameterization_gradient(f, q, phi, num_samples=100):
"""
重参数化梯度估计
"""
gradients = []
for _ in range(num_samples):
epsilon = q.base_dist.sample() # 从基础分布采样
theta = q.transform(phi, epsilon) # 确定性变换
grad_theta = torch.autograd.grad(f(theta), phi)[0]
gradients.append(grad_theta)
return torch.stack(gradients).mean()5.3 方差可视化
Score Function: ████████████████░░░░ (高方差)
Path Derivative: ████████░░░░░░░░░░░ (中等)
Local Rep: ████░░░░░░░░░░░░░░░ (低方差)
6. 实践注意事项
6.1 数值稳定性
def stable_kl_divergence(mu, logvar, prior_mean=0, prior_std=1):
"""
数值稳定的KL散度计算
"""
prior_var = prior_std ** 2
# 使用log-sum-exp提高数值稳定性
kl = 0.5 * (
torch.log(prior_var) - logvar +
torch.exp(logvar) / prior_var +
(mu - prior_mean) ** 2 / prior_var -
1
)
return kl.sum()6.2 梯度裁剪
对于高方差估计,使用梯度裁剪:
def clipped_gradient(gradient, max_norm=1.0):
"""梯度裁剪"""
return torch.nn.utils.clip_grad_norm_(gradient, max_norm)6.3 蒙特卡洛估计的样本数
- 训练初期:使用较多样本(10-100)降低方差
- 训练后期:减少样本数(1-5)提高效率
- 可使用方差减少技术(如Control Variates)
7. 扩展阅读
7.1 相关技术
| 技术 | 描述 | 参考 |
|---|---|---|
| ** normalizing-flows** | 可逆变换增强表达能力 | normalizing-flows |
| IWAE | 重要性加权ELBO | variational-inference-advanced |
| 流动匹配 | 最优传输视角 | diffusion-flow-matching |
| 贝叶斯优化 | 黑盒函数优化 | 贝叶斯优化文献 |
7.2 经典论文
- Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. ICLR.
- Kingma, D. P., Salimans, T., & Welling, M. (2015). “Variational Dropout and the Local Reparameterization Trick”. NeurIPS.
- Blundell, C., et al. (2015). “Weight Uncertainty in Neural Networks”. ICML.