概述

自然梯度(Natural Gradient)是一种利用Fisher信息矩阵作为黎曼度量的二阶优化方法。相比普通梯度下降,自然梯度在参数空间中具有坐标不变性——无论参数如何缩放或重参数化,优化方向都保持一致。K-FAC(Kronecker-Factored Approximate Curvature)是高效近似自然梯度的方法,已在Transformer训练中被广泛采用。1

核心洞察:在概率分布空间中,“最速下降方向”不是欧几里得梯度方向,而是由Fisher信息矩阵定义的黎曼梯度方向。


从欧几里得梯度到自然梯度

普通梯度下降的问题

设损失函数 ,普通梯度更新:

问题:梯度方向依赖于参数化——相同的学习算法在不同参数化下表现不同。

示例:考虑两种参数化

参数化 A: θ                     参数化 B: 10θ
    │                              │
    │  ∇ℒ                          │  ∇ℒ
    ↓                              ↓
  等高线                          等高线(被拉伸!)
    
    真实最速方向相同               梯度方向不同!

黎曼梯度下降

在黎曼流形上,梯度定义为:

其中 是黎曼度量矩阵。

自然梯度使用Fisher信息矩阵 作为度量:

Fisher信息矩阵

定义:对于概率模型 ,Fisher信息矩阵定义为:

性质

  • 对称正定
  • 参数化不变
  • 是损失函数Hessian的期望(下界)

自然梯度的几何意义

import numpy as np
import matplotlib.pyplot as plt
 
def illustrate_natural_gradient():
    """可视化自然梯度与普通梯度的区别"""
    
    # 定义损失函数(二维椭球)
    def loss(theta):
        return 0.5 * (theta[0]**2 + 100 * theta[1]**2)
    
    def gradient(theta):
        return np.array([theta[0], 100 * theta[1]])
    
    # Fisher信息矩阵(对于高斯分布,这里简化为对角矩阵)
    # 假设参数尺度不同
    F = np.diag([1.0, 0.01])  # 模拟不同的信息量
    
    theta = np.array([2.0, 0.1])
    
    grad = gradient(theta)
    nat_grad = np.linalg.solve(F, grad)
    
    print(f"普通梯度: {grad}")
    print(f"自然梯度: {nat_grad}")
    print(f"Fisher矩阵: {F}")
    print(f"注意:自然梯度在信息量小的方向上步长更大")
 
illustrate_natural_gradient()

Fisher信息矩阵的计算

对于神经网络

对于神经网络的权重参数 ,Fisher信息矩阵的大小为 是参数数量),直接计算和存储不可行。

定义

Fisher vs Hessian

方面Fisher信息矩阵Hessian
定义梯度的协方差损失的二阶导
正定性总是正定不一定
计算只需一阶导数需要二阶导数
期望遍历数据分布遍历训练样本

关系

其中 是对数似然的Hessian。

经验Fisher

在实际计算中,使用经验Fisher

这等价于梯度外积的均值。


Kronecker-Factored Approximate Curvature (K-FAC)

核心思想

K-FAC将巨大的Fisher矩阵分解为小块的Kronecker积

其中 是小得多的矩阵。

神经网络层分解

对于全连接层 ,参数为

独立近似

Kronecker分解

  • :对输出单元的Fisher估计
  • :对输入的Fisher估计

K-FAC更新公式

对于参数块

等价于:

Python实现

import torch
import torch.nn as nn
from collections import OrderedDict
 
