概述

重参数化技巧(Reparameterization Trick) 是变分推断和贝叶斯深度学习中的核心技术,它将随机采样操作转化为确定性操作与独立噪声的组合,从而使得梯度可以通过随机变量进行反向传播。1

┌─────────────────────────────────────────────────────────────────┐
│                     重参数化技巧示意                               │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  直接采样(不可导):                                              │
│                                                                 │
│     θ ~ p(θ)  →  forward(θ)  →  loss                           │
│         ↑                                                           │
│         └── 随机采样,无法求导                                    │
│                                                                 │
│  重参数化(可导):                                                │
│                                                                 │
│     ε ~ p(ε)  →  θ = μ + σ · ε  →  forward(θ)  →  loss          │
│                      ↑                                           │
│                      └── 确定性变换,可求导                       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

1. 变分推断基础回顾

1.1 问题设置

在贝叶斯推断中,我们希望计算后验分布:

但分母中的边缘似然 通常难以计算

1.2 变分推断的核心思想

使用参数化的近似分布 来逼近真实后验 ,通过优化变分参数 最小化两者的差异。

1.3 ELBO目标函数

证据下界(Evidence Lower Bound, ELBO)

  • 第一项:重构似然(数据拟合度)
  • 第二项:KL散度(正则化,使近似后验接近先验)
def elbo(x, model, q_theta, p_theta):
    """
    计算ELBO
    x: 观测数据
    model: 生成模型 p(x|θ)
    q_theta: 变分分布 q(θ)
    p_theta: 先验分布 p(θ)
    """
    # 从变分分布采样
    theta = q_theta.sample()
    
    # 重构损失
    log_px_given_theta = model.log_likelihood(x, theta)
    
    # KL散度
    kl_div = q_theta.kl_divergence(p_theta)
    
    # ELBO
    return log_px_given_theta - kl_div

2. 重参数化技巧详解

2.1 核心思想

重参数化的核心思想是将随机性抽离出来,用一个独立的标准随机变量表示:

2.2 数学推导

设变分分布为 ,我们希望计算:

直接求导会遇到采样操作,无法传播梯度。

重参数化后

现在梯度可以直接通过 传播!

2.3 通用形式

对于任意可参数的分布 ,重参数化的条件是存在可逆变换

常见的重参数化形式:

分布重参数化条件
高斯
指数族指数-仿射变换存在闭式变换
伯努利概率-Logit变换Gumbel-Softmax
离散Gumbel-Softmax温度参数

