Fisher信息与自然梯度

自然梯度(Natural Gradient)是信息几何框架下的最优梯度下降方向,它通过Fisher信息矩阵对梯度进行预处理,在统计学习的收敛性方面具有理论优势。

1. Fisher信息矩阵

1.1 统计学基础

是参数为 的概率分布族。得分函数(Score Function)定义为:

Fisher信息矩阵定义为得分函数的协方差:

1.2 Fisher信息的直观理解

两种等价定义

  1. 协方差形式

  2. Hessian形式(对数似然的负Hessian):

直觉:Fisher信息度量了参数空间流形的局部曲率,反映了分布对参数变化的敏感性。

1.3 神经网络中的Fisher信息

对于神经网络 和损失函数 ,经验Fisher信息为:

与梯度外积的关系

1.4 Fisher信息矩阵的性质

性质描述
对称性
半正定
正则化(满秩时)
标度不变性Fisher信息不随参数重参数化而改变

2. 自然梯度

2.1 参数空间的几何结构

参数空间 配备黎曼度量:

这定义了参数空间的信息几何结构。

2.2 KL散度作为度量

两个分布 之间的二阶KL散度近似为:

2.3 自然梯度定义

自然梯度定义为在KL散度约束下的最速下降方向:

解析形式

2.4 自然梯度 vs 普通梯度

特性普通梯度自然梯度
更新方向
度量欧几里得Fisher信息
坐标系固定随分布变化
收敛性线性可能超线性
计算复杂度 或更多

3. 与常用优化器的联系

3.1 Adam优化器

Adam的更新规则:


与自然梯度的联系

Adam使用:

  • :梯度的一阶矩估计(类似自然梯度中的预处理)
  • :梯度平方的二阶矩估计(对角近似Fisher信息)

3.2 FAdam:Adam作为自然梯度

定理(Hwang, 2024):Adam可以被解释为使用对角经验Fisher的近似自然梯度。

修正项

  1. 增强动量:改进 的估计
  2. 自适应 :考虑Fisher的尺度
  3. 梯度裁剪:提高稳定性

3.3 Adagrad

Adagrad的更新:

与Fisher的联系 累积梯度平方,类似于对角Fisher信息的估计。

3.4 优化器与Fisher的关系表

优化器Fisher近似预处理类型
SGD
Momentum动量项
Adagrad对角
RMSprop对角
Adam对角 + 动量
K-FACKronecker分解
Shampoo矩阵形式

4. 经验Fisher近似

4.1 经验Fisher的定义

真实Fisher

经验Fisher

4.2 忽略的原因

经验Fisher忽略了:

  1. 期望 vs 样本:用采样均值替代真实期望
  2. 数据分布 vs 模型分布:使用数据分布而非模型分布

何时影响大

  • 小数据集
  • 高度非线性的模型
  • 离模型流形较远时

4.3 改进的经验Fisher

iEF(Improved Empirical Fisher, NeurIPS 2024):

提出反比例投影问题并修正:

5. K-FAC:Kronecker分解近似

5.1 问题

直接求 的复杂度是 ,对于大型神经网络不可行。

5.2 Kronecker分解思想

假设Fisher信息矩阵可以分解为:

其中:

  • :关于激活值的Kronecker因子
  • :关于权重的Kronecker因子

5.3 K-FAC算法

class KFACOptimizer:
    def __init__(self, model, lr=0.001, momentum=0.9, damping=0.001):
        self.model = model
        self.lr = lr
        self.momentum = momentum
        self.damping = damping
        
        # 存储Kronecker因子
        self.F_accum = {}  # 激活值的Fisher估计
        self.G_accum = {}  # 梯度的Fisher估计
        
    def step(self, inputs, targets):
        # 1. 前向传播
        outputs = self.model(inputs)
        loss = self.compute_loss(outputs, targets)
        
        # 2. 反向传播
        self.model.zero_grad()
        loss.backward()
        
        # 3. 累积Kronecker因子
        self._accumulate_factors()
        
        # 4. 计算自然梯度
        for name, param in self.model.named_parameters():
            if name in self.F_accum:
                A = self.F_accum[name]
                G = self.G_accum[name]
                
                # Kronecker逆
                F_inv = self._kronecker_inverse(A, G, self.damping)
                
                # 自然梯度更新
                param.grad.data = F_inv @ param.grad.data
        
        # 5. 参数更新
        self.optimizer.step()
    
    def _accumulate_factors(self):
        """累积激活值和梯度的协方差"""
        # 简化版本:使用当前batch的统计
        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            
            # 简化:使用对角近似
            grad = param.grad.data.flatten()
            self.F_accum[name] = self.F_accum.get(name, 0) * 0.99 + grad @ grad.T * 0.01

