概述
自然梯度(Natural Gradient)是一种利用Fisher信息矩阵作为黎曼度量的二阶优化方法。相比普通梯度下降,自然梯度在参数空间中具有坐标不变性——无论参数如何缩放或重参数化,优化方向都保持一致。K-FAC(Kronecker-Factored Approximate Curvature)是高效近似自然梯度的方法,已在Transformer训练中被广泛采用。1
核心洞察:在概率分布空间中,“最速下降方向”不是欧几里得梯度方向,而是由Fisher信息矩阵定义的黎曼梯度方向。
从欧几里得梯度到自然梯度
普通梯度下降的问题
设损失函数 ,普通梯度更新:
问题:梯度方向依赖于参数化——相同的学习算法在不同参数化下表现不同。
示例:考虑两种参数化 ,:
参数化 A: θ 参数化 B: 10θ
│ │
│ ∇ℒ │ ∇ℒ
↓ ↓
等高线 等高线(被拉伸!)
真实最速方向相同 梯度方向不同!
黎曼梯度下降
在黎曼流形上,梯度定义为:
其中 是黎曼度量矩阵。
自然梯度使用Fisher信息矩阵 作为度量:
Fisher信息矩阵
定义:对于概率模型 ,Fisher信息矩阵定义为:
性质:
- 对称正定
- 参数化不变
- 是损失函数Hessian的期望(下界)
自然梯度的几何意义
import numpy as np
import matplotlib.pyplot as plt
def illustrate_natural_gradient():
"""可视化自然梯度与普通梯度的区别"""
# 定义损失函数(二维椭球)
def loss(theta):
return 0.5 * (theta[0]**2 + 100 * theta[1]**2)
def gradient(theta):
return np.array([theta[0], 100 * theta[1]])
# Fisher信息矩阵(对于高斯分布,这里简化为对角矩阵)
# 假设参数尺度不同
F = np.diag([1.0, 0.01]) # 模拟不同的信息量
theta = np.array([2.0, 0.1])
grad = gradient(theta)
nat_grad = np.linalg.solve(F, grad)
print(f"普通梯度: {grad}")
print(f"自然梯度: {nat_grad}")
print(f"Fisher矩阵: {F}")
print(f"注意:自然梯度在信息量小的方向上步长更大")
illustrate_natural_gradient()Fisher信息矩阵的计算
对于神经网络
对于神经网络的权重参数 ,Fisher信息矩阵的大小为 ( 是参数数量),直接计算和存储不可行。
定义:
Fisher vs Hessian
| 方面 | Fisher信息矩阵 | Hessian |
|---|---|---|
| 定义 | 梯度的协方差 | 损失的二阶导 |
| 正定性 | 总是正定 | 不一定 |
| 计算 | 只需一阶导数 | 需要二阶导数 |
| 期望 | 遍历数据分布 | 遍历训练样本 |
关系:
其中 是对数似然的Hessian。
经验Fisher
在实际计算中,使用经验Fisher:
这等价于梯度外积的均值。
Kronecker-Factored Approximate Curvature (K-FAC)
核心思想
K-FAC将巨大的Fisher矩阵分解为小块的Kronecker积:
其中 和 是小得多的矩阵。
神经网络层分解
对于全连接层 ,参数为 。
独立近似:
Kronecker分解:
- :对输出单元的Fisher估计
- :对输入的Fisher估计
K-FAC更新公式
对于参数块 :
等价于:
Python实现
import torch
import torch.nn as nn
from collections import OrderedDict
class KFACOptimizer:
"""
K-FAC优化器实现
简化版本:仅处理全连接层
"""
def __init__(self, model, lr=0.001, damping=1e-3, kl_clip=0.01):
self.model = model
self.lr = lr
self.damping = damping
self.kl_clip = kl_clip
# 为每个参数块存储Fisher估计
self.fisher_info = {}
self.state = {}
# 注册所有全连接层
self._register_modules()
def _register_modules(self):
"""识别可优化的参数块"""
self.modules = []
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
self.modules.append((name, module))
def step(self, closure=None):
"""一次优化步骤"""
# 1. 计算梯度
if closure is not None:
loss = closure()
else:
loss = None
# 2. 累积Fisher信息(简化:使用单样本估计)
self._update_fisher()
# 3. 执行K-FAC更新
self._kfac_update()
return loss
def _update_fisher(self):
"""
更新Fisher信息估计
简化版本:使用梯度外积
"""
for name, module in self.modules:
if hasattr(module, 'weight') and module.weight.grad is not None:
# 获取梯度
grad_w = module.weight.grad.detach()
grad_b = module.bias.grad.detach() if module.bias is not None else None
# 创建或更新Fisher估计
if name not in self.fisher_info:
self.fisher_info[name] = {
'G': torch.zeros_like(grad_w.T @ grad_w),
'S': torch.zeros_like(grad_w @ grad_w.T),
'G_b': torch.zeros(module.out_features, module.out_features),
}
# Kronecker分解:G ⊗ S ≈ F
# 这里用移动平均近似期望
beta = 0.99 # 指数衰减
# 对于输出维度
G = grad_w.T @ grad_w / grad_w.shape[0]
self.fisher_info[name]['G'] = beta * self.fisher_info[name]['G'] + (1-beta) * G
# 对于输入维度
S = grad_w @ grad_w.T / grad_w.shape[0]
self.fisher_info[name]['S'] = beta * self.fisher_info[name]['S'] + (1-beta) * S
if grad_b is not None:
G_b = grad_b.unsqueeze(1) @ grad_b.unsqueeze(0) / grad_b.shape[0]
self.fisher_info[name]['G_b'] = beta * self.fisher_info[name]['G_b'] + (1-beta) * G_b
def _kfac_update(self):
"""
执行K-FAC更新
W ← W - α * S⁻¹ * (∇W ⊙ ∇W) * G⁻¹
"""
for name, module in self.modules:
if name not in self.fisher_info:
continue
info = self.fisher_info[name]
# 提取梯度
grad_w = module.weight.grad
grad_b = module.bias.grad if module.bias is not None else None
# 计算 Kronecker 逆
G = info['G'] + self.damping * torch.eye(info['G'].shape[0])
S = info['S'] + self.damping * torch.eye(info['S'].shape[0])
G_inv = torch.linalg.solve(G, torch.eye(G.shape[0]))
S_inv = torch.linalg.solve(S, torch.eye(S.shape[0]))
# K-FAC更新(利用分块结构)
# ∇W (S⁻¹ ⊗ G⁻¹) = G⁻¹ @ ∇W @ S⁻¹
update = G_inv @ grad_w @ S_inv
# 应用更新
module.weight.data -= self.lr * update.T
# 偏置更新(如果存在)
if grad_b is not None:
G_b = info['G_b'] + self.damping * torch.eye(info['G_b'].shape[0])
G_b_inv = torch.linalg.solve(G_b, torch.eye(G_b.shape[0]))
update_b = G_b_inv @ grad_b
module.bias.data -= self.lr * update_b
def zero_grad(self):
"""清零梯度"""
for _, module in self.modules:
if module.weight.grad is not None:
module.weight.grad.zero_()
if module.bias is not None and module.bias.grad is not None:
module.bias.grad.zero_()
# 测试K-FAC
def test_kfac():
"""简单测试"""
torch.manual_seed(42)
model = nn.Sequential(
nn.Linear(10, 20),
nn.ReLU(),
nn.Linear(20, 5)
)
optimizer = KFACOptimizer(model, lr=0.01, damping=1e-3)
criterion = nn.MSELoss()
# 随机数据
x = torch.randn(32, 10)
y = torch.randn(32, 5)
for epoch in range(3):
optimizer.zero_grad()
pred = model(x)
loss = criterion(pred, y)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
test_kfac()K-FAC的理论保证
| 近似类型 | 准确性 | 计算复杂度 |
|---|---|---|
| 全Fisher | 精确 | 存储, 逆 |
| K-FAC | 分块近似 | 存储, 更新 |
对角近似(EKFAC)
更激进的近似:对每个参数独立估计Fisher对角元:
这将复杂度降到 ,但准确性下降。
K-FAC在Transformer中的应用
Transformer的Fisher结构
Transformer中的注意力机制具有特殊的Fisher结构:
- Query/Key/Value投影:独立块
- FFN层:两个线性层
- 注意力权重:高度条件化
实践考量
class TransformerKFAC:
"""
Transformer的K-FAC优化器
"""
def __init__(self, model, lr=0.001, damping=1e-3):
self.model = model
self.lr = lr
self.damping = damping
# 分层存储Fisher信息
self.fisher = {}
self._setup_blocks()
def _setup_blocks(self):
"""识别参数块"""
self.blocks = []
for name, param in self.model.named_parameters():
if 'weight' in name:
# 按层组织
layer_name = '.'.join(name.split('.')[:-2]) # e.g., encoder.layer.0
if layer_name not in self.fisher:
self.fisher[layer_name] = {}
# 确定是哪个张量
tensor_type = name.split('.')[-1] # weight or bias
self.fisher[layer_name][tensor_type] = {
'param': param,
'G': None, # output dimension
'S': None, # input dimension
}
def update_fisher(self, grads):
"""
从梯度统计更新Fisher估计
"""
# 使用EMA累积
beta = 0.99
for name, grad in grads.items():
layer = '.'.join(name.split('.')[:-2])
tensor_type = name.split('.')[-1]
if layer in self.fisher and tensor_type in self.fisher[layer]:
param_shape = grad.shape
# 计算 Kronecker 因子
if len(param_shape) == 2: # 权重矩阵
# G ≈ E[∂L/∂y ∂L/∂y^T] (output dim)
# S ≈ E[∂L/∂x ∂L/∂x^T] (input dim)
G = torch.einsum('ij,ik->jk', grad, grad) / grad.shape[0]
S = torch.einsum('ij,ik->ji', grad, grad) / grad.shape[0]
else: # 偏置
G = torch.einsum('i,j->ij', grad, grad) / grad.shape[0]
S = None
# 更新估计
if self.fisher[layer][tensor_type]['G'] is None:
self.fisher[layer][tensor_type]['G'] = G
if S is not None:
self.fisher[layer][tensor_type]['S'] = S
else:
self.fisher[layer][tensor_type]['G'] = beta * self.fisher[layer][tensor_type]['G'] + (1-beta) * G
if S is not None:
self.fisher[layer][tensor_type]['S'] = beta * self.fisher[layer][tensor_type]['S'] + (1-beta) * S
def kfac_step(self):
"""
执行K-FAC优化步骤
"""
for layer_name, layer_fisher in self.fisher.items():
for tensor_type, info in layer_fisher.items():
if info['G'] is None:
continue
param = info['param']
grad = param.grad
G = info['G'] + self.damping * torch.eye(info['G'].shape[0])
G_inv = torch.linalg.inv(G)
if info['S'] is not None:
S = info['S'] + self.damping * torch.eye(info['S'].shape[0])
S_inv = torch.linalg.inv(S)
# Kronecker 乘积的逆
update = G_inv @ grad @ S_inv
param.data -= self.lr * update.T
else:
# 偏置更新
update = G_inv @ grad
param.data -= self.lr * update自然梯度的其他应用
1. 神经进化策略
神经进化策略(Natural Evolution Strategies, NES)使用自然梯度更新分布参数:
def nes_update(theta, fitness, population_size=100, lr=0.01):
"""
简化的自然进化策略更新
"""
# 从分布采样
sigma = 0.1
epsilons = [np.random.randn(len(theta)) for _ in range(population_size)]
rewards = [fitness(theta + sigma * eps) for eps in epsilons]
# 归一化奖励
rewards = np.array(rewards)
rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# 自然梯度估计
nat_grad = np.zeros_like(theta)
for eps, r in zip(epsilons, rewards):
nat_grad += eps * r
# 更新
theta += lr * nat_grad / (sigma * population_size)
return theta2. 贝叶斯优化中的自然梯度
在高斯过程超参数优化中,自然梯度用于最大化边缘似然:
3. 强化学习中的策略梯度
策略梯度方法中的自然梯度用于改进策略更新:
实践建议
何时使用自然梯度
| 场景 | 推荐 | 理由 |
|---|---|---|
| 小型网络 | K-FAC | 效果显著 |
| Transformer | K-FAC/AdaGrad | 大批量训练 |
| RNN | 自然梯度必要 | 梯度问题严重 |
| 在线学习 | EKFAC | 效率重要 |
与其他优化器的比较
| 优化器 | 曲率近似 | 每次迭代成本 | 收敛速度 |
|---|---|---|---|
| SGD | 无 | 慢 | |
| Adam | 对角 | 中等 | |
| AdaGrad | 对角+累积 | 中等 | |
| K-FAC | Kronecker | 快 | |
| L-BFGS | 完整 | 快(但内存大) |
总结
| 概念 | 核心要点 |
|---|---|
| 自然梯度 | Fisher度量下的最速下降方向 |
| Fisher信息矩阵 | 梯度协方差的期望 |
| K-FAC | Fisher的Kronecker分解近似 |
| 几何意义 | 坐标不变性、概率空间的几何 |
自然梯度和K-FAC代表了优化理论中几何视角的力量,它们将参数空间视为概率分布的流形,在那里Fisher信息提供了自然的度量。
参考
Footnotes
-
Martens, J., & Grosse, R. (2015). Optimizing neural networks with kronecker-factored approximate curvature. ICML. ↩