函数空间变分推断

1 引言

传统的变分推断(VI)在参数空间进行:将神经网络权重 作为随机变量,引入先验 和变分分布 ,通过优化ELBO近似后验 1

然而,参数空间VI面临三大核心挑战:

问题描述影响
先验指定困难如何为百万级参数指定有意义的高斯先验?先验与数据可能不匹配
后验病态神经网络后验在高维参数空间中可能严重非高斯、多峰高斯近似不准确
维度诅咒参数数量 巨大,KL散度计算困难可扩展性受限

函数空间变分推断(Function-Space VI, FS-VI) 将推断从参数空间转移到函数空间,直接对函数的分布进行变分推断,从根本上绕过了上述问题。

本文系统介绍FS-VI的理论框架、实现方法,以及与Neural Tangent Kernel(NTK)理论的联系。


2 参数空间VI的问题

2.1 先验与数据的不匹配

在参数空间VI中,先验 完全独立于数据和任务。

问题:对于深度网络,这个先验在参数空间中均匀分布,但对函数行为的影响高度不均匀:

  • 改变一个参数的符号可能导致函数行为的巨大变化
  • 不同参数对网络输出的敏感度差异可达
  • 参数空间的”近邻”不一定是函数空间的”近邻”

2.2 后验的复杂结构

神经网络的真实后验 在参数空间中可能具有:

  1. 多峰结构:不同峰值对应于不同的对称性(权重交换、符号翻转等)
  2. 窄谷结构:后验集中在参数空间中的低维流形上
  3. 高度各向异性:不同方向的曲率差异巨大

高斯变分分布 完全无法捕获这些结构。

2.3 KL散度的计算困难

对于 参数的网络,计算 需要存储完整的协方差矩阵 ,这在实践中是不可能的。

即使使用对角协方差近似 ,计算和存储 个方差项仍然很昂贵。


3 函数空间VI的理论框架

3.1 从参数空间到函数空间

核心思想:不推断参数的后验分布,而是推断函数的后验分布

是函数空间(例如,所有从 的神经网络函数)。FS-VI的目标是:

其中 是函数(而非参数向量), 是函数后验。

3.2 函数空间的概率分布

函数空间上的概率分布可以通过随机函数来定义:

其中 是高斯过程, 是均值函数, 是协方差函数(核函数)。

关键洞察:神经网络在无限宽度极限下就是高斯过程,其核函数为Neural Tangent Kernel (NTK):

3.3 函数空间ELBO

给定观测数据 ,函数空间ELBO为:

关键区别:KL散度现在是在函数空间中计算的,而非参数空间:

3.4 函数空间先验的构造

方案1:NTK-GP先验

使用NTK作为核函数:

方案2:数据依赖先验

利用训练数据构造先验:

其中

方案3:贝叶斯线性模型先验

将神经网络视为在特征空间 上的贝叶斯线性回归:


4 实现方法

4.1 随机函数近似

FS-VI的核心挑战是如何在有限计算资源下表示和操作函数分布。关键方法是将函数分布投影到有限基上。

基函数分解:任意函数 可以表示为基函数的线性组合:

其中 是基函数, 是系数。

变分分布在系数空间:设 ,则函数分布 通过线性映射从 导出:

4.2 基于NTK的有限维近似

方法1:随机傅里叶特征(RFF)

使用随机傅里叶特征近似NTK核:

其中 是随机投影特征。

import torch
import torch.nn as nn
import numpy as np
 