class KFACOptimizer:
    """
    K-FAC优化器实现
    
    简化版本:仅处理全连接层
    """
    
    def __init__(self, model, lr=0.001, damping=1e-3, kl_clip=0.01):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.kl_clip = kl_clip
        
        # 为每个参数块存储Fisher估计
        self.fisher_info = {}
        self.state = {}
        
        # 注册所有全连接层
        self._register_modules()
        
    def _register_modules(self):
        """识别可优化的参数块"""
        self.modules = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                self.modules.append((name, module))
                
    def step(self, closure=None):
        """一次优化步骤"""
        
        # 1. 计算梯度
        if closure is not None:
            loss = closure()
        else:
            loss = None
            
        # 2. 累积Fisher信息(简化:使用单样本估计)
        self._update_fisher()
        
        # 3. 执行K-FAC更新
        self._kfac_update()
        
        return loss
    
    def _update_fisher(self):
        """
        更新Fisher信息估计
        
        简化版本:使用梯度外积
        """
        for name, module in self.modules:
            if hasattr(module, 'weight') and module.weight.grad is not None:
                # 获取梯度
                grad_w = module.weight.grad.detach()
                grad_b = module.bias.grad.detach() if module.bias is not None else None
                
                # 创建或更新Fisher估计
                if name not in self.fisher_info:
                    self.fisher_info[name] = {
                        'G': torch.zeros_like(grad_w.T @ grad_w),
                        'S': torch.zeros_like(grad_w @ grad_w.T),
                        'G_b': torch.zeros(module.out_features, module.out_features),
                    }
                
                # Kronecker分解:G ⊗ S ≈ F
                # 这里用移动平均近似期望
                beta = 0.99  # 指数衰减
                
                # 对于输出维度
                G = grad_w.T @ grad_w / grad_w.shape[0]
                self.fisher_info[name]['G'] = beta * self.fisher_info[name]['G'] + (1-beta) * G
                
                # 对于输入维度
                S = grad_w @ grad_w.T / grad_w.shape[0]
                self.fisher_info[name]['S'] = beta * self.fisher_info[name]['S'] + (1-beta) * S
                
                if grad_b is not None:
                    G_b = grad_b.unsqueeze(1) @ grad_b.unsqueeze(0) / grad_b.shape[0]
                    self.fisher_info[name]['G_b'] = beta * self.fisher_info[name]['G_b'] + (1-beta) * G_b
    
    def _kfac_update(self):
        """
        执行K-FAC更新
        
        W ← W - α * S⁻¹ * (∇W ⊙ ∇W) * G⁻¹
        """
        for name, module in self.modules:
            if name not in self.fisher_info:
                continue
                
            info = self.fisher_info[name]
            
            # 提取梯度
            grad_w = module.weight.grad
            grad_b = module.bias.grad if module.bias is not None else None
            
            # 计算 Kronecker 逆
            G = info['G'] + self.damping * torch.eye(info['G'].shape[0])
            S = info['S'] + self.damping * torch.eye(info['S'].shape[0])
            
            G_inv = torch.linalg.solve(G, torch.eye(G.shape[0]))
            S_inv = torch.linalg.solve(S, torch.eye(S.shape[0]))
            
            # K-FAC更新(利用分块结构)
            # ∇W (S⁻¹ ⊗ G⁻¹) = G⁻¹ @ ∇W @ S⁻¹
            update = G_inv @ grad_w @ S_inv
            
            # 应用更新
            module.weight.data -= self.lr * update.T
            
            # 偏置更新(如果存在)
            if grad_b is not None:
                G_b = info['G_b'] + self.damping * torch.eye(info['G_b'].shape[0])
                G_b_inv = torch.linalg.solve(G_b, torch.eye(G_b.shape[0]))
                update_b = G_b_inv @ grad_b
                module.bias.data -= self.lr * update_b
    
    def zero_grad(self):
        """清零梯度"""
        for _, module in self.modules:
            if module.weight.grad is not None:
                module.weight.grad.zero_()
            if module.bias is not None and module.bias.grad is not None:
                module.bias.grad.zero_()
 
 
# 测试K-FAC
def test_kfac():
    """简单测试"""
    torch.manual_seed(42)
    
    model = nn.Sequential(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 5)
    )
    
    optimizer = KFACOptimizer(model, lr=0.01, damping=1e-3)
    criterion = nn.MSELoss()
    
    # 随机数据
    x = torch.randn(32, 10)
    y = torch.randn(32, 5)
    
    for epoch in range(3):
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
 
test_kfac()

K-FAC的理论保证

近似类型准确性计算复杂度
全Fisher精确 存储,
K-FAC分块近似 存储, 更新

对角近似(EKFAC)

更激进的近似:对每个参数独立估计Fisher对角元:

这将复杂度降到 ,但准确性下降。


K-FAC在Transformer中的应用

Transformer的Fisher结构

Transformer中的注意力机制具有特殊的Fisher结构:

  • Query/Key/Value投影:独立块
  • FFN层:两个线性层
  • 注意力权重:高度条件化

实践考量

