函数空间变分推断
1 引言
传统的变分推断(VI)在参数空间进行:将神经网络权重 作为随机变量,引入先验 和变分分布 ,通过优化ELBO近似后验 。1
然而,参数空间VI面临三大核心挑战:
| 问题 | 描述 | 影响 |
|---|---|---|
| 先验指定困难 | 如何为百万级参数指定有意义的高斯先验? | 先验与数据可能不匹配 |
| 后验病态 | 神经网络后验在高维参数空间中可能严重非高斯、多峰 | 高斯近似不准确 |
| 维度诅咒 | 参数数量 巨大,KL散度计算困难 | 可扩展性受限 |
函数空间变分推断(Function-Space VI, FS-VI) 将推断从参数空间转移到函数空间,直接对函数的分布进行变分推断,从根本上绕过了上述问题。
本文系统介绍FS-VI的理论框架、实现方法,以及与Neural Tangent Kernel(NTK)理论的联系。
2 参数空间VI的问题
2.1 先验与数据的不匹配
在参数空间VI中,先验 完全独立于数据和任务。
问题:对于深度网络,这个先验在参数空间中均匀分布,但对函数行为的影响高度不均匀:
- 改变一个参数的符号可能导致函数行为的巨大变化
- 不同参数对网络输出的敏感度差异可达 倍
- 参数空间的”近邻”不一定是函数空间的”近邻”
2.2 后验的复杂结构
神经网络的真实后验 在参数空间中可能具有:
- 多峰结构:不同峰值对应于不同的对称性(权重交换、符号翻转等)
- 窄谷结构:后验集中在参数空间中的低维流形上
- 高度各向异性:不同方向的曲率差异巨大
高斯变分分布 完全无法捕获这些结构。
2.3 KL散度的计算困难
对于 参数的网络,计算 需要存储完整的协方差矩阵 ,这在实践中是不可能的。
即使使用对角协方差近似 ,计算和存储 个方差项仍然很昂贵。
3 函数空间VI的理论框架
3.1 从参数空间到函数空间
核心思想:不推断参数的后验分布,而是推断函数的后验分布。
设 是函数空间(例如,所有从 到 的神经网络函数)。FS-VI的目标是:
其中 是函数(而非参数向量), 是函数后验。
3.2 函数空间的概率分布
函数空间上的概率分布可以通过随机函数来定义:
其中 是高斯过程, 是均值函数, 是协方差函数(核函数)。
关键洞察:神经网络在无限宽度极限下就是高斯过程,其核函数为Neural Tangent Kernel (NTK):
3.3 函数空间ELBO
给定观测数据 ,函数空间ELBO为:
关键区别:KL散度现在是在函数空间中计算的,而非参数空间:
3.4 函数空间先验的构造
方案1:NTK-GP先验
使用NTK作为核函数:
方案2:数据依赖先验
利用训练数据构造先验:
其中 。
方案3:贝叶斯线性模型先验
将神经网络视为在特征空间 上的贝叶斯线性回归:
4 实现方法
4.1 随机函数近似
FS-VI的核心挑战是如何在有限计算资源下表示和操作函数分布。关键方法是将函数分布投影到有限基上。
基函数分解:任意函数 可以表示为基函数的线性组合:
其中 是基函数, 是系数。
变分分布在系数空间:设 ,则函数分布 通过线性映射从 导出:
4.2 基于NTK的有限维近似
方法1:随机傅里叶特征(RFF)
使用随机傅里叶特征近似NTK核:
其中 是随机投影特征。
import torch
import torch.nn as nn
import numpy as np
class FunctionSpaceVI:
"""
函数空间变分推断
核心思想:在函数空间(而非参数空间)进行变分推断
"""
def __init__(self, base_model, kernel_type='ntk', n_features=512, prior_std=1.0):
self.base_model = base_model
self.kernel_type = kernel_type
self.n_features = n_features
self.prior_std = prior_std
# 冻结的基础模型(用于提取特征)
self.frozen_model = self._freeze_model()
# 随机傅里叶特征(NTK近似)
self.rff_weights = self._init_rff_features()
# 变分参数:在函数空间
self.vi_mean = None
self.vi_cov = None
def _freeze_model(self):
"""冻结基础模型用于特征提取"""
model = copy.deepcopy(self.base_model)
model.eval()
for p in model.parameters():
p.requires_grad = False
return model
def _init_rff_features(self):
"""初始化随机傅里叶特征"""
# NTK特征维度
input_dim = self.base_model.input_dim
output_dim = self.base_model.output_dim
# 随机投影矩阵(用于近似NTK核)
W = torch.randn(input_dim, self.n_features) / np.sqrt(input_dim)
return W
def extract_features(self, x):
"""提取输入x的NTK特征"""
# 简化的NTK特征:使用梯度信息
x.requires_grad_(True)
logits = self.frozen_model(x)
# 计算一阶导数(NTK特征)
features = []
for k in range(logits.shape[-1]):
grad = torch.autograd.grad(
logits[0, k], x,
retain_graph=True, create_graph=True
)[0]
features.append(grad.flatten())
return torch.stack(features, dim=-1)
def compute_kernel_matrix(self, X):
"""计算NTK核矩阵(有限宽度近似)"""
n = X.shape[0]
K = torch.zeros(n, n)
self.frozen_model.zero_grad()
for i in range(n):
xi = X[i:i+1]
xi.requires_grad_(True)
# 计算 NTK: K(x_i, x_j) = <∇_θ f(x_i), ∇_θ f(x_j)>
logits = self.frozen_model(xi)
for j in range(n):
xj = X[j:j+1]
xj.requires_grad_(True)
logits_j = self.frozen_model(xj)
# 简化:使用一阶泰勒展开
K[i, j] = torch.dot(
torch.autograd.grad(logits[0, 0], xi, retain_graph=True)[0].flatten(),
torch.autograd.grad(logits_j[0, 0], xj, retain_graph=True)[0].flatten()
)
return K
def function_space_elbo(self, X, y, q_mean, q_cov):
"""
计算函数空间ELBO
ELBO = E_q[log p(y|X)] - KL(q(f) || p(f))
"""
n = X.shape[0]
# Step 1: 预测均值和方差
# f(x) ≈ N(φ(x)^T μ_w, φ(x)^T Σ_w φ(x))
pred_mean = X @ q_mean
pred_var = torch.diag(X @ q_cov @ X.T)
# Step 2: 数据拟合项
# E_q[log p(y|x)] = E_{f~q}[log p(y|f(x))]
log_likelihood = -0.5 * torch.sum(
(y - pred_mean) ** 2 / (pred_var + 1e-6)
) - 0.5 * torch.sum(torch.log(pred_var + 1e-6))
# Step 3: KL散度(在函数空间)
# KL(q(f) || p(f)) = 0.5 * tr(K^{-1} Σ) + 0.5 * ||μ||_K^2 - 0.5 * log|Σ|
# 简化为对角近似
prior_var = self.prior_std ** 2
kl_div = 0.5 * torch.sum(
q_cov.diag() / prior_var
+ (q_mean ** 2) / prior_var
- 1
- torch.log(q_cov.diag() / prior_var + 1e-8)
)
return log_likelihood - kl_div, log_likelihood, kl_div
def fit(self, X, y, lr=1e-3, n_iterations=1000):
"""
优化函数空间变分参数
"""
n_features = X.shape[1]
# 初始化变分参数
self.vi_mean = torch.zeros(n_features, requires_grad=True)
self.vi_cov_diag = torch.ones(n_features, requires_grad=True)
optimizer = torch.optim.Adam([self.vi_mean, self.vi_cov_diag], lr=lr)
for iteration in range(n_iterations):
optimizer.zero_grad()
# 构建协方差矩阵(对角近似)
q_cov = torch.diag(torch.exp(self.vi_cov_diag) + 1e-6)
# 计算ELBO
elbo, ll, kl = self.function_space_elbo(X, y, self.vi_mean, q_cov)
loss = -elbo # 最大化ELBO = 最小化负ELBO
loss.backward()
optimizer.step()
if (iteration + 1) % 100 == 0:
print(f"Iter {iteration+1}: ELBO={elbo.item():.3f}, "
f"LL={ll.item():.3f}, KL={kl.item():.3f}")
return self.vi_mean.detach(), torch.diag(torch.exp(self.vi_cov_diag)).detach()4.3 与NTK的联系
定理1(FS-VI与NTK的等价性):当网络宽度趋于无穷且变分分布为高斯时,FS-VI的后验预测分布与NTK-GP后验预测分布一致:
其中:
4.4 有限宽度修正
在有限宽度网络中,NTK是固定不变的(不随训练变化)。FS-VI需要考虑有限宽度效应:
def finite_width_correction(self, X, y, base_vi_mean, base_vi_cov,
update_cov=True):
"""
有限宽度修正:考虑网络学习特征的能力
有限宽度网络中,特征 Φ(X) 不是固定的,而是随训练变化
"""
# Step 1: 估计当前特征
current_features = self.extract_features_batch(X)
# Step 2: 修正先验协方差
# 有限宽度:K_NTK(t) ≈ K_NTK(∞) - c/t(核退化)
t = self.base_model.n_parameters
correction_factor = 1.0 / (1.0 + 1.0 / t)
adjusted_cov = base_vi_cov * correction_factor
# Step 3: 如果需要,更新协方差
if update_cov:
# 使用经验Fisher信息调整
empirical_fisher = self._estimate_fisher(X, y)
adjusted_cov = adjusted_cov * (1.0 + empirical_fisher)
return base_vi_mean, adjusted_cov5 与参数空间VI的比较
5.1 理论比较
| 维度 | 参数空间VI | 函数空间VI |
|---|---|---|
| 推断对象 | 参数后验 $p(\theta | \mathcal{D})$ |
| 先验形式 | ||
| 先验解释 | 参数的先验信念 | 函数的先验行为 |
| 后验结构 | 可能高度非高斯 | 通常更接近高斯过程 |
| 维度依赖性 | 线性(参数数量) | 对数(核矩阵维度) |
| 可解释性 | 低 | 高(函数空间更直观) |
5.2 实践比较
| 维度 | 参数空间VI | 函数空间VI |
|---|---|---|
| 计算成本 | per step | per step |
| 内存占用 | ||
| 不确定性质量 | 依赖后验近似质量 | 通常更好(函数空间更自然) |
| 超参数 | (方差) | 核函数 + |
| 实现复杂度 | 中等 | 较高 |
5.3 混合方法
结合参数空间VI和函数空间VI的优点:
方案:层次化变分推断
- 第一层:在函数空间指定先验
- 第二层:在参数空间近似后验
- 第三层:优化 的超参数
class HierarchicalFunctionSpaceVI:
"""
层次化函数空间变分推断
层次1: 函数空间先验 p(f)
层次2: 参数空间后验 q(θ|f) = N(μ(f), σ^2 I)
层次3: 函数空间变分 q(f)
"""
def __init__(self, model, kernel_fn):
self.model = model
self.kernel_fn = kernel_fn # e.g., NTK kernel
def hierarchical_elbo(self, X, y, q_f_mean, q_f_cov,
q_theta_given_f_std):
"""
层次化ELBO
ELBO = E_{q(f)}[E_{q(θ|f)}[log p(y|θ)]] - KL(q(f)||p(f))
"""
# 第一层:函数空间变分
kl_f = self._function_space_kl(q_f_mean, q_f_cov)
# 第二层:参数空间条件期望
# E_{q(θ|f)}[log p(y|θ)] ≈ log p(y|μ(f))
pred_mean = X @ q_f_mean # 简化
data_fit = -0.5 * torch.sum((y - pred_mean) ** 2)
return data_fit - kl_f, data_fit, kl_f6 FS-VI与Neural Tangent Kernel的联系
6.1 NTK视角下的FS-VI
NTK定义:神经网络对输入 的输出 ,其梯度为:
NTK定义为两个输入点的梯度内积:
6.2 FS-VI作为NTK上的贝叶斯推断
将网络输出视为特征空间上的线性模型:
则FS-VI等价于在这个特征空间上进行贝叶斯线性回归:
- 先验:
- 似然:
- 后验:,其中:
6.3 NTK固定 vs NTK学习
| 模式 | NTK行为 | FS-VI行为 | 表达能力 |
|---|---|---|---|
| NTK固定(无限宽度) | 训练前后不变 | 等价于GP推断 | 有限 |
| NTK学习(有限宽度) | 训练中变化 | 动态核推断 | 更强 |
| FS-VI + NTK学习 | 通过后验采样捕获 | 自适应核调整 | 最强 |
7 应用场景
7.1 不确定性量化
FS-VI为深度网络提供自然的不确定性估计:
def predict_with_uncertainty(model, x_test, q_f_mean, q_f_cov):
"""
使用FS-VI进行不确定性感知预测
"""
# 预测均值
pred_mean = x_test @ q_f_mean
# 预测方差(包含认知不确定性和偶然不确定性)
# Var[f(x*)] = Var_posterior + Var_noise
pred_variance = torch.diag(
x_test @ q_f_cov @ x_test.T
) + model.noise_variance
# 分解不确定性
epistemic_uncertainty = torch.diag(
x_test @ q_f_cov @ x_test.T
)
aleatoric_uncertainty = model.noise_variance
return pred_mean, pred_variance, epistemic_uncertainty, aleatoric_uncertainty7.2 主动学习
FS-VI自然地提供获取函数(acquisition function):
def active_learning_acquisition(model, x_candidates, q_f_mean, q_f_cov):
"""
基于FS-VI的主动学习获取函数
使用预测方差作为获取函数
"""
pred_variance = torch.diag(
x_candidates @ q_f_cov @ x_candidates.T
)
# 多种获取函数
variance_acq = pred_variance # 方差最大化
mean_std_acq = torch.sqrt(pred_variance) # 标准差
confidence_bound = -torch.abs(x_candidates @ q_f_mean) + 2 * torch.sqrt(pred_variance) # UCB
return {
'variance': variance_acq,
'std': mean_std_acq,
'ucb': confidence_bound
}7.3 安全关键应用
FS-VI提供个体级别的预测认证(参见 PAC-Bayes深度网络风险认证):
def certify_prediction(x, y_true, q_f_mean, q_f_cov, delta=0.05):
"""
为FS-VI预测提供PAC-Bayes风格的风险认证
"""
# 预测分布
pred_mean = x @ q_f_mean
pred_std = torch.sqrt(torch.diag(x @ q_f_cov @ x.T))
# 基于预测分布的风险上界
# P(|f(x) - y| > ε) ≤ exp(-ε²/(2σ²))
epsilon = torch.sqrt(-2 * pred_std**2 * torch.log(delta))
# 风险上界
risk_bound = torch.norm(pred_mean - y_true) + epsilon
return {
'prediction': pred_mean.item(),
'uncertainty': pred_std.item(),
'risk_bound': risk_bound.item(),
'confidence_level': 1 - delta
}8 总结
8.1 核心贡献
- 函数空间范式转换:将变分推断从参数空间转移到函数空间,解决先验指定和后验病态问题
- 与NTK的统一:揭示了FS-VI与Neural Tangent Kernel的理论联系
- 不确定性量化:FS-VI自然地提供高质量的不确定性估计
- 实践框架:提供了完整的算法实现和代码模板
8.2 与本 Wiki 其他内容的联系
Footnotes
-
Wu, M. et al. (2025). “Bridging the Gap Between Variational Inference and Stochastic Gradient MCMC in Function Space.” ICLR 2025. ↩