2.4 PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class VariationalLinear(nn.Module):
    """
    变分线性层:使用重参数化技巧
    """
    def __init__(self, in_features, out_features, prior_mean=0.0, prior_std=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # 变分参数
        self.weight_mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.weight_logvar = nn.Parameter(torch.zeros(out_features, in_features) - 3)  # log(0.1²)
        
        self.bias_mean = nn.Parameter(torch.zeros(out_features))
        self.bias_logvar = nn.Parameter(torch.zeros(out_features) - 3)
        
        # 先验参数
        self.register_buffer('prior_mean', torch.tensor(prior_mean))
        self.register_buffer('prior_std', torch.tensor(prior_std))
    
    def sample_weights(self):
        """
        重参数化采样
        θ = μ + σ · ε
        """
        # 权重采样
        weight_std = torch.exp(0.5 * self.weight_logvar)
        weight_epsilon = torch.randn_like(self.weight_mean)
        weight = self.weight_mean + weight_std * weight_epsilon
        
        # 偏置采样
        bias_std = torch.exp(0.5 * self.bias_logvar)
        bias_epsilon = torch.randn_like(self.bias_mean)
        bias = self.bias_mean + bias_std * bias_epsilon
        
        return weight, bias
    
    def forward(self, x):
        weight, bias = self.sample_weights()
        return F.linear(x, weight, bias)
    
    def kl_divergence(self):
        """
        计算与先验的KL散度
        KL(N(μ,σ²) || N(μ₀,σ₀²)) = log(σ₀/σ) + (σ² + (μ-μ₀)²)/(2σ₀²) - 1/2
        """
        prior_var = self.prior_std ** 2
        
        # 权重KL
        weight_var = torch.exp(self.weight_logvar) ** 2
        kl_weight = 0.5 * (
            torch.log(prior_var / weight_var) + 
            (weight_var + (self.weight_mean - self.prior_mean) ** 2) / prior_var - 
            1
        )
        
        # 偏置KL
        bias_var = torch.exp(self.bias_logvar) ** 2
        kl_bias = 0.5 * (
            torch.log(prior_var / bias_var) + 
            (bias_var + (self.bias_mean - self.prior_mean) ** 2) / prior_var - 
            1
        )
        
        return kl_weight.sum() + kl_bias.sum()
 
 
class BayesianMLP(nn.Module):
    """
    贝叶斯多层感知机
    """
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(VariationalLinear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        layers.append(VariationalLinear(prev_dim, output_dim))
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x)
    
    def elbo_loss(self, x, y, num_samples=1):
        """
        ELBO损失 = 重构损失 + KL散度
        """
        batch_size = x.shape[0]
        
        # 重构损失:多次采样平均
        log_likelihoods = []
        kl_total = 0
        
        for _ in range(num_samples):
            output = self.forward(x)
            log_likelihood = F.cross_entropy(output, y, reduction='sum')
            log_likelihoods.append(log_likelihood)
            
            # 累加KL散度
            for module in self.modules():
                if isinstance(module, VariationalLinear):
                    kl_total += module.kl_divergence()
        
        # 平均重构损失
        avg_log_likelihood = torch.stack(log_likelihoods).mean()
        
        # ELBO = log p(x|θ) - KL(q(θ)||p(θ))
        # 需要归一化
        return -(avg_log_likelihood - kl_total) / batch_size

2.5 方差分析

重参数化技巧的梯度估计方差:

优点

  • 梯度估计方差较低(相对于Score Function Estimator)
  • 收敛更快

缺点

  • 仅适用于可重参数化的分布(通常为连续分布)

3. 局部重参数化技巧

3.1 问题背景

标准重参数化的计算量问题:对于神经网络,需要对每个权重独立采样。

假设网络有 个参数,每次前向传播需要 次采样操作。

3.2 核心思想

局部重参数化(Local Reparameterization Trick) 的关键洞见是:直接在激活值上采样,而非权重上采样2

标准重参数化:
    权重采样: W ~ q(W)  →  每层 O(|W|) 个随机变量
    
局部重参数化:
    激活采样: a ~ q(a)  →  每层 O(|a|) 个随机变量 (通常 |a| << |W|)

3.3 数学推导

考虑线性变换 ,其中:

标准方法:从 的变分分布中采样,然后计算

局部重参数化:直接在 的分布上采样

由于 的线性组合,若 独立,则:

def local_reparameterize(weight_mean, weight_logvar, input_x):
    """
    局部重参数化
    直接在激活值上采样,而非权重上采样
    
    z ~ N(μ_z, σ²_z) 其中:
    μ_z = mean(W) @ x + b
    σ²_z = var(W) @ (x²)
    """
    # 计算激活的均值
    mu_z = F.linear(input_x, weight_mean)
    
    # 计算激活的方差
    weight_var = torch.exp(weight_logvar)
    # F.linear(x², var(W)) 近似激活方差
    mu_z2 = F.linear(input_x ** 2, weight_var)
    
    # 在激活空间采样
    std_z = torch.sqrt(mu_z2 + 1e-8)  # 加小常数防止数值问题
    epsilon = torch.randn_like(mu_z)
    z = mu_z + std_z * epsilon
    
    return z

3.4 方差对比

方法梯度方差计算复杂度
标准重参数化 缩放
局部重参数化 降低
Score Function较高

局部重参数化的方差降低原理

较大时,激活值的方差通过求和被平均化,因此梯度方差降低。

3.5 PyTorch实现

class LocalReparameterizedLinear(nn.Module):
    """
    使用局部重参数化技巧的线性层
    """
    def __init__(self, in_features, out_features, prior_std=1.0):
        super().__init__()
        
        # 变分参数
        self.weight_mean = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        self.weight_logvar = nn.Parameter(torch.zeros(out_features, in_features) - 3)
        
        self.bias_mean = nn.Parameter(torch.zeros(out_features))
        
        # 先验标准差
        self.prior_std = prior_std
    
    def forward(self, x):
        # 局部重参数化:直接在激活空间采样
        z = self._local_reparameterize(x)
        return z
    
    def _local_reparameterize(self, x):
        """
        局部重参数化实现
        """
        # 激活均值
        mu_z = F.linear(x, self.weight_mean, self.bias_mean)
        
        # 激活方差
        weight_var = torch.exp(self.weight_logvar)
        mu_z2 = F.linear(x ** 2, weight_var)  # E[(Wx)²] = Var(Wx) + (E[Wx])²
        
        # 标准差
        std_z = torch.sqrt(mu_z2 - F.linear(x, self.weight_mean) ** 2 + 1e-8)
        
        # 采样
        epsilon = torch.randn_like(mu_z)
        return mu_z + std_z * epsilon
    
    def kl_divergence(self):
        """
        计算与高斯先验的KL散度
        """
        prior_var = self.prior_std ** 2
        
        # KL = log(σ₀/σ) + (σ² + μ²)/(2σ₀²) - 1/2
        weight_var = torch.exp(self.weight_logvar)
        
        kl = 0.5 * (
            torch.log(prior_var / weight_var) + 
            (weight_var + self.weight_mean ** 2) / prior_var - 
            1
        )
        
        return kl.sum()
 
 
class LocalReparameterizedMLP(nn.Module):
    """
    使用局部重参数化的贝叶斯MLP
    """
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.append(LocalReparameterizedLinear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())
            prev_dim = hidden_dim
        layers.append(LocalReparameterizedLinear(prev_dim, output_dim))
        
        self.network = nn.Sequential(*layers)
    
    def forward(self, x):
        return self.network(x)
    
    def elbo_loss(self, x, y, beta=1.0):
        """
        β-VAE风格的ELBO损失
        """
        # 前向传播(内部采样)
        output = self.forward(x)
        
        # 重构损失
        recon_loss = F.cross_entropy(output, y, reduction='mean')
        
        # KL损失
        kl_loss = 0
        for module in self.modules():
            if isinstance(module, LocalReparameterizedLinear):
                kl_loss += module.kl_divergence()
        
        # β加权
        return recon_loss + beta * kl_loss / x.shape[0]

4. 与VAE的关系

4.1 VAE中的重参数化

变分自编码器(VAE)使用重参数化来处理离散潜在变量

class VAE(nn.Module):
    """
    标准VAE实现
    """
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        
        # 编码器 q(z|x)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim * 2)  # 均值 + 对数方差
        )
        
        # 解码器 p(x|z)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )
    
    def reparameterize(self, mu, logvar):
        """
        重参数化技巧
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def forward(self, x):
        # 编码
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=-1)
        
        # 重参数化采样
        z = self.reparameterize(mu, logvar)
        
        # 解码
        x_recon = self.decoder(z)
        
        return x_recon, mu, logvar
    
    def loss(self, x, x_recon, mu, logvar):
        """
        VAE损失 = 重构损失 + KL散度
        """
        # 二元交叉熵重构损失
        recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
        
        # KL散度: KL(N(μ,σ) || N(0,I)) = -0.5 * Σ(1 + log(σ²) - μ² - σ²)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        
        return recon_loss + kl_loss

4.2 对比:重参数化 vs SGVB

SGVB(Stochastic Gradient Variational Bayes) 估计器:

其中 使用重参数化。


5. 梯度估计方法对比

5.1 三种梯度估计器

方法公式适用场景方差
Score Function (REINFORCE)任意分布
Path Derivative (重参数化)可重参数化
局部重参数化激活空间采样可重参数化更低

5.2 Score Function vs 重参数化

# Score Function估计器
def score_function_gradient(f, q, phi, num_samples=100):
    """
    Score Function (REINFORCE) 梯度估计
    """
    gradients = []
    for _ in range(num_samples):
        theta = q.sample()  # 从变分分布采样
        log_prob = q.log_prob(theta)  # log q(θ; φ)
        grad_log_prob = torch.autograd.grad(log_prob, phi, create_graph=True)[0]
        gradients.append(f(theta) * grad_log_prob)
    
    return torch.stack(gradients).mean()
 
 
# 重参数化梯度估计
def reparameterization_gradient(f, q, phi, num_samples=100):
    """
    重参数化梯度估计
    """
    gradients = []
    for _ in range(num_samples):
        epsilon = q.base_dist.sample()  # 从基础分布采样
        theta = q.transform(phi, epsilon)  # 确定性变换
        grad_theta = torch.autograd.grad(f(theta), phi)[0]
        gradients.append(grad_theta)
    
    return torch.stack(gradients).mean()

5.3 方差可视化

Score Function:    ████████████████░░░░  (高方差)
Path Derivative:   ████████░░░░░░░░░░░  (中等)
Local Rep:         ████░░░░░░░░░░░░░░░  (低方差)

6. 实践注意事项

6.1 数值稳定性

def stable_kl_divergence(mu, logvar, prior_mean=0, prior_std=1):
    """
    数值稳定的KL散度计算
    """
    prior_var = prior_std ** 2
    
    # 使用log-sum-exp提高数值稳定性
    kl = 0.5 * (
        torch.log(prior_var) - logvar + 
        torch.exp(logvar) / prior_var + 
        (mu - prior_mean) ** 2 / prior_var - 
        1
    )
    
    return kl.sum()

6.2 梯度裁剪

对于高方差估计,使用梯度裁剪:

def clipped_gradient(gradient, max_norm=1.0):
    """梯度裁剪"""
    return torch.nn.utils.clip_grad_norm_(gradient, max_norm)

6.3 蒙特卡洛估计的样本数

  • 训练初期:使用较多样本(10-100)降低方差
  • 训练后期:减少样本数(1-5)提高效率
  • 可使用方差减少技术(如Control Variates)

7. 扩展阅读

7.1 相关技术

技术描述参考
** normalizing-flows**可逆变换增强表达能力normalizing-flows
IWAE重要性加权ELBOvariational-inference-advanced
流动匹配最优传输视角diffusion-flow-matching
贝叶斯优化黑盒函数优化贝叶斯优化文献

7.2 经典论文

  1. Kingma, D. P., & Welling, M. (2014). “Auto-Encoding Variational Bayes”. ICLR.
  2. Kingma, D. P., Salimans, T., & Welling, M. (2015). “Variational Dropout and the Local Reparameterization Trick”. NeurIPS.
  3. Blundell, C., et al. (2015). “Weight Uncertainty in Neural Networks”. ICML.

参考资料


相关链接

Footnotes

  1. Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. ICLR 2014.

  2. Kingma, D. P., Salimans, T., & Welling, M. (2015). Variational Dropout and the Local Reparameterization Trick. NeurIPS 2015.