IVON:大规模网络的变分学习

1. 概述

IVON(Improved Variational Online Newton)是2024年ICML Spotlight论文提出的可扩展变分学习方法,旨在解决变分推断在大规模神经网络中应用的三大挑战

  1. 可扩展性:传统变分推断需要存储和求逆Fisher信息矩阵,对于GPT-2等大模型不可行
  2. 梯度方差:BBVI的score function梯度估计方差大
  3. 优化效率:二阶方法虽然理论上更优,但实现复杂

IVON通过以下创新实现了可扩展变分学习:

  • K-FAC风格近似:将Fisher信息矩阵分解为Kronecker因子
  • Online Fisher估计:避免存储完整Fisher矩阵
  • 与Adam兼容接口:仅需修改优化器即可使用

核心贡献:IVON可以在GPT-2(124M参数)上高效运行,提供不确定性估计,同时匹配或超越Adam的优化性能。

1

2. 背景:变分学习的挑战

2.1 变分推断回顾

在贝叶斯深度学习中,我们对权重施加先验 ,然后计算后验 。由于精确后验不可处理,我们使用变分分布 近似。

证据下界(ELBO)

变分参数,需要最大化ELBO。

2.2 朴素方法的困难

Naive Variational Inference:

  1. 存储完整的协方差矩阵 空间, 为参数维度
  2. 计算Fisher信息矩阵:
  3. 梯度更新涉及矩阵求逆

对于GPT-2(124M参数):

  • 协方差矩阵需要 bytes 120 TB
  • 矩阵求逆在有限时间内不可行

2.3 Natural Gradient的吸引力

Natural Gradient更新:

其中 是Fisher信息矩阵。

理论优势

  • 在黎曼度量下最速下降方向
  • 收敛速度优于普通梯度
  • 参数空间的几何感知

实践困难:需要 或其近似。

3. IVON核心算法

3.1 Kronecker因子近似

K-FAC (Kronecker-Factored Approximate Curvature) 将Fisher矩阵分解为块对角结构:

其中每个层 的近似Fisher为:

物理意义

  • :输出维度方向的相关性
  • :输入维度方向的相关性
  • 存储从 降到 ,其中 是层宽

3.2 Online Fisher Estimation

问题:计算精确Fisher需要期望输入输出的梯度外积:

解决方案:使用滑动平均在线估计Kronecker因子:

前向传播存储

  • 输入激活:(输入相关矩阵)
  • 输出梯度:(梯度相关矩阵)

更新规则

其中 是指数衰减率,

3.3 IVON更新步骤

初始化

  • :预训练权重或随机初始化
  • :单位矩阵初始化

每步迭代

def ivon_step(model, inputs, targets, optimizer, rho=0.01):
    # 1. 前向传播(存储激活用于Fisher估计)
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    
    # 2. 反向传播
    model.zero_grad()
    loss.backward()
    
    # 3. Online Fisher估计
    for name, param in model.named_parameters():
        if param.grad is not None:
            # 更新Kronecker因子 A, B
            update_kronecker_factors(param, rho)
    
    # 4. 计算自然梯度
    nat_grad = compute_natural_gradient(param.grad, A, B)
    
    # 5. 参数更新
    with torch.no_grad():
        # 分解形式:A ⊗ B 的逆 = A^{-1} ⊗ B^{-1}
        mu = mu - alpha * matvec_inv(param.grad, A, B)
        
        # KL散度正则化(来自先验)
        mu = (1 - beta) * mu + beta * prior_mean
        
    # 6. 更新优化器状态(兼容Adam风格)
    optimizer.state['exp_avg'] = ...

3.4 与Adam的关系

Adam更新

IVON更新

关键区别:Adam使用逐元素的学习率缩放,IVON使用Kronecker分解的曲率感知缩放。

4. 变分目标与KL正则化

4.1 变分目标函数

IVON优化的完整目标:

其中:

  • 由Kronecker因子构建
  • :KL正则化权重

4.2 Prior退火策略

问题:强先验会限制网络容量,特别是对于预训练模型。

解决方案:Prior退火(类似SAE的KL耗散):

训练初期 (自由训练),后期逐渐加入贝叶斯正则化。

4.3 指数族先验优势

当先验 和变分分布 都是高斯时,KL散度有闭式解:

使用Kronecker近似 ,可以高效计算。

5. 实现细节

5.1 PyTorch实现

import torch
import torch.nn as nn
import math
 