class FunctionSpaceVI:
    """
    函数空间变分推断
    
    核心思想:在函数空间(而非参数空间)进行变分推断
    """
    def __init__(self, base_model, kernel_type='ntk', n_features=512, prior_std=1.0):
        self.base_model = base_model
        self.kernel_type = kernel_type
        self.n_features = n_features
        self.prior_std = prior_std
        
        # 冻结的基础模型(用于提取特征)
        self.frozen_model = self._freeze_model()
        
        # 随机傅里叶特征(NTK近似)
        self.rff_weights = self._init_rff_features()
        
        # 变分参数:在函数空间
        self.vi_mean = None
        self.vi_cov = None
        
    def _freeze_model(self):
        """冻结基础模型用于特征提取"""
        model = copy.deepcopy(self.base_model)
        model.eval()
        for p in model.parameters():
            p.requires_grad = False
        return model
    
    def _init_rff_features(self):
        """初始化随机傅里叶特征"""
        # NTK特征维度
        input_dim = self.base_model.input_dim
        output_dim = self.base_model.output_dim
        
        # 随机投影矩阵(用于近似NTK核)
        W = torch.randn(input_dim, self.n_features) / np.sqrt(input_dim)
        return W
    
    def extract_features(self, x):
        """提取输入x的NTK特征"""
        # 简化的NTK特征:使用梯度信息
        x.requires_grad_(True)
        logits = self.frozen_model(x)
        
        # 计算一阶导数(NTK特征)
        features = []
        for k in range(logits.shape[-1]):
            grad = torch.autograd.grad(
                logits[0, k], x, 
                retain_graph=True, create_graph=True
            )[0]
            features.append(grad.flatten())
        
        return torch.stack(features, dim=-1)
    
    def compute_kernel_matrix(self, X):
        """计算NTK核矩阵(有限宽度近似)"""
        n = X.shape[0]
        K = torch.zeros(n, n)
        
        self.frozen_model.zero_grad()
        for i in range(n):
            xi = X[i:i+1]
            xi.requires_grad_(True)
            
            # 计算 NTK: K(x_i, x_j) = <∇_θ f(x_i), ∇_θ f(x_j)>
            logits = self.frozen_model(xi)
            for j in range(n):
                xj = X[j:j+1]
                xj.requires_grad_(True)
                
                logits_j = self.frozen_model(xj)
                # 简化:使用一阶泰勒展开
                K[i, j] = torch.dot(
                    torch.autograd.grad(logits[0, 0], xi, retain_graph=True)[0].flatten(),
                    torch.autograd.grad(logits_j[0, 0], xj, retain_graph=True)[0].flatten()
                )
        
        return K
    
    def function_space_elbo(self, X, y, q_mean, q_cov):
        """
        计算函数空间ELBO
        
        ELBO = E_q[log p(y|X)] - KL(q(f) || p(f))
        """
        n = X.shape[0]
        
        # Step 1: 预测均值和方差
        # f(x) ≈ N(φ(x)^T μ_w, φ(x)^T Σ_w φ(x))
        pred_mean = X @ q_mean
        pred_var = torch.diag(X @ q_cov @ X.T)
        
        # Step 2: 数据拟合项
        # E_q[log p(y|x)] = E_{f~q}[log p(y|f(x))]
        log_likelihood = -0.5 * torch.sum(
            (y - pred_mean) ** 2 / (pred_var + 1e-6)
        ) - 0.5 * torch.sum(torch.log(pred_var + 1e-6))
        
        # Step 3: KL散度(在函数空间)
        # KL(q(f) || p(f)) = 0.5 * tr(K^{-1} Σ) + 0.5 * ||μ||_K^2 - 0.5 * log|Σ|
        # 简化为对角近似
        prior_var = self.prior_std ** 2
        kl_div = 0.5 * torch.sum(
            q_cov.diag() / prior_var 
            + (q_mean ** 2) / prior_var 
            - 1 
            - torch.log(q_cov.diag() / prior_var + 1e-8)
        )
        
        return log_likelihood - kl_div, log_likelihood, kl_div
    
    def fit(self, X, y, lr=1e-3, n_iterations=1000):
        """
        优化函数空间变分参数
        """
        n_features = X.shape[1]
        
        # 初始化变分参数
        self.vi_mean = torch.zeros(n_features, requires_grad=True)
        self.vi_cov_diag = torch.ones(n_features, requires_grad=True)
        
        optimizer = torch.optim.Adam([self.vi_mean, self.vi_cov_diag], lr=lr)
        
        for iteration in range(n_iterations):
            optimizer.zero_grad()
            
            # 构建协方差矩阵(对角近似)
            q_cov = torch.diag(torch.exp(self.vi_cov_diag) + 1e-6)
            
            # 计算ELBO
            elbo, ll, kl = self.function_space_elbo(X, y, self.vi_mean, q_cov)
            
            loss = -elbo  # 最大化ELBO = 最小化负ELBO
            loss.backward()
            optimizer.step()
            
            if (iteration + 1) % 100 == 0:
                print(f"Iter {iteration+1}: ELBO={elbo.item():.3f}, "
                      f"LL={ll.item():.3f}, KL={kl.item():.3f}")
        
        return self.vi_mean.detach(), torch.diag(torch.exp(self.vi_cov_diag)).detach()

