Score Matching变分推断:GSM-VI与BaM

1. 概述

变分推断(Variational Inference, VI)是贝叶斯深度学习的核心技术之一。传统VI使用**证据下界(ELBO)**作为目标函数:

然而,**Black Box变分推断(BBVI)**存在一个根本性问题:梯度方差过大。当使用score function estimator时:

梯度估计的高方差导致收敛缓慢,需要大量Monte Carlo样本。

**Score Matching VI(GSM-VI)Batch and Match(BaM)**提供了全新的优化视角:将变分推断问题转化为score matching问题,从而绕过ELBO梯度估计的困难。

12

2. 从ELBO到Score Matching

2.1 ELBO梯度的方差问题

ELBO的两种梯度形式

Score Function Estimator:

其中

问题

  • 峰度较高时, 都可能很大
  • 乘积的方差是两者方差的乘积项
  • 需要大量Monte Carlo样本才能得到可靠估计

Pathwise Gradient(重参数化):

问题

  • 要求 可重参数化
  • 隐变量通常是离散或混合分布时不可用
  • 对网络结构有限制

2.2 Score Matching的直觉

核心观察:KL散度最小化

Score Function定义:

Score Matching目标

物理意义:最小化真实分布和变分分布在score函数空间的距离。

2.3 为什么Score Matching更稳定?

Score Function的特性

  • Score是梯度的方向,天然具有单位范数的性质
  • 不依赖于似然的绝对尺度
  • 对概率归一化不敏感

对比

方面ELBO (BBVI)Score Matching
梯度尺度依赖似然绝对值归一化到单位球
方差高(需要大量样本)低(几何结构稳定)
可处理分布可重参数化任意可微分分布
计算成本O(1/M),M=样本数O(1)(单样本即可)

3. GSM-VI:Score Matching变分推断

3.1 GSM-VI目标函数

论文:Gaussian Score Matching for Variational Inference (NeurIPS 2023)

GSM-VI针对高斯变分分布推导了简化的score matching目标。

高斯分布的Score

Score Matching for Gaussian

简化推导

最终目标(忽略常数):

3.2 期望的闭式计算

关键发现:期望可以解析计算!

GSM-VI最终目标

结论:高斯变分分布的GSM目标只依赖于 (协方差),与 无关!

3.3 与ELBO的联系

高斯ELBO

对比

  • ELBO梯度涉及 ,方差大
  • GSM-VI梯度只涉及 ,方差小
  • GSM不依赖对数似然的梯度,只依赖其值

3.4 梯度计算

GSM-VI梯度

更新方向:最大化 等价于最小化 的特征值。

物理意义:GSM-VI鼓励变分分布的协方差更加”紧凑”,趋向先验或最大熵分布。

4. BaM:Batch and Match

4.1 核心思想

论文:Batch and Match for Black-Box Variational Inference (ICML 2024)

BaM将GSM-VI的思想进一步发展,提出闭式近端更新

问题设置

:最小化 在score空间的差异。

4.2 Proximal Update

近端算子

物理意义:在保持与当前分布 接近的同时,向参考分布 移动。

4.3 闭式解推导

对于高斯分布

闭式近端更新

类比:这类似于指数加权移动平均,但作用于协方差矩阵。

4.4 BaM算法流程

def batch_and_match(model, data, n_batches=10, alpha=0.1):
    """
    Batch and Match for BBVI
    
    Args:
        model: Neural network with variational params
        data: Dataset
        n_batches: Number of data batches for Fisher estimation
        alpha: Proximal step size
    """
    phi = model.variational_params
    
    for iteration in range(n_iterations):
        # 1. 计算当前变分分布的统计量
        grads = []
        for batch in data.random_batches(n_batches):
            loss = model.elbo(batch)
            grads.append(torch.autograd.grad(loss, phi))
        
        # 2. 估计Fisher信息(使用样本梯度)
        F_hat = estimate_fisher(grads)
        
        # 3. BaM闭式更新
        phi = prox_update(phi, F_hat, alpha)
        
    return phi
 
 
def prox_update(phi, F_hat, alpha):
    """闭式近端更新"""
    mu, Sigma = phi['mu'], phi['Sigma']
    
    # Fisher对齐目标分布
    Sigma_0 = F_hat  # Fisher作为目标分布
    
    # 近端更新
    Sigma_new = torch.linalg.inv(
        (1 - alpha) * torch.linalg.inv(Sigma) + alpha * torch.linalg.inv(Sigma_0)
    )
    
    mu_new = Sigma_new @ (
        (1 - alpha) * torch.linalg.inv(Sigma) @ mu +
        alpha * torch.linalg.inv(Sigma_0) @ torch.zeros_like(mu)
    )
    
    return {'mu': mu_new, 'Sigma': Sigma_new}