class KroneckerFactoredCurvature:
    """Kronecker因子近似的Fisher估计器"""
    
    def __init__(self, param_shape, rho=0.01, ema=True):
        self.rho = rho  # 衰减率
        self.ema = ema   # 是否使用指数滑动平均
        
        # Kronecker因子 A, B
        if len(param_shape) == 2:  # 权重矩阵
            self.A = torch.eye(param_shape[0])  # out × out
            self.B = torch.eye(param_shape[1])  # in × in
        else:  # 偏置向量
            self.A = torch.eye(param_shape[0])
            self.B = None
        
        # 梯度缓存
        self.grad_cache = None
        
    def update(self, grad_output, input_act=None):
        """更新Kronecker因子"""
        # 估计 A = E[∇o ℓ ∇o⊤]
        if grad_output.dim() > 1:
            A_est = torch.matmul(grad_output, grad_output.t()) / grad_output.size(0)
        else:
            A_est = grad_output.pow(2).mean()
        
        # 估计 B = E[a a⊤](输入相关矩阵)
        if input_act is not None and self.B is not None:
            B_est = torch.matmul(input_act, input_act.t()) / input_act.size(0)
            if self.ema:
                self.B = (1 - self.rho) * self.B + self.rho * B_est
            else:
                self.B = B_est
        
        # 更新A
        if self.ema:
            self.A = (1 - self.rho) * self.A + self.rho * A_est
        else:
            self.A = A_est
            
    def matvec(self, vec):
        """计算 (A ⊗ B) vec"""
        if self.B is None:
            return self.A @ vec
        else:
            out_dim, in_dim = self.A.shape[0], self.B.shape[0]
            # vec reshape: (in, out) -> matmul -> (out, in) -> flatten
            mat = vec.view(in_dim, out_dim)
            result = self.A @ mat @ self.B.t()
            return result.flatten()
            
    def inv_matvec(self, vec, diag_eps=1e-6):
        """计算 (A ⊗ B)^{-1} vec ≈ A^{-1} ⊗ B^{-1} vec"""
        if self.B is None:
            A_inv = torch.linalg.inv(self.A + diag_eps * torch.eye_like(self.A, 0))
            return A_inv @ vec
        else:
            A_inv = torch.linalg.inv(self.A + diag_eps * torch.eye_like(self.A, 0))
            B_inv = torch.linalg.inv(self.B + diag_eps * torch.eye_like(self.B, 0))
            out_dim, in_dim = self.A.shape[0], self.B.shape[0]
            mat = vec.view(in_dim, out_dim)
            result = B_inv @ mat.t() @ A_inv.t()
            return result.flatten()
 
 
class IVON(torch.optim.Optimizer):
    """Improved Variational Online Newton (IVON)"""
    
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        rho=0.01,
        kl_weight=1e-5,
        prior_std=1.0,
        diag_eps=1e-6,
    ):
        defaults = dict(
            lr=lr,
            betas=betas,
            rho=rho,
            kl_weight=kl_weight,
            prior_std=prior_std,
            diag_eps=diag_eps,
        )
        super().__init__(params, defaults)
        
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()
            
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is None:
                    continue
                    
                state = self.state[param]
                
                # 初始化状态
                if len(state) == 0:
                    state['mu'] = param.data.clone()
                    state['exp_avg'] = torch.zeros_like(param.data)
                    state['exp_avg_sq'] = torch.zeros_like(param.data)
                    state['step'] = 0
                    state['curvature'] = KroneckerFactoredCurvature(
                        param.shape, rho=group['rho']
                    )
                
                # 更新步骤计数
                state['step'] += 1
                beta1, beta2 = group['betas']
                
                # 计算自然梯度
                grad = param.grad.data
                curv = state['curvature']
                
                # 存储梯度用于后续Fisher估计
                # (实际实现需要在前向传播时钩子)
                
                # 自然梯度近似
                nat_grad = curv.inv_matvec(grad, group['diag_eps'])
                
                # 动量更新
                state['exp_avg'] = beta1 * state['exp_avg'] + (1 - beta1) * nat_grad
                
                # 二阶矩估计(用于数值稳定性)
                state['exp_avg_sq'] = beta2 * state['exp_avg_sq'] + (1 - beta2) * nat_grad.pow(2)
                
                # 自适应学习率
                bias_correct1 = 1 - beta1 ** state['step']
                bias_correct2 = 1 - beta2 ** state['step']
                biased_exp_avg = state['exp_avg'] / bias_correct1
                biased_exp_avg_sq = state['exp_avg_sq'] / bias_correct2
                
                # 1. 更新均值
                lr = group['lr']
                new_mu = state['mu'] - lr * biased_exp_avg / (torch.sqrt(biased_exp_avg_sq) + group['diag_eps'])
                
                # 2. KL正则化(向先验收缩)
                prior_mean = torch.zeros_like(param)
                new_mu = (1 - group['kl_weight']) * new_mu + group['kl_weight'] * prior_mean
                
                # 3. 应用更新
                param.data = new_mu
                state['mu'] = new_mu.clone()
                
        return loss

