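In-Depth Theory of Second-Order Optimization Methods

Second-order optimization methods use curvature information about the objective function to guide parameter updates, and they are an important theoretical foundation of optimization in deep learning. Compared with first-order methods, second-order methods cope better with ill-conditioning and can converge in fewer iterations. Starting from the viewpoint of information geometry, this document examines the theory and practice of second-order optimizers such as the natural gradient, K-FAC, and Shampoo.[1][2]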


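1. Natural Gradient and Information Geometry

1.1 Definition and Properties of the Fisher Information Matrix

Geometry of the Parameter Space of Probability Distributions

A Riemannian geometric structure can be defined on the space of parameterized probability distributions $p(x \mid \theta)$. Let the parameter space $\Theta \subseteq \mathbb{R}^n$ be an $n$-dimensional manifold, where each parameter $\theta \in \Theta$ defines a probability distribution $p(x \mid \theta)$.

The **Fisher Information Matrix (FIM)** is defined as the expected outer product of the gradient of the log-likelihood:

$$F(\theta) = \mathbb{E}_{x \sim p(x \mid \theta)}\left[\nabla_\theta \log p(x \mid \theta)\, \nabla_\theta \log p(x \mid \theta)^\top\right] = \mathbb{E}\left[s(\theta)\, s(\theta)^\top\right]$$

where $s(\theta) = \nabla_\theta \log p(x \mid \theta)$ is the score function.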

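Mathematical Properties

  1. Symmetric positive semidefiniteness: $F(\theta)^\top = F(\theta)$, and for every vector $v$, $v^\top F(\theta)\, v = \mathbb{E}\left[(v^\top s(\theta))^2\right] \ge 0$.

  2. Cramér-Rao lower bound: for any unbiased estimator $\hat{\theta}$ of $\theta$, the covariance matrix satisfies:

    $$\operatorname{Cov}(\hat{\theta}) \succeq F(\theta)^{-1}$$

    This shows that the inverse of the Fisher information bounds from below the covariance, and hence the best achievable precision, of any unbiased parameter estimate.

  3. Second-order approximation of the KL divergence: the KL divergence between two infinitesimally close distributions can be approximated with the Fisher matrix:

    $$D_{\mathrm{KL}}\left(p_\theta \,\|\, p_{\theta + \delta\theta}\right) \approx \frac{1}{2}\, \delta\theta^\top F(\theta)\, \delta\theta$$ [3]

  4. Consistency condition: if the model $p(x \mid \theta)$ is correctly specified, the Fisher information matrix equals the negative expected Hessian of the log-likelihood:

    $$F(\theta) = -\mathbb{E}\left[\nabla_\theta^2 \log p(x \mid \theta)\right]$$

    This establishes the link between the Fisher information and the Hessian.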
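Properties 1 and 4 can be checked numerically. The sketch below assumes a logistic-regression model $p(y=1 \mid x, \theta) = \sigma(\theta^\top x)$, an illustrative choice rather than a model from this document, and the helper names are hypothetical. It computes the Fisher matrix exactly (the expectation over $y$ has only two terms) and compares it with a finite-difference estimate of the expected negative Hessian.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fisher_logistic(theta, X):
    """Exact Fisher E_x E_{y ~ p(y|x,theta)}[s s^T] for logistic regression."""
    d = X.shape[1]
    F = np.zeros((d, d))
    for x in X:
        p = sigmoid(theta @ x)
        # The score for label y is (y - p) x; y = 1 has probability p, y = 0 has 1 - p
        for y, weight in ((1.0, p), (0.0, 1.0 - p)):
            s = (y - p) * x
            F += weight * np.outer(s, s)
    return F / len(X)

def expected_neg_hessian(theta, X, h=1e-4):
    """-E[Hessian of log p] via central differences of the gradient of -log p."""
    def grad_nll(t):
        # For this model the Hessian of -log p(y|x,t) does not depend on y,
        # so the gradient for y = 1, i.e. (p - 1) x, is enough.
        return np.mean([(sigmoid(t @ x) - 1.0) * x for x in X], axis=0)

    d = len(theta)
    H = np.zeros((d, d))
    for j in range(d):
        e = np.zeros(d)
        e[j] = h
        H[:, j] = (grad_nll(theta + e) - grad_nll(theta - e)) / (2 * h)
    return H

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
theta = rng.normal(size=3)
F = fisher_logistic(theta, X)
H = expected_neg_hessian(theta, X)
print("max |F - H|:", np.abs(F - H).max())
```

For this model both matrices equal $\frac{1}{N}\sum_i p_i(1-p_i)\, x_i x_i^\top$, so the printed difference is only finite-difference error.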

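Fisher Information Matrix in Neural Networks

For a neural network $f_\theta(x)$ trained with the cross-entropy loss $\mathcal{L}(x, y) = -\log p(y \mid f_\theta(x))$, the Fisher matrix is the expected outer product of the loss gradients, with labels drawn from the model's own predictive distribution:

$$F(\theta) = \mathbb{E}_{x \sim \mathcal{D}}\, \mathbb{E}_{y \sim p(y \mid f_\theta(x))}\left[\nabla_\theta \mathcal{L}(x, y)\, \nabla_\theta \mathcal{L}(x, y)^\top\right]$$

This makes the Fisher matrix the natural curvature measure for natural-gradient updates.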

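1.2 KL Divergence as a Riemannian Metric

Riemannian Metric and Parameter-Space Geometry

On the parameter space $\Theta$, a Riemannian metric assigns to each point $\theta$ a positive-definite matrix $G(\theta)$ and defines the infinitesimal distance between two nearby points as:

$$ds^2 = d\theta^\top G(\theta)\, d\theta = \sum_{i,j} G_{ij}(\theta)\, d\theta_i\, d\theta_j$$

Metric Structure Induced by the KL Divergence

The KL divergence $D_{\mathrm{KL}}(p \,\|\, q)$ defines an asymmetric notion of "distance" on the space of distributions. For infinitesimally close parameters $\theta$ and $\theta + d\theta$:

$$D_{\mathrm{KL}}\left(p_\theta \,\|\, p_{\theta + d\theta}\right) = \frac{1}{2}\, d\theta^\top F(\theta)\, d\theta + O\left(\lVert d\theta \rVert^3\right)$$

After dropping the higher-order terms, the quadratic form $d\theta^\top F(\theta)\, d\theta$ defines a Riemannian metric on the parameter space, with the Fisher matrix $F(\theta)$ acting as the metric tensor.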
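A quick numerical check of this expansion, assuming a Bernoulli distribution parameterized by its mean $p$ (an illustrative choice), whose Fisher information is $1/(p(1-p))$: the ratio of the exact KL divergence to $\frac{1}{2} F \delta^2$ approaches 1 as $\delta \to 0$.

```python
import numpy as np

def kl_bernoulli(p, q):
    """KL(Bern(p) || Bern(q))."""
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

def fisher_bernoulli(p):
    """Fisher information of Bern(p) with respect to the mean parameter p."""
    return 1.0 / (p * (1 - p))

p = 0.3
for delta in (1e-1, 1e-2, 1e-3):
    exact = kl_bernoulli(p, p + delta)
    quadratic = 0.5 * fisher_bernoulli(p) * delta ** 2
    print(f"delta={delta:g}  KL={exact:.3e}  0.5*F*delta^2={quadratic:.3e}  ratio={exact / quadratic:.4f}")
```

The relative error shrinks roughly linearly in $\delta$, which is the $O(\lVert d\theta \rVert^3)$ remainder divided by the quadratic term.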

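Fundamental Differences from Euclidean Space

| Aspect | Euclidean gradient descent | Natural gradient descent |
|---|---|---|
| Metric | $I$ (identity matrix) | $F(\theta)$ (Fisher matrix) |
| Distance | $\lVert \delta\theta \rVert^2 = \delta\theta^\top \delta\theta$ | $\delta\theta^\top F(\theta)\, \delta\theta$ |
| Update direction | Steepest descent of the loss in Euclidean distance | Steepest descent of the loss in the KL-divergence sense |
| Learning-rate sensitivity | Higher (depends on how the parameters are scaled) | Lower (invariant to reparameterization) |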

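1.3 Natural Gradient Update Formula

Derivation from a Variational Principle

The natural gradient method looks for the direction that decreases the loss the most in distribution space rather than in parameter space. Consider the constrained problem:

$$\min_{\delta\theta}\; \mathcal{L}(\theta + \delta\theta) \quad \text{s.t.} \quad D_{\mathrm{KL}}\left(p_\theta \,\|\, p_{\theta + \delta\theta}\right) = c$$

Using a Lagrange multiplier $\lambda$, together with the first-order expansion of the loss and the second-order expansion of the KL divergence from Section 1.1:

$$\delta\theta^* = \arg\min_{\delta\theta}\; \mathcal{L}(\theta) + \nabla_\theta \mathcal{L}(\theta)^\top \delta\theta + \lambda\left(\frac{1}{2}\, \delta\theta^\top F(\theta)\, \delta\theta - c\right)$$

Setting the derivative with respect to $\delta\theta$ to zero gives $\nabla_\theta \mathcal{L} + \lambda F(\theta)\, \delta\theta = 0$, i.e. $\delta\theta^* = -\frac{1}{\lambda} F(\theta)^{-1} \nabla_\theta \mathcal{L}$. Absorbing $1/\lambda$ into the learning rate $\eta$ yields the natural gradient update:

$$\theta_{t+1} = \theta_t - \eta\, F(\theta_t)^{-1} \nabla_\theta \mathcal{L}(\theta_t)$$

Here $F(\theta)^{-1}$, the inverse of the natural metric, maps the Euclidean gradient to a tangent vector on the Riemannian manifold; $\tilde{\nabla} \mathcal{L} = F^{-1} \nabla_\theta \mathcal{L}$ is called the natural gradient.[4]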
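The constrained problem can also be checked numerically. This sketch assumes an arbitrary symmetric positive-definite matrix standing in for $F(\theta)$ and a random gradient $g$ (both hypothetical). It compares the closed-form natural-gradient direction with the plain gradient direction and with random directions, all rescaled to the same KL budget $\frac{1}{2}\delta\theta^\top F \delta\theta = c$.

```python
import numpy as np

def natural_direction(g, F, c):
    """Minimize g^T d subject to 0.5 * d^T F d = c (closed-form Lagrange solution)."""
    nat = np.linalg.solve(F, g)              # F^{-1} g, the natural gradient
    return -np.sqrt(2 * c / (g @ nat)) * nat

def rescale_to_budget(d, F, c):
    """Rescale d so that 0.5 * d^T F d = c."""
    return d * np.sqrt(2 * c / (d @ F @ d))

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 4))
F = M @ M.T + 0.1 * np.eye(4)               # hypothetical SPD stand-in for F(theta)
g = rng.normal(size=4)                      # hypothetical loss gradient
c = 1e-2                                    # KL budget

d_nat = natural_direction(g, F, c)
d_euc = rescale_to_budget(-g, F, c)         # plain gradient direction, same KL budget
print("first-order loss change, natural:  ", g @ d_nat)
print("first-order loss change, Euclidean:", g @ d_euc)
```

By the Cauchy-Schwarz inequality in the $F$-inner product, no direction with the same KL budget achieves a larger first-order decrease than $-F^{-1}g$.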

Computing the Natural Gradient

import torch
import torch.nn as nn
 
def compute_natural_gradient(model, fisher_matrix, loss_fn, damping=1e-3):
    """
    Compute the natural gradient F^{-1} ∇L.

    fisher_matrix is the Fisher information matrix over all (flattened)
    parameters. For efficiency only its diagonal is used, i.e. F^{-1} ∇L
    is approximated by ∇L / (diag(F) + damping).
    """
    params = [p for p in model.parameters() if p.requires_grad]

    # 1. Ordinary (Euclidean) gradient
    loss = loss_fn(model)
    grads = torch.autograd.grad(loss, params)

    # 2. Flatten the gradients into one vector
    grad_vector = torch.cat([g.reshape(-1) for g in grads])

    # 3. Approximately solve F x = grad with a damped diagonal approximation
    fisher_diag = fisher_matrix.diagonal()
    natural_grad_vector = grad_vector / (fisher_diag + damping)

    # 4. Reshape back to the parameter shapes
    natural_grads = []
    offset = 0
    for p in params:
        size = p.numel()
        natural_grads.append(
            natural_grad_vector[offset:offset + size].view_as(p)
        )
        offset += size

    return natural_grads

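1.4 Geometric Meaning of the Natural Gradient

Gradient Descent on a Riemannian Manifold

Natural gradient descent is a special case of Riemannian gradient descent, on the manifold whose metric tensor is the Fisher matrix. In Riemannian geometry the gradient is obtained by transforming the Euclidean gradient with the inverse of the metric tensor:

$$\operatorname{grad} \mathcal{L}(\theta) = G(\theta)^{-1} \nabla_\theta \mathcal{L}(\theta)$$

When $G(\theta) = F(\theta)$, this is exactly the natural gradient.

Geometric Intuition in Distribution Space

Consider Gaussian distributions $\mathcal{N}(\mu, \sigma^2)$ with a fixed variance $\sigma^2$, parameterized by the mean $\mu$ (for example two members $\mathcal{N}(\mu_1, \sigma^2)$ and $\mathcal{N}(\mu_2, \sigma^2)$ of this family). Their Fisher information is $F(\mu) = 1/\sigma^2$ (a scalar), so the natural gradient update is:

$$\mu_{t+1} = \mu_t - \eta\, \sigma^2\, \frac{\partial \mathcal{L}}{\partial \mu}$$

This shows that the natural gradient automatically adapts the step size to the curvature of the distribution: it takes large steps in low-curvature regions (large variance) and small steps in high-curvature regions (small variance).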
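A one-function illustration of this scaling, assuming a fixed Euclidean gradient $\partial \mathcal{L}/\partial\mu = 1$; the helper name is hypothetical:

```python
def natural_gradient_step_mean(mu, grad_mu, sigma, lr):
    """Natural-gradient step for the mean of N(mu, sigma^2)."""
    fisher = 1.0 / sigma ** 2            # F(mu) = 1 / sigma^2
    return mu - lr * grad_mu / fisher    # F^{-1} grad = sigma^2 * grad

# Same Euclidean gradient, different variances: the step grows with sigma^2
for sigma in (0.1, 1.0, 10.0):
    print(sigma, natural_gradient_step_mean(0.0, 1.0, sigma, lr=0.1))
```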

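Connection to the Kullback-Leibler Flow

The natural gradient update can be interpreted as a discretization of a KL-divergence descent flow (KL Divergence Flow), i.e. a proximal step that uses the KL divergence as the proximity term:

$$\theta_{t+1} = \arg\min_{\theta}\left\{\mathcal{L}(\theta) + \frac{1}{\eta}\, D_{\mathrm{KL}}\left(p_{\theta_t} \,\|\, p_\theta\right)\right\}$$

As $\eta \to 0$, this converges to the natural gradient flow $\dot{\theta} = -F(\theta)^{-1} \nabla_\theta \mathcal{L}(\theta)$.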

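Invariance under Reparameterization

An important property of the natural gradient is invariance under reparameterization. If the parameters are transformed by an invertible map $\theta = h(\phi)$ with Jacobian $J = \partial \theta / \partial \phi$, then $\nabla_\phi \mathcal{L} = J^\top \nabla_\theta \mathcal{L}$ and $F_\phi = J^\top F_\theta J$, so the natural gradient transforms like a tangent vector:

$$F_\phi^{-1} \nabla_\phi \mathcal{L} = J^{-1} F_\theta^{-1} \nabla_\theta \mathcal{L}$$

Mapped back through $J$, a small natural-gradient step therefore produces the same change in the distribution in either parameterization. Euclidean gradient descent, by contrast, changes its update direction under nonlinear reparameterization.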
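A small numerical illustration, assuming a Bernoulli model written either in terms of its mean $p$ or its logit $z$, with a hypothetical loss $\mathcal{L}(p) = (p - 0.8)^2$. One small natural-gradient step changes $p$ by (almost) the same amount in both parameterizations, while one Euclidean step does not.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def grad_loss_p(p):
    # Hypothetical loss L(p) = (p - 0.8)^2, used only for illustration
    return 2.0 * (p - 0.8)

def change_in_p(p, lr):
    """Change in p after one step taken in the p- or in the logit z-parameterization."""
    z = np.log(p / (1 - p))
    g_p = grad_loss_p(p)
    g_z = g_p * p * (1 - p)                       # chain rule: dp/dz = p(1 - p)
    F_p, F_z = 1.0 / (p * (1 - p)), p * (1 - p)   # Bernoulli Fisher in each parameterization
    return {
        "natural_p": -lr * g_p / F_p,
        "natural_z": sigmoid(z - lr * g_z / F_z) - p,
        "euclidean_p": -lr * g_p,
        "euclidean_z": sigmoid(z - lr * g_z) - p,
    }

steps = change_in_p(0.3, lr=1e-3)
print(steps)
```

The two natural steps agree up to $O(\eta^2)$; the two Euclidean steps differ by the factor $p^2(1-p)^2$.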


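2. Kronecker-Factored Approximate Curvature (K-FAC)

2.1 Mathematical Principles of Kronecker Factorization

Kronecker Product Recap

For matrices $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{p \times q}$, the Kronecker product $A \otimes B \in \mathbb{R}^{mp \times nq}$ is defined as:

$$A \otimes B = \begin{pmatrix} a_{11} B & \cdots & a_{1n} B \\ \vdots & \ddots & \vdots \\ a_{m1} B & \cdots & a_{mn} B \end{pmatrix}$$

Core Idea of Kronecker Factorization

Let $F$ be the full Fisher information matrix. K-FAC assumes that $F$ is approximately block-diagonal, with one block per layer, and that each block factorizes as a Kronecker product, so that $F$ is approximately a direct sum of Kronecker products:

$$F \approx \bigoplus_{l=1}^{L} \left(A_l \otimes B_l\right) = \operatorname{diag}\left(A_1 \otimes B_1, \ldots, A_L \otimes B_L\right)$$

For a layer with weight $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$, $A_l \in \mathbb{R}^{d_{\text{in}} \times d_{\text{in}}}$ and $B_l \in \mathbb{R}^{d_{\text{out}} \times d_{\text{out}}}$. This factorization has the following properties:

  1. Storage: from $O(d_{\text{in}}^2 d_{\text{out}}^2)$ down to $O(d_{\text{in}}^2 + d_{\text{out}}^2)$
  2. Inversion: from $O(d_{\text{in}}^3 d_{\text{out}}^3)$ down to $O(d_{\text{in}}^3 + d_{\text{out}}^3)$, because $(A \otimes B)^{-1} = A^{-1} \otimes B^{-1}$
  3. Matrix-vector products without forming the full matrix: $(A \otimes B)\, \operatorname{vec}(V) = \operatorname{vec}(B V A^\top)$ (column-stacking $\operatorname{vec}$), at cost $O(d_{\text{in}}^2 d_{\text{out}} + d_{\text{in}} d_{\text{out}}^2)$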
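Both identities can be checked with NumPy's `np.kron`, which follows the block definition above. The factors are random symmetric positive-definite matrices, purely illustrative, and `vec` stacks columns (`order="F"`).

```python
import numpy as np

def vec(M):
    """Column-stacking vec operator."""
    return M.flatten(order="F")

rng = np.random.default_rng(0)
Na, Nb = rng.normal(size=(3, 3)), rng.normal(size=(2, 2))
A = Na @ Na.T + np.eye(3)        # d_in x d_in factor (illustrative, SPD)
B = Nb @ Nb.T + np.eye(2)        # d_out x d_out factor (illustrative, SPD)
V = rng.normal(size=(2, 3))      # a d_out x d_in matrix, e.g. a weight gradient

# Matrix-vector product without forming A ⊗ B: (A ⊗ B) vec(V) = vec(B V A^T)
mv_full = np.kron(A, B) @ vec(V)
mv_factored = vec(B @ V @ A.T)

# Inverse of a Kronecker product: (A ⊗ B)^{-1} = A^{-1} ⊗ B^{-1}
inv_full = np.linalg.inv(np.kron(A, B))
inv_factored = np.kron(np.linalg.inv(A), np.linalg.inv(B))
print(np.abs(mv_full - mv_factored).max(), np.abs(inv_full - inv_factored).max())
```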

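2.2 K-FAC's Approximation of the Fisher Information

Structure of a Neural-Network Layer

Consider a linear layer $s = W a$, where $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ is the weight matrix, $a \in \mathbb{R}^{d_{\text{in}}}$ is the input, and $s \in \mathbb{R}^{d_{\text{out}}}$ is the output. Write $g = \nabla_s \log p$ for the gradient back-propagated to the output.

Independence Assumption

The key assumption of K-FAC is that, for a given layer, the input activations $a$ and the output gradients $g$ are approximately independent. This lets the layer's Fisher block factorize as:

$$F_l \approx A \otimes B$$

where:

  • $A = \mathbb{E}[a a^\top]$ is the (uncentered) covariance matrix of the inputs
  • $B = \mathbb{E}[g g^\top]$ is the covariance matrix of the output gradients

Empirical Estimates

In practice, empirical estimates over a mini-batch are used:

$$\hat{A} = \frac{1}{N} \sum_{i=1}^{N} a_i a_i^\top, \qquad \hat{B} = \frac{1}{N} \sum_{i=1}^{N} g_i g_i^\top$$

where $N$ is the batch size.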

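Proof of the factorization (simplified):

Let $\mathcal{D}W = \nabla_W \log p$ be the gradient of the log-likelihood with respect to the weights. By the chain rule (with column-stacking $\operatorname{vec}$):

$$\mathcal{D}W = g\, a^\top, \qquad \operatorname{vec}(\mathcal{D}W) = a \otimes g$$

Then the Fisher block is:

$$F_l = \mathbb{E}\left[\operatorname{vec}(\mathcal{D}W)\, \operatorname{vec}(\mathcal{D}W)^\top\right] = \mathbb{E}\left[(a a^\top) \otimes (g g^\top)\right]$$

If $a$ and $g$ are independent, then:

$$F_l = \mathbb{E}\left[a a^\top\right] \otimes \mathbb{E}\left[g g^\top\right] = A \otimes B$$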
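A Monte Carlo sketch of the last step, with synthetic Gaussian samples standing in for the activations $a$ and output gradients $g$ (an assumption made only for illustration). When $g$ is drawn independently of $a$, the exact block $\mathbb{E}[(aa^\top) \otimes (gg^\top)]$ matches $A \otimes B$ up to sampling noise; when $g$ is a function of $a$, it does not.

```python
import numpy as np

def kfac_error(a, g):
    """Relative error between E[(a a^T) ⊗ (g g^T)] and E[a a^T] ⊗ E[g g^T] (sample estimates)."""
    n = a.shape[0]
    # Per-sample vec(DW) = a ⊗ g, built for all samples at once
    v = np.einsum("ni,nj->nij", a, g).reshape(n, -1)
    F_exact = v.T @ v / n
    F_kfac = np.kron(a.T @ a / n, g.T @ g / n)
    return np.linalg.norm(F_exact - F_kfac) / np.linalg.norm(F_exact)

rng = np.random.default_rng(0)
n, d_in, d_out = 200_000, 3, 2
a = rng.normal(size=(n, d_in)) @ rng.normal(size=(d_in, d_in))  # correlated inputs

g_indep = rng.normal(size=(n, d_out))     # independent of a
g_dep = a[:, :d_out]                      # strongly dependent on a

print("independent:", kfac_error(a, g_indep))
print("dependent:  ", kfac_error(a, g_dep))
```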

2.3 Practical Implementation and Computational Efficiency

A K-FAC Implementation Framework

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
 
class KFACOptimizer:
    """
    Kronecker-Factored Approximate Curvature (K-FAC) optimizer

    Reference: Martens & Grosse (2015)

    Simplified version: it uses the gradients of the actual training loss
    (empirical Fisher), supports nn.Linear and nn.Conv2d (groups=1, numeric
    padding), and updates biases with plain SGD. `momentum` is stored but
    not used. Call step() after loss.backward().
    """

    def __init__(
        self,
        model,
        lr=0.001,
        momentum=0.9,
        stat_decay=0.99,
        damping=1e-3,
        kl_interval=10,      # interval (in steps) between Fisher-factor updates
        inverse_interval=50  # interval (in steps) between inverse recomputations
    ):
        self.model = model
        self.lr = lr
        self.momentum = momentum
        self.stat_decay = stat_decay
        self.damping = damping
        self.kl_interval = kl_interval
        self.inverse_interval = inverse_interval

        self.step_count = 0
        self.fisher_blocks = {}   # Kronecker factors A, B per layer
        self.inverse_blocks = {}  # inverses of the damped factors
        self.grad_buffers = {}    # gradient buffers (unused in this version)

        self._register_hooks()

    def _register_hooks(self):
        """Register hooks that capture layer inputs (forward) and output gradients (backward)"""
        self.activations = {}
        self.grad_outputs = {}

        def forward_hook(name):
            def hook(module, input, output):
                self.activations[name] = input[0].detach()
            return hook

        def backward_hook(name):
            def hook(module, grad_input, grad_output):
                self.grad_outputs[name] = grad_output[0].detach()
            return hook

        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                module.register_forward_hook(forward_hook(name))
                module.register_full_backward_hook(backward_hook(name))

    def _update_fisher_statistics(self):
        """Update the Kronecker factors with an exponential moving average (EMA)"""
        for name, module in self.model.named_modules():
            if not isinstance(module, (nn.Linear, nn.Conv2d)):
                continue

            grad = module.weight.grad
            if grad is None:
                continue

            if name not in self.fisher_blocks:
                self.fisher_blocks[name] = {'A': None, 'B': None, 'cnt': 0}

            # Factors for the current mini-batch
            if isinstance(module, nn.Linear):
                A, B = self._compute_linear_factors(name, grad)
            else:
                A, B = self._compute_conv_factors(name, module, grad)

            # Exponential moving average
            block = self.fisher_blocks[name]
            if block['A'] is None:
                block['A'], block['B'] = A, B
            else:
                block['A'] = self.stat_decay * block['A'] + (1 - self.stat_decay) * A
                block['B'] = self.stat_decay * block['B'] + (1 - self.stat_decay) * B

            block['cnt'] += 1

    def _compute_linear_factors(self, name, grad):
        """
        Kronecker factors of a linear layer Linear(in, out):
        - A: input second moment (in × in)
        - B: output-gradient second moment (out × out)
        """
        x = self.activations.get(name)
        grad_out = self.grad_outputs.get(name)

        if x is None or grad_out is None:
            # Fall back to identity factors (plain gradient step)
            return (torch.eye(grad.shape[1], device=grad.device),
                    torch.eye(grad.shape[0], device=grad.device))

        # Treat all leading dimensions (batch, sequence, ...) as samples
        x = x.reshape(-1, x.shape[-1])
        grad_out = grad_out.reshape(-1, grad_out.shape[-1])
        n = x.shape[0]

        # Uncentered second moments. With a mean-reduced loss, grad_out carries
        # a 1/batch factor, so B is off by a constant that damping/lr absorb.
        A = x.T @ x / n
        B = grad_out.T @ grad_out / n

        return A, B

    def _compute_conv_factors(self, name, module, grad):
        """
        Kronecker factors of a Conv2d layer

        The convolution is unrolled into a matrix product with im2col, and
        every spatial position is treated as an extra sample (a simplification).
        """
        x = self.activations.get(name)
        grad_out = self.grad_outputs.get(name)
        in_dim = grad[0].numel()    # in_channels * k_h * k_w
        out_dim = grad.shape[0]     # out_channels

        if x is None or grad_out is None:
            return (torch.eye(in_dim, device=grad.device),
                    torch.eye(out_dim, device=grad.device))

        # (N, C*k_h*k_w, L) -> (N*L, C*k_h*k_w)
        x_col = self._im2col(x, module).transpose(1, 2).reshape(-1, in_dim)
        # (N, C_out, H_out, W_out) -> (N*H_out*W_out, C_out)
        g = grad_out.permute(0, 2, 3, 1).reshape(-1, out_dim)

        A = x_col.T @ x_col / x_col.shape[0]
        B = g.T @ g / g.shape[0]

        return A, B

    def _im2col(self, x, conv_module):
        """Unfold the convolution input into patch columns (handles padding, stride, dilation)"""
        return F.unfold(
            x,
            kernel_size=conv_module.kernel_size,
            dilation=conv_module.dilation,
            padding=conv_module.padding,
            stride=conv_module.stride,
        )

    def _update_inverse_matrices(self):
        """Recompute the inverses of the damped factors via eigendecomposition"""
        for name, fisher in self.fisher_blocks.items():
            A, B = fisher['A'], fisher['B']

            if A is None or B is None:
                continue

            # Add damping and compute the symmetric eigendecomposition
            eigenvalues_A, eigenvectors_A = torch.linalg.eigh(
                A + self.damping * torch.eye(A.shape[0], device=A.device)
            )
            eigenvalues_B, eigenvectors_B = torch.linalg.eigh(
                B + self.damping * torch.eye(B.shape[0], device=B.device)
            )

            # Inverse through the eigenbasis (clamping guards against round-off)
            inv_eigenvalues_A = 1.0 / torch.clamp(eigenvalues_A, min=self.damping)
            inv_eigenvalues_B = 1.0 / torch.clamp(eigenvalues_B, min=self.damping)

            self.inverse_blocks[name] = {
                'A_inv': eigenvectors_A @ torch.diag(inv_eigenvalues_A) @ eigenvectors_A.T,
                'B_inv': eigenvectors_B @ torch.diag(inv_eigenvalues_B) @ eigenvectors_B.T
            }

    def _compute_corrected_gradient(self, name, module, grad):
        """
        Precondition the weight gradient with the Kronecker-factored Fisher

        With row-major flattening (as in torch.flatten):
        (B^{-1} ⊗ A^{-1}) vec(∇W) = vec(B^{-1} ∇W A^{-1})
        """
        if name not in self.inverse_blocks:
            return grad

        inv_A = self.inverse_blocks[name]['A_inv']
        inv_B = self.inverse_blocks[name]['B_inv']

        # Conv weights (out, in, k_h, k_w) are viewed as (out, in*k_h*k_w)
        grad_2d = grad.reshape(grad.shape[0], -1)
        corrected_grad = inv_B @ grad_2d @ inv_A   # inv_A is symmetric

        return corrected_grad.view_as(grad)

    def step(self):
        """One K-FAC optimization step"""
        self.step_count += 1

        # 1. Periodically update the Fisher factors
        if self.step_count % self.kl_interval == 0:
            self._update_fisher_statistics()

        # 2. Periodically recompute the inverses
        if self.step_count % self.inverse_interval == 0:
            self._update_inverse_matrices()

        # 3. Apply the preconditioned gradient (plain SGD until inverses exist)
        with torch.no_grad():
            for name, module in self.model.named_modules():
                if not isinstance(module, (nn.Linear, nn.Conv2d)):
                    continue
                if module.weight.grad is None:
                    continue

                corrected_grad = self._compute_corrected_gradient(
                    name, module, module.weight.grad
                )
                module.weight -= self.lr * corrected_grad

                # Biases always get a plain gradient step
                if module.bias is not None and module.bias.grad is not None:
                    module.bias -= self.lr * module.bias.grad

        self.model.zero_grad()

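Computational Efficiency Analysis

For a single layer with $W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ and a batch of $N$ samples:

| Operation | Full Fisher | K-FAC |
|---|---|---|
| Storage | $O(d_{\text{in}}^2 d_{\text{out}}^2)$ | $O(d_{\text{in}}^2 + d_{\text{out}}^2)$ |
| Fisher update | $O(N d_{\text{in}}^2 d_{\text{out}}^2)$ | $O(N (d_{\text{in}}^2 + d_{\text{out}}^2))$ |
| Inversion | $O(d_{\text{in}}^3 d_{\text{out}}^3)$ | $O(d_{\text{in}}^3 + d_{\text{out}}^3)$ |
| Gradient correction | $O(d_{\text{in}}^2 d_{\text{out}}^2)$ | $O(d_{\text{in}}^2 d_{\text{out}} + d_{\text{in}} d_{\text{out}}^2)$ |

For large Transformer models, K-FAC therefore typically needs memory on the order of $O\left(\sum_l (d_{\text{in},l}^2 + d_{\text{out},l}^2)\right)$ for the factors and their inverses, instead of $O\left(\sum_l d_{\text{in},l}^2 d_{\text{out},l}^2\right)$ for exact per-layer Fisher blocks.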
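A small arithmetic sketch of the storage row above, assuming fp32 (4 bytes per entry) and hypothetical Transformer-like layer shapes; storing the inverses as well doubles the K-FAC figure.

```python
def fisher_memory_entries(d_in, d_out):
    """Stored entries for one layer: full Fisher block vs. K-FAC factors A and B."""
    full = (d_in * d_out) ** 2
    kfac = d_in ** 2 + d_out ** 2
    return full, kfac

# Hypothetical layer shapes (d_in, d_out)
for d_in, d_out in [(1024, 4096), (4096, 4096)]:
    full, kfac = fisher_memory_entries(d_in, d_out)
    print(f"{d_in}x{d_out}: full {full:.2e} entries, K-FAC {kfac:.2e} entries, "
          f"K-FAC in fp32 ~ {4 * kfac / 2**20:.0f} MiB")
```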

2.4 Applying K-FAC to Large-Model Training

Distributed K-FAC Implementation

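import torch
import torch.distributed as dist

class DistributedKFACOptimizer:
    """
    Distributed K-FAC optimizer for multi-GPU training (sketch)

    Every GPU computes Kronecker factors from its local mini-batch; the
    factors are averaged with AllReduce, rank 0 inverts them, and the
    inverses are broadcast to all ranks. Gradients are assumed to be
    synchronized already (e.g. by DistributedDataParallel).
    """

    def __init__(self, model, lr=0.001, world_size=1, rank=0, **kwargs):
        self.model = model
        self.lr = lr
        self.world_size = world_size
        self.rank = rank
        self.kfac = KFACOptimizer(model, lr=lr, **kwargs)

    def sync_fisher_stats(self):
        """Average the Fisher factors across GPUs"""
        for name in self.kfac.fisher_blocks:
            for key in ('A', 'B'):
                factor = self.kfac.fisher_blocks[name][key]
                # all_reduce works in place and does not return the result
                dist.all_reduce(factor, op=dist.ReduceOp.SUM)
                factor /= self.world_size

    def _broadcast_inverse_blocks(self):
        """Broadcast rank 0's inverse factors to every rank"""
        for name, block in self.kfac.fisher_blocks.items():
            if name not in self.kfac.inverse_blocks:
                # Receive buffers on ranks that have not inverted anything yet
                self.kfac.inverse_blocks[name] = {
                    'A_inv': torch.empty_like(block['A']),
                    'B_inv': torch.empty_like(block['B']),
                }
            for key in ('A_inv', 'B_inv'):
                dist.broadcast(self.kfac.inverse_blocks[name][key], src=0)

    def step(self):
        """Distributed optimization step (call after loss.backward())"""
        kfac = self.kfac
        kfac.step_count += 1

        # 1. Every GPU updates its local factors, then they are averaged
        if kfac.step_count % kfac.kl_interval == 0:
            kfac._update_fisher_statistics()
            self.sync_fisher_stats()

        # 2. Rank 0 inverts the averaged factors and broadcasts the inverses
        if kfac.step_count % kfac.inverse_interval == 0:
            if self.rank == 0:
                kfac._update_inverse_matrices()
            self._broadcast_inverse_blocks()

        # 3. Every GPU applies the same preconditioned update
        with torch.no_grad():
            for name, module in self.model.named_modules():
                if not isinstance(module, (nn.Linear, nn.Conv2d)) or module.weight.grad is None:
                    continue
                corrected = kfac._compute_corrected_gradient(name, module, module.weight.grad)
                module.weight -= self.lr * corrected
                if module.bias is not None and module.bias.grad is not None:
                    module.bias -= self.lr * module.bias.grad

        self.model.zero_grad()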
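Why averaging with a SUM all-reduce is the right aggregation: when every rank holds a shard of the same size, the mean of the per-rank factors equals the factor computed from the whole global batch. This NumPy sketch simulates the ranks in a single process (no `torch.distributed` involved) with synthetic activations.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, local_batch, d_in = 4, 8, 5
x = rng.normal(size=(world_size * local_batch, d_in))   # the global mini-batch

# Input factor A computed from the whole global batch on a single device
A_global = x.T @ x / x.shape[0]

# Each simulated rank computes A from its own shard; then SUM all-reduce / world_size
shards = np.split(x, world_size)
A_local = [s.T @ s / s.shape[0] for s in shards]
A_allreduced = sum(A_local) / world_size

print(np.abs(A_global - A_allreduced).max())
```

Because the EMA update is also linear, averaging EMA-smoothed factors across ranks preserves this equality as long as every rank uses the same decay.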

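Comparison with Standard Optimizers

| Property | SGD | Adam | K-FAC |
|---|---|---|---|
| Convergence speed (iterations) | Slow | Medium | Fast |
| Memory overhead | Low | Medium | High |
| Hyperparameter sensitivity | High | Medium | High (damping, update intervals) |
| Use of curvature | None | Diagonal approximation | Block (Kronecker) structure |
| Large-scale scalability | Excellent | Excellent | Medium |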

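3. Shampoo Optimizer Theory

3.1 Definition of the Preconditioning Matrix

Geometric Meaning of a Preconditioner

The core idea of a preconditioner is to transform the optimization problem from an "ill-conditioned" coordinate system into a "well-conditioned" one. In second-order optimization, the preconditioner $P$ is designed to approximate the inverse of the Hessian:

$$\theta_{t+1} = \theta_t - \eta\, P\, \nabla_\theta \mathcal{L}(\theta_t), \qquad P \approx H^{-1}$$

An ideal preconditioner should satisfy:

  1. $P \approx H^{-1}$ (curvature approximation)
  2. It is cheap to compute
  3. Its storage cost is acceptable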
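To see why $P \approx H^{-1}$ is the ideal, consider a hypothetical ill-conditioned quadratic $f(x) = \frac{1}{2} x^\top H x$ with condition number 100. Plain gradient descent must use a learning rate limited by the largest curvature and is still far from the minimum after 50 steps; with $P = H^{-1}$, a single step reaches it.

```python
import numpy as np

H = np.diag([100.0, 1.0])          # Hessian of f(x) = 0.5 x^T H x (condition number 100)
x0 = np.array([1.0, 1.0])

def run(P, lr, steps):
    """Preconditioned gradient descent x <- x - lr * P @ grad, with grad = H x."""
    x = x0.copy()
    for _ in range(steps):
        x = x - lr * P @ (H @ x)
    return x

x_gd = run(np.eye(2), lr=1.0 / 100.0, steps=50)    # lr limited by the largest curvature
x_pre = run(np.linalg.inv(H), lr=1.0, steps=1)      # P = H^{-1}
print("plain GD after 50 steps:", x_gd)
print("preconditioned after 1 step:", x_pre)
```

The slow coordinate of plain gradient descent shrinks only by a factor of $0.99$ per step.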

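Shampoo's Preconditioning Strategy

For each parameter tensor $W \in \mathbb{R}^{n_1 \times \cdots \times n_k}$, Shampoo builds $k$ preconditioner matrices $H^{(1)}, \ldots, H^{(k)}$ with $H^{(i)} \in \mathbb{R}^{n_i \times n_i}$, one for each dimension.

3.2 Block-Diagonal Generalization of Adagrad

Adagrad Recap

Adagrad maintains an accumulated second-moment matrix $H_t = \sum_{s=1}^{t} g_s g_s^\top$ and updates:

$$\theta_{t+1} = \theta_t - \eta\, \operatorname{diag}(H_t)^{-1/2}\, g_t$$

where $\operatorname{diag}(H_t)$ keeps only the diagonal entries of $H_t$ (in practice a small $\epsilon$ is added before taking the inverse square root).

Shampoo's Block-Diagonal Generalization

For a matrix parameter $W \in \mathbb{R}^{m \times n}$ with gradient $G_t$, Shampoo defines two preconditioners:

$$L_t = L_{t-1} + G_t G_t^\top \in \mathbb{R}^{m \times m}, \qquad R_t = R_{t-1} + G_t^\top G_t \in \mathbb{R}^{n \times n}$$

with $L_0 = \epsilon I_m$ and $R_0 = \epsilon I_n$. The update is:

$$W_{t+1} = W_t - \eta\, L_t^{-1/4}\, G_t\, R_t^{-1/4}$$

By the properties of the Kronecker product this is equivalent to (with row-major $\operatorname{vec}$):

$$\operatorname{vec}(W_{t+1}) = \operatorname{vec}(W_t) - \eta \left(L_t^{1/4} \otimes R_t^{1/4}\right)^{-1} \operatorname{vec}(G_t)$$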

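Mathematical Derivation

Let $\bar{H}_t = \sum_{s=1}^{t} \operatorname{vec}(G_s)\, \operatorname{vec}(G_s)^\top$ be the full Fisher/Gram matrix, of size $mn \times mn$; full-matrix Adagrad would precondition with $\bar{H}_t^{-1/2}$.

Shampoo replaces $\bar{H}_t$ with the Kronecker-structured $L_t^{1/2} \otimes R_t^{1/2}$. Since $(A \otimes B)^{p} = A^{p} \otimes B^{p}$ for symmetric positive-definite $A$ and $B$:

$$\left(L_t^{1/2} \otimes R_t^{1/2}\right)^{-1/2} = L_t^{-1/4} \otimes R_t^{-1/4}$$

This means we only need inverse roots of the $m \times m$ matrix $L_t$ and the $n \times n$ matrix $R_t$, rather than of the full $mn \times mn$ matrix.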
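A numerical check of this equivalence with synthetic gradients (illustrative only): the factored update $L^{-1/4} G R^{-1/4}$ matches the update obtained through the full $mn \times mn$ Kronecker matrix with row-major `vec` (NumPy's default `reshape`).

```python
import numpy as np

def psd_power(M, p):
    """M^p for a symmetric positive-definite matrix M via eigendecomposition."""
    w, Q = np.linalg.eigh(M)
    return (Q * w ** p) @ Q.T

rng = np.random.default_rng(0)
m, n, eps = 3, 4, 1e-3
grads = [rng.normal(size=(m, n)) for _ in range(10)]   # synthetic gradients G_1..G_t
G = grads[-1]
L = eps * np.eye(m) + sum(g @ g.T for g in grads)      # L_t
R = eps * np.eye(n) + sum(g.T @ g for g in grads)      # R_t

# Factored Shampoo direction: L^{-1/4} G R^{-1/4}
update_factored = psd_power(L, -0.25) @ G @ psd_power(R, -0.25)

# Same direction through the full Kronecker matrix (row-major vec)
P_full = np.linalg.inv(np.kron(psd_power(L, 0.25), psd_power(R, 0.25)))
update_full = (P_full @ G.reshape(-1)).reshape(m, n)

print(np.abs(update_factored - update_full).max())
```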

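3.3 SVD Decomposition and Parameter Transformation

Computing the Matrix Square Root

For a symmetric positive-definite matrix $A$, the matrix square root $A^{1/2}$ (and more generally any power $A^p$) can be computed from the eigendecomposition, which for such matrices coincides with the SVD:

$$A = Q \Lambda Q^\top \;\Longrightarrow\; A^{1/2} = Q \Lambda^{1/2} Q^\top, \qquad A^{p} = Q \Lambda^{p} Q^\top$$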
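The formula translates directly into code. This NumPy sketch (the function name is hypothetical) clamps tiny eigenvalues to guard against round-off and checks $A^{1/2} A^{1/2} = A$ and $(A^{-1/4})^4 A = I$ on a random SPD matrix.

```python
import numpy as np

def matrix_power_psd(A, p, eps=1e-12):
    """A^p for symmetric PSD A: A = Q diag(lam) Q^T  =>  A^p = Q diag(lam^p) Q^T."""
    lam, Q = np.linalg.eigh(A)
    lam = np.maximum(lam, eps)           # guard tiny or negative round-off eigenvalues
    return (Q * lam ** p) @ Q.T

rng = np.random.default_rng(0)
M = rng.normal(size=(5, 5))
A = M @ M.T + np.eye(5)                  # symmetric positive definite

S = matrix_power_psd(A, 0.5)             # A^{1/2}
T = matrix_power_psd(A, -0.25)           # A^{-1/4}, the exponent Shampoo uses for matrices
```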

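Power Iteration Method

To avoid a full eigendecomposition at every update, Shampoo implementations can use iterative methods instead: power iteration to estimate the largest eigenvalue, followed by a coupled Newton iteration that approximates the inverse $p$-th root of the matrix: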

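import torch

def matrix_power_iteration(A, p, k=20, newton_iters=50, eps=1e-6):
    """
    Approximate A^{p} for a symmetric PSD matrix A, where p = -1/r with a
    positive integer r (e.g. p = -1/4), without an eigendecomposition.

    1. Power iteration (k steps) estimates the largest eigenvalue of A.
    2. A coupled Newton iteration keeps the invariant M = H^r A;
       as M -> I, H -> A^{-1/r}.
    """
    r = round(-1.0 / p)                       # p = -1/r
    n = A.shape[0]
    identity = torch.eye(n, dtype=A.dtype, device=A.device)
    A = A + eps * identity                    # damping keeps A positive definite

    # 1. Power iteration for the largest eigenvalue
    v = torch.randn(n, dtype=A.dtype, device=A.device)
    for _ in range(k):
        v = A @ v
        v = v / torch.linalg.norm(v)
    max_ev = torch.dot(v, A @ v)

    # 2. Coupled Newton iteration for the inverse r-th root
    alpha = -1.0 / r
    z = (1 + r) / (2 * max_ev)                # scales the spectrum of M into (0, (1+r)/2]
    M = z * A
    H = identity * z ** (1.0 / r)
    for _ in range(newton_iters):
        M_i = (1 - alpha) * identity + alpha * M
        M = torch.linalg.matrix_power(M_i, r) @ M
        H = H @ M_i

    return H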

Shampoo Optimizer Implementation

class ShampooOptimizer:
    """
    Shampoo optimizer (simplified)

    Reference: Gupta et al. (2018) - Shampoo: Preconditioned Stochastic Tensor Optimization

    For an order-k parameter tensor, every dimension gets a statistics matrix
    (an EMA of the gradient Gram matrix along that dimension) and is
    preconditioned with its inverse 2k-th root (L^{-1/4} and R^{-1/4} for
    matrices). Inverse roots are recomputed by eigendecomposition every
    `update_interval` steps; until the first recomputation, parameters get
    plain SGD updates. betas[0] (momentum) is not used in this version.
    """

    def __init__(
        self,
        model,
        lr=0.001,
        betas=(0.9, 0.999),
        epsilon=1e-6,
        weight_decay=0.01,
        update_interval=100
    ):
        self.model = model
        self.lr = lr
        self.betas = betas
        self.epsilon = epsilon
        self.weight_decay = weight_decay
        self.update_interval = update_interval

        self.preconditioners = {}  # preconditioner state of each parameter
        self.timestep = 0
        self._init_preconditioners()

    def _init_preconditioners(self):
        """Create one statistics matrix per dimension of every tensor with dim >= 2"""
        for name, param in self.model.named_parameters():
            if param.dim() >= 2:
                self.preconditioners[name] = {
                    # statistics matrix of each dimension
                    'stats': [
                        torch.zeros(size, size, device=param.device)
                        for size in param.shape
                    ],
                    # inverse root of each dimension (computed later)
                    'inv_stats': [None] * param.dim()
                }

    def _compute_tensor_gradient(self, grad):
        """
        Gram matrix of the gradient along each dimension

        For dimension d: G_d = mat_d(grad) @ mat_d(grad)^T, where mat_d moves
        dimension d to the front and flattens the remaining dimensions.
        """
        stats = []
        for dim in range(grad.dim()):
            grad_flattened = grad.transpose(0, dim).reshape(grad.shape[dim], -1)
            stats.append(grad_flattened @ grad_flattened.T)
        return stats

    def _update_statistics(self, name, grad):
        """Update the statistics of parameter `name` with an EMA (decay betas[1])"""
        if name not in self.preconditioners:
            return

        decay = self.betas[1]
        stats = self.preconditioners[name]['stats']
        for i, stat in enumerate(self._compute_tensor_gradient(grad)):
            stats[i] = decay * stats[i] + (1 - decay) * stat

    def _compute_inverse_root(self, A, exponent):
        """
        Compute (A + εI)^{exponent} for a negative exponent via eigendecomposition
        """
        A = A + self.epsilon * torch.eye(A.shape[0], device=A.device)
        eigenvalues, eigenvectors = torch.linalg.eigh(A)
        eigenvalues = torch.clamp(eigenvalues, min=self.epsilon)
        return eigenvectors @ torch.diag(eigenvalues ** exponent) @ eigenvectors.T

    def _update_inverse_statistics(self):
        """Recompute the inverse roots; the exponent is -1/(2k) for an order-k tensor"""
        for name, precond in self.preconditioners.items():
            exponent = -1.0 / (2 * len(precond['stats']))
            for i, stat in enumerate(precond['stats']):
                if torch.trace(stat) > 0:
                    precond['inv_stats'][i] = self._compute_inverse_root(stat, exponent)

    def _precondition(self, grad, inv_stats):
        """
        Multiply the gradient by the inverse root along every dimension

        tensordot over axis 0 moves the contracted axis to the end, so after
        k contractions the axes are back in their original order; for a
        matrix this yields L^{-1/4} G R^{-1/4}.
        """
        update = grad
        for inv_stat in inv_stats:
            update = torch.tensordot(update, inv_stat, dims=([0], [0]))
        return update

    def step(self):
        """One Shampoo step (call after loss.backward())"""
        self.timestep += 1

        with torch.no_grad():
            # Update the statistics
            for name, param in self.model.named_parameters():
                if param.grad is not None and param.dim() >= 2:
                    self._update_statistics(name, param.grad)

            # Periodically recompute the inverse roots
            if self.timestep % self.update_interval == 0:
                self._update_inverse_statistics()

            # Apply the update
            for name, param in self.model.named_parameters():
                if param.grad is None:
                    continue

                precond = self.preconditioners.get(name)
                if precond is not None and all(inv is not None for inv in precond['inv_stats']):
                    update = self._precondition(param.grad, precond['inv_stats'])
                else:
                    update = param.grad   # fall back to plain SGD

                param -= self.lr * update

                # Decoupled weight decay
                if self.weight_decay > 0:
                    param -= self.weight_decay * self.lr * param

        self.model.zero_grad()

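3.4 Convergence Analysis of Shampoo

Theoretical Basis

The convergence analysis of Shampoo rests on the following observation:

Let $H_t = \sum_{s=1}^{t} g_s g_s^\top$ be the accumulated gradient outer-product matrix. Then (full-matrix) Adagrad's learning-rate adaptation amounts to using $H_t^{-1/2}$ as the preconditioner; Shampoo approximates this preconditioner with the Kronecker-structured $L_t^{-1/4} \otimes R_t^{-1/4}$ from Section 3.2.

Convergence Rate of Shampoo

For convex (online) optimization problems, Shampoo achieves an $O(\sqrt{T})$ regret bound, comparable to Adagrad. In non-convex settings such as deep learning it is reported to perform better in practice, but that advantage is empirical rather than covered by this bound.

Comparison with First-Order Methods

For a matrix parameter $W \in \mathbb{R}^{m \times n}$ (the regret bounds differ in their problem-dependent constants):

| Method | Preconditioner | Regret (online convex) | Per-step cost |
|---|---|---|---|
| SGD | $I$ | $O(\sqrt{T})$ | $O(mn)$ |
| Adagrad (diagonal) | $\operatorname{diag}(H_t)^{-1/2}$ | $O(\sqrt{T})$ | $O(mn)$ |
| Shampoo | $L_t^{-1/4} \otimes R_t^{-1/4}$ | $O(\sqrt{T})$ | $O(m^2 n + m n^2)$, plus $O(m^3 + n^3)$ for the inverse roots |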

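4. Comparison of Approximate Second-Order Methods

4.1 Conjugate Gradient (CG)

Basic Principle of the Conjugate Gradient Method

The conjugate gradient method is an iterative method for solving a linear system $A x = b$ (where $A$ is symmetric positive definite). It only needs matrix-vector products with $A$, which makes it well suited to sparse matrices and large systems.

Mathematical Framework

Given the quadratic function $f(x) = \frac{1}{2} x^\top A x - b^\top x$, its gradient is $\nabla f(x) = A x - b$. At the minimizer, $\nabla f(x^*) = 0$, i.e. $A x^* = b$.

Conjugate gradient solves this by building a sequence of $A$-conjugate directions ($p_i^\top A p_j = 0$ for $i \ne j$), starting from $r_0 = p_0 = b - A x_0$:

$$\alpha_k = \frac{r_k^\top r_k}{p_k^\top A p_k}, \quad x_{k+1} = x_k + \alpha_k p_k, \quad r_{k+1} = r_k - \alpha_k A p_k, \quad \beta_k = \frac{r_{k+1}^\top r_{k+1}}{r_k^\top r_k}, \quad p_{k+1} = r_{k+1} + \beta_k p_k$$

In exact arithmetic it reaches the solution in at most $n$ iterations for an $n \times n$ system.

Algorithm Implementation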

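import torch

def conjugate_gradient(
    matvec,  # function computing A @ v
    b,       # right-hand-side vector
    x0=None,
    max_iter=None,
    tol=1e-6,
    preconditioner=None
):
    """
    (Preconditioned) conjugate gradient for A x = b

    Args:
        matvec: function computing the matrix-vector product A @ v
        b: right-hand-side vector
        x0: initial guess (zeros if None)
        max_iter: maximum number of iterations (defaults to b.numel())
        tol: stop once ||r|| < tol
        preconditioner: optional function applying M^{-1} to a vector

    A must be symmetric positive definite. If a direction of non-positive
    curvature is found, the iteration stops and returns the current iterate
    (truncated CG, as used in Hessian-free optimization).
    """
    x = torch.zeros_like(b) if x0 is None else x0.clone()
    r = b - matvec(x)
    z = preconditioner(r) if preconditioner is not None else r.clone()
    p = z.clone()
    rz_old = torch.dot(r, z)

    if max_iter is None:
        max_iter = b.numel()

    for _ in range(max_iter):
        Ap = matvec(p)

        # Exact line search along p
        pAp = torch.dot(p, Ap)
        if pAp <= 0:
            # Non-positive curvature: A is not positive definite along p
            break
        alpha = rz_old / pAp

        # Update the iterate and the residual
        x = x + alpha * p
        r = r - alpha * Ap

        # Convergence check
        if torch.norm(r) < tol:
            break

        # Apply the preconditioner
        z = preconditioner(r) if preconditioner is not None else r.clone()

        # Next A-conjugate direction
        rz_new = torch.dot(r, z)
        beta = rz_new / (rz_old + 1e-10)
        p = z + beta * p
        rz_old = rz_new

    return x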
 
def hessian_free_optimization(
    model,
    loss_fn,
    lr=1.0,
    damping=0.1,
    cg_iterations=50,
    max_iter=100
):
    """
    Hessian-free optimization

    At every iteration, CG approximately solves (H + damping * I) delta = -g
    using only Hessian-vector products, and the parameters move by lr * delta.
    loss_fn(model) must return a scalar loss.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    history = {'grads': [], 'params': [], 'losses': []}

    def flat_grad():
        # Loss and flattened gradient at the current parameters
        loss = loss_fn(model)
        grads = torch.autograd.grad(loss, params)
        return loss.detach(), torch.cat([g.reshape(-1) for g in grads])

    def add_to_params(v, scale):
        # In place: w <- w + scale * v, where v is a flat vector
        offset = 0
        with torch.no_grad():
            for p in params:
                size = p.numel()
                p.add_(v[offset:offset + size].view_as(p), alpha=scale)
                offset += size

    def restore(saved):
        with torch.no_grad():
            for p, orig in zip(params, saved):
                p.copy_(orig)

    for iteration in range(max_iter):
        # Loss and gradient at the current point
        loss, g = flat_grad()
        history['losses'].append(loss.item())
        history['grads'].append(g.clone())
        saved = [p.detach().clone() for p in params]

        def hvp(v, epsilon=1e-3):
            # Hessian-vector product by central finite differences (no explicit Hessian):
            # Hv ≈ (∇L(w + εv) - ∇L(w - εv)) / (2ε)
            add_to_params(v, epsilon)          # positive perturbation
            _, grads_plus = flat_grad()
            restore(saved)
            add_to_params(v, -epsilon)         # negative perturbation
            _, grads_minus = flat_grad()
            restore(saved)                     # back to the original parameters
            result = (grads_plus - grads_minus) / (2 * epsilon)
            return result + damping * v        # add damping

        # Use CG to solve (H + damping * I) delta = -g
        delta = conjugate_gradient(hvp, -g, max_iter=cg_iterations)

        # delta already points downhill, so it is added
        add_to_params(delta, lr)

        print(f"Iter {iteration}: Loss = {loss.item():.6f}")

    return history

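4.2 Quasi-Newton Methods (L-BFGS)

The BFGS Update Formula

BFGS (Broyden-Fletcher-Goldfarb-Shanno) approximates the inverse Hessian with a rank-2 correction at every step:

$$H_{k+1} = \left(I - \rho_k s_k y_k^\top\right) H_k \left(I - \rho_k y_k s_k^\top\right) + \rho_k s_k s_k^\top, \qquad \rho_k = \frac{1}{y_k^\top s_k}$$

where $s_k = \theta_{k+1} - \theta_k$ (the parameter update) and $y_k = \nabla \mathcal{L}(\theta_{k+1}) - \nabla \mathcal{L}(\theta_k)$ (the gradient difference).

L-BFGS (Limited-memory BFGS) stores only the most recent $m$ pairs $\{(s_i, y_i)\}$ and uses them to approximate $H_k$ implicitly, without ever forming the matrix.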
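A small NumPy sketch of the BFGS inverse update on a hypothetical quadratic with a random SPD Hessian. After each update, the new approximation satisfies the secant equation $H_{k+1} y_k = s_k$ and stays positive definite, because $y_k^\top s_k > 0$ holds automatically on a strictly convex quadratic.

```python
import numpy as np

def bfgs_inverse_update(H, s, y):
    """BFGS update of the inverse-Hessian approximation H (requires y^T s > 0)."""
    rho = 1.0 / (y @ s)
    I = np.eye(len(s))
    return (I - rho * np.outer(s, y)) @ H @ (I - rho * np.outer(y, s)) + rho * np.outer(s, s)

rng = np.random.default_rng(0)
Q = rng.normal(size=(4, 4))
A = Q @ Q.T + np.eye(4)          # Hessian of the quadratic 0.5 x^T A x
H = np.eye(4)                    # initial inverse-Hessian approximation
x = rng.normal(size=4)
for _ in range(3):
    s = -0.1 * H @ (A @ x)       # a small quasi-Newton step
    y = A @ s                    # gradient difference for this quadratic
    H = bfgs_inverse_update(H, s, y)
    x = x + s
```

L-BFGS applies the same recursion implicitly through the two-loop recursion below, without storing `H`.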

class LBFGSOptimizer:
    """
    Limited-memory BFGS optimizer

    loss_fn(model) must return a scalar loss. L-BFGS assumes a deterministic
    objective, so loss_fn should evaluate a fixed (e.g. full) batch.
    """

    def __init__(
        self,
        model,
        loss_fn,
        lr=1.0,
        max_history=10,
        damping=0.0,
        Wolfe=True
    ):
        self.model = model
        self.loss_fn = loss_fn
        self.lr = lr                 # fixed step size when Wolfe=False
        self.max_history = max_history
        self.damping = damping       # kept for interface compatibility; unused here
        self.Wolfe = Wolfe
        self.params = [p for p in model.parameters() if p.requires_grad]

        self.s_history = []    # parameter differences s_k
        self.y_history = []    # gradient differences y_k
        self.rho_history = []  # scaling factors rho_k = 1 / (y_k^T s_k)

    def _two_loop_recursion(self, grad):
        """
        Two-loop recursion: returns H_k @ grad, where H_k is the implicit
        L-BFGS approximation of the inverse Hessian

        This is the core algorithm of L-BFGS
        """
        q = grad.clone()

        # First loop: newest pair to oldest
        alphas = []
        for s, y, rho in zip(
            reversed(self.s_history),
            reversed(self.y_history),
            reversed(self.rho_history)
        ):
            alpha = rho * torch.dot(s, q)
            alphas.append(alpha)
            q = q - alpha * y

        # Initial scaling H_0 = gamma * I from the most recent curvature pair
        if len(self.s_history) > 0:
            s, y = self.s_history[-1], self.y_history[-1]
            gamma = torch.dot(y, s) / (torch.dot(y, y) + 1e-10)
            z = gamma * q
        else:
            z = q

        # Second loop: oldest pair to newest
        for s, y, rho, alpha in zip(
            self.s_history,
            self.y_history,
            self.rho_history,
            reversed(alphas)
        ):
            beta = rho * torch.dot(y, z)
            z = z + s * (alpha - beta)

        return z

    def _get_flat_params(self):
        return torch.cat([p.detach().reshape(-1) for p in self.params])

    def _set_flat_params(self, flat):
        offset = 0
        with torch.no_grad():
            for p in self.params:
                size = p.numel()
                p.copy_(flat[offset:offset + size].view_as(p))
                offset += size

    def _compute_loss_grad(self):
        """Loss and flattened gradient at the current parameters"""
        loss = self.loss_fn(self.model)
        grads = torch.autograd.grad(loss, self.params)
        return loss.detach(), torch.cat([g.reshape(-1) for g in grads])

    def step(self):
        """
        One L-BFGS step

        Uses a line search that enforces the (weak) Wolfe conditions
        """
        # Current parameters and gradient
        x0 = self._get_flat_params()
        loss, grad = self._compute_loss_grad()

        # Search direction d = -H_k grad (a descent direction)
        direction = -self._two_loop_recursion(grad)

        # Line search
        if self.Wolfe:
            alpha = self._wolfe_line_search(x0, loss, grad, direction)
        else:
            alpha = self.lr

        # Update the parameters
        s = alpha * direction
        self._set_flat_params(x0 + s)

        # Update the history
        _, new_grad = self._compute_loss_grad()
        y = new_grad - grad

        # Curvature condition y^T s > 0 keeps H_k positive definite
        sty = torch.dot(y, s)
        if sty > 1e-10:
            self.s_history.append(s)
            self.y_history.append(y)
            self.rho_history.append(1.0 / sty)

            # Limit the history length
            if len(self.s_history) > self.max_history:
                self.s_history.pop(0)
                self.y_history.pop(0)
                self.rho_history.pop(0)

        return loss

    def _wolfe_line_search(self, x0, loss0, grad, direction, c1=1e-4, c2=0.9, max_iter=20):
        """
        Bisection line search for the weak Wolfe conditions:
        1. Armijo condition:    f(x + αd) ≤ f(x) + c1 * α * g^T d
        2. Curvature condition: ∇f(x + αd)^T d ≥ c2 * g^T d

        The parameters are restored to x0 before returning.
        """
        lo, hi = 0.0, float('inf')
        alpha = 1.0
        grad_dot_d = torch.dot(grad, direction).item()

        for _ in range(max_iter):
            self._set_flat_params(x0 + alpha * direction)
            loss1, grad1 = self._compute_loss_grad()

            if loss1.item() > loss0.item() + c1 * alpha * grad_dot_d:
                hi = alpha                 # Armijo fails: step too long
            elif torch.dot(grad1, direction).item() < c2 * grad_dot_d:
                lo = alpha                 # curvature fails: step too short
            else:
                break                      # both conditions hold
            alpha = (lo + hi) / 2 if hi < float('inf') else 2 * lo

        self._set_flat_params(x0)
        return alpha

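4.3 Stochastic Second-Order Methods: Theory and Practice

Challenges for Stochastic Second-Order Methods

Second-order methods in large-scale machine learning face the following challenges:

  1. Stochastic estimates of the Hessian/Fisher matrix have high variance
  2. A trade-off between convergence speed and batch size
  3. Communication overhead (in distributed training)

Stochastic Curvature Estimation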

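class StochasticSecondOrderOptimizer:
    """
    Base class for stochastic second-order optimizers

    Provides a generic framework for stochastic curvature estimation
    """

    def __init__(self, model, lr, damping=0.1, batch_size=32):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.batch_size = batch_size

    def _estimate_curvature(self, dataloader):
        """
        Stochastically estimate the curvature matrix

        Subclasses should implement the concrete estimator
        """
        raise NotImplementedError

    def _solve_linear_system(self, A, b):
        """
        Solve the damped linear system (A + damping * I) x = b

        Subclasses may provide a more efficient solver (e.g. CG)
        """
        eye = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
        return torch.linalg.solve(A + self.damping * eye, b)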
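One concrete choice for `_estimate_curvature` is Hutchinson's estimator, which is not part of the original text and is shown here only as an example: $\operatorname{diag}(H) = \mathbb{E}_z[z \odot Hz]$ for Rademacher vectors $z$, so only Hessian-vector products are needed. In this NumPy sketch, `hvp` is a hypothetical stand-in for such a product, here simply multiplication by an explicit PSD matrix.

```python
import numpy as np

def hutchinson_diag(hvp, dim, num_samples=1000, seed=0):
    """Estimate diag(H) as the mean of z * (H z) over Rademacher vectors z (E[z z^T] = I)."""
    rng = np.random.default_rng(seed)
    est = np.zeros(dim)
    for _ in range(num_samples):
        z = rng.choice([-1.0, 1.0], size=dim)
        est += z * hvp(z)
    return est / num_samples

rng = np.random.default_rng(1)
M = rng.normal(size=(6, 6))
H = M @ M.T                                   # a PSD "curvature" matrix for the demo
diag_est = hutchinson_diag(lambda v: H @ v, dim=6, num_samples=20000)
print(np.round(diag_est, 2))
print(np.round(np.diag(H), 2))
```

The estimator is unbiased, and its variance for coordinate $i$ is $\sum_{j \ne i} H_{ij}^2$ per sample, which is the variance problem listed as challenge 1 above.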

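SVRG (Stochastic Variance Reduced Gradient) Second-Order Extension

The SVRG-SO algorithm combines variance reduction with second-order information. With a reference point $\tilde{\theta}$ and its full gradient $\nabla \mathcal{L}(\tilde{\theta})$, each inner step uses the variance-reduced gradient and a damped diagonal curvature estimate $D_t$:

$$\tilde{g}_t = \nabla \ell_{i_t}(\theta_t) - \nabla \ell_{i_t}(\tilde{\theta}) + \nabla \mathcal{L}(\tilde{\theta}), \qquad \theta_{t+1} = \theta_t - \eta \left(D_t + \lambda I\right)^{-1} \tilde{g}_t$$

where $\ell_{i_t}$ is the loss on the sampled mini-batch and $\lambda$ is the damping.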

class SVRGSecondOrder:
    """
    SVRG-二阶方法
    
    结合SVRG的方差缩减与二阶曲率信息
    """
    
    def __init__(
        self,
        model,
        data_loader,
        lr=0.01,
        inner_iter=100,
        damping=0.1
    ):
        self.model = model
        self.data_loader = data_loader
        self.lr = lr
        self.inner_iter = inner_iter
        self.damping = damping
        
        # Reference point and full gradient at the reference point
        self.theta_ref = None
        self.grad_ref = None
    
    def step(self):
        """SVRG-SO outer loop"""
        # Snapshot the reference point
        self.theta_ref = [p.detach().clone() for p in self.model.parameters()]
        
        # Full gradient at the reference point
        self.grad_ref = self._compute_full_gradient()
        
        # Inner loop
        for _ in range(self.inner_iter):
            self._inner_step()
    
    def _inner_step(self):
        """SVRG inner-loop step"""
        # Draw a mini-batch (random only if the DataLoader was built with shuffle=True)
        x_batch, y_batch = next(iter(self.data_loader))
        
        # Stochastic gradient at the current point
        self.model.zero_grad()
        loss = self.model.loss(x_batch, y_batch)
        loss.backward()
        g_t = self._flat_grad()
        
        # Estimate the diagonal curvature from the current gradient *before*
        # _compute_gradient_at_ref overwrites p.grad
        diag_curvature = self._estimate_diag_curvature(g_t)
        
        # SVRG correction: unbiased, and its variance shrinks as theta approaches theta_ref
        g_ref = self._compute_gradient_at_ref(x_batch, y_batch)
        g_corrected = g_t - g_ref + self.grad_ref
        
        # Second-order step (simplified: diagonal preconditioner)
        update = g_corrected / (diag_curvature + self.damping)
        
        # Apply the update
        offset = 0
        with torch.no_grad():
            for p in self.model.parameters():
                size = p.numel()
                delta = self.lr * update[offset:offset+size].reshape(p.shape)
                p.sub_(delta)
                offset += size
    
    def _flat_grad(self):
        """Concatenate all parameter gradients into one flat vector (a copy)"""
        return torch.cat([p.grad.detach().flatten() for p in self.model.parameters()])
    
    def _compute_full_gradient(self):
        """Full gradient over the data set (expensive): mean of the per-batch gradients"""
        self.model.zero_grad()
        n_batches = 0
        for x, y in self.data_loader:
            loss = self.model.loss(x, y)
            loss.backward()  # gradients accumulate in p.grad across batches
            n_batches += 1
        return self._flat_grad() / n_batches
    
    def _compute_gradient_at_ref(self, x, y):
        """Gradient on the same mini-batch, evaluated at the reference point"""
        # Save the current parameters
        current_params = [p.detach().clone() for p in self.model.parameters()]
        
        # Load the reference parameters (in-place writes to leaf tensors need no_grad)
        with torch.no_grad():
            for p, ref in zip(self.model.parameters(), self.theta_ref):
                p.copy_(ref)
        
        # Compute the gradient
        self.model.zero_grad()
        loss = self.model.loss(x, y)
        loss.backward()
        grad = self._flat_grad()
        
        # Restore the current parameters
        with torch.no_grad():
            for p, cur in zip(self.model.parameters(), current_params):
                p.copy_(cur)
        
        return grad
    
    def _estimate_diag_curvature(self, grad):
        """Crude diagonal curvature estimate: the element-wise squared gradient
        (a one-sample empirical-Fisher diagonal)"""
        return grad ** 2
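
Why the correction `g_t - g_ref + grad_ref` in `_inner_step` helps can be checked on a small least-squares problem. The sketch below uses synthetic data, and all of its names are illustrative. It compares single-example gradient estimates at an iterate `w` close to the reference point `w_ref`. Both the plain stochastic estimate and the SVRG estimate average to the full gradient, but the SVRG estimate has a much smaller spread, because most of the per-example noise cancels between `w` and `w_ref`.

```python
import torch

torch.manual_seed(0)
n, d = 1000, 10
X = torch.randn(n, d)
y = X @ torch.randn(d) + 0.1 * torch.randn(n)
all_idx = torch.arange(n)

def per_example_grads(w, idx):
    """Gradients of 0.5 * (x_i^T w - y_i)^2 for the examples in idx, shape (len(idx), d)"""
    r = X[idx] @ w - y[idx]
    return r[:, None] * X[idx]

w_ref = torch.zeros(d)                                   # reference point (outer loop)
full_grad_ref = per_example_grads(w_ref, all_idx).mean(0)
w = w_ref + 0.05 * torch.randn(d)                        # current iterate, close to w_ref
full_grad = per_example_grads(w, all_idx).mean(0)        # used only for checking

# Every possible single-example estimator, one row per example
sgd = per_example_grads(w, all_idx)
svrg = per_example_grads(w, all_idx) - per_example_grads(w_ref, all_idx) + full_grad_ref

# Both are unbiased: averaging over all examples recovers the full gradient ...
print(torch.allclose(sgd.mean(0), full_grad, atol=1e-3),
      torch.allclose(svrg.mean(0), full_grad, atol=1e-3))
# ... but the SVRG estimator has far smaller variance
var_sgd = (sgd - full_grad).pow(2).sum(1).mean().item()
var_svrg = (svrg - full_grad).pow(2).sum(1).mean().item()
print(var_sgd, var_svrg)
```

For least squares, the difference `grad_i(w) - grad_i(w_ref)` equals `x_i x_i^T (w - w_ref)`. The variance of the SVRG estimator therefore scales with the squared distance to the reference point, which is why the outer loop refreshes `theta_ref` periodically.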

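5. Connections between second-order and adaptive methods

5.1 Adam as a first-order approximation of K-FAC

Understanding Adam from an information-geometric perspective

Adam's adaptive learning-rate mechanism can be read as a coarse-grained approximation of K-FAC. Consider the following correspondence: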

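| K-FAC | Adam |
| --- | --- |
| Fisher matrix $F$ | Second-moment estimate $v_t$ |
| Fisher inverse $F^{-1}$ | Gradient normalization by $1/(\sqrt{\hat v_t}+\epsilon)$ |
| Kronecker factorization $F \approx A \otimes G$ (per layer) | Diagonal approximation (per coordinate) |
| Richer curvature: full layer-wise Kronecker blocks | Exponential moving average of squared gradients |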

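Formal correspondence

For a fully connected layer with input activations $a$ and back-propagated pre-activation gradients $g$, K-FAC approximates the Fisher block as $F \approx A \otimes G$ with $A = \mathbb{E}[a a^\top]$ and $G = \mathbb{E}[g g^\top]$, and the K-FAC update is:

$$W \leftarrow W - \eta\, G^{-1} \left(\nabla_W \mathcal{L}\right) A^{-1}$$

With $g_t = \nabla_\theta \mathcal{L}(\theta_t)$, the Adam update is:

$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^{2}$$

$$\theta_{t+1} = \theta_t - \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}, \qquad \hat m_t = \frac{m_t}{1-\beta_1^t}, \quad \hat v_t = \frac{v_t}{1-\beta_2^t}$$

If $F$ is approximated by the diagonal matrix $\mathrm{diag}(\hat v_t)$, i.e. the running average of squared gradients (a diagonal empirical Fisher), then:

$$F^{-1}\nabla\mathcal{L} \approx \frac{\nabla\mathcal{L}}{\hat v_t} \ \ \text{(natural gradient)}, \qquad \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon} \ \ \text{(Adam)}$$

This suggests that Adam applies a per-parameter, "coarse-grained" version of K-FAC-style preconditioning, with two differences: it uses the inverse square root $\mathrm{diag}(\hat v_t)^{-1/2}$ rather than the inverse, and it replaces the layer-wise Kronecker blocks with a diagonal.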
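
The link between Kronecker factors and a per-coordinate second moment can be made concrete for a single linear layer. In the sketch below, the activations `a` and output-side gradients `g` are synthetic and drawn independently; that independence is an illustrative assumption. The per-example weight gradient is $g a^\top$. The diagonal empirical Fisher (a per-example version of Adam's second-moment statistic) averages its element-wise square, while K-FAC's diagonal is $\mathrm{diag}(A \otimes G)$. When $a$ and $g$ are independent, the two agree up to sampling error; when they are correlated, they differ.

```python
import torch

torch.manual_seed(0)
n, d_in, d_out = 20000, 4, 3
a = torch.randn(n, d_in)               # layer inputs, one row per example
g = 0.5 * torch.randn(n, d_out)        # output-side gradients, independent of a here

# Per-example weight gradients g_i a_i^T, shape (n, d_out, d_in)
per_example = g[:, :, None] * a[:, None, :]

# Diagonal empirical Fisher: mean squared gradient per weight entry (Adam-like statistic)
v = per_example.pow(2).mean(0)

# K-FAC factors and the diagonal of their Kronecker product, laid out as (d_out, d_in)
A = a.T @ a / n
G = g.T @ g / n
kfac_diag = torch.outer(torch.diagonal(G), torch.diagonal(A))

# Same numbers read off diag(A ⊗ G): diagonal entry i*d_out + j equals A_ii * G_jj
kron_diag = torch.diagonal(torch.kron(A, G)).reshape(d_in, d_out).T

print(torch.allclose(kfac_diag, kron_diag))
print(torch.allclose(v, kfac_diag, rtol=0.15))
```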

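5.2 Understanding Adam from a second-order perspective

Adam's convergence revisited

From a second-order perspective, Adam's effective learning rate $\eta/(\sqrt{\hat v_t}+\epsilon)$ is an anisotropic scaling of the gradient direction: every coordinate is rescaled by its own factor.

Per-coordinate decomposition

For each parameter coordinate $\theta_i$, the Adam update can be written as:

$$\theta_{i,t+1} = \theta_{i,t} - \underbrace{\frac{\eta}{\sqrt{\hat v_{i,t}}+\epsilon}}_{\eta_{i,t}}\, \hat m_{i,t}$$

This amounts to maintaining a local learning rate $\eta_{i,t}$ for each parameter, inversely proportional to the root-mean-square of that coordinate's recent gradients.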
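
The per-coordinate learning rate can be observed directly with PyTorch's `torch.optim.Adam`. In this sketch, two coordinates receive constant gradients whose scales differ by a factor of $10^4$. For a constant gradient, bias correction gives exactly $\hat m_t = g$ and $\hat v_t = g^2$, so each step has size about $\eta$ in both coordinates. Plain SGD would instead move them by $100\eta$ and $0.01\eta$.

```python
import torch

p = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.Adam([p], lr=1e-3)
g = torch.tensor([100.0, 0.01])        # gradient scales differ by 10^4

for _ in range(10):
    p.grad = g.clone()                 # feed the same gradient every step
    opt.step()

# Each coordinate moved by about 10 steps * lr = 0.01, independent of its gradient scale
print(p.detach())
```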

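Relation to the natural gradient

If the equivalent Fisher matrix is defined as:

$$\tilde F_t = \mathrm{diag}\left(\sqrt{\hat v_t} + \epsilon\right)$$

then the Adam update becomes:

$$\theta_{t+1} = \theta_t - \eta\, \tilde F_t^{-1} \hat m_t$$

This has exactly the form of a natural-gradient update, with the momentum $\hat m_t$ in place of the gradient and $\tilde F_t$ as the metric. The correspondence is formal rather than exact, however: $\tilde F_t$ is the square root of a diagonal empirical Fisher, not the Fisher information matrix itself.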
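
The identity above can be checked by replaying Adam by hand. The helper `adam_as_diag_preconditioner` below is a hypothetical name. It applies $\theta \leftarrow \theta - \eta\, \tilde F_t^{-1} \hat m_t$ with $\tilde F_t = \mathrm{diag}(\sqrt{\hat v_t} + \epsilon)$, and the result is compared against `torch.optim.Adam` (default betas and eps) run on the same gradient sequence.

```python
import torch

def adam_as_diag_preconditioner(grads, lr=1e-2, betas=(0.9, 0.999), eps=1e-8):
    """Replay Adam as theta <- theta - lr * F_tilde^{-1} m_hat, F_tilde = diag(sqrt(v_hat) + eps)"""
    theta = torch.zeros_like(grads[0])
    m = torch.zeros_like(theta)
    v = torch.zeros_like(theta)
    for t, g in enumerate(grads, start=1):
        m = betas[0] * m + (1 - betas[0]) * g
        v = betas[1] * v + (1 - betas[1]) * g ** 2
        m_hat = m / (1 - betas[0] ** t)
        v_hat = v / (1 - betas[1] ** t)
        F_tilde_diag = v_hat.sqrt() + eps          # diagonal of the "equivalent Fisher"
        theta = theta - lr * m_hat / F_tilde_diag  # F_tilde^{-1} applied element-wise
    return theta

torch.manual_seed(0)
grads = [torch.randn(5) for _ in range(3)]

# Reference: PyTorch's Adam on the same gradient sequence
p = torch.nn.Parameter(torch.zeros(5))
opt = torch.optim.Adam([p], lr=1e-2)
for g in grads:
    p.grad = g.clone()
    opt.step()

theta = adam_as_diag_preconditioner(grads)
print(torch.allclose(p.detach(), theta, atol=1e-6))
```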

5.3 The potential of hybrid methods

A unified AdaGrad-Fisher framework

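import torch
 
class HybridOptimizer:
    """
    Hybrid optimizer: combine second-order information with adaptive learning rates
    
    Strategy 1: K-FAC on small parameter tensors, Adam on large ones (implemented here)
    Strategy 2: switch dynamically depending on the training phase
    Strategy 3: hierarchical preconditioning
    
    kfac_factory is an assumed interface: a callable that receives a list of
    (name, param) pairs and returns an object whose .step() updates only those
    parameters (e.g. a K-FAC optimizer restricted to them).
    """
    
    def __init__(
        self,
        model,
        kfac_factory,
        adam_lr=0.001,
        threshold=1e5  # parameter-count threshold
    ):
        self.model = model
        self.adam_lr = adam_lr
        self.threshold = threshold
        
        # Partition first, so that each sub-optimizer only sees its own parameters
        self.kfac_params = []
        self.adam_params = []
        self._partition_params()
        
        # Sub-optimizers
        self.kfac = kfac_factory(self.kfac_params) if self.kfac_params else None
        self.adam = (
            torch.optim.Adam(self.adam_params, lr=adam_lr)
            if self.adam_params else None
        )
    
    def _partition_params(self):
        """Partition by parameter count (K-FAC's factor inversions get expensive for large layers)"""
        for name, param in self.model.named_parameters():
            if param.numel() > self.threshold:
                self.adam_params.append(param)
            else:
                self.kfac_params.append((name, param))
    
    def step(self):
        """Hybrid step; call after loss.backward()"""
        # K-FAC handles the small parameters
        if self.kfac is not None:
            self.kfac.step()
        
        # Adam handles the large parameters
        if self.adam is not None:
            self.adam.step()
    
    def zero_grad(self):
        """Zero all gradients"""
        self.model.zero_grad()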
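
With the default `threshold=1e5`, the partition rule in `_partition_params` is easy to check by hand. The sketch below applies the same rule to a network with the layer sizes of the SimpleNet used in Section 6.1, rebuilt here with `nn.Sequential`, so the parameter names differ. Only the first layer's 784×256 weight (200,704 entries) exceeds the threshold and would go to Adam; every other tensor would be handled by K-FAC. Note that `numel()` is only a proxy: K-FAC's cost is driven by the factor sizes $(d_{in}+1)^2$ and $d_{out}^2$.

```python
import torch.nn as nn

# Layer sizes of the SimpleNet from Section 6.1
model = nn.Sequential(
    nn.Linear(784, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)
threshold = 1e5

adam_side = [(name, p.numel()) for name, p in model.named_parameters() if p.numel() > threshold]
kfac_side = [(name, p.numel()) for name, p in model.named_parameters() if p.numel() <= threshold]
print(adam_side)   # [('0.weight', 200704)]
print(kfac_side)
```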

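Sophia optimizer: a second-order/first-order hybrid

Sophia uses an estimate of the Hessian diagonal to precondition Adam-style momentum updates, and clips each coordinate of the preconditioned step: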

import torch
 
class SophiaOptimizer:
    """
    Sophia optimizer (simplified)
    
    Uses the Hessian diagonal as a preconditioner on top of Adam-style momentum.
    In this sketch the Hessian diagonal is estimated with Hutchinson's estimator
    diag(H) ≈ u ⊙ (H u), u ~ Rademacher, refreshed every k steps, and each
    coordinate of the preconditioned step is clipped to [-1, 1]:
        theta <- theta - lr * clip(m / max(rho * h, eps), 1)
    """
    
    def __init__(
        self,
        model,
        lr=0.001,
        betas=(0.9, 0.999),
        rho=0.04,
        eps=1e-8,
        weight_decay=0.01,
        k=10
    ):
        self.model = model
        self.lr = lr
        self.betas = betas                # (gradient EMA, Hessian-diagonal EMA)
        self.rho = rho                    # scale of the Hessian in the clipped update
        self.eps = eps
        self.weight_decay = weight_decay  # decoupled weight decay
        self.k = k                        # refresh the Hessian estimate every k steps
        self.step_count = 0
        
        # State
        self.exp_avg_grad = {}
        self.exp_avg_hessian_diag = {}
        for name, param in model.named_parameters():
            self.exp_avg_grad[name] = torch.zeros_like(param)
            self.exp_avg_hessian_diag[name] = torch.zeros_like(param)
    
    def step(self):
        """One Sophia step; returns the loss"""
        named = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]
        params = [p for _, p in named]
        refresh = self.step_count % self.k == 0
        self.step_count += 1
        
        loss = self._compute_loss()
        # create_graph is only needed on steps that also take a Hessian-vector product
        grads = torch.autograd.grad(loss, params, create_graph=refresh)
        
        if refresh:
            # Hutchinson: for Rademacher u, E[u ⊙ (H u)] = diag(H)
            us = [torch.randint_like(p, 2) * 2 - 1 for p in params]
            hvps = torch.autograd.grad(grads, params, grad_outputs=us)
        
        b1, b2 = self.betas
        with torch.no_grad():
            for i, (name, param) in enumerate(named):
                grad = grads[i].detach()
                
                # EMA of gradients (momentum)
                m = self.exp_avg_grad[name]
                m.mul_(b1).add_(grad, alpha=1 - b1)
                
                # EMA of the Hessian-diagonal estimate
                if refresh:
                    h = self.exp_avg_hessian_diag[name]
                    h.mul_(b2).add_(us[i] * hvps[i], alpha=1 - b2)
                
                # Decoupled weight decay
                param.mul_(1 - self.lr * self.weight_decay)
                
                # Preconditioned, clipped update; max(rho * h, eps) also covers negative curvature
                denom = torch.clamp(self.rho * self.exp_avg_hessian_diag[name], min=self.eps)
                param.sub_(self.lr * torch.clamp(m / denom, -1.0, 1.0))
        
        return loss.detach()
    
    def _compute_loss(self):
        """Compute the loss on the current batch (implemented by a subclass)"""
        raise NotImplementedError
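
A common way to estimate the Hessian diagonal without forming the Hessian is Hutchinson's estimator. For a random vector $u$ with independent $\pm 1$ entries, $\mathbb{E}[u \odot (Hu)] = \mathrm{diag}(H)$, and $Hu$ costs a single Hessian-vector product via double backpropagation. The sketch below checks this on a quadratic $\tfrac12 w^\top H w$ whose Hessian $H$ is known. The helper `hutchinson_diag` and the sample count are illustrative choices.

```python
import torch

def hutchinson_diag(loss_fn, w, n_samples=2000):
    """Estimate diag(H) of loss_fn at w by averaging u * (H u) over Rademacher vectors u"""
    w = w.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(loss_fn(w), w, create_graph=True)
    est = torch.zeros_like(w)
    for _ in range(n_samples):
        u = torch.randint_like(w, 2) * 2 - 1                                   # ±1 entries
        (hu,) = torch.autograd.grad(grad, w, grad_outputs=u, retain_graph=True)  # H u
        est += u * hu
    return est / n_samples

torch.manual_seed(0)
d = 6
B = torch.randn(d, d, dtype=torch.float64)
H = B @ B.T + torch.eye(d, dtype=torch.float64)     # known symmetric Hessian

def loss_fn(w):
    return 0.5 * w @ H @ w

est = hutchinson_diag(loss_fn, torch.randn(d, dtype=torch.float64))
print(torch.diagonal(H))
print(est)
```

The error of each entry comes from the off-diagonal terms $\sum_{j \ne i} H_{ij} u_i u_j$. This is one reason to smooth the estimate with an EMA rather than trusting a single sample.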

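Muon optimizer: the Newton-Schulz method

Muon orthogonalizes the (momentum) gradient matrix. For a full-rank $G = U\Sigma V^\top$, it replaces $G$ with $UV^\top = (GG^\top)^{-1/2}G$, and uses a Newton-Schulz iteration to approximate this inverse-matrix-square-root computation without an SVD: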

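import torch
 
class MuonOptimizer:
    """
    Muon optimizer (simplified)
    
    Matrix-shaped parameters are updated with the orthogonalized momentum U V^T,
    computed by a Newton-Schulz iteration; vector parameters (biases, norms) fall
    back to plain momentum updates.
    """
    
    def __init__(
        self,
        model,
        lr=0.001,
        momentum=0.95,
        nesterov=True,
        ns_iters=5
    ):
        self.model = model
        self.lr = lr
        self.momentum = momentum
        self.nesterov = nesterov
        self.ns_iters = ns_iters
        
        self.m = {}  # momentum buffers
        self._init_momentum()
    
    def _init_momentum(self):
        """Initialize the momentum buffers"""
        for name, param in self.model.named_parameters():
            self.m[name] = torch.zeros_like(param)
    
    def _newton_schulz_iteration(self, G, num_iters=5):
        """
        Newton-Schulz iteration
        
        Drives G towards its orthogonal polar factor Q = U V^T (G = U S V^T):
        every singular value is pushed to 1, so Q is not simply G / ||G||.
        
        Iteration: X_{k+1} = (3/2) X_k - (1/2) X_k X_k^T X_k
        It converges when all singular values of X_0 lie in (0, sqrt(3)); the
        Frobenius normalization below guarantees they are at most 1. With only a
        few iterations the result is approximately orthogonal.
        """
        X = G / (G.norm() + 1e-8)
        transposed = X.shape[0] > X.shape[1]
        if transposed:  # use the wide orientation so X X^T is the smaller Gram matrix
            X = X.T
        
        for _ in range(num_iters):
            X = 1.5 * X - 0.5 * X @ X.T @ X
        
        return X.T if transposed else X
    
    def step(self):
        """One Muon step"""
        loss = self._compute_loss()
        loss.backward()
        
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if param.grad is None:
                    continue
                
                grad = param.grad
                
                # Update the momentum buffer (EMA form)
                self.m[name].mul_(self.momentum).add_(grad, alpha=1 - self.momentum)
                
                # Update direction (Nesterov: look ahead along the momentum)
                if self.nesterov:
                    grad_corrected = (
                        self.momentum * self.m[name] +
                        (1 - self.momentum) * grad
                    )
                else:
                    grad_corrected = self.m[name]
                
                if param.dim() == 1:
                    # Vector parameters: plain momentum update
                    update = grad_corrected
                else:
                    # Matrix parameters (conv kernels flattened to 2-D): orthogonalize
                    G = grad_corrected.reshape(grad_corrected.shape[0], -1)
                    update = self._newton_schulz_iteration(G, self.ns_iters).reshape(param.shape)
                
                # Apply the update
                param.sub_(self.lr * update)
        
        self.model.zero_grad()
        return loss.detach()
    
    def _compute_loss(self):
        """Compute the loss on the current batch (implemented by a subclass)"""
        raise NotImplementedError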
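
That the cubic Newton-Schulz iteration produces the polar factor $UV^\top$, rather than $G/\lVert G\rVert$, can be verified against an SVD. The sketch below runs many more iterations than an optimizer would, in float64, on a random 8×5 matrix. After Frobenius normalization every singular value lies in $(0, 1]$, and each iteration maps $\sigma \mapsto 1.5\sigma - 0.5\sigma^3$, which drives all of them to 1. An optimizer can afford only a handful of iterations, so its result is only approximately orthogonal; public Muon implementations use a tuned quintic polynomial to converge faster.

```python
import torch

def newton_schulz_orthogonalize(G, num_iters=30):
    """Cubic Newton-Schulz: X <- 1.5 X - 0.5 X X^T X, starting from G / ||G||_F"""
    X = G / G.norm()
    transposed = X.shape[0] > X.shape[1]
    if transposed:                      # iterate on the wide orientation
        X = X.T
    for _ in range(num_iters):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X.T if transposed else X

torch.manual_seed(0)
G = torch.randn(8, 5, dtype=torch.float64)
Q = newton_schulz_orthogonalize(G)

# Polar factor from the thin SVD G = U S V^T
U, S, Vh = torch.linalg.svd(G, full_matrices=False)
print(torch.allclose(Q, U @ Vh, atol=1e-6))
print(torch.allclose(Q.T @ Q, torch.eye(5, dtype=torch.float64), atol=1e-6))
```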

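6. Implementation details and code examples

6.1 A simplified K-FAC implementation

The following is a complete, runnable simplified K-FAC implementation: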

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
# ============== Simplified K-FAC ==============
 
class SimpleKFAC:
    """
    Simplified K-FAC optimizer for small to medium models
    
    Handles nn.Linear layers with 2-D inputs (batch, features). For a layer with
    inputs a (a constant 1 is appended for the bias) and pre-activation gradients g,
    the Fisher block is approximated as F ≈ A ⊗ G with A = E[a a^T], G = E[g g^T],
    and the gradient [∇W, ∇b] is preconditioned as G^{-1} [∇W, ∇b] A^{-1}.
    Simplification: the empirical Fisher (true labels) is used instead of labels
    sampled from the model. Parameters outside nn.Linear get plain SGD.
    """
    
    def __init__(
        self,
        model,
        lr=0.01,
        damping=1e-3,
        stat_decay=0.95,
        inverse_interval=30
    ):
        self.model = model
        self.lr = lr
        self.damping = damping
        self.stat_decay = stat_decay
        self.inverse_interval = inverse_interval
        
        self.step_count = 0
        self.layers = [m for m in model.modules() if isinstance(m, nn.Linear)]
        self.fisher = {m: {'a': None, 'g': None} for m in self.layers}   # Kronecker factors A, G
        self.inverse = {m: {'a': None, 'g': None} for m in self.layers}  # their damped inverses
        self._linear_param_ids = {id(p) for m in self.layers for p in m.parameters()}
        
        # Hooks capture the per-batch statistics: a (layer input) and g (output gradient)
        self._acts, self._grads = {}, {}
        for m in self.layers:
            m.register_forward_hook(self._forward_hook)
    
    def _forward_hook(self, module, inputs, output):
        """Save the layer input and register a hook for the gradient w.r.t. the output"""
        a = inputs[0].detach()
        if module.bias is not None:
            a = torch.cat([a, torch.ones_like(a[:, :1])], dim=1)
        self._acts[module] = a
        if output.requires_grad:
            def save_grad(grad, module=module):
                self._grads[module] = grad.detach()
            output.register_hook(save_grad)
    
    def _update_fisher(self):
        """EMA update of the Kronecker factors from the current batch"""
        for m in self.layers:
            a, g = self._acts[m], self._grads[m]
            batch_size = a.shape[0]
            # The output gradient of a mean-reduced loss carries a 1/batch_size factor; undo it
            g = g * batch_size
            A_new = a.T @ a / batch_size
            G_new = g.T @ g / batch_size
            f = self.fisher[m]
            if f['a'] is None:
                f['a'], f['g'] = A_new, G_new
            else:
                f['a'] = self.stat_decay * f['a'] + (1 - self.stat_decay) * A_new
                f['g'] = self.stat_decay * f['g'] + (1 - self.stat_decay) * G_new
    
    def _update_inverse(self):
        """Recompute the damped inverses (M + damping * I)^{-1} via eigendecomposition"""
        for m in self.layers:
            for key in ('a', 'g'):
                eigvals, eigvecs = torch.linalg.eigh(self.fisher[m][key])
                # Clamp tiny negative eigenvalues caused by round-off, then add damping
                inv_vals = 1.0 / (eigvals.clamp(min=0) + self.damping)
                self.inverse[m][key] = eigvecs @ torch.diag(inv_vals) @ eigvecs.T
    
    def step(self, batch_x, batch_y):
        """One K-FAC step. A single forward/backward pass provides both the
        curvature statistics and the gradient. Returns the batch loss."""
        self.step_count += 1
        
        self.model.zero_grad()
        loss = F.cross_entropy(self.model(batch_x), batch_y)
        loss.backward()
        
        # Update the Fisher statistics
        self._update_fisher()
        
        # Invert on the first step, then every inverse_interval steps
        if (self.step_count - 1) % self.inverse_interval == 0:
            self._update_inverse()
        
        with torch.no_grad():
            for m in self.layers:
                inv_a = self.inverse[m]['a']
                inv_g = self.inverse[m]['g']
                grad = m.weight.grad
                if m.bias is not None:
                    grad = torch.cat([grad, m.bias.grad.unsqueeze(1)], dim=1)
                
                # Kronecker trick: (A ⊗ G)^{-1} vec(∇) = vec(G^{-1} ∇ A^{-1}) with column-stacking vec,
                # equivalently (G^{-1} ⊗ A^{-1}) applied to the row-major flattening
                corrected = inv_g @ grad @ inv_a
                
                if m.bias is not None:
                    m.weight.sub_(self.lr * corrected[:, :-1])
                    m.bias.sub_(self.lr * corrected[:, -1])
                else:
                    m.weight.sub_(self.lr * corrected)
            
            # Remaining (non-Linear) parameters: plain SGD
            for p in self.model.parameters():
                if id(p) not in self._linear_param_ids and p.grad is not None:
                    p.sub_(self.lr * p.grad)
        
        return loss.detach()
 
 
# ============== Comparison experiment ==============
 
def train_comparison():
    """
    Compare SGD, Adam and K-FAC on synthetic MNIST-shaped data
    """
    # Synthetic data with MNIST's shape (784 features, 10 classes). Labels come from a
    # fixed random linear "teacher" (an illustrative choice) so there is something to
    # learn; with purely random labels every optimizer stays at chance on the test set.
    torch.manual_seed(42)
    n_samples = 5000
    n_features = 784
    n_classes = 10
    
    teacher = torch.randn(n_features, n_classes)
    X_train = torch.randn(n_samples, n_features)
    y_train = (X_train @ teacher).argmax(dim=1)
    
    X_test = torch.randn(1000, n_features)
    y_test = (X_test @ teacher).argmax(dim=1)
    
    # Model
    class SimpleNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 256)
            self.fc2 = nn.Linear(256, 128)
            self.fc3 = nn.Linear(128, 10)
        
        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)
    
    # Training loop: n_iters mini-batch steps (batch size 128), evaluated every eval_every steps
    def train_model(model, optimizer, n_iters=200, eval_every=10):
        losses = []
        accuracies = []
        
        for it in range(n_iters):
            # Random mini-batch
            indices = torch.randperm(n_samples)[:128]
            batch_x = X_train[indices]
            batch_y = y_train[indices]
            
            if isinstance(optimizer, SimpleKFAC):
                # SimpleKFAC runs its own forward/backward pass
                optimizer.step(batch_x, batch_y)
            else:
                optimizer.zero_grad()
                loss = F.cross_entropy(model(batch_x), batch_y)
                loss.backward()
                optimizer.step()
            
            # Evaluate: loss on the current batch after the update, and test accuracy
            if it % eval_every == 0:
                with torch.no_grad():
                    batch_loss = F.cross_entropy(model(batch_x), batch_y)
                    pred = model(X_test).argmax(dim=1)
                    acc = (pred == y_test).float().mean()
                losses.append(batch_loss.item())
                accuracies.append(acc.item())
        
        return losses, accuracies
    
    # Run the comparison
    results = {}
    
    # SGD
    print("Training with SGD...")
    torch.manual_seed(42)
    model_sgd = SimpleNet()
    opt_sgd = torch.optim.SGD(model_sgd.parameters(), lr=0.01, momentum=0.9)
    losses_sgd, accs_sgd = train_model(model_sgd, opt_sgd)
    results['SGD'] = {'loss': losses_sgd, 'acc': accs_sgd}
    
    # Adam
    print("Training with Adam...")
    torch.manual_seed(42)
    model_adam = SimpleNet()
    opt_adam = torch.optim.Adam(model_adam.parameters(), lr=0.001)
    losses_adam, accs_adam = train_model(model_adam, opt_adam)
    results['Adam'] = {'loss': losses_adam, 'acc': accs_adam}
    
    # K-FAC
    print("Training with K-FAC...")
    torch.manual_seed(42)
    model_kfac = SimpleNet()
    opt_kfac = SimpleKFAC(model_kfac, lr=0.01, damping=1e-2)
    losses_kfac, accs_kfac = train_model(model_kfac, opt_kfac)
    results['K-FAC'] = {'loss': losses_kfac, 'acc': accs_kfac}
    
    # Plot the results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    for name, res in results.items():
        ax1.plot(res['loss'], label=name)
        ax2.plot(res['acc'], label=name)
    
    ax1.set_xlabel('Evaluation Step')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training Loss Comparison')
    ax1.legend()
    ax1.grid(True)
    
    ax2.set_xlabel('Evaluation Step')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Test Accuracy Comparison')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.savefig('/tmp/kfac_comparison.png', dpi=150)
    plt.show()
    
    return results
 
 
# Run the experiment
if __name__ == "__main__":
    results = train_comparison()
    
    print("\n=== Final Results ===")
    for name, res in results.items():
        print(f"{name}: Final Loss = {res['loss'][-1]:.4f}, Final Acc = {res['acc'][-1]:.4f}")
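
The preconditioning line in `SimpleKFAC` never forms the Fisher block. It relies on the Kronecker identity $(A \otimes G)^{-1}\,\mathrm{vec}(\nabla W) = \mathrm{vec}(G^{-1}\,\nabla W\,A^{-1})$ for symmetric $A$ and $G$. The sketch below checks this identity numerically on small random SPD factors, under two conventions: the column-stacking $\mathrm{vec}$ common in the literature, and PyTorch's row-major `.flatten()`, under which the same matrix corresponds to $(G \otimes A)^{-1}$. The sizes and the `random_spd` helper are illustrative.

```python
import torch

torch.manual_seed(0)
d_out, d_in = 3, 4

def random_spd(n):
    """Random symmetric positive-definite matrix"""
    M = torch.randn(n, n, dtype=torch.float64)
    return M @ M.T + n * torch.eye(n, dtype=torch.float64)

A, G = random_spd(d_in), random_spd(d_out)          # input-side / output-side factors
dW = torch.randn(d_out, d_in, dtype=torch.float64)  # weight gradient, shape (out, in)

# Cheap form: two small inverses instead of one (d_out*d_in) x (d_out*d_in) solve
cheap = torch.linalg.inv(G) @ dW @ torch.linalg.inv(A)

# Row-major flatten (PyTorch's .flatten()): the block is G ⊗ A
rowmajor = torch.linalg.solve(torch.kron(G, A), dW.flatten())
# Column-stacking vec (math convention), vec(X) = X.T.flatten(): the block is A ⊗ G
colmajor = torch.linalg.solve(torch.kron(A, G), dW.T.flatten())

print(torch.allclose(rowmajor, cheap.flatten()), torch.allclose(colmajor, cheap.T.flatten()))
```

The cheap form costs $O(d_{in}^3 + d_{out}^3)$ instead of $O(d_{in}^3 d_{out}^3)$ for a dense solve of the full block, which is what makes K-FAC practical for layers of realistic width.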

6.2 Comparison experiments with standard optimizers

import time
 
import numpy as np
import torch
import torch.nn.functional as F
 
class OptimizerBenchmark:
    """
    Optimizer performance benchmark
    
    Every measurement starts from a freshly built model and optimizer, so the timing
    run does not pre-train the model used for the convergence run, and optimizer
    state does not leak between runs.
    
    model_fn:      zero-argument callable returning a new model (seed inside it to get
                   identical initializations across optimizers)
    optimizer_fns: dict name -> callable(params) returning an optimizer
    train_data, test_data: (x, y) tensor tuples; training is full-batch
    """
    
    def __init__(
        self,
        model_fn,
        train_data,
        test_data,
        optimizer_fns
    ):
        self.model_fn = model_fn
        self.train_data = train_data
        self.test_data = test_data
        self.optimizer_fns = optimizer_fns
    
    def _fresh(self, optimizer_fn):
        """Build a new model and an optimizer bound to its parameters"""
        model = self.model_fn()
        return model, optimizer_fn(model.parameters())
    
    def _train_step(self, model, optimizer):
        """One full-batch training step"""
        x, y = self.train_data
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        return loss
    
    def benchmark_step_time(self, optimizer_fn, n_runs=100, n_warmup=5):
        """Measure the wall-clock time of one training step"""
        model, optimizer = self._fresh(optimizer_fn)
        times = []
        
        for i in range(n_warmup + n_runs):
            # Synchronize before and after so queued CUDA kernels are not misattributed
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start = time.perf_counter()
            self._train_step(model, optimizer)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            if i >= n_warmup:  # discard warm-up iterations
                times.append(time.perf_counter() - start)
        
        return np.mean(times), np.std(times)
    
    def benchmark_convergence(self, optimizer_name, optimizer_fn, max_steps=50):
        """Measure convergence speed and final test performance"""
        model, optimizer = self._fresh(optimizer_fn)
        x_test, y_test = self.test_data
        losses = []
        accuracies = []
        
        for step in range(max_steps):
            self._train_step(model, optimizer)
            
            # Evaluate
            with torch.no_grad():
                test_output = model(x_test)
                test_loss = F.cross_entropy(test_output, y_test).item()
                acc = (test_output.argmax(dim=1) == y_test).float().mean().item()
            
            losses.append(test_loss)
            accuracies.append(acc)
            
            if step % 10 == 0:
                print(f"{optimizer_name} - Step {step}: Loss = {test_loss:.4f}, Acc = {acc:.4f}")
        
        return losses, accuracies
    
    def run_full_benchmark(self):
        """Run the full benchmark"""
        results = {}
        
        for name, optimizer_fn in self.optimizer_fns.items():
            print(f"\n{'='*50}")
            print(f"Benchmarking: {name}")
            print('='*50)
            
            # Timing
            mean_time, std_time = self.benchmark_step_time(optimizer_fn)
            print(f"Step time: {mean_time*1000:.2f} ± {std_time*1000:.2f} ms")
            
            # Convergence
            losses, accuracies = self.benchmark_convergence(name, optimizer_fn)
            
            results[name] = {
                'step_time': (mean_time, std_time),
                'final_loss': losses[-1],
                'final_acc': accuracies[-1],
                'losses': losses,
                'accuracies': accuracies
            }
        
        return results
 
 
def create_optimizers():
    """Optimizers to benchmark, as factories taking an iterable of parameters"""
    return {
        'SGD': lambda params: torch.optim.SGD(params, lr=0.01, momentum=0.9),
        'Adam': lambda params: torch.optim.Adam(params, lr=0.001),
        'AdamW': lambda params: torch.optim.AdamW(params, lr=0.001, weight_decay=0.01),
        'RMSprop': lambda params: torch.optim.RMSprop(params, lr=0.001),
    }
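
As an alternative to hand-written timing loops, PyTorch ships `torch.utils.benchmark.Timer`, which performs warm-up runs and synchronizes CUDA when necessary. The sketch below times one full training step per optimizer on a small model with random data; the model size, batch, and the `train_step` helper are illustrative. The timing runs keep updating the same model. That is harmless for timing, but it means this loop should not be reused to measure convergence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.benchmark import Timer

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 10))
x, y = torch.randn(128, 784), torch.randint(0, 10, (128,))

def train_step(opt):
    """One full training step: forward, backward, optimizer update"""
    opt.zero_grad()
    F.cross_entropy(model(x), y).backward()
    opt.step()

factories = {
    'SGD': lambda ps: torch.optim.SGD(ps, lr=0.01, momentum=0.9),
    'Adam': lambda ps: torch.optim.Adam(ps, lr=1e-3),
}
step_ms = {}
for name, make_opt in factories.items():
    opt = make_opt(model.parameters())
    timer = Timer(stmt="train_step(opt)", globals={"train_step": train_step, "opt": opt})
    step_ms[name] = timer.timeit(50).mean * 1e3   # Measurement.mean is in seconds
    print(f"{name}: {step_ms[name]:.3f} ms/step")
```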

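7. Summary and outlook

7.1 Comparison of methods

$n$: number of parameters; $d$: width of a layer (for Shampoo, the size of one tensor dimension); $m$: L-BFGS history length; $k$: number of CG iterations.

| Method | Curvature estimate | Compute per step | Memory | Convergence speed | Suitable for |
| --- | --- | --- | --- | --- | --- |
| Newton's method | Exact Hessian | $O(n^3)$ | $O(n^2)$ | Quadratic (local) | Small-scale problems |
| K-FAC | Kronecker factorization | $O(d^3)$ per layer per inversion | $O(d^2)$ per layer |  | Medium scale |
| Shampoo | Per-dimension inverse matrix roots | $O(d^3)$ per dimension per root | $O(d^2)$ per dimension |  | Large scale |
| L-BFGS | Rank-2 updates | $O(mn)$ | $O(mn)$ |  | Medium scale |
| CG | Hessian-vector products | $k$ Hessian-vector products | $O(n)$ |  | Large scale |
| Adam | Diagonal approximation | $O(n)$ | $O(n)$ | Moderate | Very large scale |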
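
To put the memory column into numbers, the sketch below counts the floats each method stores for the three-layer SimpleNet of Section 6.1, which has 235,146 parameters. It counts a dense Hessian for Newton's method, the two Kronecker factors per layer for K-FAC (the input side is enlarged by one for the bias; storing the inverses as well doubles this figure), and the two moment vectors for Adam. Activations and gradients are common to all methods and are not counted.

```python
# (in_features, out_features) of the SimpleNet Linear layers in Section 6.1
layers = [(784, 256), (256, 128), (128, 10)]

n = sum(d_in * d_out + d_out for d_in, d_out in layers)              # weights + biases
newton = n * n                                                        # dense Hessian
kfac = sum((d_in + 1) ** 2 + d_out ** 2 for d_in, d_out in layers)   # A and G per layer
adam = 2 * n                                                          # m and v

for name, count in [("Newton", newton), ("K-FAC", kfac), ("Adam", adam)]:
    print(f"{name:7s}{count:>16,d} floats  ~{count * 4 / 1e6:>12,.1f} MB in fp32")
```

Even for this small network, a dense Hessian needs about 221 GB in fp32, whereas K-FAC's factors (about 3.1 MB) and Adam's moments (about 1.9 MB) are of the same order as the model itself.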

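7.2 Future research directions

  1. Communication-efficient distributed second-order optimization: reduce cross-GPU synchronization overhead
  2. Low-precision second-order methods: keep second-order methods numerically stable in BF16/FP8
  3. Adaptive curvature estimation: adjust the fidelity of the curvature estimate to the training phase
  4. Interaction with Transformer scaling laws: understand how second-order optimization scales in large-model training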

References

Footnotes

  1. Amari, S. (1998). Natural gradient works efficiently in learning. Neural Computation, 10(2), 251-276.

  2. Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored approximate curvature. ICML 2015.

  3. Rao, C. R. (1945). Information and the accuracy attainable in the estimation of statistical parameters. Bulletin of the Calcutta Mathematical Society, 37, 81-91.

  4. Amari, S. (2020). Information geometry and its applications: On the front line of data science. IEEE Signal Processing Magazine, 37(6), 119-127.