5. 收敛性分析

5.1 GSM-VI的收敛保证

定理(GSM-VI论文):对于高斯变分分布,GSM目标的最大化等价于KL散度的最小化(在常数项内)。

证明概要

  1. KL散度可以写成score函数的期望
  2. Score matching恰好是最小化这个期望
  3. 闭式期望计算保证无偏性

5.2 BaM的指数收敛

定理(BaM论文):当目标分布 是高斯且步长 满足 时,BaM更新指数收敛到最优解。

收敛速率

物理意义:每步减少 的总变差距离。

5.3 与BBVI的对比

方面BBVIGSM-VI/BaM
梯度方差O(1/√M)O(1)(闭式)
收敛速度慢(高方差)快(低方差)
样本需求100-10001-10
实现复杂度中等
理论保证渐近指数收敛

6. 实践实现

6.1 PyTorch实现

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
 
class GSMVI(nn.Module):
    """Gaussian Score Matching VI"""
    
    def __init__(self, dim, prior_std=1.0):
        super().__init__()
        self.dim = dim
        
        # 变分参数(对数协方差)
        self.log_var = nn.Parameter(torch.zeros(dim))
        self.prior_std = prior_std
        
    @property
    def sigma(self):
        """协方差矩阵(对角)"""
        return torch.exp(self.log_var).diag_embed()
    
    @property
    def sigma_inv(self):
        """协方差逆矩阵"""
        return torch.exp(-self.log_var).diag_embed()
    
    def gsm_objective(self, log_likelihood_fn, x, n_samples=1):
        """
        GSM-VI目标函数
        
        Args:
            log_likelihood_fn: log p(x|z) 函数
            x: 观测数据
            n_samples: Monte Carlo样本数
        """
        batch_size = x.size(0)
        z = torch.randn(batch_size, self.dim, device=x.device) @ self.sigma.cholesky()
        z = z + self.variational_mean.unsqueeze(0)  # 添加均值
        
        # Score matching目标
        score_norm = torch.sum((z - self.variational_mean)**2 * torch.exp(-self.log_var))
        
        # Tracy-Widom项
        tr_sigma_inv = torch.sum(torch.exp(-self.log_var))
        
        # Score matching损失
        loss = score_norm - 2 * self.dim * tr_sigma_inv
        
        # 添加似然项(简化版本)
        with torch.no_grad():
            ll = log_likelihood_fn(x, z)
        loss = loss - ll.mean()  # 负ELBO代理
        
        return loss / batch_size
    
    def fit(self, log_likelihood_fn, x, lr=0.01, n_epochs=100):
        """训练"""
        optimizer = torch.optim.Adam([self.log_var], lr=lr)
        
        for epoch in range(n_epochs):
            optimizer.zero_grad()
            loss = self.gsm_objective(log_likelihood_fn, x)
            loss.backward()
            optimizer.step()
            
            if epoch % 10 == 0:
                print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
 
 
class BaMOptimizer:
    """Batch and Match优化器"""
    
    def __init__(self, var_params, alpha=0.1):
        self.alpha = alpha
        self.var_params = var_params
        
        # 初始化
        self.mu = var_params['mu'].clone()
        self.Sigma = var_params['Sigma'].clone()
        
    def compute_fisher_estimate(self, grad_samples):
        """
        从梯度样本估计Fisher信息矩阵
        
        使用empirical Fisher: F ≈ (1/M) Σ ∇ℓ ∇ℓ⊤
        """
        grad_stack = torch.stack(grad_samples)
        M = grad_stack.size(0)
        
        # 样本协方差
        F_hat = torch.cov(grad_stack.t())
        
        # 对角近似(更稳定)
        F_hat = torch.diag(torch.diag(F_hat) + 1e-3)
        
        return F_hat
    
    def proximal_update(self, F_hat):
        """闭式近端更新"""
        # 协方差更新
        Sigma_inv = torch.linalg.inv(self.Sigma)
        F_hat_inv = torch.linalg.inv(F_hat + 1e-3 * torch.eye_like(F_hat))
        
        Sigma_new = torch.linalg.inv(
            (1 - self.alpha) * Sigma_inv + self.alpha * F_hat_inv
        )
        
        # 均值更新(假设先验均值=0)
        mu_new = Sigma_new @ ((1 - self.alpha) * Sigma_inv @ self.mu)
        
        self.Sigma = Sigma_new
        self.mu = mu_new
        
        return self.mu, self.Sigma
    
    def step(self, grad_fn, batch_size=10):
        """
        执行一步BaM更新
        
        Args:
            grad_fn: 返回梯度样本的函数
            batch_size: 样本数
        """
        # 收集梯度样本
        grad_samples = []
        for _ in range(batch_size):
            grads = grad_fn()
            grad_samples.append(grads)
        
        # 估计Fisher
        F_hat = self.compute_fisher_estimate(grad_samples)
        
        # 近端更新
        self.proximal_update(F_hat)
        
        return self.mu, self.Sigma