5.4 K-FAC的理论优势

特性SGDAdamK-FAC
预处理对角块对角
收敛速度
理论保证强(线性收敛)
适用场景通用通用大batch

6. Shampoo优化器

6.1 矩阵值预处理

Shampoo使用更精细的Fisher近似:

6.2 矩阵分解

对于权重矩阵

计算:

更新:

6.3 SVD分解优化

def shampoo_update(G, lr, momentum):
    """
    Shampoo优化器的核心更新
    G: 梯度矩阵
    """
    # 计算 Gram 矩阵
    GGT = G @ G.T + eps * I  # m × m
    GTG = G.T @ G + eps * I  # n × n
    
    # SVD 分解
    U_A, S_A, _ = torch.svd(GGT)
    U_B, S_B, _ = torch.svd(GTG)
    
    # 预处理梯度
    G_preconditioned = U_A @ (U_A.T @ G) + (G @ U_B) @ U_B.T
    
    return G_preconditioned

7. 自然梯度的应用

7.1 神经网络训练

何时使用自然梯度

  • 大batch训练(batch size > 1024)
  • 需要快速收敛的任务
  • 理论分析需求

7.2 变分推断

变分推断中,自然梯度用于更新变分参数:

7.3 元学习

MAML使用自然梯度快速适应新任务:

8. 实现注意事项

8.1 数值稳定性

def stable_fisher_inverse(F, damping=1e-5):
    """稳定的Fisher逆计算"""
    # 添加阻尼
    F_damped = F + damping * torch.eye(F.shape[0])
    
    # 使用Cholesky分解
    try:
        L = torch.linalg.cholesky(F_damped)
        F_inv = torch.cholesky_inverse(L)
    except:
        # 备用:特征分解
        eigvals, eigvecs = torch.linalg.eigh(F_damped)
        eigvals = torch.clamp(eigvals, min=1e-8)
        F_inv = eigvecs @ torch.diag(1/eigvals) @ eigvecs.T
    
    return F_inv

8.2 内存考虑

方法内存复杂度适用规模
直接求逆
对角近似任意规模
K-FAC
Shampoo

8.3 阻尼参数选择

# 自适应阻尼
def adaptive_damping(F, target_cond=1e6):
    """自适应阻尼以控制条件数"""
    eigvals = torch.linalg.eigvalsh(F)
    cond = eigvals.max() / (eigvals.min() + 1e-10)
    
    if cond > target_cond:
        damping = eigvals.max() / target_cond - eigvals.min()
    else:
        damping = 1e-6
    
    return damping

9. 近期研究进展

9.1 FAdam (ICLR 2024)

核心发现:Adam与对角自然梯度的联系

改进

  1. 修正动量估计
  2. 自适应阻尼
  3. 梯度裁剪集成

9.2 iEF (NeurIPS 2024)

发现:经验Fisher的”反比例投影”问题

解决方案:引入对角缩放矩阵

9.3 自然梯度与泛化

研究:自然梯度优化器是否带来更好的泛化?

结论

  • 在某些设置下,自然梯度确实改善泛化
  • 效果与batch size相关
  • 与隐式正则化有关联

10. 总结

10.1 核心要点

  1. Fisher信息矩阵:梯度外积的期望,度量分布曲率
  2. 自然梯度,信息几何最优方向
  3. Adam:对角经验Fisher的近似
  4. K-FAC:Kronecker分解实现高效自然梯度
  5. Shampoo:矩阵值预处理,更精细的近似

10.2 选择指南

场景推荐优化器
标准训练AdamW / SGD+Momentum
大batch训练K-FAC
超参数优化自然梯度方法
大嵌入表Shampoo

10.3 相关专题

参考资料