Score Matching变分推断:GSM-VI与BaM
1. 概述
变分推断(Variational Inference, VI)是贝叶斯深度学习的核心技术之一。传统VI使用**证据下界(ELBO)**作为目标函数:
然而,**Black Box变分推断(BBVI)**存在一个根本性问题:梯度方差过大。当使用score function estimator时:
梯度估计的高方差导致收敛缓慢,需要大量Monte Carlo样本。
**Score Matching VI(GSM-VI)和Batch and Match(BaM)**提供了全新的优化视角:将变分推断问题转化为score matching问题,从而绕过ELBO梯度估计的困难。
2. 从ELBO到Score Matching
2.1 ELBO梯度的方差问题
ELBO的两种梯度形式:
Score Function Estimator:
其中 。
问题:
- 当 峰度较高时, 和 都可能很大
- 乘积的方差是两者方差的乘积项
- 需要大量Monte Carlo样本才能得到可靠估计
Pathwise Gradient(重参数化):
问题:
- 要求 可重参数化
- 隐变量通常是离散或混合分布时不可用
- 对网络结构有限制
2.2 Score Matching的直觉
核心观察:KL散度最小化
Score Function定义:
Score Matching目标:
物理意义:最小化真实分布和变分分布在score函数空间的距离。
2.3 为什么Score Matching更稳定?
Score Function的特性:
- Score是梯度的方向,天然具有单位范数的性质
- 不依赖于似然的绝对尺度
- 对概率归一化不敏感
对比:
| 方面 | ELBO (BBVI) | Score Matching |
|---|---|---|
| 梯度尺度 | 依赖似然绝对值 | 归一化到单位球 |
| 方差 | 高(需要大量样本) | 低(几何结构稳定) |
| 可处理分布 | 可重参数化 | 任意可微分分布 |
| 计算成本 | O(1/M),M=样本数 | O(1)(单样本即可) |
3. GSM-VI:Score Matching变分推断
3.1 GSM-VI目标函数
论文:Gaussian Score Matching for Variational Inference (NeurIPS 2023)
GSM-VI针对高斯变分分布推导了简化的score matching目标。
高斯分布的Score:
Score Matching for Gaussian:
简化推导:
最终目标(忽略常数):
3.2 期望的闭式计算
关键发现:期望可以解析计算!
GSM-VI最终目标:
结论:高斯变分分布的GSM目标只依赖于 (协方差),与 无关!
3.3 与ELBO的联系
高斯ELBO:
对比:
- ELBO梯度涉及 和 ,方差大
- GSM-VI梯度只涉及 ,方差小
- GSM不依赖对数似然的梯度,只依赖其值
3.4 梯度计算
GSM-VI梯度:
更新方向:最大化 等价于最小化 的特征值。
物理意义:GSM-VI鼓励变分分布的协方差更加”紧凑”,趋向先验或最大熵分布。
4. BaM:Batch and Match
4.1 核心思想
论文:Batch and Match for Black-Box Variational Inference (ICML 2024)
BaM将GSM-VI的思想进一步发展,提出闭式近端更新:
问题设置:
解:最小化 和 在score空间的差异。
4.2 Proximal Update
近端算子:
物理意义:在保持与当前分布 接近的同时,向参考分布 移动。
4.3 闭式解推导
对于高斯分布 ,:
闭式近端更新:
类比:这类似于指数加权移动平均,但作用于协方差矩阵。
4.4 BaM算法流程
def batch_and_match(model, data, n_batches=10, alpha=0.1):
"""
Batch and Match for BBVI
Args:
model: Neural network with variational params
data: Dataset
n_batches: Number of data batches for Fisher estimation
alpha: Proximal step size
"""
phi = model.variational_params
for iteration in range(n_iterations):
# 1. 计算当前变分分布的统计量
grads = []
for batch in data.random_batches(n_batches):
loss = model.elbo(batch)
grads.append(torch.autograd.grad(loss, phi))
# 2. 估计Fisher信息(使用样本梯度)
F_hat = estimate_fisher(grads)
# 3. BaM闭式更新
phi = prox_update(phi, F_hat, alpha)
return phi
def prox_update(phi, F_hat, alpha):
"""闭式近端更新"""
mu, Sigma = phi['mu'], phi['Sigma']
# Fisher对齐目标分布
Sigma_0 = F_hat # Fisher作为目标分布
# 近端更新
Sigma_new = torch.linalg.inv(
(1 - alpha) * torch.linalg.inv(Sigma) + alpha * torch.linalg.inv(Sigma_0)
)
mu_new = Sigma_new @ (
(1 - alpha) * torch.linalg.inv(Sigma) @ mu +
alpha * torch.linalg.inv(Sigma_0) @ torch.zeros_like(mu)
)
return {'mu': mu_new, 'Sigma': Sigma_new}5. 收敛性分析
5.1 GSM-VI的收敛保证
定理(GSM-VI论文):对于高斯变分分布,GSM目标的最大化等价于KL散度的最小化(在常数项内)。
证明概要:
- KL散度可以写成score函数的期望
- Score matching恰好是最小化这个期望
- 闭式期望计算保证无偏性
5.2 BaM的指数收敛
定理(BaM论文):当目标分布 是高斯且步长 满足 时,BaM更新指数收敛到最优解。
收敛速率:
物理意义:每步减少 的总变差距离。
5.3 与BBVI的对比
| 方面 | BBVI | GSM-VI/BaM |
|---|---|---|
| 梯度方差 | O(1/√M) | O(1)(闭式) |
| 收敛速度 | 慢(高方差) | 快(低方差) |
| 样本需求 | 100-1000 | 1-10 |
| 实现复杂度 | 中等 | 低 |
| 理论保证 | 渐近 | 指数收敛 |
6. 实践实现
6.1 PyTorch实现
import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
class GSMVI(nn.Module):
"""Gaussian Score Matching VI"""
def __init__(self, dim, prior_std=1.0):
super().__init__()
self.dim = dim
# 变分参数(对数协方差)
self.log_var = nn.Parameter(torch.zeros(dim))
self.prior_std = prior_std
@property
def sigma(self):
"""协方差矩阵(对角)"""
return torch.exp(self.log_var).diag_embed()
@property
def sigma_inv(self):
"""协方差逆矩阵"""
return torch.exp(-self.log_var).diag_embed()
def gsm_objective(self, log_likelihood_fn, x, n_samples=1):
"""
GSM-VI目标函数
Args:
log_likelihood_fn: log p(x|z) 函数
x: 观测数据
n_samples: Monte Carlo样本数
"""
batch_size = x.size(0)
z = torch.randn(batch_size, self.dim, device=x.device) @ self.sigma.cholesky()
z = z + self.variational_mean.unsqueeze(0) # 添加均值
# Score matching目标
score_norm = torch.sum((z - self.variational_mean)**2 * torch.exp(-self.log_var))
# Tracy-Widom项
tr_sigma_inv = torch.sum(torch.exp(-self.log_var))
# Score matching损失
loss = score_norm - 2 * self.dim * tr_sigma_inv
# 添加似然项(简化版本)
with torch.no_grad():
ll = log_likelihood_fn(x, z)
loss = loss - ll.mean() # 负ELBO代理
return loss / batch_size
def fit(self, log_likelihood_fn, x, lr=0.01, n_epochs=100):
"""训练"""
optimizer = torch.optim.Adam([self.log_var], lr=lr)
for epoch in range(n_epochs):
optimizer.zero_grad()
loss = self.gsm_objective(log_likelihood_fn, x)
loss.backward()
optimizer.step()
if epoch % 10 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
class BaMOptimizer:
"""Batch and Match优化器"""
def __init__(self, var_params, alpha=0.1):
self.alpha = alpha
self.var_params = var_params
# 初始化
self.mu = var_params['mu'].clone()
self.Sigma = var_params['Sigma'].clone()
def compute_fisher_estimate(self, grad_samples):
"""
从梯度样本估计Fisher信息矩阵
使用empirical Fisher: F ≈ (1/M) Σ ∇ℓ ∇ℓ⊤
"""
grad_stack = torch.stack(grad_samples)
M = grad_stack.size(0)
# 样本协方差
F_hat = torch.cov(grad_stack.t())
# 对角近似(更稳定)
F_hat = torch.diag(torch.diag(F_hat) + 1e-3)
return F_hat
def proximal_update(self, F_hat):
"""闭式近端更新"""
# 协方差更新
Sigma_inv = torch.linalg.inv(self.Sigma)
F_hat_inv = torch.linalg.inv(F_hat + 1e-3 * torch.eye_like(F_hat))
Sigma_new = torch.linalg.inv(
(1 - self.alpha) * Sigma_inv + self.alpha * F_hat_inv
)
# 均值更新(假设先验均值=0)
mu_new = Sigma_new @ ((1 - self.alpha) * Sigma_inv @ self.mu)
self.Sigma = Sigma_new
self.mu = mu_new
return self.mu, self.Sigma
def step(self, grad_fn, batch_size=10):
"""
执行一步BaM更新
Args:
grad_fn: 返回梯度样本的函数
batch_size: 样本数
"""
# 收集梯度样本
grad_samples = []
for _ in range(batch_size):
grads = grad_fn()
grad_samples.append(grads)
# 估计Fisher
F_hat = self.compute_fisher_estimate(grad_samples)
# 近端更新
self.proximal_update(F_hat)
return self.mu, self.Sigma6.2 应用示例:贝叶斯线性回归
def bayesian_linear_regression_gsm(X, y, alpha_prior=1.0, n_epochs=500):
"""
GSM-VI贝叶斯线性回归
Args:
X: (n, d) 设计矩阵
y: (n,) 目标值
alpha_prior: 先验精度
"""
n, d = X.shape
# 初始化
model = GSMVI(dim=d, prior_std=1.0)
def log_likelihood(z):
"""log p(y|X, z) = N(y; Xz, σ²I)"""
pred = X @ z.unsqueeze(-1)
return -0.5 * torch.sum((y.unsqueeze(-1) - pred)**2)
# 训练
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(n_epochs):
optimizer.zero_grad()
# GSM-VI目标
loss = model.gsm_objective(log_likelihood, torch.zeros(d))
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# 后验均值和方差
posterior_mean = model.variational_mean.detach()
posterior_var = torch.exp(model.log_var.detach())
return posterior_mean, posterior_var
def bayesian_linear_regression_bam(X, y, n_batches=10):
"""
BaM贝叶斯线性回归
"""
n, d = X.shape
# 初始化变分参数
var_params = {
'mu': torch.zeros(d),
'Sigma': torch.eye(d),
}
optimizer = BaMOptimizer(var_params, alpha=0.1)
# 定义梯度函数
def grad_fn():
z = torch.randn(d)
pred = X @ z
loss = torch.sum((y - pred)**2)
grads = torch.autograd.grad(loss, var_params['mu'])
return grads[0]
# 训练
for _ in range(n_batches):
optimizer.step(grad_fn)
return var_params['mu'], var_params['Sigma']7. 与其他方法的对比
7.1 方法对比表
| 方法 | 梯度方差 | 收敛速度 | 实现难度 | 适用范围 |
|---|---|---|---|---|
| BBVI (Score Function) | 高 | 慢 | 中 | 任意分布 |
| BBVI (Pathwise) | 低 | 快 | 中 | 可重参数化 |
| REINFORCE | 极高 | 很慢 | 低 | 任意分布 |
| GSM-VI | 零 | 极快 | 低 | 高斯分布 |
| BaM | 零 | 指数 | 低 | 任意分布 |
| Laplace | N/A | N/A | 低 | MAP估计 |
7.2 选择指南
选择GSM-VI/BaM当:
- 变分分布是高斯的
- 需要快速收敛
- 梯度方差是关键瓶颈
- 资源有限(减少Monte Carlo样本)
选择BBVI当:
- 变分分布是非高斯的
- 需要更灵活的近似族
- 可以负担大量样本
混合方法:
- 初期使用GSM-VI/BaM快速收敛
- 后期切换到BBVI精细调节
8. 扩展与应用
8.1 非高斯变分分布
混合高斯:
Score Matching应用于每个成分,然后加权合并。
归一化流:
使用Jacobian行列式修正score函数。
8.2 深度学习应用
贝叶斯神经网络:
- 用GSM-VI学习权重的后验协方差
- 减少MC Dropout的采样数
变分自编码器:
- 使用BaM优化encoder的变分参数
- 加速收敛,减少重建质量波动
生成模型:
- Flow Matching + GSM-VI
- Score-based模型的变分正则化
9. 理论深度
9.1 Score空间的度量
Fisher散度:
与KL散度的关系:
9.2 信息几何视角
Score流形:
- 每个分布对应流形上的一个点
- Score函数定义切空间
- GSM-VI沿自然梯度方向优化
黎曼度量:
- Fisher信息矩阵定义黎曼度量
- 自然梯度 = 黎曼最速下降方向
- GSM-VI使用黎曼几何结构
10. 总结
GSM-VI和BaM的核心贡献:
- 新视角:将变分推断转化为score matching问题
- 低方差:避开ELBO梯度的高方差问题
- 闭式解:高斯情况下有解析更新
- 指数收敛:BaM有理论收敛保证
实践建议:
- 对于高斯变分分布,优先尝试GSM-VI
- 使用BaM作为通用BBVI加速器
- 监控score函数的范数作为收敛指标