4.3 与NTK的联系

定理1(FS-VI与NTK的等价性):当网络宽度趋于无穷且变分分布为高斯时,FS-VI的后验预测分布与NTK-GP后验预测分布一致:

其中:

4.4 有限宽度修正

在有限宽度网络中,NTK是固定不变的(不随训练变化)。FS-VI需要考虑有限宽度效应

def finite_width_correction(self, X, y, base_vi_mean, base_vi_cov, 
                            update_cov=True):
    """
    有限宽度修正:考虑网络学习特征的能力
    
    有限宽度网络中,特征 Φ(X) 不是固定的,而是随训练变化
    """
    # Step 1: 估计当前特征
    current_features = self.extract_features_batch(X)
    
    # Step 2: 修正先验协方差
    # 有限宽度:K_NTK(t) ≈ K_NTK(∞) - c/t(核退化)
    t = self.base_model.n_parameters
    correction_factor = 1.0 / (1.0 + 1.0 / t)
    
    adjusted_cov = base_vi_cov * correction_factor
    
    # Step 3: 如果需要,更新协方差
    if update_cov:
        # 使用经验Fisher信息调整
        empirical_fisher = self._estimate_fisher(X, y)
        adjusted_cov = adjusted_cov * (1.0 + empirical_fisher)
    
    return base_vi_mean, adjusted_cov

5 与参数空间VI的比较

5.1 理论比较

维度参数空间VI函数空间VI
推断对象参数后验 $p(\theta\mathcal{D})$
先验形式
先验解释参数的先验信念函数的先验行为
后验结构可能高度非高斯通常更接近高斯过程
维度依赖性线性(参数数量)对数(核矩阵维度)
可解释性高(函数空间更直观)

5.2 实践比较

维度参数空间VI函数空间VI
计算成本 per step per step
内存占用
不确定性质量依赖后验近似质量通常更好(函数空间更自然)
超参数(方差)核函数 +
实现复杂度中等较高

5.3 混合方法

结合参数空间VI和函数空间VI的优点:

方案:层次化变分推断

  • 第一层:在函数空间指定先验
  • 第二层:在参数空间近似后验
  • 第三层:优化 的超参数
class HierarchicalFunctionSpaceVI:
    """
    层次化函数空间变分推断
    
    层次1: 函数空间先验 p(f)
    层次2: 参数空间后验 q(θ|f) = N(μ(f), σ^2 I)
    层次3: 函数空间变分 q(f)
    """
    def __init__(self, model, kernel_fn):
        self.model = model
        self.kernel_fn = kernel_fn  # e.g., NTK kernel
        
    def hierarchical_elbo(self, X, y, q_f_mean, q_f_cov, 
                          q_theta_given_f_std):
        """
        层次化ELBO
        
        ELBO = E_{q(f)}[E_{q(θ|f)}[log p(y|θ)]] - KL(q(f)||p(f))
        """
        # 第一层:函数空间变分
        kl_f = self._function_space_kl(q_f_mean, q_f_cov)
        
        # 第二层:参数空间条件期望
        # E_{q(θ|f)}[log p(y|θ)] ≈ log p(y|μ(f))
        pred_mean = X @ q_f_mean  # 简化
        data_fit = -0.5 * torch.sum((y - pred_mean) ** 2)
        
        return data_fit - kl_f, data_fit, kl_f