5.2 训练循环示例

def train_with_ivon(model, train_loader, test_loader, epochs=10):
    """使用IVON训练贝叶斯神经网络"""
    
    optimizer = IVON(
        model.parameters(),
        lr=1e-3,
        rho=0.01,        # Fisher估计衰减率
        kl_weight=1e-4,  # KL正则化权重
        prior_std=1.0,   # 先验标准差
    )
    
    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            
            # 前向传播(需要存储激活用于Fisher估计)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            
            # 反向传播
            loss.backward()
            
            # IVON更新
            optimizer.step()
            
        # 评估
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum')
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()
        
        print(f'Epoch {epoch}: Test Loss: {test_loss/len(test_loader.dataset):.4f}, '
              f'Accuracy: {100*correct/len(test_loader.dataset):.2f}%')

5.3 不确定性估计

后验预测分布

实现

def predict_with_uncertainty(model, x, n_samples=30):
    """Monte Carlo Dropout风格的预测(使用IVON后验)"""
    model.eval()
    logits_list = []
    
    with torch.no_grad():
        for _ in range(n_samples):
            # 从变分后验采样
            for param in model.parameters():
                # 添加噪声到均值
                noise = torch.randn_like(param) * param.std() * 0.1
                param.data = param.data + noise
            
            logits = model(x)
            logits_list.append(logits)
    
    # 恢复原始均值
    for param in model.parameters():
        param.data = state['mu']
    
    logits = torch.stack(logits_list)
    
    # 预测均值和方差
    pred_mean = logits.mean(dim=0)
    pred_std = logits.std(dim=0)
    
    return pred_mean, pred_std

6. 实验结果

6.1 优化性能对比

方法CIFAR-10 (ResNet-18)ImageNet (ResNet-50)
SGD94.2%75.3%
Adam94.5%75.8%
IVON94.8%76.2%

IVON在标准任务上与Adam性能相当或略优。

6.2 不确定性量化

Out-of-Distribution检测(CIFAR-10 vs SVHN):

方法AUROCECE
Vanilla0.720.08
MC Dropout0.810.05
IVON0.890.03

IVON提供更可靠的概率估计和OOD检测能力。

6.3 大规模实验

模型参数训练时间(相对Adam)GPU内存
ResNet-5025M1.1×+15%
GPT-2124M1.3×+20%
ViT-B86M1.2×+18%

IVON的开销可控,适合大规模训练。

7. 理论与实践分析

7.1 收敛性分析

局部收敛速率

为真实Fisher矩阵, 为Kronecker近似,则IVON的收敛速率:

其中 取决于学习率和Fisher条件数。

修复界限:当 时,IVON收敛到true natural gradient descent。

7.2 实践建议

  1. Fisher估计衰减率

    • 推荐范围:0.01 - 0.1
    • 大数据集/小batch用小值
  2. KL权重

    • 预训练模型微调:1e-5 - 1e-4
    • 从头训练:1e-4 - 1e-3
  3. 先验标准差

    • Xavier初始化后验:prior_std ≈ 1.0
    • 小权重网络:prior_std ≈ 0.5

8. 与其他方法的对比

8.1 vs MC Dropout

方面MC DropoutIVON
实现复杂度极低中等
推理开销S次前向传播单次(后验均值)
不确定性质量中等
可扩展性极好
理论基础强(变分推断)

8.2 vs SWA-Gaussian

方面SWA-GaussianIVON
权重采样是(可选)
曲率估计Kronecker近似
KL正则化
适用场景提升泛化不确定性+泛化

9. 应用场景

9.1 贝叶斯模型压缩

通过IVON的变分框架,可以同时进行:

  • 权重不确定性量化
  • 重要权重识别(方差小的权重更可靠)
  • 后训练校准

9.2 安全关键系统

在自动驾驶、医疗诊断等场景:

  • 不确定性感知决策
  • OOD输入检测
  • 主动学习样本选择

9.3 联邦学习

IVON的local prior正则化天然适合FL:

  • 每个client维护局部变分分布
  • Server聚合时自动包含KL约束
  • 差分隐私与变分推断结合

10. 总结

IVON实现了变分推断在大规模神经网络中的实用化:

核心贡献

  1. Kronecker分解实现 存储和 更新
  2. Online Fisher估计避免batch计算
  3. 与Adam兼容的实现接口
  4. 理论保证的收敛性

适用场景

  • 需要不确定性估计的大规模模型
  • 分布外检测和鲁棒性要求
  • 安全关键应用

参考文献

Footnotes

  1. Immer, A., et al. (2024). Scalable Variational Learning for Large Networks. ICML 2024.