class TransformerKFAC:
    """
    Transformer的K-FAC优化器
    """
    
    def __init__(self, model, lr=0.001, damping=1e-3):
        self.model = model
        self.lr = lr
        self.damping = damping
        
        # 分层存储Fisher信息
        self.fisher = {}
        self._setup_blocks()
    
    def _setup_blocks(self):
        """识别参数块"""
        self.blocks = []
        
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                # 按层组织
                layer_name = '.'.join(name.split('.')[:-2])  # e.g., encoder.layer.0
                if layer_name not in self.fisher:
                    self.fisher[layer_name] = {}
                
                # 确定是哪个张量
                tensor_type = name.split('.')[-1]  # weight or bias
                self.fisher[layer_name][tensor_type] = {
                    'param': param,
                    'G': None,  # output dimension
                    'S': None,   # input dimension
                }
    
    def update_fisher(self, grads):
        """
        从梯度统计更新Fisher估计
        """
        # 使用EMA累积
        beta = 0.99
        
        for name, grad in grads.items():
            layer = '.'.join(name.split('.')[:-2])
            tensor_type = name.split('.')[-1]
            
            if layer in self.fisher and tensor_type in self.fisher[layer]:
                param_shape = grad.shape
                
                # 计算 Kronecker 因子
                if len(param_shape) == 2:  # 权重矩阵
                    # G ≈ E[∂L/∂y ∂L/∂y^T] (output dim)
                    # S ≈ E[∂L/∂x ∂L/∂x^T] (input dim)
                    G = torch.einsum('ij,ik->jk', grad, grad) / grad.shape[0]
                    S = torch.einsum('ij,ik->ji', grad, grad) / grad.shape[0]
                else:  # 偏置
                    G = torch.einsum('i,j->ij', grad, grad) / grad.shape[0]
                    S = None
                
                # 更新估计
                if self.fisher[layer][tensor_type]['G'] is None:
                    self.fisher[layer][tensor_type]['G'] = G
                    if S is not None:
                        self.fisher[layer][tensor_type]['S'] = S
                else:
                    self.fisher[layer][tensor_type]['G'] = beta * self.fisher[layer][tensor_type]['G'] + (1-beta) * G
                    if S is not None:
                        self.fisher[layer][tensor_type]['S'] = beta * self.fisher[layer][tensor_type]['S'] + (1-beta) * S
    
    def kfac_step(self):
        """
        执行K-FAC优化步骤
        """
        for layer_name, layer_fisher in self.fisher.items():
            for tensor_type, info in layer_fisher.items():
                if info['G'] is None:
                    continue
                
                param = info['param']
                grad = param.grad
                
                G = info['G'] + self.damping * torch.eye(info['G'].shape[0])
                G_inv = torch.linalg.inv(G)
                
                if info['S'] is not None:
                    S = info['S'] + self.damping * torch.eye(info['S'].shape[0])
                    S_inv = torch.linalg.inv(S)
                    
                    # Kronecker 乘积的逆
                    update = G_inv @ grad @ S_inv
                    param.data -= self.lr * update.T
                else:
                    # 偏置更新
                    update = G_inv @ grad
                    param.data -= self.lr * update

自然梯度的其他应用

1. 神经进化策略

神经进化策略(Natural Evolution Strategies, NES)使用自然梯度更新分布参数:

def nes_update(theta, fitness, population_size=100, lr=0.01):
    """
    简化的自然进化策略更新
    """
    # 从分布采样
    sigma = 0.1
    epsilons = [np.random.randn(len(theta)) for _ in range(population_size)]
    rewards = [fitness(theta + sigma * eps) for eps in epsilons]
    
    # 归一化奖励
    rewards = np.array(rewards)
    rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    
    # 自然梯度估计
    nat_grad = np.zeros_like(theta)
    for eps, r in zip(epsilons, rewards):
        nat_grad += eps * r
    
    # 更新
    theta += lr * nat_grad / (sigma * population_size)
    
    return theta

2. 贝叶斯优化中的自然梯度

在高斯过程超参数优化中,自然梯度用于最大化边缘似然:

3. 强化学习中的策略梯度

策略梯度方法中的自然梯度用于改进策略更新:


实践建议

何时使用自然梯度

场景推荐理由
小型网络K-FAC效果显著
TransformerK-FAC/AdaGrad大批量训练
RNN自然梯度必要梯度问题严重
在线学习EKFAC效率重要

与其他优化器的比较

优化器曲率近似每次迭代成本收敛速度
SGD
Adam对角中等
AdaGrad对角+累积中等
K-FACKronecker
L-BFGS完整快(但内存大)

总结

概念核心要点
自然梯度Fisher度量下的最速下降方向
Fisher信息矩阵梯度协方差的期望
K-FACFisher的Kronecker分解近似
几何意义坐标不变性、概率空间的几何

自然梯度和K-FAC代表了优化理论中几何视角的力量,它们将参数空间视为概率分布的流形,在那里Fisher信息提供了自然的度量。


参考

Footnotes

  1. Martens, J., & Grosse, R. (2015). Optimizing neural networks with kronecker-factored approximate curvature. ICML.