6 FS-VI与Neural Tangent Kernel的联系

6.1 NTK视角下的FS-VI

NTK定义:神经网络对输入 的输出 ,其梯度为:

NTK定义为两个输入点的梯度内积:

6.2 FS-VI作为NTK上的贝叶斯推断

将网络输出视为特征空间上的线性模型:

则FS-VI等价于在这个特征空间上进行贝叶斯线性回归

  • 先验
  • 似然
  • 后验,其中:

6.3 NTK固定 vs NTK学习

模式NTK行为FS-VI行为表达能力
NTK固定(无限宽度)训练前后不变等价于GP推断有限
NTK学习(有限宽度)训练中变化动态核推断更强
FS-VI + NTK学习通过后验采样捕获自适应核调整最强

7 应用场景

7.1 不确定性量化

FS-VI为深度网络提供自然的不确定性估计

def predict_with_uncertainty(model, x_test, q_f_mean, q_f_cov):
    """
    使用FS-VI进行不确定性感知预测
    """
    # 预测均值
    pred_mean = x_test @ q_f_mean
    
    # 预测方差(包含认知不确定性和偶然不确定性)
    # Var[f(x*)] = Var_posterior + Var_noise
    pred_variance = torch.diag(
        x_test @ q_f_cov @ x_test.T
    ) + model.noise_variance
    
    # 分解不确定性
    epistemic_uncertainty = torch.diag(
        x_test @ q_f_cov @ x_test.T
    )
    aleatoric_uncertainty = model.noise_variance
    
    return pred_mean, pred_variance, epistemic_uncertainty, aleatoric_uncertainty

7.2 主动学习

FS-VI自然地提供获取函数(acquisition function):

def active_learning_acquisition(model, x_candidates, q_f_mean, q_f_cov):
    """
    基于FS-VI的主动学习获取函数
    
    使用预测方差作为获取函数
    """
    pred_variance = torch.diag(
        x_candidates @ q_f_cov @ x_candidates.T
    )
    
    # 多种获取函数
    variance_acq = pred_variance  # 方差最大化
    mean_std_acq = torch.sqrt(pred_variance)  # 标准差
    confidence_bound = -torch.abs(x_candidates @ q_f_mean) + 2 * torch.sqrt(pred_variance)  # UCB
    
    return {
        'variance': variance_acq,
        'std': mean_std_acq, 
        'ucb': confidence_bound
    }

7.3 安全关键应用

FS-VI提供个体级别的预测认证(参见 PAC-Bayes深度网络风险认证):

def certify_prediction(x, y_true, q_f_mean, q_f_cov, delta=0.05):
    """
    为FS-VI预测提供PAC-Bayes风格的风险认证
    """
    # 预测分布
    pred_mean = x @ q_f_mean
    pred_std = torch.sqrt(torch.diag(x @ q_f_cov @ x.T))
    
    # 基于预测分布的风险上界
    # P(|f(x) - y| > ε) ≤ exp(-ε²/(2σ²))
    epsilon = torch.sqrt(-2 * pred_std**2 * torch.log(delta))
    
    # 风险上界
    risk_bound = torch.norm(pred_mean - y_true) + epsilon
    
    return {
        'prediction': pred_mean.item(),
        'uncertainty': pred_std.item(),
        'risk_bound': risk_bound.item(),
        'confidence_level': 1 - delta
    }

8 总结

8.1 核心贡献

  1. 函数空间范式转换:将变分推断从参数空间转移到函数空间,解决先验指定和后验病态问题
  2. 与NTK的统一:揭示了FS-VI与Neural Tangent Kernel的理论联系
  3. 不确定性量化:FS-VI自然地提供高质量的不确定性估计
  4. 实践框架:提供了完整的算法实现和代码模板

8.2 与本 Wiki 其他内容的联系


Footnotes

  1. Wu, M. et al. (2025). “Bridging the Gap Between Variational Inference and Stochastic Gradient MCMC in Function Space.” ICLR 2025.