Fisher信息与自然梯度
自然梯度(Natural Gradient)是信息几何框架下的最优梯度下降方向,它通过Fisher信息矩阵对梯度进行预处理,在统计学习的收敛性方面具有理论优势。
1. Fisher信息矩阵
1.1 统计学基础
设 是参数为 的概率分布族。得分函数(Score Function)定义为:
Fisher信息矩阵定义为得分函数的协方差:
1.2 Fisher信息的直观理解
两种等价定义:
-
协方差形式:
-
Hessian形式(对数似然的负Hessian):
直觉:Fisher信息度量了参数空间流形的局部曲率,反映了分布对参数变化的敏感性。
1.3 神经网络中的Fisher信息
对于神经网络 和损失函数 ,经验Fisher信息为:
与梯度外积的关系:
1.4 Fisher信息矩阵的性质
| 性质 | 描述 |
|---|---|
| 对称性 | |
| 半正定 | |
| 正则化 | (满秩时) |
| 标度不变性 | Fisher信息不随参数重参数化而改变 |
2. 自然梯度
2.1 参数空间的几何结构
参数空间 配备黎曼度量:
这定义了参数空间的信息几何结构。
2.2 KL散度作为度量
两个分布 和 之间的二阶KL散度近似为:
2.3 自然梯度定义
自然梯度定义为在KL散度约束下的最速下降方向:
解析形式:
2.4 自然梯度 vs 普通梯度
| 特性 | 普通梯度 | 自然梯度 |
|---|---|---|
| 更新方向 | ||
| 度量 | 欧几里得 | Fisher信息 |
| 坐标系 | 固定 | 随分布变化 |
| 收敛性 | 线性 | 可能超线性 |
| 计算复杂度 | 或更多 |
3. 与常用优化器的联系
3.1 Adam优化器
Adam的更新规则:
与自然梯度的联系:
Adam使用:
- :梯度的一阶矩估计(类似自然梯度中的预处理)
- :梯度平方的二阶矩估计(对角近似Fisher信息)
3.2 FAdam:Adam作为自然梯度
定理(Hwang, 2024):Adam可以被解释为使用对角经验Fisher的近似自然梯度。
修正项:
- 增强动量:改进 的估计
- 自适应 :考虑Fisher的尺度
- 梯度裁剪:提高稳定性
3.3 Adagrad
Adagrad的更新:
与Fisher的联系: 累积梯度平方,类似于对角Fisher信息的估计。
3.4 优化器与Fisher的关系表
| 优化器 | Fisher近似 | 预处理类型 |
|---|---|---|
| SGD | 无 | 无 |
| Momentum | 无 | 动量项 |
| Adagrad | 对角 | |
| RMSprop | 对角 | |
| Adam | 对角 + 动量 | |
| K-FAC | Kronecker分解 | |
| Shampoo | 矩阵形式 |
4. 经验Fisher近似
4.1 经验Fisher的定义
真实Fisher:
经验Fisher:
4.2 忽略的原因
经验Fisher忽略了:
- 期望 vs 样本:用采样均值替代真实期望
- 数据分布 vs 模型分布:使用数据分布而非模型分布
何时影响大:
- 小数据集
- 高度非线性的模型
- 离模型流形较远时
4.3 改进的经验Fisher
iEF(Improved Empirical Fisher, NeurIPS 2024):
提出反比例投影问题并修正:
5. K-FAC:Kronecker分解近似
5.1 问题
直接求 的复杂度是 ,对于大型神经网络不可行。
5.2 Kronecker分解思想
假设Fisher信息矩阵可以分解为:
其中:
- :关于激活值的Kronecker因子
- :关于权重的Kronecker因子
5.3 K-FAC算法
class KFACOptimizer:
def __init__(self, model, lr=0.001, momentum=0.9, damping=0.001):
self.model = model
self.lr = lr
self.momentum = momentum
self.damping = damping
# 存储Kronecker因子
self.F_accum = {} # 激活值的Fisher估计
self.G_accum = {} # 梯度的Fisher估计
def step(self, inputs, targets):
# 1. 前向传播
outputs = self.model(inputs)
loss = self.compute_loss(outputs, targets)
# 2. 反向传播
self.model.zero_grad()
loss.backward()
# 3. 累积Kronecker因子
self._accumulate_factors()
# 4. 计算自然梯度
for name, param in self.model.named_parameters():
if name in self.F_accum:
A = self.F_accum[name]
G = self.G_accum[name]
# Kronecker逆
F_inv = self._kronecker_inverse(A, G, self.damping)
# 自然梯度更新
param.grad.data = F_inv @ param.grad.data
# 5. 参数更新
self.optimizer.step()
def _accumulate_factors(self):
"""累积激活值和梯度的协方差"""
# 简化版本:使用当前batch的统计
for name, param in self.model.named_parameters():
if param.grad is None:
continue
# 简化:使用对角近似
grad = param.grad.data.flatten()
self.F_accum[name] = self.F_accum.get(name, 0) * 0.99 + grad @ grad.T * 0.015.4 K-FAC的理论优势
| 特性 | SGD | Adam | K-FAC |
|---|---|---|---|
| 预处理 | 无 | 对角 | 块对角 |
| 收敛速度 | 慢 | 中 | 快 |
| 理论保证 | 无 | 弱 | 强(线性收敛) |
| 适用场景 | 通用 | 通用 | 大batch |
6. Shampoo优化器
6.1 矩阵值预处理
Shampoo使用更精细的Fisher近似:
6.2 矩阵分解
对于权重矩阵 :
计算:
更新:
6.3 SVD分解优化
def shampoo_update(G, lr, momentum):
"""
Shampoo优化器的核心更新
G: 梯度矩阵
"""
# 计算 Gram 矩阵
GGT = G @ G.T + eps * I # m × m
GTG = G.T @ G + eps * I # n × n
# SVD 分解
U_A, S_A, _ = torch.svd(GGT)
U_B, S_B, _ = torch.svd(GTG)
# 预处理梯度
G_preconditioned = U_A @ (U_A.T @ G) + (G @ U_B) @ U_B.T
return G_preconditioned7. 自然梯度的应用
7.1 神经网络训练
何时使用自然梯度:
- 大batch训练(batch size > 1024)
- 需要快速收敛的任务
- 理论分析需求
7.2 变分推断
变分推断中,自然梯度用于更新变分参数:
7.3 元学习
MAML使用自然梯度快速适应新任务:
8. 实现注意事项
8.1 数值稳定性
def stable_fisher_inverse(F, damping=1e-5):
"""稳定的Fisher逆计算"""
# 添加阻尼
F_damped = F + damping * torch.eye(F.shape[0])
# 使用Cholesky分解
try:
L = torch.linalg.cholesky(F_damped)
F_inv = torch.cholesky_inverse(L)
except:
# 备用:特征分解
eigvals, eigvecs = torch.linalg.eigh(F_damped)
eigvals = torch.clamp(eigvals, min=1e-8)
F_inv = eigvecs @ torch.diag(1/eigvals) @ eigvecs.T
return F_inv8.2 内存考虑
| 方法 | 内存复杂度 | 适用规模 |
|---|---|---|
| 直接求逆 | ||
| 对角近似 | 任意规模 | |
| K-FAC | ||
| Shampoo |
8.3 阻尼参数选择
# 自适应阻尼
def adaptive_damping(F, target_cond=1e6):
"""自适应阻尼以控制条件数"""
eigvals = torch.linalg.eigvalsh(F)
cond = eigvals.max() / (eigvals.min() + 1e-10)
if cond > target_cond:
damping = eigvals.max() / target_cond - eigvals.min()
else:
damping = 1e-6
return damping9. 近期研究进展
9.1 FAdam (ICLR 2024)
核心发现:Adam与对角自然梯度的联系
改进:
- 修正动量估计
- 自适应阻尼
- 梯度裁剪集成
9.2 iEF (NeurIPS 2024)
发现:经验Fisher的”反比例投影”问题
解决方案:引入对角缩放矩阵
9.3 自然梯度与泛化
研究:自然梯度优化器是否带来更好的泛化?
结论:
- 在某些设置下,自然梯度确实改善泛化
- 效果与batch size相关
- 与隐式正则化有关联
10. 总结
10.1 核心要点
- Fisher信息矩阵:梯度外积的期望,度量分布曲率
- 自然梯度:,信息几何最优方向
- Adam:对角经验Fisher的近似
- K-FAC:Kronecker分解实现高效自然梯度
- Shampoo:矩阵值预处理,更精细的近似
10.2 选择指南
| 场景 | 推荐优化器 |
|---|---|
| 标准训练 | AdamW / SGD+Momentum |
| 大batch训练 | K-FAC |
| 超参数优化 | 自然梯度方法 |
| 大嵌入表 | Shampoo |