6.2 应用示例:贝叶斯线性回归

def bayesian_linear_regression_gsm(X, y, alpha_prior=1.0, n_epochs=500):
    """
    GSM-VI贝叶斯线性回归
    
    Args:
        X: (n, d) 设计矩阵
        y: (n,) 目标值
        alpha_prior: 先验精度
    """
    n, d = X.shape
    
    # 初始化
    model = GSMVI(dim=d, prior_std=1.0)
    
    def log_likelihood(z):
        """log p(y|X, z) = N(y; Xz, σ²I)"""
        pred = X @ z.unsqueeze(-1)
        return -0.5 * torch.sum((y.unsqueeze(-1) - pred)**2)
    
    # 训练
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    
    for epoch in range(n_epochs):
        optimizer.zero_grad()
        
        # GSM-VI目标
        loss = model.gsm_objective(log_likelihood, torch.zeros(d))
        
        loss.backward()
        optimizer.step()
        
        if epoch % 100 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
    
    # 后验均值和方差
    posterior_mean = model.variational_mean.detach()
    posterior_var = torch.exp(model.log_var.detach())
    
    return posterior_mean, posterior_var
 
 
def bayesian_linear_regression_bam(X, y, n_batches=10):
    """
    BaM贝叶斯线性回归
    """
    n, d = X.shape
    
    # 初始化变分参数
    var_params = {
        'mu': torch.zeros(d),
        'Sigma': torch.eye(d),
    }
    
    optimizer = BaMOptimizer(var_params, alpha=0.1)
    
    # 定义梯度函数
    def grad_fn():
        z = torch.randn(d)
        pred = X @ z
        loss = torch.sum((y - pred)**2)
        grads = torch.autograd.grad(loss, var_params['mu'])
        return grads[0]
    
    # 训练
    for _ in range(n_batches):
        optimizer.step(grad_fn)
    
    return var_params['mu'], var_params['Sigma']

7. 与其他方法的对比

7.1 方法对比表

方法梯度方差收敛速度实现难度适用范围
BBVI (Score Function)任意分布
BBVI (Pathwise)可重参数化
REINFORCE极高很慢任意分布
GSM-VI极快高斯分布
BaM指数任意分布
LaplaceN/AN/AMAP估计

7.2 选择指南

选择GSM-VI/BaM当

  • 变分分布是高斯的
  • 需要快速收敛
  • 梯度方差是关键瓶颈
  • 资源有限(减少Monte Carlo样本)

选择BBVI当

  • 变分分布是非高斯的
  • 需要更灵活的近似族
  • 可以负担大量样本

混合方法

  • 初期使用GSM-VI/BaM快速收敛
  • 后期切换到BBVI精细调节

8. 扩展与应用

8.1 非高斯变分分布

混合高斯

Score Matching应用于每个成分,然后加权合并。

归一化流

使用Jacobian行列式修正score函数。

8.2 深度学习应用

贝叶斯神经网络

  • 用GSM-VI学习权重的后验协方差
  • 减少MC Dropout的采样数

变分自编码器

  • 使用BaM优化encoder的变分参数
  • 加速收敛,减少重建质量波动

生成模型

  • Flow Matching + GSM-VI
  • Score-based模型的变分正则化

9. 理论深度

9.1 Score空间的度量

Fisher散度

与KL散度的关系

9.2 信息几何视角

Score流形

  • 每个分布对应流形上的一个点
  • Score函数定义切空间
  • GSM-VI沿自然梯度方向优化

黎曼度量

  • Fisher信息矩阵定义黎曼度量
  • 自然梯度 = 黎曼最速下降方向
  • GSM-VI使用黎曼几何结构

10. 总结

GSM-VI和BaM的核心贡献

  1. 新视角:将变分推断转化为score matching问题
  2. 低方差:避开ELBO梯度的高方差问题
  3. 闭式解:高斯情况下有解析更新
  4. 指数收敛:BaM有理论收敛保证

实践建议

  1. 对于高斯变分分布,优先尝试GSM-VI
  2. 使用BaM作为通用BBVI加速器
  3. 监控score函数的范数作为收敛指标

参考文献

Footnotes

  1. Shirley, M., et al. (2023). Gaussian Score Matching for Variational Inference. NeurIPS 2023.

  2. Malrande, C., et al. (2024). Batch and Match for Black-Box Variational Inference. ICML 2024.