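An In-Depth Theory of Second-Order Optimization Methods
Second-order optimization methods use curvature information about the objective function to guide parameter updates, and they form an important theoretical foundation for deep-learning optimization. Compared with first-order methods, they cope better with ill-conditioned problems and can converge in fewer iterations. Starting from the viewpoint of information geometry, this document examines the theory and practice of second-order optimizers such as natural gradient descent, K-FAC, and Shampoo.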
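1. Natural Gradient and Information Geometry
1.1 Definition and Properties of the Fisher Information Matrix
Geometric structure of the parameter space of probability distributions
On the space of parameterized probability distributions $p(x \mid \theta)$ we can define a Riemannian geometric structure. Let the parameter space $\Theta \subseteq \mathbb{R}^d$ be a $d$-dimensional manifold, where each parameter $\theta \in \Theta$ defines a probability distribution $p(x \mid \theta)$.
The **Fisher information matrix (FIM)** is defined as the expected outer product of the gradient of the log-likelihood:
$$F(\theta) = \mathbb{E}_{x \sim p(x \mid \theta)}\left[\nabla_\theta \log p(x \mid \theta)\, \nabla_\theta \log p(x \mid \theta)^\top\right]$$
where $\nabla_\theta \log p(x \mid \theta)$ is the score function.
Mathematical properties:
- Symmetric positive semidefiniteness: $F(\theta) = F(\theta)^\top \succeq 0$.
- Cramér-Rao lower bound: for any unbiased estimator $\hat{\theta}$, the covariance matrix satisfies $\operatorname{Cov}(\hat{\theta}) \succeq F(\theta)^{-1}$. The inverse Fisher information therefore bounds the achievable precision of parameter estimation from below.
- Second-order approximation of the KL divergence: the KL divergence between two infinitesimally close distributions is approximated by the Fisher matrix, $D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta+\delta}) = \tfrac{1}{2}\, \delta^\top F(\theta)\, \delta + O(\|\delta\|^3)$.
- Consistency condition: if the model $p(x \mid \theta)$ is correctly specified and the usual regularity conditions hold, the Fisher information matrix equals the negative expected Hessian of the log-likelihood, $F(\theta) = -\mathbb{E}\left[\nabla_\theta^2 \log p(x \mid \theta)\right]$. This links the Fisher information to the Hessian.
The Fisher information matrix in neural networks
For a neural network $f_\theta$ and loss function $L$, when the loss is the cross-entropy (the negative log-likelihood of the predictive distribution $p_\theta(y \mid x)$), we have:
$$F(\theta) = \mathbb{E}_{x \sim \mathcal{D},\; y \sim p_\theta(y \mid x)}\left[\nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^\top\right]$$
Here the labels are sampled from the model's own predictive distribution; plugging in the training labels instead gives the so-called empirical Fisher. This makes the Fisher matrix the natural curvature measure for natural gradient updates.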
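As a small sanity check of the definition and of the KL approximation listed among the properties, the sketch below works through the one-parameter Bernoulli family, where the Fisher information has the closed form $1/(p(1-p))$. The function names and the chosen values of `p` and `delta` are illustrative assumptions, not part of the original text.

```python
import math

def bernoulli_fisher(p):
    """Fisher information of Bernoulli(p) w.r.t. p, computed as E[score^2]."""
    # score(x) = d/dp log p(x | p): 1/p for x = 1 and -1/(1 - p) for x = 0
    score_one = 1.0 / p
    score_zero = -1.0 / (1.0 - p)
    return p * score_one ** 2 + (1.0 - p) * score_zero ** 2

def bernoulli_kl(p, q):
    """KL(Bernoulli(p) || Bernoulli(q))."""
    return p * math.log(p / q) + (1.0 - p) * math.log((1.0 - p) / (1.0 - q))

p, delta = 0.3, 1e-3                      # illustrative values
fisher = bernoulli_fisher(p)
kl_exact = bernoulli_kl(p, p + delta)
kl_quadratic = 0.5 * fisher * delta ** 2  # second-order approximation
print(fisher, kl_exact, kl_quadratic)
```

The exact KL divergence and its quadratic approximation agree up to a relative error of order `delta`, which is what the expansion predicts.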
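1.2 KL Divergence as a Riemannian Metric
Riemannian metrics and the geometry of parameter space
On the parameter space $\Theta$, a Riemannian metric $G$ assigns a positive definite matrix $G(\theta)$ to each point $\theta$ and defines the infinitesimal squared distance between neighboring points as:
$$ds^2 = d\theta^\top G(\theta)\, d\theta$$
Metric structure induced by the KL divergence:
The KL divergence $D_{\mathrm{KL}}(p \,\|\, q)$ defines an asymmetric notion of "distance" on the space of distributions. For infinitesimally close parameters $\theta$ and $\theta + \delta$:
$$D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta+\delta}) = \tfrac{1}{2}\, \delta^\top F(\theta)\, \delta + O(\|\delta\|^3)$$
After dropping the higher-order terms, the quadratic form $\delta^\top F(\theta)\, \delta$ defines a Riemannian metric on the parameter space, with the Fisher matrix $F(\theta)$ acting as the metric tensor.
Essential differences from Euclidean space:
| Aspect | Euclidean gradient descent | Natural gradient descent |
|---|---|---|
| Metric | $I$ (identity matrix) | $F(\theta)$ (Fisher matrix) |
| Distance | $\lVert \delta \rVert_2^2 = \delta^\top \delta$ | $\delta^\top F(\theta)\, \delta \approx 2 D_{\mathrm{KL}}$ |
| Update direction | Steepest descent of the loss in Euclidean distance | Steepest descent of the loss in the KL sense |
| Learning-rate sensitivity | High | Low |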
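1.3 The Natural Gradient Update
Derivation from a variational principle
Natural gradient descent looks for the direction of greatest loss decrease in distribution space rather than in parameter space. Consider the constrained problem:
$$\min_{\delta}\; L(\theta + \delta) \quad \text{s.t.} \quad D_{\mathrm{KL}}(p_\theta \,\|\, p_{\theta+\delta}) \le \epsilon$$
Using a Lagrange multiplier $\lambda$, with the loss linearized and the KL term replaced by its quadratic approximation:
$$\mathcal{L}(\delta, \lambda) = L(\theta) + \nabla_\theta L(\theta)^\top \delta + \lambda \left(\tfrac{1}{2}\, \delta^\top F(\theta)\, \delta - \epsilon\right)$$
Setting the derivative with respect to $\delta$ to zero gives $\delta = -\tfrac{1}{\lambda} F(\theta)^{-1} \nabla_\theta L(\theta)$, which yields the natural gradient update:
$$\theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla_\theta L(\theta_t)$$
where $F^{-1}$ is the inverse of the natural metric; it converts the Euclidean gradient into a tangent vector on the Riemannian manifold.
Computing the natural gradient: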
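import torch
import torch.nn as nn

def compute_natural_gradient(model, fisher_matrix, loss_fn, damping=1e-3):
    """
    Compute an approximate natural gradient.
    The exact natural gradient is F^{-1} ∇L, where F is the Fisher information matrix.
    Only the diagonal of F is used here, so the result is a diagonal approximation.
    """
    params = list(model.parameters())
    # 1. Compute the ordinary (Euclidean) gradient
    loss = loss_fn(model)
    grad = torch.autograd.grad(loss, params, retain_graph=True)
    # 2. Flatten the gradient into a single vector
    grad_vector = torch.cat([g.flatten() for g in grad])
    # 3. Solve the linear system F x = grad,
    #    using the damped diagonal of F to keep the solve cheap
    fisher_diag = fisher_matrix.diagonal()
    natural_grad_vector = grad_vector / (fisher_diag + damping)
    # 4. Reshape back to the parameter shapes
    natural_grads = []
    offset = 0
    for p in params:
        size = p.numel()
        natural_grads.append(
            natural_grad_vector[offset:offset + size].view(p.shape)
        )
        offset += size
    return natural_grads
1.4 Geometric Meaning of the Natural Gradient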
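Gradient descent on a Riemannian manifold
Natural gradient descent is the special case of Riemannian gradient descent on the Fisher-Riemannian manifold. In Riemannian geometry the gradient is obtained by transforming with the inverse of the metric tensor:
$$\operatorname{grad} L(\theta) = G(\theta)^{-1} \nabla_\theta L(\theta)$$
When $G(\theta) = F(\theta)$, this is the natural gradient.
Geometric intuition in the space of probability distributions:
Consider two Gaussian distributions $\mathcal{N}(\mu_1, \sigma^2)$ and $\mathcal{N}(\mu_2, \sigma^2)$ with a shared, fixed variance. The Fisher information with respect to the mean is $F = 1/\sigma^2$ (a scalar). The natural gradient update is:
$$\mu_{t+1} = \mu_t - \eta\, \sigma^2\, \frac{\partial L}{\partial \mu}$$
The natural gradient thus adapts the step size to the curvature of the distribution: large steps in low-curvature regions (large variance) and small steps in high-curvature regions (small variance).
Connection to the Kullback-Leibler flow:
The natural gradient update can be interpreted as a KL divergence descent flow:
$$\theta_{t+1} = \arg\min_{\theta} \left\{ \nabla_\theta L(\theta_t)^\top (\theta - \theta_t) + \frac{1}{\eta}\, D_{\mathrm{KL}}(p_{\theta_t} \,\|\, p_\theta) \right\}$$
As $\eta \to 0$, this converges to the natural gradient flow $\dot{\theta} = -F(\theta)^{-1} \nabla_\theta L(\theta)$.
Reparameterization invariance:
An important property of the natural gradient is its invariance under reparameterization. If we apply an invertible transformation $\phi = h(\theta)$ with Jacobian $J = \partial \phi / \partial \theta$, the natural gradient keeps the same form in the new parameter space and transforms like a tangent vector:
$$F_\phi^{-1} \nabla_\phi L = J\, F_\theta^{-1} \nabla_\theta L$$
The invariance is exact for the continuous-time flow and holds to first order in the step size for discrete updates. Euclidean gradient descent, by contrast, generally follows a different trajectory after a reparameterization.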
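The invariance claim can be checked numerically on a one-parameter example. The sketch below takes one small step on a Bernoulli cross-entropy loss, once in the probability parameter $p$ and once in the logit $\phi = \log(p/(1-p))$, and compares where the two steps land. The loss, the helper names, and the values of `p0`, `target`, and `lr` are illustrative assumptions.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logit(p):
    return math.log(p / (1.0 - p))

def step_in_p(p, q, lr, natural):
    """One step on L(p) = -(q log p + (1-q) log(1-p)), parameterized by p."""
    grad = (p - q) / (p * (1.0 - p))   # dL/dp
    fisher = 1.0 / (p * (1.0 - p))     # Fisher information w.r.t. p
    direction = grad / fisher if natural else grad
    return p - lr * direction

def step_in_logit(p, q, lr, natural):
    """The same step taken in phi = logit(p); the result is mapped back to p."""
    grad = p - q                       # dL/dphi
    fisher = p * (1.0 - p)             # Fisher information w.r.t. phi
    direction = grad / fisher if natural else grad
    return sigmoid(logit(p) - lr * direction)

p0, target, lr = 0.9, 0.5, 1e-3        # illustrative values
natural_gap = abs(step_in_p(p0, target, lr, True) - step_in_logit(p0, target, lr, True))
euclid_gap = abs(step_in_p(p0, target, lr, False) - step_in_logit(p0, target, lr, False))
print(natural_gap, euclid_gap)
```

The two natural gradient steps agree up to a second-order term in the learning rate, while the two Euclidean steps already differ at first order.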
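2. Kronecker-Factored Approximate Curvature (K-FAC)
2.1 Mathematical Principles of the Kronecker Factorization
Review of the Kronecker product
For matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{p \times q}$, the Kronecker product $A \otimes B \in \mathbb{R}^{mp \times nq}$ is defined as:
$$A \otimes B = \begin{pmatrix} a_{11} B & \cdots & a_{1n} B \\ \vdots & \ddots & \vdots \\ a_{m1} B & \cdots & a_{mn} B \end{pmatrix}$$
Core idea of the factorization:
Let $F \in \mathbb{R}^{d \times d}$ be the full Fisher information matrix. K-FAC assumes that $F$ can be approximated by a direct sum of Kronecker products, one block per layer $l$:
$$F \approx \bigoplus_{l} \left(B_l \otimes A_l\right)$$
Throughout, $\mathrm{vec}(\cdot)$ denotes row-major flattening, which matches how the code below flattens tensors. For a layer with a $d_{\text{out}} \times d_{\text{in}}$ weight matrix, the factorization has the following properties:
- Storage: reduced from $O(d_{\text{in}}^2 d_{\text{out}}^2)$ to $O(d_{\text{in}}^2 + d_{\text{out}}^2)$
- Inversion: reduced from $O(d_{\text{in}}^3 d_{\text{out}}^3)$ to $O(d_{\text{in}}^3 + d_{\text{out}}^3)$, because $(B \otimes A)^{-1} = B^{-1} \otimes A^{-1}$
- Matrix-vector product: $(B \otimes A)\, \mathrm{vec}(V) = \mathrm{vec}(B V A^\top)$
2.2 The K-FAC Approximation to the Fisher Information
Structure of a network layer
Consider a linear layer $s = W a$, where $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ is the weight matrix, $a \in \mathbb{R}^{d_{\text{in}}}$ is the input, and $s \in \mathbb{R}^{d_{\text{out}}}$ is the output.
Independence assumption:
The key assumption of K-FAC is that, for a given layer, the input activation $a$ and the output gradient $g = \partial \log p / \partial s$ are approximately independent. The Fisher block of the layer then factorizes as:
$$F_W \approx B \otimes A$$
where:
- $A = \mathbb{E}[a a^\top]$ is the (uncentered) covariance matrix of the inputs
- $B = \mathbb{E}[g g^\top]$ is the (uncentered) covariance matrix of the output gradients
Empirical estimates:
In practice, the factors are estimated on a mini-batch:
$$\hat{A} = \frac{1}{N} \sum_{i=1}^{N} a_i a_i^\top, \qquad \hat{B} = \frac{1}{N} \sum_{i=1}^{N} g_i g_i^\top$$
where $N$ is the batch size.
Proof sketch of the factorization:
Let $\nabla_W$ denote the gradient of the log-likelihood with respect to the weights. By the chain rule:
$$\nabla_W = g\, a^\top, \qquad \mathrm{vec}(\nabla_W) = g \otimes a$$
The Fisher matrix is then:
$$F_W = \mathbb{E}\left[(g \otimes a)(g \otimes a)^\top\right] = \mathbb{E}\left[g g^\top \otimes a a^\top\right]$$
If $a$ and $g$ are independent, then:
$$F_W = \mathbb{E}[g g^\top] \otimes \mathbb{E}[a a^\top] = B \otimes A$$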
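The factorization argument rests on two Kronecker-product facts under row-major flattening: $\mathrm{vec}(g a^\top) = g \otimes a$ and $(B \otimes A)\,\mathrm{vec}(V) = \mathrm{vec}(B V A^\top)$. The following dependency-free sketch checks both on small matrices. The helper functions and the numeric values are illustrative assumptions; a real implementation would use tensor operations instead of nested lists.

```python
def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

def transpose(X):
    return [list(row) for row in zip(*X)]

def kron(X, Y):
    """Kronecker product of two dense matrices given as nested lists."""
    return [[x * y for x in x_row for y in y_row] for x_row in X for y_row in Y]

def vec(X):
    """Row-major flattening, matching tensor.flatten() in PyTorch."""
    return [v for row in X for v in row]

# Fact 1: vec(g a^T) equals g ⊗ a.
g, a = [1.0, 2.0], [3.0, 4.0, 5.0]
outer_vec = vec([[gi * aj for aj in a] for gi in g])

# Fact 2: (B ⊗ A) vec(V) equals vec(B V A^T).
B = [[2.0, 0.5], [0.5, 1.0]]                             # d_out x d_out
A = [[1.0, 0.2, 0.0], [0.2, 3.0, 0.1], [0.0, 0.1, 2.0]]  # d_in x d_in
V = [[1.0, -2.0, 0.5], [0.0, 4.0, -1.0]]                 # d_out x d_in
lhs = [sum(k * v for k, v in zip(row, vec(V))) for row in kron(B, A)]
rhs = vec(matmul(matmul(B, V), transpose(A)))
max_err = max(abs(l - r) for l, r in zip(lhs, rhs))
print(outer_vec, max_err)
```

The second identity is what lets K-FAC apply the inverse of a $d_{\text{in}} d_{\text{out}} \times d_{\text{in}} d_{\text{out}}$ block using only two small matrix products.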
2.3 Practical Implementation and Computational Cost
A K-FAC implementation framework:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
class KFACOptimizer:
"""
Kronecker-Factored Approximate Curvature (K-FAC) 优化器
参考文献: Martens & Grosse (2015)
"""
def __init__(
self,
model,
lr=0.001,
momentum=0.9,
stat_decay=0.99,
damping=1e-3,
kl_interval=10, # Fisher矩阵更新间隔
inverse_interval=50 # 逆矩阵更新间隔
):
self.model = model
self.lr = lr
self.momentum = momentum
self.stat_decay = stat_decay
self.damping = damping
self.kl_interval = kl_interval
self.inverse_interval = inverse_interval
self.step_count = 0
self.fisher_blocks = {} # Fisher矩阵块
self.inverse_blocks = {} # 逆矩阵块
self.grad_buffers = {} # 梯度缓冲
self._register_hooks()
def _register_hooks(self):
"""注册前向钩子以捕获中间激活"""
self.activations = {}
self.grad_outputs = {}
def forward_hook(name):
def hook(module, input, output):
self.activations[name] = input[0].detach()
return hook
def backward_hook(name):
def hook(module, grad_input, grad_output):
self.grad_outputs[name] = grad_output[0].detach()
return hook
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
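# Key the captured tensors by str(id(module)), which is the key the factor routines look up
module.register_forward_hook(forward_hook(str(id(module))))
module.register_full_backward_hook(backward_hook(str(id(module))))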
def _update_fisher_statistics(self):
"""更新Fisher统计量(使用EMA)"""
for name, module in self.model.named_modules():
if not isinstance(module, (nn.Linear, nn.Conv2d)):
continue
grad = module.weight.grad
if grad is None:
continue
if name not in self.fisher_blocks:
self.fisher_blocks[name] = {
'A': None, 'B': None, 'cnt': 0
}
# 计算当前batch的统计量
if isinstance(module, nn.Linear):
A, B = self._compute_linear_factors(module, grad)
else:
A, B = self._compute_conv_factors(module, grad)
# 指数滑动平均
if self.fisher_blocks[name]['A'] is None:
self.fisher_blocks[name]['A'] = A
self.fisher_blocks[name]['B'] = B
else:
self.fisher_blocks[name]['A'] = (
self.stat_decay * self.fisher_blocks[name]['A'] +
(1 - self.stat_decay) * A
)
self.fisher_blocks[name]['B'] = (
self.stat_decay * self.fisher_blocks[name]['B'] +
(1 - self.stat_decay) * B
)
self.fisher_blocks[name]['cnt'] += 1
def _compute_linear_factors(self, module, grad):
"""
计算线性层的Kronecker因子
对于 Linear(in, out):
- A: 输入协方差 (in × in)
- B: 输出梯度协方差 (out × out)
"""
x = self.activations.get(str(id(module)), None)
grad_out = self.grad_outputs.get(str(id(module)), None)
if x is None or grad_out is None:
# 使用默认的独立近似
return torch.eye(grad.shape[1], device=grad.device), \
torch.eye(grad.shape[0], device=grad.device)
batch_size = x.shape[0]
# A: 输入协方差 (使用无偏估计)
A = x.T @ x / batch_size
# B: 输出梯度协方差
B = grad_out.T @ grad_out / batch_size
return A, B
def _compute_conv_factors(self, module, grad):
"""
计算卷积层的Kronecker因子
将卷积操作展开为 im2col 矩阵乘法
"""
x = self.activations.get(str(id(module)), None)
grad_out = self.grad_outputs.get(str(id(module)), None)
if x is None or grad_out is None:
in_channels = grad.shape[1] * module.kernel_size[0] * module.kernel_size[1]
out_channels = grad.shape[0]
return torch.eye(in_channels, device=grad.device), \
torch.eye(out_channels, device=grad.device)
# 使用 im2col 展开
x_col = self._im2col(x, module)
grad_out_unfold = grad_out.flatten(1)
batch_size = x.shape[0]
A = x_col @ x_col.T / batch_size
B = grad_out_unfold @ grad_out_unfold.T / batch_size
return A, B
def _im2col(self, x, conv_module):
"""将卷积输入展开为列矩阵"""
# 简化实现:实际中需要正确处理padding, stride等
batch, channels, height, width = x.shape
k_h, k_w = conv_module.kernel_size
out_h = conv_module.output_size[0] if hasattr(conv_module, 'output_size') else \
(height - k_h) // conv_module.stride[0] + 1
out_w = conv_module.output_size[1] if hasattr(conv_module, 'output_size') else \
(width - k_w) // conv_module.stride[1] + 1
# 简化为展平
return x.flatten(2)
def _update_inverse_matrices(self):
"""使用特征分解更新逆矩阵"""
for name, fisher in self.fisher_blocks.items():
A, B = fisher['A'], fisher['B']
if A is None or B is None:
continue
# 添加阻尼并计算特征分解
eigenvalues_A, eigenvectors_A = torch.linalg.eigh(
A + self.damping * torch.eye(A.shape[0], device=A.device)
)
eigenvalues_B, eigenvectors_B = torch.linalg.eigh(
B + self.damping * torch.eye(B.shape[0], device=B.device)
)
# 计算逆矩阵(通过特征分解)
inv_eigenvalues_A = 1.0 / torch.clamp(eigenvalues_A, min=self.damping)
inv_eigenvalues_B = 1.0 / torch.clamp(eigenvalues_B, min=self.damping)
self.inverse_blocks[name] = {
'A_inv': eigenvectors_A @ torch.diag(inv_eigenvalues_A) @ eigenvectors_A.T,
'B_inv': eigenvectors_B @ torch.diag(inv_eigenvalues_B) @ eigenvectors_B.T
}
def _compute_corrected_gradient(self, name, module, grad):
"""
计算Kronecker分解修正后的梯度
(B^{-1} ⊗ A^{-1}) vec(∇W) = vec(B^{-1} ∇W A^{-1})
"""
if name not in self.inverse_blocks:
return grad
inv_A = self.inverse_blocks[name]['A_inv']
inv_B = self.inverse_blocks[name]['B_inv']
# Kronecker乘积的性质:(B^{-1} ⊗ A^{-1})vec(∇W) = vec(B^{-1}∇W A^{-1})
corrected_grad = inv_B @ grad @ inv_A.T
return corrected_grad
def step(self):
"""K-FAC优化步骤"""
self.step_count += 1
# 1. 定期更新Fisher统计
if self.step_count % self.kl_interval == 0:
self._update_fisher_statistics()
# 2. 定期更新逆矩阵
if self.step_count % self.inverse_interval == 0:
self._update_inverse_matrices()
# 3. 应用修正梯度
with torch.no_grad():
for name, module in self.model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
if module.weight.grad is None:
continue
if name in self.inverse_blocks:
corrected_grad = self._compute_corrected_gradient(
name, module, module.weight.grad
)
module.weight -= self.lr * corrected_grad
if module.bias is not None and module.bias.grad is not None:
module.bias -= self.lr * module.bias.grad
else:
module.weight -= self.lr * module.weight.grad
if module.bias is not None:
module.bias -= self.lr * module.bias.grad
self.model.zero_grad()
Computational cost analysis:
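| Operation | Full Fisher | K-FAC |
|---|---|---|
| Storage | $O(d_{\text{in}}^2 d_{\text{out}}^2)$ | $O(d_{\text{in}}^2 + d_{\text{out}}^2)$ |
| Fisher update | $O(N d_{\text{in}}^2 d_{\text{out}}^2)$ | $O(N (d_{\text{in}}^2 + d_{\text{out}}^2))$ |
| Inversion | $O(d_{\text{in}}^3 d_{\text{out}}^3)$ | $O(d_{\text{in}}^3 + d_{\text{out}}^3)$ |
| Gradient correction | $O(d_{\text{in}}^2 d_{\text{out}}^2)$ | $O(d_{\text{in}} d_{\text{out}} (d_{\text{in}} + d_{\text{out}}))$ |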
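For large Transformer models, K-FAC typically only needs memory on the order of $\sum_l (d_{\text{in},l}^2 + d_{\text{out},l}^2)$ for the per-layer factors, with a comparable amount for their inverses.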
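To make the storage gap concrete, the short sketch below counts matrix entries for a full per-layer Fisher block and for the two Kronecker factors. The layer sizes are those of the 784-256-128-10 network used later in this document; the function name is an illustrative assumption, and biases are ignored.

```python
def curvature_entries(d_in, d_out):
    """Number of stored entries for one layer: (full Fisher block, K-FAC factors)."""
    full = (d_in * d_out) ** 2        # (d_in*d_out) x (d_in*d_out) block
    kfac = d_in ** 2 + d_out ** 2     # A is d_in x d_in, B is d_out x d_out
    return full, kfac

layers = [(784, 256), (256, 128), (128, 10)]
for d_in, d_out in layers:
    full, kfac = curvature_entries(d_in, d_out)
    print(f"{d_in}x{d_out}: full={full:,} entries, K-FAC={kfac:,} entries")
```

Even for the first layer of this small network, the full block would hold about $4 \times 10^{10}$ entries, whereas the two factors together hold fewer than $10^6$.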
2.4 K-FAC in Large-Model Training
A distributed K-FAC implementation:
class DistributedKFACOptimizer:
    """
    Distributed K-FAC optimizer for multi-GPU training.
    Fisher statistics are computed independently on each GPU and then
    aggregated with all-reduce. Assumes torch.distributed is already initialized.
    """
    def __init__(self, model, lr=0.001, world_size=1, rank=0, **kwargs):
        self.model = model
        self.lr = lr
        self.world_size = world_size
        self.rank = rank
        self.kfac = KFACOptimizer(model, lr=lr, **kwargs)

    def sync_fisher_stats(self):
        """Average the Fisher statistics across GPUs."""
        for block in self.kfac.fisher_blocks.values():
            for key in ('A', 'B'):
                if block[key] is None:
                    continue
                # all_reduce sums in place; it does not return the reduced tensor
                factor = block[key].contiguous()
                torch.distributed.all_reduce(
                    factor, op=torch.distributed.ReduceOp.SUM
                )
                block[key] = factor / self.world_size

    def step(self):
        """Distributed optimization step."""
        # Every rank updates its local statistics: all ranks must take part in
        # the collective below, so this cannot be restricted to rank 0.
        self.kfac._update_fisher_statistics()
        # Synchronize the Fisher statistics
        self.sync_fisher_stats()
        # After averaging, all ranks hold identical factors, so each rank can
        # invert them locally and no broadcast of the inverses is needed.
        self.kfac._update_inverse_matrices()
        # Apply the preconditioned update on every rank. Note that
        # KFACOptimizer.step may also refresh statistics on its own schedule.
        self.kfac.step()
Comparison with standard optimizers:
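| Property | SGD | Adam | K-FAC |
|---|---|---|---|
| Convergence speed (per iteration) | Slow | Medium | Fast |
| Memory overhead | Low | Medium | High |
| Hyperparameter sensitivity | High | Low | Medium |
| Use of curvature | None | Diagonal approximation | Block structure |
| Scalability to large models | Good | Very good | Medium |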
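3. Theory of the Shampoo Optimizer
3.1 Definition of the Preconditioning Matrix
Geometric meaning of a preconditioner
The core idea of preconditioning is to transform the optimization problem from an "ill-conditioned" coordinate system into a "well-conditioned" one. In second-order optimization the preconditioner $P$ is designed to approximate the inverse Hessian:
$$\theta_{t+1} = \theta_t - \eta\, P\, \nabla_\theta L(\theta_t), \qquad P \approx H^{-1}$$
An ideal preconditioner should:
- satisfy $P \approx H^{-1}$ (curvature approximation)
- be cheap to compute
- have an acceptable storage cost
Shampoo's preconditioning strategy:
For each parameter tensor $W \in \mathbb{R}^{n_1 \times \cdots \times n_k}$, Shampoo builds $k$ preconditioner matrices $H^{(i)} \in \mathbb{R}^{n_i \times n_i}$, one per dimension.
3.2 A Block-Diagonal Generalization of Adagrad
Adagrad review:
Adagrad maintains an accumulated second-moment matrix $G_t = \sum_{s \le t} g_s g_s^\top$, with the update:
$$\theta_{t+1} = \theta_t - \eta\, \operatorname{diag}(G_t)^{-1/2}\, g_t$$
where $\operatorname{diag}(G_t)$ keeps only the diagonal entries.
Shampoo's block-diagonal generalization:
For a matrix parameter $W \in \mathbb{R}^{m \times n}$ with gradient $G_t$, Shampoo defines two preconditioners:
$$L_t = \epsilon I_m + \sum_{s \le t} G_s G_s^\top, \qquad R_t = \epsilon I_n + \sum_{s \le t} G_s^\top G_s$$
The update is:
$$W_{t+1} = W_t - \eta\, L_t^{-1/4}\, G_t\, R_t^{-1/4}$$
By the properties of the Kronecker product (with row-major $\mathrm{vec}$), this is equivalent to:
$$\mathrm{vec}(W_{t+1}) = \mathrm{vec}(W_t) - \eta \left(L_t^{-1/4} \otimes R_t^{-1/4}\right) \mathrm{vec}(G_t)$$
Mathematical derivation:
Let $H_t = \sum_{s \le t} \mathrm{vec}(G_s)\, \mathrm{vec}(G_s)^\top$ be the full Fisher/Gram matrix, whose inverse square root is the full-matrix Adagrad preconditioner.
If $H_t^{1/2} \approx L_t^{1/4} \otimes R_t^{1/4}$, then:
$$H_t^{-1/2}\, \mathrm{vec}(G_t) \approx \left(L_t^{-1/4} \otimes R_t^{-1/4}\right) \mathrm{vec}(G_t) = \mathrm{vec}\left(L_t^{-1/4}\, G_t\, R_t^{-1/4}\right)$$
This means we only need inverse roots of the small matrices $L_t$ and $R_t$, not of the full matrix $H_t$.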
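A useful special case for intuition: with a single gradient and $\epsilon = 0$, the Shampoo direction $L^{-1/4} G R^{-1/4}$ equals $U V^\top$ from the SVD $G = U \Sigma V^\top$, so all singular values are normalized to one. The sketch below checks this on a 2×2 example using a closed-form eigendecomposition. The helper `sym2_power`, the example gradient, and the choice $\epsilon = 0$ are assumptions made for illustration only.

```python
import math

def sym2_power(M, p):
    """Return M**p for a symmetric positive definite 2x2 matrix M."""
    a, b, d = M[0][0], M[0][1], M[1][1]
    mean = 0.5 * (a + d)
    radius = math.hypot(0.5 * (a - d), b)
    lam1, lam2 = mean + radius, mean - radius   # eigenvalues, lam1 >= lam2
    theta = 0.5 * math.atan2(2.0 * b, a - d)    # angle of the lam1 eigenvector
    c, s = math.cos(theta), math.sin(theta)
    l1, l2 = lam1 ** p, lam2 ** p
    # Q diag(l1, l2) Q^T with Q = [[c, -s], [s, c]]
    return [[c * c * l1 + s * s * l2, c * s * (l1 - l2)],
            [c * s * (l1 - l2), s * s * l1 + c * c * l2]]

def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

def transpose(X):
    return [list(row) for row in zip(*X)]

G = [[3.0, 1.0], [0.0, 2.0]]        # a hypothetical full-rank gradient
L = matmul(G, transpose(G))         # left statistic  G G^T
R = matmul(transpose(G), G)         # right statistic G^T G
update = matmul(matmul(sym2_power(L, -0.25), G), sym2_power(R, -0.25))
gram = matmul(update, transpose(update))   # should be close to the identity
print(update, gram)
```

With accumulated statistics and $\epsilon > 0$ the direction is no longer exactly orthogonal, but the example shows why the two fourth roots together act like one inverse square root.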
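3.3 SVD Decomposition and Parameter Transforms
Computing matrix roots:
For a positive definite matrix $A$ with eigendecomposition $A = Q \Lambda Q^\top$, the matrix square root is $A^{1/2} = Q \Lambda^{1/2} Q^\top$, and more generally $A^{p} = Q \Lambda^{p} Q^\top$ for any real power $p$.
Power iteration method:
To avoid a full eigendecomposition at every iteration, Shampoo implementations can use iterative schemes such as power iteration to approximate the inverse matrix root: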
def matrix_power_iteration(A, p, k=5, num_repeats=2):
"""
计算 A^{p} 的近似
使用幂迭代法,其中 p 通常为负分数(如 -1/4)
"""
m, n = A.shape
device = A.device
# 初始化随机矩阵
Q = torch.randn(n, min(n, m), device=device)
Q, _ = torch.linalg.qr(Q)
for _ in range(num_repeats):
# Power iteration
for _ in range(k):
Q = A @ Q
Q, _ = torch.linalg.qr(Q)
# 归一化
R = Q.T @ A @ Q
R_inv_pow = torch.sign(R) * torch.abs(R) ** p
Q = A @ Q @ torch.linalg.inv(R) @ torch.linalg.cholesky(R_inv_pow + 1e-6 * torch.eye(min(n, m), device=device)).transpose(-2, -1)
Q, _ = torch.linalg.qr(Q)
return Q @ torch.linalg.cholesky(
Q.T @ A @ Q + 1e-6 * torch.eye(min(n, m), device=device)
).transpose(-2, -1)
Shampoo implementation:
class ShampooOptimizer:
"""
Shampoo优化器实现
参考文献: Gupta et al. (2018) - Shampoo: Preconditioned Stochastic Tensor Optimization
"""
def __init__(
self,
model,
lr=0.001,
betas=(0.9, 0.999),
epsilon=1e-6,
weight_decay=0.01,
matrix_power=4,
power_iter_steps=5,
power_iter_repeats=2
):
self.model = model
self.lr = lr
self.betas = betas
self.epsilon = epsilon
self.weight_decay = weight_decay
self.matrix_power = matrix_power
self.power_iter_steps = power_iter_steps
self.power_iter_repeats = power_iter_repeats
self.preconditioners = {} # 存储每个参数的预条件子
self.timestep = 0
self._init_preconditioners()
def _init_preconditioners(self):
"""初始化预条件子"""
for name, param in self.model.named_parameters():
if param.dim() >= 2:
self.preconditioners[name] = {
'stats': [], # 各维度的统计矩阵
'inv_stats': [] # 各维度的逆统计矩阵
}
# 为每个维度初始化统计矩阵
for dim in range(param.dim()):
size = param.shape[dim]
self.preconditioners[name]['stats'].append(
torch.zeros(size, size, device=param.device)
)
self.preconditioners[name]['inv_stats'].append(None)
def _compute_tensor_gradient(self, param, grad):
"""
将梯度张量按各维度求外积
对于 W_{i1,i2,...,ik},计算各维度的 Gram 矩阵
"""
stats = []
# 沿各维度计算统计量
for dim in range(grad.dim()):
# 计算该维度的投影
# G_d = sum_{i1,...,ik} grad_{..., :, ...} grad_{..., :, ...}^T
grad_transposed = grad.transpose(0, dim)
grad_flattened = grad_transposed.reshape(
grad.shape[dim], -1
)
stat = grad_flattened @ grad_flattened.T
stats.append(stat)
return stats
def _update_statistics(self, param, grad):
"""更新统计矩阵"""
name = None
for n, p in self.model.named_parameters():
if p is param:
name = n
break
if name is None or name not in self.preconditioners:
return
# 计算当前batch的统计量
current_stats = self._compute_tensor_gradient(param, grad)
# 使用EMA更新
decay = self.betas[1]
for i, stat in enumerate(current_stats):
self.preconditioners[name]['stats'][i] = (
decay * self.preconditioners[name]['stats'][i] +
(1 - decay) * stat
)
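def _compute_inverse_square_root(self, A):
    """
    Compute the inverse root (A + eps*I)^(-1/matrix_power) by eigendecomposition.
    With the default matrix_power=4 this is the inverse fourth root that Shampoo
    applies to each factor of a matrix parameter. The method name is kept so
    that existing callers continue to work.
    """
    # Regularize
    A = A + self.epsilon * torch.eye(A.shape[0], device=A.device)
    # Eigendecomposition of the symmetric matrix
    eigenvalues, eigenvectors = torch.linalg.eigh(A)
    # Inverse root of the (clamped) eigenvalues
    eigenvalues = torch.clamp(eigenvalues, min=self.epsilon)
    inv_root_eigenvalues = eigenvalues ** (-1.0 / self.matrix_power)
    return eigenvectors @ torch.diag(inv_root_eigenvalues) @ eigenvectors.T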
def _update_inverse_statistics(self):
"""更新逆统计矩阵"""
for name, precond in self.preconditioners.items():
for i, stat in enumerate(precond['stats']):
if stat is not None and torch.trace(stat) > 0:
precond['inv_stats'][i] = self._compute_inverse_square_root(stat)
def step(self):
"""Shampoo优化步骤"""
self.timestep += 1
# 更新统计量
with torch.no_grad():
for param in self.model.parameters():
if param.grad is None:
continue
if param.dim() >= 2:
self._update_statistics(param, param.grad)
# 定期更新逆统计矩阵
if self.timestep % 100 == 0:
self._update_inverse_statistics()
# 应用更新
with torch.no_grad():
for name, param in self.model.named_parameters():
if param.grad is None:
continue
if name in self.preconditioners:
precond = self.preconditioners[name]
if all(inv is not None for inv in precond['inv_stats']):
# 应用预条件梯度
update = param.grad.clone()
# Apply each dimension's preconditioner as a mode-i product
for i, inv_stat in enumerate(precond['inv_stats']):
    update = torch.tensordot(inv_stat, update, dims=([1], [i])).movedim(0, i)
param -= self.lr * update
else:
param -= self.lr * param.grad
else:
param -= self.lr * param.grad
# 权重衰减
if self.weight_decay > 0:
param -= self.weight_decay * self.lr * param
self.model.zero_grad()
3.4 Convergence Analysis of Shampoo
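Theoretical basis:
Shampoo's convergence analysis builds on the following observation:
Let $G_t = \sum_{s \le t} g_s g_s^\top$ be the accumulated gradient outer-product matrix. The learning-rate adaptation in Adagrad is then equivalent to using $G_t^{-1/2}$ as the preconditioner (restricted to its diagonal in the common diagonal variant).
Convergence rate of Shampoo:
For convex problems, Shampoo attains an $O(\sqrt{T})$ regret bound, comparable to Adagrad. It is also reported to work well in non-convex settings, although the bound does not cover them.
Comparison with first-order methods: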
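| Method | Preconditioner | Convergence (convex regret) | Per-step cost for an $m \times n$ parameter |
|---|---|---|---|
| SGD | $I$ | $O(\sqrt{T})$ | $O(mn)$ |
| Adagrad | $\operatorname{diag}(G_t)^{-1/2}$ | $O(\sqrt{T})$ | $O(mn)$ |
| Shampoo | $L_t^{-1/4} \otimes R_t^{-1/4}$ | $O(\sqrt{T})$ | $O(mn(m+n))$ to apply, plus $O(m^3 + n^3)$ whenever the roots are recomputed |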
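4. Comparison of Approximate Second-Order Methods
4.1 The Conjugate Gradient Method (CG)
Basic principle of conjugate gradients
The conjugate gradient method is an iterative solver for linear systems $Ax = b$ (with $A$ symmetric positive definite) and is particularly well suited to sparse matrices and large systems.
Mathematical framework:
Given the quadratic function $f(x) = \tfrac{1}{2} x^\top A x - b^\top x$, the gradient is $\nabla f(x) = Ax - b$. At the minimizer $x^\ast$ we have $A x^\ast = b$.
The conjugate gradient method solves the system by building a sequence of $A$-conjugate directions, $p_i^\top A p_j = 0$ for $i \ne j$, starting from $r_0 = b - A x_0$ and $p_0 = r_0$:
$$\alpha_k = \frac{r_k^\top r_k}{p_k^\top A p_k}, \quad x_{k+1} = x_k + \alpha_k p_k, \quad r_{k+1} = r_k - \alpha_k A p_k, \quad \beta_k = \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}, \quad p_{k+1} = r_{k+1} + \beta_k p_k$$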
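Before the tensor-based version, here is a minimal dependency-free sketch of the same recurrences on a 2×2 system, where conjugate gradients terminate after at most two iterations in exact arithmetic. The function name and the example system are illustrative assumptions.

```python
def conjugate_gradient_dense(A, b, tol=1e-10):
    """Solve A x = b for a small symmetric positive definite matrix A (nested lists)."""
    n = len(b)

    def matvec(v):
        return [sum(A[i][j] * v[j] for j in range(n)) for i in range(n)]

    def dot(u, v):
        return sum(ui * vi for ui, vi in zip(u, v))

    x = [0.0] * n
    r = list(b)                 # residual b - A x for the starting point x = 0
    p = list(r)                 # first search direction
    rr = dot(r, r)
    for _ in range(n):          # at most n steps in exact arithmetic
        Ap = matvec(p)
        alpha = rr / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rr_new = dot(r, r)
        if rr_new ** 0.5 < tol:
            break
        beta = rr_new / rr
        p = [ri + beta * pi for ri, pi in zip(r, p)]
        rr = rr_new
    return x

A = [[4.0, 1.0], [1.0, 3.0]]
b = [1.0, 2.0]
x = conjugate_gradient_dense(A, b)
print(x)   # the exact solution is [1/11, 7/11]
```

Only matrix-vector products with $A$ are needed, which is why the method pairs well with Hessian-vector products in Hessian-free optimization.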
Implementation of the algorithm:
def conjugate_gradient(
matvec, # 矩阵-向量乘积函数
b, # 右手边向量
x0=None,
max_iter=None,
tol=1e-6,
preconditioner=None
):
"""
共轭梯度法求解 Ax = b
参数:
matvec: 矩阵-向量乘积函数
b: 右手边向量
x0: 初始解
max_iter: 最大迭代次数
tol: 收敛容忍度
preconditioner: 预条件子函数
"""
x = torch.zeros_like(b) if x0 is None else x0.clone()
r = b - matvec(x)
if preconditioner is not None:
z = preconditioner(r)
else:
z = r.clone()
p = z.clone()
rz_old = torch.dot(r, z)
n = b.numel()
if max_iter is None:
max_iter = n
for i in range(max_iter):
Ap = matvec(p)
# 线搜索步长
pAp = torch.dot(p, Ap)
if pAp <= 0:
# 负曲率方向,使用梯度下降
alpha = torch.dot(r, r) / (torch.dot(r, Ap) + 1e-10)
else:
alpha = rz_old / pAp
# 更新
x = x + alpha * p
r = r - alpha * Ap
# 检查收敛
if torch.norm(r) < tol:
break
# 预条件子
if preconditioner is not None:
z = preconditioner(r)
else:
z = r.clone()
# 共轭方向更新
rz_new = torch.dot(r, z)
beta = rz_new / (rz_old + 1e-10)
p = z + beta * p
rz_old = rz_new
return x
def hessian_free_optimization(
model,
loss_fn,
lr=1.0,
damping=0.1,
cg_iterations=50,
max_iter=100
):
"""
Hessian-Free优化器
使用CG求解 H^{-1}g
"""
history = {
'grads': [],
'params': [],
'losses': []
}
for iteration in range(max_iter):
# 计算损失和梯度
loss = loss_fn(model)
loss.backward()
# 收集梯度
g = torch.cat([p.grad.flatten() for p in model.parameters()])
history['losses'].append(loss.item())
history['grads'].append(g.clone())
# 定义Hessian-vector乘积(无精确Hessian)
def hvp(v):
# 使用有限差分近似 Hessian-vector 乘积
# Hv ≈ (∇L(w+εv) - ∇L(w-εv)) / (2ε)
epsilon = 1e-3
params = [p.clone() for p in model.parameters()]
# 正向扰动
offset = 0
for p in model.parameters():
size = p.numel()
v_p = v[offset:offset+size].reshape(p.shape)
p.data.add_(v_p, alpha=epsilon)
offset += size
model.zero_grad()
loss_plus = loss_fn(model)
loss_plus.backward()  # populate p.grad at the positively perturbed point
grads_plus = torch.cat([p.grad.flatten() for p in model.parameters()])
# 恢复并负向扰动
offset = 0
for p, orig in zip(model.parameters(), params):
p.data.copy_(orig)
for p in model.parameters():
size = p.numel()
v_p = v[offset:offset+size].reshape(p.shape)
p.data.add_(v_p, alpha=-epsilon)
offset += size
model.zero_grad()
loss_minus = loss_fn(model)
loss_minus.backward()  # populate p.grad at the negatively perturbed point
grads_minus = torch.cat([p.grad.flatten() for p in model.parameters()])
# 恢复原始参数
for p, orig in zip(model.parameters(), params):
p.data.copy_(orig)
# Use a separate name so the result does not shadow the enclosing function hvp
hv = (grads_plus - grads_minus) / (2 * epsilon)
hv = hv + damping * v  # add damping
return hv
# 使用CG求解 H^{-1}g
delta = conjugate_gradient(
hvp,
-g,
max_iter=cg_iterations
)
# 更新参数
offset = 0
with torch.no_grad():
for p in model.parameters():
size = p.numel()
delta_p = delta[offset:offset+size].reshape(p.shape)
# delta already solves H delta = -g, so it is a descent direction: step along +delta
p.add_(delta_p, alpha=lr)
offset += size
model.zero_grad()
print(f"Iter {iteration}: Loss = {loss.item():.6f}")
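return history
4.2 Quasi-Newton Methods (L-BFGS)
The BFGS update
The BFGS (Broyden-Fletcher-Goldfarb-Shanno) method uses a rank-2 correction to approximate the inverse Hessian:
$$H_{k+1} = \left(I - \rho_k s_k y_k^\top\right) H_k \left(I - \rho_k y_k s_k^\top\right) + \rho_k s_k s_k^\top, \qquad \rho_k = \frac{1}{y_k^\top s_k}$$
where $s_k = \theta_{k+1} - \theta_k$ (the parameter update) and $y_k = \nabla L(\theta_{k+1}) - \nabla L(\theta_k)$ (the gradient difference).
L-BFGS (limited-memory BFGS) stores only the most recent $m$ pairs $(s_k, y_k)$ to approximate $H_k$.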
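The defining property of the update is the secant condition $H_{k+1} y_k = s_k$: the new inverse-Hessian estimate maps the observed gradient change back to the step that caused it. A small dependency-free check, with illustrative helper names and example vectors:

```python
def matmul(X, Y):
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

def bfgs_inverse_update(H, s, y):
    """Return (I - rho s y^T) H (I - rho y s^T) + rho s s^T with rho = 1 / (y^T s)."""
    n = len(s)
    rho = 1.0 / sum(si * yi for si, yi in zip(s, y))
    left = [[(1.0 if i == j else 0.0) - rho * s[i] * y[j] for j in range(n)]
            for i in range(n)]
    right = [[left[j][i] for j in range(n)] for i in range(n)]   # transpose of left
    core = matmul(matmul(left, H), right)
    return [[core[i][j] + rho * s[i] * s[j] for j in range(n)] for i in range(n)]

H0 = [[1.0, 0.0], [0.0, 1.0]]      # initial inverse-Hessian estimate
s = [1.0, 0.5]                     # parameter step (illustrative)
y = [2.0, 1.5]                     # gradient change, with y^T s = 2.75 > 0
H1 = bfgs_inverse_update(H0, s, y)
H1_y = [sum(H1[i][j] * y[j] for j in range(2)) for i in range(2)]
print(H1, H1_y)                    # H1_y reproduces s
```

The curvature condition $y_k^\top s_k > 0$ used here is the same one the implementation below checks before storing a pair, since it keeps the estimate positive definite.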
class LBFGSOptimizer:
"""
Limited-memory BFGS 优化器
"""
def __init__(
self,
model,
lr=1.0,
max_history=10,
damping=0.0,
Wolfe=True
):
self.model = model
self.lr = lr
self.max_history = max_history
self.damping = damping
self.Wolfe = Wolfe
self.s_history = [] # 参数差
self.y_history = [] # 梯度差
self.rho_history = [] # 缩放因子
def _two_loop_recursion(self, grad):
"""
两循环递归计算 H^{-1} * grad
这是L-BFGS的核心算法
"""
q = grad.clone()
# 反向循环
alphas = []
for s, y, rho in zip(
reversed(self.s_history),
reversed(self.y_history),
reversed(self.rho_history)
):
alpha = rho * torch.dot(s, q)
alphas.append(alpha)
q = q - alpha * y
# 初始缩放(使用最近一次的曲率估计)
if len(self.s_history) > 0:
s, y = self.s_history[-1], self.y_history[-1]
gamma = torch.dot(y, s) / (torch.dot(y, y) + 1e-10)
z = gamma * q
else:
z = q
# 正向循环
betas = []
for s, y, rho, alpha in zip(
self.s_history,
self.y_history,
self.rho_history,
reversed(alphas)
):
beta = rho * torch.dot(y, z)
betas.append(beta)
z = z + s * (alpha - beta)
return z
def step(self):
"""
L-BFGS 优化步骤
使用线搜索满足 Wolfe 条件
"""
# 获取当前参数和梯度
params = [p.clone() for p in self.model.parameters()]
loss = self._compute_loss()
loss.backward()
grad = torch.cat([p.grad.flatten() for p in self.model.parameters()])
# 计算搜索方向
# The two-loop recursion returns H^{-1} g; negate it to get a descent direction
direction = -self._two_loop_recursion(grad)
# 线搜索
if self.Wolfe:
alpha = self._wolfe_line_search(params, grad, direction)
else:
alpha = self.lr
# 更新参数
offset = 0
with torch.no_grad():
for p in self.model.parameters():
size = p.numel()
delta = alpha * direction[offset:offset+size].reshape(p.shape)
p.add_(delta)
offset += size
# 更新历史
new_params = [p.clone() for p in self.model.parameters()]
new_grad = self._compute_loss_grad()
s = torch.cat([p.flatten() for p in new_params]) - \
torch.cat([p.flatten() for p in params])
y = new_grad - grad
# 曲率条件检查:y^T s > 0
sty = torch.dot(y, s)
if sty > 0:
self.s_history.append(s)
self.y_history.append(y)
self.rho_history.append(1.0 / (sty + 1e-10))
# 限制历史长度
if len(self.s_history) > self.max_history:
self.s_history.pop(0)
self.y_history.pop(0)
self.rho_history.pop(0)
self.model.zero_grad()
def _compute_loss(self):
    """Compute the current loss (to be implemented by a subclass)"""
    raise NotImplementedError
def _compute_loss_grad(self):
    """Compute the gradient at the current parameters"""
    self.model.zero_grad()
    loss = self._compute_loss()
    loss.backward()
    return torch.cat([p.grad.flatten() for p in self.model.parameters()])
def _wolfe_line_search(self, params, grad, direction):
"""
Wolfe线搜索
满足:
1. Armijo条件: f(x + αd) ≤ f(x) + c1 * α * g^T d
2. 曲率条件: ∇f(x + αd)^T d ≥ c2 * g^T d
"""
c1, c2 = 1e-4, 0.9
alpha = 1.0
alpha_min, alpha_max = 1e-10, 1e10
loss0 = self._compute_loss().item()
grad_dot_d = torch.dot(grad, direction).item()
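for _ in range(20):
    # Move to the trial point x + alpha * d
    with torch.no_grad():
        offset = 0
        for p, orig in zip(self.model.parameters(), params):
            size = p.numel()
            delta = alpha * direction[offset:offset+size].reshape(p.shape)
            p.copy_(orig + delta)
            offset += size
    loss1 = self._compute_loss().item()
    # Armijo (sufficient decrease) condition
    if loss1 <= loss0 + c1 * alpha * grad_dot_d:
        # Accepted. The curvature condition is not checked here because it
        # would require the gradient at the trial point.
        break
    # Otherwise shrink the bracket and bisect
    alpha_max = alpha
    alpha = (alpha_min + alpha_max) / 2
# Restore the original parameters; step() applies the accepted step itself
with torch.no_grad():
    for p, orig in zip(self.model.parameters(), params):
        p.copy_(orig)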
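return alpha
4.3 Stochastic Second-Order Methods: Theory and Practice
Challenges for stochastic second-order methods:
Second-order methods in large-scale machine learning face the following challenges:
- High variance in stochastic estimates of the Hessian/Fisher matrix
- The trade-off between convergence speed and batch size
- Communication overhead (in distributed training)
Stochastic curvature estimation: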
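class StochasticSecondOrderOptimizer:
    """
    Base class for stochastic second-order optimizers.
    Provides a generic framework for stochastic curvature estimation.
    """
    def __init__(self, model, lr, damping=0.1, batch_size=32):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.batch_size = batch_size

    def _estimate_curvature(self, dataloader):
        """
        Stochastically estimate the curvature matrix.
        Subclasses should implement the concrete estimator.
        """
        raise NotImplementedError

    def _solve_linear_system(self, A, b):
        """
        Solve the damped linear system (A + damping * I) x = b.
        Subclasses may provide a more efficient solver.
        """
        eye = torch.eye(A.shape[0], device=A.device, dtype=A.dtype)
        return torch.linalg.solve(A + self.damping * eye, b)
A second-order extension of SVRG (Stochastic Variance Reduced Gradient):
The SVRG-SO algorithm combines variance reduction with second-order information: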
class SVRGSecondOrder:
"""
SVRG-二阶方法
结合SVRG的方差缩减与二阶曲率信息
"""
def __init__(
self,
model,
data_loader,
lr=0.01,
inner_iter=100,
damping=0.1
):
self.model = model
self.data_loader = data_loader
self.lr = lr
self.inner_iter = inner_iter
self.damping = damping
# 存储参考点和梯度
self.theta_ref = None
self.grad_ref = None
def step(self):
"""SVRG-SO外循环"""
# 保存参考点
self.theta_ref = [p.clone() for p in self.model.parameters()]
# 计算参考梯度
self.grad_ref = self._compute_full_gradient()
# 内循环
for _ in range(self.inner_iter):
self._inner_step()
def _inner_step(self):
"""SVRG内循环"""
# 随机采样 batch
batch = next(iter(self.data_loader))
x_batch, y_batch = batch
# 计算梯度
self.model.zero_grad()
loss = self.model.loss(x_batch, y_batch)
loss.backward()
# 估计曲率(使用当前batch的Hessian)
g_t = torch.cat([p.grad.flatten() for p in self.model.parameters()])
# SVRG梯度修正
g_ref = self._compute_gradient_at_ref(x_batch, y_batch)
g_corrected = g_t - g_ref + self.grad_ref
# 二阶更新(简化版:使用对角曲率)
diag_curvature = self._estimate_diag_curvature()
update = g_corrected / (diag_curvature + self.damping)
# 更新参数
offset = 0
with torch.no_grad():
for p in self.model.parameters():
size = p.numel()
delta = self.lr * update[offset:offset+size].reshape(p.shape)
p.sub_(delta)
offset += size
def _compute_full_gradient(self):
"""计算全量梯度(代价高)"""
self.model.zero_grad()
total_loss = 0
for batch in self.data_loader:
x, y = batch
loss = self.model.loss(x, y)
loss.backward()
total_loss += loss.item()
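# Gradients were accumulated over all batches; average them to get the full gradient
return torch.cat([p.grad.flatten() for p in self.model.parameters()]) / max(len(self.data_loader), 1)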
def _compute_gradient_at_ref(self, x, y):
"""在参考点计算梯度"""
# 保存当前参数
current_params = [p.clone() for p in self.model.parameters()]
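# Load the reference parameters (in-place copies into leaf tensors must run under no_grad)
with torch.no_grad():
    for p, ref in zip(self.model.parameters(), self.theta_ref):
        p.copy_(ref)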
# 计算梯度
self.model.zero_grad()
loss = self.model.loss(x, y)
loss.backward()
grad = torch.cat([p.grad.flatten() for p in self.model.parameters()])
# Restore the current parameters (again under no_grad)
with torch.no_grad():
    for p, cur in zip(self.model.parameters(), current_params):
        p.copy_(cur)
return grad
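def _estimate_diag_curvature(self):
    """Estimate the diagonal curvature as the elementwise squared gradient (a diagonal empirical-Fisher proxy)"""
    grad = torch.cat([p.grad.flatten() for p in self.model.parameters()])
    return grad ** 2
5. Connections Between Second-Order and Adaptive Methods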
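5.1 Adam as a First-Order Approximation of K-FAC
Understanding Adam from the information-geometry viewpoint:
Adam's adaptive learning-rate mechanism can be understood as a coarse approximation of K-FAC. Consider the following correspondence:
| K-FAC | Adam |
|---|---|
| Fisher matrix $F$ | Second-moment estimate $v_t$ |
| Inverse Fisher $F^{-1}$ | Gradient normalization $1/(\sqrt{\hat{v}_t} + \epsilon)$ |
| Kronecker factorization | Diagonal approximation |
| Finer-grained curvature estimate | Exponential-moving-average estimate |
Formal correspondence:
The K-FAC update is:
$$\Delta W = -\eta\, B^{-1}\, \nabla_W L\, A^{-1}$$
The Adam update is:
$$\Delta \theta = -\eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
If $A$ and $B$ are approximated by diagonal matrices, then:
$$(\Delta W)_{ij} = -\eta\, \frac{(\nabla_W L)_{ij}}{B_{ii}\, A_{jj}}$$
so every weight receives its own scalar rescaling, as in Adam. This suggests that Adam can be viewed as a "coarse-grained", per-parameter version of K-FAC. One difference remains: the diagonal K-FAC update divides by a second moment, whereas Adam divides by its square root.
5.2 Understanding Adam from a Second-Order Perspective
Revisiting Adam's convergence behavior:
From the viewpoint of second-order methods, Adam's learning-rate adaptation $\eta / (\sqrt{\hat{v}_t} + \epsilon)$ is an anisotropic rescaling of the gradient direction.
Per-coordinate decomposition:
Let the parameters be $\theta = (\theta_1, \dots, \theta_d)$. The Adam update can be written as:
$$\theta_{i,t+1} = \theta_{i,t} - \eta\, \frac{\hat{m}_{i,t}}{\sqrt{\hat{v}_{i,t}} + \epsilon}, \qquad i = 1, \dots, d$$
This amounts to maintaining a local learning rate for each parameter, inversely proportional to the root-mean-square gradient history along that coordinate.
Connection to the natural gradient:
If we define an equivalent Fisher matrix as:
$$\hat{F}_t = \operatorname{diag}\left(\sqrt{\hat{v}_t} + \epsilon\right)$$
then the Adam update is approximately:
$$\theta_{t+1} = \theta_t - \eta\, \hat{F}_t^{-1}\, \hat{m}_t$$
This has the same form as the natural gradient update, with the caveat that $\hat{v}_t$ tracks the diagonal of the empirical Fisher and enters through its square root.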
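The per-coordinate normalization is easy to see on the very first Adam step: after bias correction $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so each coordinate moves by roughly $\eta$ regardless of the magnitude of its gradient. The sketch below is a plain-Python rendering of the standard Adam update on a list of scalars; the function name and the example gradient are illustrative assumptions.

```python
def adam_step(grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a list of scalar coordinates; returns (delta, m, v)."""
    m = [beta1 * mi + (1 - beta1) * g for mi, g in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * g * g for vi, g in zip(v, grad)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]     # bias-corrected first moment
    v_hat = [vi / (1 - beta2 ** t) for vi in v]     # bias-corrected second moment
    delta = [-lr * mh / (vh ** 0.5 + eps) for mh, vh in zip(m_hat, v_hat)]
    return delta, m, v

grad = [100.0, -0.01]               # two coordinates with very different scales
delta, m, v = adam_step(grad, [0.0, 0.0], [0.0, 0.0], t=1)
print(delta)                        # both steps have magnitude close to lr
```

A gradient four orders of magnitude larger produces essentially the same step size, which is the anisotropic rescaling described above.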
5.3 The Potential of Hybrid Methods
A unified AdaGrad-Fisher framework:
class HybridOptimizer:
"""
混合优化器:结合二阶信息与自适应学习率
策略1: 在小参数上使用K-FAC,大参数上使用Adam
策略2: 动态切换基于训练阶段
策略3: 层次化预条件
"""
def __init__(
self,
model,
lr=0.001,
kfac_lr=0.001,
adam_lr=0.001,
threshold=1e5 # 参数量阈值
):
self.model = model
self.lr = lr
self.kfac_lr = kfac_lr
self.adam_lr = adam_lr
self.threshold = threshold
# 初始化子优化器
self.kfac = KFACOptimizer(model, lr=kfac_lr)
self.adam = torch.optim.Adam(
model.parameters(),
lr=adam_lr
)
# 参数分区
self.kfac_params = []
self.adam_params = []
self._partition_params()
def _partition_params(self):
"""根据参数量分区"""
for name, param in self.model.named_parameters():
if param.numel() > self.threshold:
self.adam_params.append(param)
else:
self.kfac_params.append((name, param))
def step(self):
"""混合优化步骤"""
# K-FAC处理小参数
for name, param in self.kfac_params:
# K-FAC 更新逻辑
pass
# Adam处理大参数
self.adam.step()
def zero_grad(self):
"""清零梯度"""
self.model.zero_grad()
The Sophia optimizer, a second-order/first-order hybrid:
The Sophia optimizer uses an estimate of the Hessian diagonal to guide Adam-style updates:
class SophiaOptimizer:
"""
Sophia优化器
使用Hessian对角线作为预条件子
结合二阶信息与Adam的高效性
"""
def __init__(
self,
model,
lr=0.001,
betas=(0.9, 0.999),
rho=0.04,
eps=1e-8
):
self.model = model
self.lr = lr
self.betas = betas
self.rho = rho # Hessian估计的衰减率
self.eps = eps
# 状态变量
self.exp_avg_grad = {}
self.exp_avg_hessian_diag = {}
# 初始化
for name, param in model.named_parameters():
self.exp_avg_grad[name] = torch.zeros_like(param)
self.exp_avg_hessian_diag[name] = torch.zeros_like(param)
def _estimate_hessian_diag(self, param, grad, prev_param, prev_grad):
"""
估计Hessian对角线
使用曲率比率: (g_new - g_old) / (p_new - p_old)
"""
delta_param = param - prev_param
delta_grad = grad - prev_grad
# 避免除零
diag = delta_grad / (delta_param + self.eps)
diag = torch.clamp(diag, min=self.eps)
return diag
def step(self):
"""Sophia优化步骤"""
prev_params = {name: p.clone() for name, p in self.model.named_parameters()}
loss = self._compute_loss()
loss.backward()
with torch.no_grad():
for name, param in self.model.named_parameters():
if param.grad is None:
continue
grad = param.grad
prev_grad = self.exp_avg_grad[name]
# 估计Hessian对角线
if torch.norm(prev_params[name] - param) > self.eps:
hessian_diag = self._estimate_hessian_diag(
param, grad, prev_params[name], prev_grad
)
else:
hessian_diag = self.eps * torch.ones_like(param)
# EMA更新Hessian估计
self.exp_avg_hessian_diag[name] = (
self.rho * self.exp_avg_hessian_diag[name] +
(1 - self.rho) * hessian_diag
)
# EMA更新梯度(用于下次Hessian估计)
self.exp_avg_grad[name] = (
self.betas[0] * self.exp_avg_grad[name] +
(1 - self.betas[0]) * grad
)
# 预条件梯度
clipped_hessian = torch.clamp(
self.exp_avg_hessian_diag[name],
min=self.eps
)
precond_grad = grad / clipped_hessian
# 更新参数
param -= self.lr * precond_grad
# L2正则化
param -= self.lr * 0.01 * param
self.model.zero_grad()
def _compute_loss(self):
"""计算损失(子类实现)"""
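raise NotImplementedError
The Muon optimizer and the Newton-Schulz method:
Muon uses the Newton-Schulz iteration to orthogonalize the update matrix, which corresponds to applying an approximate inverse matrix square root: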
class MuonOptimizer:
"""
Muon优化器
使用Newton-Schulz迭代计算正交预条件子
"""
def __init__(
self,
model,
lr=0.001,
momentum=0.95,
nesterov=True
):
self.model = model
self.lr = lr
self.momentum = momentum
self.nesterov = nesterov
self.m = {} # 动量
self._init_momentum()
def _init_momentum(self):
"""初始化动量"""
for name, param in self.model.named_parameters():
self.m[name] = torch.zeros_like(param)
def _newton_schulz_iteration(self, G, num_iters=5):
"""
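Newton-Schulz iteration.
Drives G toward the nearest (semi-)orthogonal matrix, i.e. the factor U V^T of its SVD G = U S V^T.
Iteration: X_{k+1} = (3/2) X_k - (1/2) X_k X_k^T X_k, started from X_0 = G / ||G||_F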
"""
X = G / max(torch.norm(G), 1e-8)
for _ in range(num_iters):
X_new = (3/2) * X - (1/2) * X @ X.T @ X
X = X_new
return X
def step(self):
"""Muon优化步骤"""
loss = self._compute_loss()
loss.backward()
with torch.no_grad():
for name, param in self.model.named_parameters():
if param.grad is None:
continue
grad = param.grad
m_old = self.m[name]
# 更新动量
self.m[name] = (
self.momentum * m_old +
(1 - self.momentum) * grad
)
# 计算更新方向
if self.nesterov:
grad_corrected = (
self.momentum * self.m[name] +
(1 - self.momentum) * grad
)
else:
grad_corrected = self.m[name]
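# Non-matrix parameters (for example biases) get a plain momentum update
if param.dim() != 2:
    update = grad_corrected
else:
    # Matrix parameters: orthogonalize the momentum-corrected gradient
    update = self._newton_schulz_iteration(grad_corrected)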
# 应用更新
param -= self.lr * update
self.model.zero_grad()
def _compute_loss(self):
"""计算损失(子类实现)"""
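raise NotImplementedError
6. Implementation Details and Code Examples
6.1 A Simplified K-FAC Implementation
The following is a simplified K-FAC implementation together with a small comparison script: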
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import numpy as np
# ============== 简化K-FAC实现 ==============
class SimpleKFAC:
"""
简化的K-FAC优化器
适用于小型到中型模型
"""
def __init__(
self,
model,
lr=0.01,
damping=1e-3,
stat_decay=0.95,
inverse_interval=30
):
self.model = model
self.lr = lr
self.damping = damping
self.stat_decay = stat_decay
self.inverse_interval = inverse_interval
self.step_count = 0
self.fisher = {} # Fisher统计
self.inverse = {} # Fisher逆
self._init_fisher()
def _init_fisher(self):
"""初始化Fisher统计"""
for name, param in self.model.named_parameters():
if param.dim() >= 2:
self.fisher[name] = {
'a': None, # 输入侧
'g': None # 梯度侧
}
self.inverse[name] = {
'a': None,
'g': None
}
def _update_fisher(self, batch_x, batch_y):
"""更新Fisher统计"""
self.model.zero_grad()
# 前向传播
output = self.model(batch_x)
loss = F.cross_entropy(output, batch_y)
loss.backward()
for name, param in self.model.named_parameters():
if param.grad is None or name not in self.fisher:
continue
grad = param.grad
if param.dim() == 2:  # weight matrix; an nn.Parameter is never an nn.Linear instance
# 分解为 a 和 g
# 需要hook捕获输入和梯度
# 简化:使用均匀初始化
if self.fisher[name]['a'] is None:
self.fisher[name]['a'] = torch.eye(param.shape[1], device=param.device)
self.fisher[name]['g'] = torch.eye(param.shape[0], device=param.device)
# 更新g侧(梯度侧)
g_new = grad @ grad.T / grad.numel()
self.fisher[name]['g'] = (
self.stat_decay * self.fisher[name]['g'] +
(1 - self.stat_decay) * g_new
)
def _update_inverse(self):
"""更新Fisher逆矩阵"""
for name in self.fisher:
a = self.fisher[name]['a']
g = self.fisher[name]['g']
if a is None or g is None:
continue
# 添加阻尼并求逆
a_reg = a + self.damping * torch.eye(a.shape[0], device=a.device)
g_reg = g + self.damping * torch.eye(g.shape[0], device=g.device)
# 特征分解求逆
eig_a, vec_a = torch.linalg.eigh(a_reg)
eig_g, vec_g = torch.linalg.eigh(g_reg)
# 截断小特征值
eig_a = torch.clamp(eig_a, min=self.damping)
eig_g = torch.clamp(eig_g, min=self.damping)
self.inverse[name]['a'] = vec_a @ torch.diag(1/eig_a) @ vec_a.T
self.inverse[name]['g'] = vec_g @ torch.diag(1/eig_g) @ vec_g.T
def step(self, batch_x, batch_y):
"""K-FAC优化步骤"""
self.step_count += 1
# 更新Fisher统计
self._update_fisher(batch_x, batch_y)
# 定期更新逆矩阵
if self.step_count % self.inverse_interval == 0:
self._update_inverse()
# 应用修正梯度
self.model.zero_grad()
output = self.model(batch_x)
loss = F.cross_entropy(output, batch_y)
loss.backward()
with torch.no_grad():
for name, param in self.model.named_parameters():
if param.grad is None or name not in self.inverse:
continue
if self.inverse[name]['a'] is None:
continue
inv_a = self.inverse[name]['a']
inv_g = self.inverse[name]['g']
grad = param.grad
# Kronecker乘积应用:(G^{-1} ⊗ A^{-1}) vec(∇W)
corrected = inv_g @ grad @ inv_a.T
param -= self.lr * corrected
self.model.zero_grad()
# ============== 对比实验 ==============
def train_comparison():
"""
Compare SGD, Adam, and K-FAC on synthetic MNIST-shaped data (random inputs and labels)
"""
# 生成合成数据(模拟MNIST)
torch.manual_seed(42)
n_samples = 5000
n_features = 784
n_classes = 10
X_train = torch.randn(n_samples, n_features)
y_train = torch.randint(0, n_classes, (n_samples,))
X_test = torch.randn(1000, n_features)
y_test = torch.randint(0, n_classes, (1000,))
# 定义模型
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
# 训练函数
def train_model(model, optimizer, epochs=20):
losses = []
accuracies = []
for epoch in range(epochs):
# Mini-batch训练
indices = torch.randperm(n_samples)[:128]
batch_x = X_train[indices]
batch_y = y_train[indices]
if hasattr(optimizer, 'step'):
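if isinstance(optimizer, SimpleKFAC):
    optimizer.step(batch_x, batch_y)
    # SimpleKFAC.step does not return the loss, so evaluate it here for logging
    with torch.no_grad():
        loss = F.cross_entropy(model(batch_x), batch_y)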
else:
optimizer.zero_grad()
output = model(batch_x)
loss = F.cross_entropy(output, batch_y)
loss.backward()
optimizer.step()
# 评估
if epoch % 5 == 0:
with torch.no_grad():
output = model(X_test)
pred = output.argmax(dim=1)
acc = (pred == y_test).float().mean()
losses.append(loss.item())
accuracies.append(acc.item())
return losses, accuracies
    # Run the comparison; the seed is reset so every model starts from the same weights
    results = {}
    # SGD
    print("Training with SGD...")
    torch.manual_seed(42)
    model_sgd = SimpleNet()
    opt_sgd = torch.optim.SGD(model_sgd.parameters(), lr=0.01, momentum=0.9)
    losses_sgd, accs_sgd = train_model(model_sgd, opt_sgd)
    results['SGD'] = {'loss': losses_sgd, 'acc': accs_sgd}
    # Adam
    print("Training with Adam...")
    torch.manual_seed(42)
    model_adam = SimpleNet()
    opt_adam = torch.optim.Adam(model_adam.parameters(), lr=0.001)
    losses_adam, accs_adam = train_model(model_adam, opt_adam)
    results['Adam'] = {'loss': losses_adam, 'acc': accs_adam}
    # K-FAC
    print("Training with K-FAC...")
    torch.manual_seed(42)
    model_kfac = SimpleNet()
    opt_kfac = SimpleKFAC(model_kfac, lr=0.01, damping=1e-2)
    losses_kfac, accs_kfac = train_model(model_kfac, opt_kfac)
    results['K-FAC'] = {'loss': losses_kfac, 'acc': accs_kfac}
    # Plot the results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    for name, res in results.items():
        ax1.plot(res['loss'], label=name)
        ax2.plot(res['acc'], label=name)
    ax1.set_xlabel('Evaluation Step')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Comparison')
    ax1.legend()
    ax1.grid(True)
    ax2.set_xlabel('Evaluation Step')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Test Accuracy Comparison')
    ax2.legend()
    ax2.grid(True)
    plt.tight_layout()
    plt.savefig('/tmp/kfac_comparison.png', dpi=150)
    plt.show()
    return results
# Run the experiment
if __name__ == "__main__":
    results = train_comparison()
    print("\n=== Final Results ===")
    for name, res in results.items():
        print(f"{name}: Final Loss = {res['loss'][-1]:.4f}, Final Acc = {res['acc'][-1]:.4f}")
6.2 Comparison Experiments Against Standard Optimizers
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time

class OptimizerBenchmark:
    """
    Benchmark harness for optimizer performance.

    `train_data` and `test_data` are (inputs, labels) tuples;
    `optimizers_to_test` maps a display name to an optimizer instance.
    """
    def __init__(
        self,
        model,
        train_data,
        test_data,
        optimizers_to_test
    ):
        self.model = model
        self.train_data = train_data
        self.test_data = test_data
        self.optimizers = optimizers_to_test
    def benchmark_step_time(self, optimizer_name, optimizer, n_runs=100):
        """Measure the wall-clock time of a single training step."""
        times = []
        for _ in range(n_runs):
            x, y = self.train_data
            # perf_counter is a monotonic, high-resolution timer
            start = time.perf_counter()
            optimizer.zero_grad()
            output = self.model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()
            # Wait for queued GPU kernels before stopping the timer
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start)
        return np.mean(times), np.std(times)
    def benchmark_convergence(self, optimizer_name, optimizer, max_epochs=50):
        """Measure convergence speed and final performance."""
        losses = []
        accuracies = []
        for epoch in range(max_epochs):
            # One full-batch update per epoch
            x, y = self.train_data
            optimizer.zero_grad()
            output = self.model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()
            # Evaluate on the test set
            with torch.no_grad():
                test_output = self.model(self.test_data[0])
                test_loss = F.cross_entropy(test_output, self.test_data[1])
                pred = test_output.argmax(dim=1)
                acc = (pred == self.test_data[1]).float().mean()
            losses.append(test_loss.item())
            accuracies.append(acc.item())
            if epoch % 10 == 0:
                print(f"{optimizer_name} - Epoch {epoch}: Loss = {test_loss:.4f}, Acc = {acc:.4f}")
        return losses, accuracies
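    def run_full_benchmark(self):
        """Run the full benchmark for every optimizer."""
        results = {}
        for name, optimizer in self.optimizers.items():
            print(f"\n{'='*50}")
            print(f"Benchmarking: {name}")
            print('='*50)
            # Reset the model in place. xavier_uniform_ requires tensors with
            # at least 2 dimensions, so 1-D parameters (biases) are zeroed instead.
            with torch.no_grad():
                for p in self.model.parameters():
                    if p.dim() >= 2:
                        torch.nn.init.xavier_uniform_(p)
                    else:
                        p.zero_()
            # Measure step time. Note: this already takes `n_runs` optimizer
            # steps, so the convergence run below starts from a partly trained model.
            mean_time, std_time = self.benchmark_step_time(name, optimizer)
            print(f"Step time: {mean_time*1000:.2f} ± {std_time*1000:.2f} ms")
            # Measure convergence
            losses, accuracies = self.benchmark_convergence(name, optimizer)
            results[name] = {
                'step_time': (mean_time, std_time),
                'final_loss': losses[-1],
                'final_acc': accuracies[-1],
                'losses': losses,
                'accuracies': accuracies
            }
        return results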
def create_optimizers(model):
    """Build the dictionary of optimizers to benchmark."""
    return {
        'SGD': torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
        'Adam': torch.optim.Adam(model.parameters(), lr=0.001),
        'AdamW': torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01),
        'RMSprop': torch.optim.RMSprop(model.parameters(), lr=0.001),
    }
7. Summary and Outlook
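7.1 Method Comparison Summary
| Method | Curvature estimate | Compute cost | Memory cost | Convergence speed | Typical use |
|---|---|---|---|---|---|
| Newton's method | Exact Hessian | $O(n^3)$ | $O(n^2)$ | Quadratic (local) | Small-scale problems |
| K-FAC | Kronecker factorization | $O(d_{\text{in}}^3 + d_{\text{out}}^3)$ per layer | $O(d_{\text{in}}^2 + d_{\text{out}}^2)$ per layer | Fast | Medium scale |
| Shampoo | Block-wise inverse square root | $O(d_{\text{in}}^3 + d_{\text{out}}^3)$ per layer | $O(d_{\text{in}}^2 + d_{\text{out}}^2)$ per layer | Fast | Large scale |
| L-BFGS | Rank-2 updates | $O(mn)$ | $O(mn)$ | Fast | Medium scale |
| CG | Hessian-vector products | $O(kn)$ | $O(n)$ | Fast | Large scale |
| Adam | Diagonal approximation | $O(n)$ | $O(n)$ | Moderate | Very large scale |
Costs are the standard per-update figures: $n$ is the number of parameters, $d_{\text{in}}$ and $d_{\text{out}}$ are a layer's input and output widths, $m$ is the L-BFGS history length, and $k$ is the number of CG iterations.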
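The K-FAC row above depends on the Kronecker factorization being both cheap to store and exact to invert factor by factor. The following is a minimal sketch of both points, not the document's implementation: `curvature_storage` and `random_spd` are hypothetical helpers introduced here, and the factors `A` and `G` are random positive definite stand-ins rather than real Fisher statistics. It counts the entries each approximation stores for one layer, then checks on a tiny layer that the factored update `G^{-1} ∇W A^{-T}` matches a dense solve against `G ⊗ A`.

```python
import torch

torch.manual_seed(0)


def curvature_storage(d_out, d_in):
    """Entries stored per layer by each curvature approximation (hypothetical helper)."""
    n = d_out * d_in  # number of weights in the layer
    return {
        "full": n * n,                    # dense n x n curvature matrix
        "kronecker": d_in**2 + d_out**2,  # factors A (d_in x d_in) and G (d_out x d_out)
        "diagonal": n,                    # one entry per weight
    }


def random_spd(d):
    """Random symmetric positive definite matrix, a stand-in for a damped factor."""
    m = torch.randn(d, d, dtype=torch.float64)
    return m @ m.T + d * torch.eye(d, dtype=torch.float64)


# A tiny layer so that the dense Kronecker product is cheap to build.
d_out, d_in = 4, 6
A = random_spd(d_in)    # input-activation factor
G = random_spd(d_out)   # output-gradient factor
grad = torch.randn(d_out, d_in, dtype=torch.float64)

# Factored update: two small inverses, the n x n matrix is never formed.
factored = torch.linalg.inv(G) @ grad @ torch.linalg.inv(A).T

# Dense reference: solve (G ⊗ A) x = vec(grad), using row-major vec.
dense = torch.linalg.solve(torch.kron(G, A), grad.reshape(-1)).reshape(d_out, d_in)

max_diff = (factored - dense).abs().max().item()
print(curvature_storage(d_out, d_in))
print(curvature_storage(256, 784))  # sizes of fc1 in SimpleNet
print(f"max |factored - dense| = {max_diff:.2e}")
```

For the 784-to-256 layer `fc1` of `SimpleNet`, the Kronecker factors hold 680,192 entries, against roughly 4 × 10^10 for the dense matrix and 200,704 for a diagonal approximation.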
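7.2 Future Research Directions
- Communication-efficient distributed second-order optimization: reducing cross-GPU synchronization overhead
- Low-precision second-order methods: keeping second-order methods stable under BF16/FP8
- Adaptive curvature estimation: adjusting the accuracy of the curvature estimate dynamically across training phases
- Interaction with Transformer scaling laws: understanding the scaling behavior of second-order optimization in large-model training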
References
1. Amari, S. (1998). Natural gradient works efficiently in learning. Neural Computation, 10(2), 251-276.
2. Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored approximate curvature. ICML 2015.
3. Rao, C. R. (1945). Information and the accuracy attainable in the estimation of statistical parameters. Bulletin of the Calcutta Mathematical Society, 37, 81-91.
4. Amari, S. (2020). Information geometry and its applications: On the front line of data science. IEEE Signal Processing Magazine, 37(6), 119-127.