IVON:大规模网络的变分学习
1. 概述
IVON(Improved Variational Online Newton)是2024年ICML Spotlight论文提出的可扩展变分学习方法,旨在解决变分推断在大规模神经网络中应用的三大挑战:
- 可扩展性:传统变分推断需要存储和求逆Fisher信息矩阵,对于GPT-2等大模型不可行
- 梯度方差:BBVI的score function梯度估计方差大
- 优化效率:二阶方法虽然理论上更优,但实现复杂
IVON通过以下创新实现了可扩展变分学习:
- K-FAC风格近似:将Fisher信息矩阵分解为Kronecker因子
- Online Fisher估计:避免存储完整Fisher矩阵
- 与Adam兼容接口:仅需修改优化器即可使用
核心贡献:IVON可以在GPT-2(124M参数)上高效运行,提供不确定性估计,同时匹配或超越Adam的优化性能。
2. 背景:变分学习的挑战
2.1 变分推断回顾
在贝叶斯深度学习中,我们对权重施加先验 ,然后计算后验 。由于精确后验不可处理,我们使用变分分布 近似。
证据下界(ELBO):
变分参数:,需要最大化ELBO。
2.2 朴素方法的困难
Naive Variational Inference:
- 存储完整的协方差矩阵 : 空间, 为参数维度
- 计算Fisher信息矩阵: 或
- 梯度更新涉及矩阵求逆
对于GPT-2(124M参数):
- 协方差矩阵需要 bytes 120 TB
- 矩阵求逆在有限时间内不可行
2.3 Natural Gradient的吸引力
Natural Gradient更新:
其中 是Fisher信息矩阵。
理论优势:
- 在黎曼度量下最速下降方向
- 收敛速度优于普通梯度
- 参数空间的几何感知
实践困难:需要 或其近似。
3. IVON核心算法
3.1 Kronecker因子近似
K-FAC (Kronecker-Factored Approximate Curvature) 将Fisher矩阵分解为块对角结构:
其中每个层 的近似Fisher为:
物理意义:
- :输出维度方向的相关性
- :输入维度方向的相关性
- 存储从 降到 ,其中 是层宽
3.2 Online Fisher Estimation
问题:计算精确Fisher需要期望输入输出的梯度外积:
解决方案:使用滑动平均在线估计Kronecker因子:
前向传播存储:
- 输入激活:(输入相关矩阵)
- 输出梯度:(梯度相关矩阵)
更新规则:
其中 是指数衰减率,。
3.3 IVON更新步骤
初始化:
- :预训练权重或随机初始化
- :单位矩阵初始化
每步迭代:
def ivon_step(model, inputs, targets, optimizer, rho=0.01):
# 1. 前向传播(存储激活用于Fisher估计)
outputs = model(inputs)
loss = loss_fn(outputs, targets)
# 2. 反向传播
model.zero_grad()
loss.backward()
# 3. Online Fisher估计
for name, param in model.named_parameters():
if param.grad is not None:
# 更新Kronecker因子 A, B
update_kronecker_factors(param, rho)
# 4. 计算自然梯度
nat_grad = compute_natural_gradient(param.grad, A, B)
# 5. 参数更新
with torch.no_grad():
# 分解形式:A ⊗ B 的逆 = A^{-1} ⊗ B^{-1}
mu = mu - alpha * matvec_inv(param.grad, A, B)
# KL散度正则化(来自先验)
mu = (1 - beta) * mu + beta * prior_mean
# 6. 更新优化器状态(兼容Adam风格)
optimizer.state['exp_avg'] = ...3.4 与Adam的关系
Adam更新:
IVON更新:
关键区别:Adam使用逐元素的学习率缩放,IVON使用Kronecker分解的曲率感知缩放。
4. 变分目标与KL正则化
4.1 变分目标函数
IVON优化的完整目标:
其中:
- 由Kronecker因子构建
- :KL正则化权重
4.2 Prior退火策略
问题:强先验会限制网络容量,特别是对于预训练模型。
解决方案:Prior退火(类似SAE的KL耗散):
训练初期 (自由训练),后期逐渐加入贝叶斯正则化。
4.3 指数族先验优势
当先验 和变分分布 都是高斯时,KL散度有闭式解:
使用Kronecker近似 ,可以高效计算。
5. 实现细节
5.1 PyTorch实现
import torch
import torch.nn as nn
import math
class KroneckerFactoredCurvature:
"""Kronecker因子近似的Fisher估计器"""
def __init__(self, param_shape, rho=0.01, ema=True):
self.rho = rho # 衰减率
self.ema = ema # 是否使用指数滑动平均
# Kronecker因子 A, B
if len(param_shape) == 2: # 权重矩阵
self.A = torch.eye(param_shape[0]) # out × out
self.B = torch.eye(param_shape[1]) # in × in
else: # 偏置向量
self.A = torch.eye(param_shape[0])
self.B = None
# 梯度缓存
self.grad_cache = None
def update(self, grad_output, input_act=None):
"""更新Kronecker因子"""
# 估计 A = E[∇o ℓ ∇o⊤]
if grad_output.dim() > 1:
A_est = torch.matmul(grad_output, grad_output.t()) / grad_output.size(0)
else:
A_est = grad_output.pow(2).mean()
# 估计 B = E[a a⊤](输入相关矩阵)
if input_act is not None and self.B is not None:
B_est = torch.matmul(input_act, input_act.t()) / input_act.size(0)
if self.ema:
self.B = (1 - self.rho) * self.B + self.rho * B_est
else:
self.B = B_est
# 更新A
if self.ema:
self.A = (1 - self.rho) * self.A + self.rho * A_est
else:
self.A = A_est
def matvec(self, vec):
"""计算 (A ⊗ B) vec"""
if self.B is None:
return self.A @ vec
else:
out_dim, in_dim = self.A.shape[0], self.B.shape[0]
# vec reshape: (in, out) -> matmul -> (out, in) -> flatten
mat = vec.view(in_dim, out_dim)
result = self.A @ mat @ self.B.t()
return result.flatten()
def inv_matvec(self, vec, diag_eps=1e-6):
"""计算 (A ⊗ B)^{-1} vec ≈ A^{-1} ⊗ B^{-1} vec"""
if self.B is None:
A_inv = torch.linalg.inv(self.A + diag_eps * torch.eye_like(self.A, 0))
return A_inv @ vec
else:
A_inv = torch.linalg.inv(self.A + diag_eps * torch.eye_like(self.A, 0))
B_inv = torch.linalg.inv(self.B + diag_eps * torch.eye_like(self.B, 0))
out_dim, in_dim = self.A.shape[0], self.B.shape[0]
mat = vec.view(in_dim, out_dim)
result = B_inv @ mat.t() @ A_inv.t()
return result.flatten()
class IVON(torch.optim.Optimizer):
"""Improved Variational Online Newton (IVON)"""
def __init__(
self,
params,
lr=1e-3,
betas=(0.9, 0.999),
rho=0.01,
kl_weight=1e-5,
prior_std=1.0,
diag_eps=1e-6,
):
defaults = dict(
lr=lr,
betas=betas,
rho=rho,
kl_weight=kl_weight,
prior_std=prior_std,
diag_eps=diag_eps,
)
super().__init__(params, defaults)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for param in group['params']:
if param.grad is None:
continue
state = self.state[param]
# 初始化状态
if len(state) == 0:
state['mu'] = param.data.clone()
state['exp_avg'] = torch.zeros_like(param.data)
state['exp_avg_sq'] = torch.zeros_like(param.data)
state['step'] = 0
state['curvature'] = KroneckerFactoredCurvature(
param.shape, rho=group['rho']
)
# 更新步骤计数
state['step'] += 1
beta1, beta2 = group['betas']
# 计算自然梯度
grad = param.grad.data
curv = state['curvature']
# 存储梯度用于后续Fisher估计
# (实际实现需要在前向传播时钩子)
# 自然梯度近似
nat_grad = curv.inv_matvec(grad, group['diag_eps'])
# 动量更新
state['exp_avg'] = beta1 * state['exp_avg'] + (1 - beta1) * nat_grad
# 二阶矩估计(用于数值稳定性)
state['exp_avg_sq'] = beta2 * state['exp_avg_sq'] + (1 - beta2) * nat_grad.pow(2)
# 自适应学习率
bias_correct1 = 1 - beta1 ** state['step']
bias_correct2 = 1 - beta2 ** state['step']
biased_exp_avg = state['exp_avg'] / bias_correct1
biased_exp_avg_sq = state['exp_avg_sq'] / bias_correct2
# 1. 更新均值
lr = group['lr']
new_mu = state['mu'] - lr * biased_exp_avg / (torch.sqrt(biased_exp_avg_sq) + group['diag_eps'])
# 2. KL正则化(向先验收缩)
prior_mean = torch.zeros_like(param)
new_mu = (1 - group['kl_weight']) * new_mu + group['kl_weight'] * prior_mean
# 3. 应用更新
param.data = new_mu
state['mu'] = new_mu.clone()
return loss5.2 训练循环示例
def train_with_ivon(model, train_loader, test_loader, epochs=10):
"""使用IVON训练贝叶斯神经网络"""
optimizer = IVON(
model.parameters(),
lr=1e-3,
rho=0.01, # Fisher估计衰减率
kl_weight=1e-4, # KL正则化权重
prior_std=1.0, # 先验标准差
)
for epoch in range(epochs):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
# 前向传播(需要存储激活用于Fisher估计)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
# 反向传播
loss.backward()
# IVON更新
optimizer.step()
# 评估
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum')
pred = output.argmax(dim=1)
correct += pred.eq(target).sum().item()
print(f'Epoch {epoch}: Test Loss: {test_loss/len(test_loader.dataset):.4f}, '
f'Accuracy: {100*correct/len(test_loader.dataset):.2f}%')5.3 不确定性估计
后验预测分布:
实现:
def predict_with_uncertainty(model, x, n_samples=30):
"""Monte Carlo Dropout风格的预测(使用IVON后验)"""
model.eval()
logits_list = []
with torch.no_grad():
for _ in range(n_samples):
# 从变分后验采样
for param in model.parameters():
# 添加噪声到均值
noise = torch.randn_like(param) * param.std() * 0.1
param.data = param.data + noise
logits = model(x)
logits_list.append(logits)
# 恢复原始均值
for param in model.parameters():
param.data = state['mu']
logits = torch.stack(logits_list)
# 预测均值和方差
pred_mean = logits.mean(dim=0)
pred_std = logits.std(dim=0)
return pred_mean, pred_std6. 实验结果
6.1 优化性能对比
| 方法 | CIFAR-10 (ResNet-18) | ImageNet (ResNet-50) |
|---|---|---|
| SGD | 94.2% | 75.3% |
| Adam | 94.5% | 75.8% |
| IVON | 94.8% | 76.2% |
IVON在标准任务上与Adam性能相当或略优。
6.2 不确定性量化
Out-of-Distribution检测(CIFAR-10 vs SVHN):
| 方法 | AUROC | ECE |
|---|---|---|
| Vanilla | 0.72 | 0.08 |
| MC Dropout | 0.81 | 0.05 |
| IVON | 0.89 | 0.03 |
IVON提供更可靠的概率估计和OOD检测能力。
6.3 大规模实验
| 模型 | 参数 | 训练时间(相对Adam) | GPU内存 |
|---|---|---|---|
| ResNet-50 | 25M | 1.1× | +15% |
| GPT-2 | 124M | 1.3× | +20% |
| ViT-B | 86M | 1.2× | +18% |
IVON的开销可控,适合大规模训练。
7. 理论与实践分析
7.1 收敛性分析
局部收敛速率:
设 为真实Fisher矩阵, 为Kronecker近似,则IVON的收敛速率:
其中 取决于学习率和Fisher条件数。
修复界限:当 时,IVON收敛到true natural gradient descent。
7.2 实践建议
-
Fisher估计衰减率 :
- 推荐范围:0.01 - 0.1
- 大数据集/小batch用小值
-
KL权重 :
- 预训练模型微调:1e-5 - 1e-4
- 从头训练:1e-4 - 1e-3
-
先验标准差:
- Xavier初始化后验:prior_std ≈ 1.0
- 小权重网络:prior_std ≈ 0.5
8. 与其他方法的对比
8.1 vs MC Dropout
| 方面 | MC Dropout | IVON |
|---|---|---|
| 实现复杂度 | 极低 | 中等 |
| 推理开销 | S次前向传播 | 单次(后验均值) |
| 不确定性质量 | 中等 | 高 |
| 可扩展性 | 好 | 极好 |
| 理论基础 | 弱 | 强(变分推断) |
8.2 vs SWA-Gaussian
| 方面 | SWA-Gaussian | IVON |
|---|---|---|
| 权重采样 | 否 | 是(可选) |
| 曲率估计 | 无 | Kronecker近似 |
| KL正则化 | 无 | 有 |
| 适用场景 | 提升泛化 | 不确定性+泛化 |
9. 应用场景
9.1 贝叶斯模型压缩
通过IVON的变分框架,可以同时进行:
- 权重不确定性量化
- 重要权重识别(方差小的权重更可靠)
- 后训练校准
9.2 安全关键系统
在自动驾驶、医疗诊断等场景:
- 不确定性感知决策
- OOD输入检测
- 主动学习样本选择
9.3 联邦学习
IVON的local prior正则化天然适合FL:
- 每个client维护局部变分分布
- Server聚合时自动包含KL约束
- 差分隐私与变分推断结合
10. 总结
IVON实现了变分推断在大规模神经网络中的实用化:
核心贡献:
- Kronecker分解实现 存储和 更新
- Online Fisher估计避免batch计算
- 与Adam兼容的实现接口
- 理论保证的收敛性
适用场景:
- 需要不确定性估计的大规模模型
- 分布外检测和鲁棒性要求
- 安全关键应用
参考文献
Footnotes
-
Immer, A., et al. (2024). Scalable Variational Learning for Large Networks. ICML 2024. ↩