Second-Order Optimization Methods

Second-order optimization methods use second-derivative information about the objective (the Hessian matrix) to guide parameter updates. They typically converge faster than first-order methods, but each step is more expensive to compute.¹


Background and Motivation

Limitations of First-Order Methods

Gradient descent (GD) uses only first-order information:

$$\theta_{t+1} = \theta_t - \eta \, \nabla L(\theta_t)$$

Problems

  • Sensitive to ill-conditioning (a large condition number)
  • The convergence rate is limited by the ratio of the largest to the smallest curvature
  • Requires careful learning-rate tuning

The Idea Behind Second-Order Methods

Second-order methods use curvature information:

$$\theta_{t+1} = \theta_t - H^{-1} \nabla L(\theta_t)$$

where $H = \nabla^2 L(\theta_t)$ is the Hessian matrix.

Advantages

  • Insensitive to the condition number
  • Quadratic convergence in theory (locally)
  • Step sizes are effectively adapted to the local curvature

Challenges

  • Computing the Hessian costs $O(n^2)$, where $n$ is the number of parameters
  • Inverting the Hessian costs $O(n^3)$
  • The Hessian may not be positive definite

The Taylor-Expansion View

Second-order approximation

Expand the loss in a second-order Taylor series around the current point $\theta_t$:

$$L(\theta) \approx L(\theta_t) + g^\top (\theta - \theta_t) + \tfrac{1}{2} (\theta - \theta_t)^\top H (\theta - \theta_t)$$

where $g = \nabla L(\theta_t)$ and $H = \nabla^2 L(\theta_t)$.

The analytic optimal step

Setting the derivative of the quadratic model to zero:

$$g + H (\theta - \theta_t) = 0$$

gives the Newton step (a small numeric check follows below):

$$\theta^* = \theta_t - H^{-1} g$$
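
As a sanity check, for a quadratic loss the Newton step lands exactly at the minimizer in a single update. A minimal sketch (the matrices below are made-up toy data):

import numpy as np

# Quadratic loss L(θ) = ½ θᵀAθ − bᵀθ, with gradient Aθ − b and Hessian A
A = np.array([[3.0, 0.5],
              [0.5, 1.0]])                   # symmetric positive definite
b = np.array([1.0, -2.0])

theta = np.array([5.0, 5.0])                 # arbitrary starting point
g = A @ theta - b                            # ∇L(θ)
theta_new = theta - np.linalg.solve(A, g)    # Newton step θ − H⁻¹g

print(theta_new)                             # equals A⁻¹b, the exact minimizer
print(np.linalg.solve(A, b))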


Newton's Method

Vanilla Newton's method

import numpy as np
 
def newton_step(loss_fn, grad_fn, hess_fn, theta):
    """
    One Newton update step
    
    θ_{t+1} = θ_t - H^{-1}∇L
    """
    g = grad_fn(theta)
    H = hess_fn(theta)
    
    # Solve H Δ = -g via Cholesky factorization
    try:
        L = np.linalg.cholesky(H)
        delta = np.linalg.solve(L.T, np.linalg.solve(L, -g))
    except np.linalg.LinAlgError:
        # If H is not positive definite, add damping
        H_reg = H + 1e-4 * np.eye(len(theta))
        delta = -np.linalg.solve(H_reg, g)
    
    return delta
 
def newton_descent(loss_fn, grad_fn, hess_fn, theta_init, max_iter=100, tol=1e-6):
    """Newton's method optimization loop"""
    theta = theta_init.copy()
    
    for i in range(max_iter):
        g = grad_fn(theta)
        if np.linalg.norm(g) < tol:
            break
        
        delta = newton_step(loss_fn, grad_fn, hess_fn, theta)
        theta = theta + delta
        
        print(f"Iter {i}: loss = {loss_fn(theta):.6f}, |grad| = {np.linalg.norm(g):.6f}")
    
    return theta
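
A short usage sketch on the 2-D Rosenbrock function (the gradient and Hessian below are its standard closed forms):

def rosenbrock(t):
    return (1 - t[0])**2 + 100 * (t[1] - t[0]**2)**2

def rosenbrock_grad(t):
    return np.array([
        -2 * (1 - t[0]) - 400 * t[0] * (t[1] - t[0]**2),
        200 * (t[1] - t[0]**2),
    ])

def rosenbrock_hess(t):
    return np.array([
        [2 - 400 * (t[1] - 3 * t[0]**2), -400 * t[0]],
        [-400 * t[0], 200.0],
    ])

# Typically converges to the minimizer (1, 1) in a handful of iterations
theta_opt = newton_descent(rosenbrock, rosenbrock_grad, rosenbrock_hess,
                           theta_init=np.array([-1.2, 1.0]))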

Convergence properties of Newton's method

| Condition | Convergence rate |
| --- | --- |
| Strongly convex + smooth | Quadratic |
| Convex | Quadratic (locally) |
| General non-convex | May diverge |

Problem: the cost of the Hessian

For modern neural networks (see the rough estimate below):

  • The number of parameters $n$ can reach millions to billions
  • The Hessian has $n \times n$ entries
  • It cannot be stored or computed explicitly
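
To make the storage claim concrete, a back-of-the-envelope estimate (the parameter count is an arbitrary example):

n = 10**8                      # e.g. a 100M-parameter model
hessian_bytes = n * n * 4      # dense float32 Hessian
print(hessian_bytes / 1e15)    # ≈ 40 petabytes, far beyond any GPU or cluster memory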

Natural Gradient Method

Idea

Natural gradient descent performs gradient descent on a Riemannian manifold, taking the geometry of the parameter space into account.

KL divergence as the metric

On the space of probability distributions, use the KL divergence to define a "distance":

$$\mathrm{KL}\big(p_\theta \,\|\, p_{\theta+\delta}\big) \approx \tfrac{1}{2}\, \delta^\top F \,\delta$$

where $F = \mathbb{E}\big[\nabla_\theta \log p(x;\theta)\, \nabla_\theta \log p(x;\theta)^\top\big]$ is the Fisher information matrix (an approximation to the Hessian).
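
As a quick sanity check of the definition of $F$ on a toy 1-D Gaussian (illustrative only): for $\mathcal{N}(\mu, \sigma^2)$ with $\sigma$ fixed, the Fisher information of $\mu$ is $1/\sigma^2$, which Monte Carlo sampling recovers:

import torch

mu, sigma = 0.3, 2.0
x = mu + sigma * torch.randn(100_000)             # samples from N(mu, sigma^2)
score = (x - mu) / sigma**2                       # d/dmu log N(x; mu, sigma^2)
print(score.pow(2).mean().item(), 1 / sigma**2)   # both ≈ 0.25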

Natural gradient update

$$\theta_{t+1} = \theta_t - \eta \, F^{-1} \nabla L(\theta_t)$$

Key difference: the gradient is rescaled by $F^{-1}$, which accounts for the geometry of the parameter space.

import torch

class NaturalGradient:
    """
    Natural gradient descent
    
    Uses an approximation of the Fisher information matrix
    """
    
    def __init__(self, model, lr=0.01, fisher_estimator=None):
        self.model = model
        self.lr = lr
        self.fisher_estimator = fisher_estimator or EmpiricFisherEstimator()
        self.fisher = None
    
    def step(self, batch_x, batch_y):
        """
        One natural gradient step
        """
        # 1. Estimate the Fisher information matrix first
        #    (the estimator runs its own backward passes, which would
        #    otherwise overwrite the batch gradient)
        self.fisher = self.fisher_estimator.estimate(self.model)
        
        # 2. Compute the ordinary gradient on the batch
        self.model.zero_grad()
        loss = self.model.loss(batch_x, batch_y)
        loss.backward()
        
        # 3. Compute the natural gradient: F^{-1} g
        natural_grad = self._solve_fisher_system()
        
        # 4. Update the parameters
        with torch.no_grad():
            params_with_grad = [p for p in self.model.parameters() if p.grad is not None]
            for param, ng in zip(params_with_grad, natural_grad):
                param.add_(ng, alpha=-self.lr)
        
        # Clear gradients
        self.model.zero_grad()
    
    def _solve_fisher_system(self):
        """
        Approximately solve F Δ = g
        
        A full solve would use conjugate gradients; here we fall back to
        a diagonal approximation of F
        """
        diag_fisher = self.fisher.diagonal()
        natural_grads = []
        
        offset = 0
        for p in self.model.parameters():
            if p.grad is None:
                continue
            size = p.numel()
            grad = p.grad.flatten()
            diag = diag_fisher[offset:offset+size]
            ng = grad / (diag + 1e-8)
            natural_grads.append(ng.reshape(p.shape))
            offset += size
        
        return natural_grads
 
class EmpiricFisherEstimator:
    """
    Empirical Fisher information matrix estimate
    
    F ≈ (1/N) Σ ∇_θ log p(y|x,θ) ∇_θ log p(y|x,θ)^T
    """
    
    def estimate(self, model, n_samples=100):
        """Estimate the Fisher matrix"""
        model.eval()
        
        # Collect per-sample gradients (assumes the model provides sample_batch())
        grads = []
        for _ in range(n_samples):
            x, y = model.sample_batch()
            model.zero_grad()
            loss = model.loss(x, y)
            loss.backward()
            
            g = torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None])
            grads.append(g)
        
        grads = torch.stack(grads)
        
        # Fisher ≈ uncentered sample covariance of the gradients
        fisher = (grads.T @ grads) / n_samples
        
        return fisher

K-FAC (Kronecker-Factored Approximate Curvature)

The Kronecker factorization idea

K-FAC factors each layer's block of the Fisher information matrix into a Kronecker product of two small matrices:

$$F_\ell \approx A_\ell \otimes B_\ell, \qquad A_\ell = \mathbb{E}[a\, a^\top], \quad B_\ell = \mathbb{E}[\delta\, \delta^\top]$$

where $a$ is the layer input and $\delta$ is the gradient with respect to the layer's pre-activation output. For a weight matrix $W \in \mathbb{R}^{d_{out} \times d_{in}}$ this means (see the numerical check below):

  • Storage: from $O(d_{in}^2 d_{out}^2)$ down to $O(d_{in}^2 + d_{out}^2)$
  • Inversion: from $O(d_{in}^3 d_{out}^3)$ down to $O(d_{in}^3 + d_{out}^3)$
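
The algebraic fact that makes this cheap is the Kronecker-vec identity $(B^{-1} \otimes A^{-1})\,\mathrm{vec}(\nabla W) = \mathrm{vec}(B^{-1}\, \nabla W\, A^{-1})$ for symmetric factors, so the large matrix never has to be formed. A small numerical check (toy sizes, illustrative only):

import torch

d_out, d_in = 4, 3
G = torch.randn(d_out, d_in)                                      # a fake weight gradient
A = torch.randn(d_in, d_in);   A = A @ A.T + torch.eye(d_in)      # input-side factor
B = torch.randn(d_out, d_out); B = B @ B.T + torch.eye(d_out)     # output-side factor

A_inv, B_inv = torch.linalg.inv(A), torch.linalg.inv(B)

# Explicit (d_out*d_in)² matrix — only feasible at toy scale
big = torch.kron(B_inv, A_inv) @ G.flatten()
# Kronecker-factored form — two small matrix products
small = (B_inv @ G @ A_inv).flatten()

print(torch.allclose(big, small, atol=1e-5))    # True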

Core assumptions of K-FAC

For a neural network layer $s = W a$ (input $a$, pre-activation $s$):

  • The layer input $a$ and the back-propagated output gradient $\delta = \partial L / \partial s$ are approximately independent
  • So the layer's Fisher block factors into an input-side part and an output-side part
import torch
import torch.nn as nn
 
class KFACOptimizer:
    """
    K-FAC optimizer
    
    Kronecker-Factored Approximate Curvature
    """
    
    def __init__(self, model, lr=0.001, momentum=0.9, 
                 kl_step=10, ra_step=200, damping=1e-3):
        self.model = model
        self.lr = lr
        self.momentum = momentum
        self.kl_step = kl_step  # how often to update the Fisher factors
        self.ra_step = ra_step  # how often to refresh their inverses
        self.damping = damping
        
        self.step_counter = 0
        self.F_blocks = {}  # Fisher factors per parameter
        self.buffers = {}   # cached inverses
        self.momentum_buffers = {}
        
        self._init_buffers()
    
    def _init_buffers(self):
        """Initialize the Fisher factor and inverse buffers"""
        self.F_blocks = {}
        self.buffers = {}
        self.momentum_buffers = {}
        
        for name, param in self.model.named_parameters():
            if param.dim() >= 2:
                # Linear and convolutional weight matrices
                self.F_blocks[name] = {
                    'A': None,  # input-side Fisher factor
                    'B': None,  # output-side Fisher factor
                }
                self.buffers[name] = {
                    'inv_A': None,
                    'inv_B': None,
                }
                self.momentum_buffers[name] = {
                    'grad_buffer': torch.zeros_like(param),
                }
    
    def step(self):
        """One K-FAC optimization step"""
        self.step_counter += 1
        
        # 1. Accumulate the Fisher factor estimates
        if self.step_counter % self.kl_step == 0:
            self._update_fisher_blocks()
        
        # 2. Periodically refresh the inverse factors
        if self.step_counter % self.ra_step == 0:
            self._update_inverse_matrices()
        
        # 3. Apply the preconditioned gradient
        self._apply_corrected_gradient()
        
        # Clear gradients
        self.model.zero_grad()
    
    def _update_fisher_blocks(self):
        """
        Update the Fisher factor estimates from the current gradients
        
        E[∇W ⊗ ∇W^T] ≈ (1/B) Σ ∇W_i ⊗ ∇W_i^T
        """
        for name, param in self.model.named_parameters():
            if param.grad is None or name not in self.F_blocks:
                continue
            
            grad = param.grad
            
            # Dispatch on the parameter shape (param is a tensor, so
            # isinstance checks against module classes would never match)
            if param.dim() == 2:       # linear-layer weight
                self._update_linear_fisher(name, grad, param)
            elif param.dim() == 4:     # conv-layer weight
                self._update_conv_fisher(name, grad, param)
    
    def _update_linear_fisher(self, name, grad, param):
        """Fisher factor update for a linear layer"""
        # grad: (out_features, in_features)
        # Factorization: A ⊗ B
        
        # B: output-side factor (out_features x out_features).
        # This is a crude proxy built from the averaged weight gradient;
        # a faithful K-FAC implementation uses per-example pre-activation
        # gradients captured with backward hooks.
        B = grad @ grad.t() / grad.size(1)
        
        # A: input-side covariance E[x x^T]
        # Requires the layer input x, captured via a forward hook in practice
        if hasattr(self.model, '_kfac_inputs') and name in self.model._kfac_inputs:
            x = self.model._kfac_inputs[name]
            A = torch.einsum('bi,bj->ij', x, x) / x.size(0)
        else:
            # Fall back to the identity matrix
            A = torch.eye(grad.size(1), device=grad.device)
        
        # Exponential moving average update
        if self.F_blocks[name]['A'] is None:
            self.F_blocks[name]['A'] = A
            self.F_blocks[name]['B'] = B
        else:
            self.F_blocks[name]['A'] = 0.9 * self.F_blocks[name]['A'] + 0.1 * A
            self.F_blocks[name]['B'] = 0.9 * self.F_blocks[name]['B'] + 0.1 * B
    
    def _update_conv_fisher(self, name, grad, param):
        """Fisher factor update for a convolutional layer"""
        # Conv layers need the weight reshaped to (out_channels, -1) and the
        # unfolded input patches; omitted here for simplicity
        pass
    
    def _update_inverse_matrices(self):
        """
        Refresh the inverses of the Fisher factors
        
        Uses a symmetric eigendecomposition with damping
        """
        for name in self.F_blocks:
            A = self.F_blocks[name]['A']
            B = self.F_blocks[name]['B']
            
            if A is None or B is None:
                continue
            
            # A^{-1}
            eigenvalues_A, eigenvectors_A = torch.linalg.eigh(A)
            inv_eigenvalues_A = 1.0 / (eigenvalues_A + self.damping)
            inv_A = eigenvectors_A @ torch.diag(inv_eigenvalues_A) @ eigenvectors_A.T
            
            # B^{-1}
            eigenvalues_B, eigenvectors_B = torch.linalg.eigh(B)
            inv_eigenvalues_B = 1.0 / (eigenvalues_B + self.damping)
            inv_B = eigenvectors_B @ torch.diag(inv_eigenvalues_B) @ eigenvectors_B.T
            
            self.buffers[name]['inv_A'] = inv_A
            self.buffers[name]['inv_B'] = inv_B
    
    def _apply_corrected_gradient(self):
        """
        Apply the Kronecker-factored gradient correction
        
        ΔW = -η * (B^{-1} ⊗ A^{-1}) vec(∇W)
        """
        for name, param in self.model.named_parameters():
            if param.grad is None or name not in self.F_blocks:
                continue
            
            grad = param.grad
            
            inv_A = self.buffers[name]['inv_A']
            inv_B = self.buffers[name]['inv_B']
            
            if inv_A is None or inv_B is None:
                # No Fisher estimate yet: fall back to the plain gradient
                param.data.add_(grad, alpha=-self.lr)
                continue
            
            # Inverse of a Kronecker product = Kronecker product of the inverses
            # (B^{-1} ⊗ A^{-1}) vec(∇W) = vec(B^{-1} ∇W A^{-1})
            
            # Reshape the gradient and precondition it
            if param.dim() == 2:   # linear-layer weight
                grad_matrix = grad.view(grad.size(0), -1)
                corrected_grad = torch.einsum('ij,jk,kl->il', 
                    inv_B, grad_matrix, inv_A)
                corrected_grad = corrected_grad.view_as(grad)
            else:
                corrected_grad = grad
            
            # Apply the update
            param.data.add_(corrected_grad, alpha=-self.lr)
 
# Usage example
def demo_kfac():
    """Example of training with the K-FAC optimizer"""
    import torch.nn.functional as F
    
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10)
    )
    
    optimizer = KFACOptimizer(model, lr=0.001)
    
    # Training loop (assumes `dataloader` yields (batch_x, batch_y) pairs)
    for epoch in range(10):
        for batch_x, batch_y in dataloader:
            output = model(batch_x)
            loss = F.cross_entropy(output, batch_y)
            loss.backward()
            optimizer.step()

Hessian-Free Optimization

Idea

Never form or store the Hessian explicitly; instead, solve the Newton system $H \Delta = -g$ with the conjugate gradient (CG) method, which only requires Hessian-vector products $Hv$.

Conjugate gradient method

import torch

def conjugate_gradient(A, b, x0=None, max_iter=None, tol=1e-6):
    """
    Conjugate gradient method for Ax = b
    
    A is a symmetric positive definite matrix, or a callable that
    computes the matrix-vector product A @ v
    """
    matvec = A if callable(A) else (lambda v: A @ v)
    
    if x0 is None:
        x0 = torch.zeros_like(b)
    
    x = x0.clone()
    r = b - matvec(x)
    p = r.clone()
    rsold = r @ r
    
    if max_iter is None:
        max_iter = b.numel()
    
    for i in range(max_iter):
        Ap = matvec(p)
        alpha = rsold / (p @ Ap + 1e-10)
        
        x = x + alpha * p
        r = r - alpha * Ap
        
        rsnew = r @ r
        if torch.sqrt(rsnew) < tol:
            break
        
        beta = rsnew / rsold
        p = r + beta * p
        rsold = rsnew
    
    return x
 
class HessianFreeOptimizer:
    """
    Hessian-Free optimizer
    
    Uses conjugate gradients to solve for the Newton step
    """
    
    def __init__(self, model, lr=1.0, damping=0.1):
        self.model = model
        self.lr = lr
        self.damping = damping
    
    def _hessian_vector_product(self, v):
        """
        Compute Hv without forming H explicitly
        
        Finite-difference approximation:
        Hv ≈ (∇L(w + εv) - ∇L(w - εv)) / (2ε)
        
        `v` is a flat vector with one entry per model parameter
        """
        epsilon = 1e-3
        
        # Save the original parameters
        params = [p.detach().clone() for p in self.model.parameters()]
        
        # Compute ∇L(w + εv); assumes the model exposes a zero-arg loss()
        self._add_to_params(v, epsilon)
        self.model.zero_grad()
        loss_plus = self.model.loss()
        loss_plus.backward()
        grads_plus = torch.cat([p.grad.flatten() for p in self.model.parameters() if p.grad is not None])
        
        # Compute ∇L(w - εv)
        self._add_to_params(v, -2 * epsilon)
        self.model.zero_grad()
        loss_minus = self.model.loss()
        loss_minus.backward()
        grads_minus = torch.cat([p.grad.flatten() for p in self.model.parameters() if p.grad is not None])
        
        # Restore the original parameters
        with torch.no_grad():
            for p, orig in zip(self.model.parameters(), params):
                p.copy_(orig)
        
        # Hv ≈ (g_plus - g_minus) / (2ε)
        Hv = (grads_plus - grads_minus) / (2 * epsilon)
        
        # Add the damping term: (H + damping·I) v
        Hv = Hv + self.damping * v
        
        return Hv
    
    def _add_to_params(self, v, alpha):
        """Add alpha * v (a flat vector) to the model parameters"""
        offset = 0
        with torch.no_grad():
            for p in self.model.parameters():
                size = p.numel()
                chunk = v[offset:offset+size].reshape(p.shape)
                p.add_(chunk, alpha=alpha)
                offset += size
    
    def step(self):
        """
        One Hessian-Free optimization step
        
        Solve H Δ = -g with CG
        """
        # Compute the gradient
        self.model.zero_grad()
        loss = self.model.loss()
        loss.backward()
        
        g = torch.cat([p.grad.flatten() for p in self.model.parameters() if p.grad is not None])
        
        # Hessian-vector product as a function of a flat vector
        def Hv_func(v):
            return self._hessian_vector_product(v)
        
        # Solve H Δ = -g with CG
        delta = conjugate_gradient(
            Hv_func, 
            -g,  # CG solves Ax = b, so the right-hand side is -g
            max_iter=50
        )
        
        # Apply the update; delta already points downhill (Δ ≈ -H^{-1}g)
        offset = 0
        with torch.no_grad():
            for p in self.model.parameters():
                if p.grad is None:
                    continue
                size = p.numel()
                delta_p = delta[offset:offset+size].reshape(p.shape)
                p.add_(delta_p, alpha=self.lr)
                offset += size
        
        return delta
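
The finite-difference Hessian-vector product above is simple but numerically noisy. In practice $Hv$ is usually computed exactly with a double backward pass (Pearlmutter's trick); a minimal sketch, assuming `loss` is a scalar already computed from the parameter list `params`:

import torch

def exact_hvp(loss, params, v):
    """Exact Hessian-vector product via double backward (Pearlmutter's trick)."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.flatten() for g in grads])
    # Differentiating (∇L · v) with respect to the parameters yields H v
    gv = torch.dot(flat_grad, v)
    hv = torch.autograd.grad(gv, params, retain_graph=True)
    return torch.cat([h.flatten() for h in hv])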

Other Second-Order Methods

1. L-BFGS (Limited-memory BFGS)

Approximates the inverse Hessian from a history of parameter differences $s_k = \theta_{k+1} - \theta_k$ and gradient differences $y_k = \nabla L(\theta_{k+1}) - \nabla L(\theta_k)$:

import torch

def lbfgs_step(g, s_list, y_list):
    """
    L-BFGS two-loop recursion
    
    Computes the product of the approximate inverse Hessian with the
    gradient `g`, using the stored history of parameter differences s_k
    and gradient differences y_k
    """
    q = g.clone()
    rho_list = [1.0 / (torch.dot(y, s) + 1e-10) for s, y in zip(s_list, y_list)]
    
    # Backward loop (most recent pair first)
    alphas = []
    for s, y, rho in zip(reversed(s_list), reversed(y_list), reversed(rho_list)):
        alpha = rho * torch.dot(s, q)
        alphas.append(alpha)
        q = q - alpha * y
    
    # Initial guess for H^{-1}: scaled identity from the newest curvature pair
    if s_list:
        s, y = s_list[-1], y_list[-1]
        gamma = torch.dot(s, y) / (torch.dot(y, y) + 1e-10)
        z = gamma * q
    else:
        z = q
    
    # Forward loop (oldest pair first)
    for (s, y, rho), alpha in zip(zip(s_list, y_list, rho_list), reversed(alphas)):
        beta = rho * torch.dot(y, z)
        z = z + s * (alpha - beta)
    
    return z  # ≈ H^{-1} g
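
An illustrative loop on a toy quadratic $\tfrac{1}{2}x^\top A x - b^\top x$ (the names and constants below are made up for the example; a real L-BFGS implementation would also use a line search):

A = torch.diag(torch.linspace(1.0, 10.0, 20))
b = torch.randn(20)
grad_fn = lambda x: A @ x - b

x = torch.zeros(20)
s_hist, y_hist, memory = [], [], 10          # bounded history of (s, y) pairs
g = grad_fn(x)
print("initial |grad|:", torch.norm(g).item())
for _ in range(100):
    d = lbfgs_step(g, s_hist, y_hist)        # ≈ H⁻¹ g
    x_new = x - 0.1 * d                      # small fixed step instead of a line search
    g_new = grad_fn(x_new)
    s_hist.append(x_new - x); y_hist.append(g_new - g)
    if len(s_hist) > memory:
        s_hist.pop(0); y_hist.pop(0)
    x, g = x_new, g_new
print("final   |grad|:", torch.norm(g).item())   # should have shrunk substantially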

2. The connection between Gauss-Newton and Fisher

The Gauss-Newton matrix (for a least-squares loss $L = \tfrac{1}{2}\|r(\theta)\|^2$):

$$G = J^\top J$$

where $J = \partial r / \partial \theta$ is the Jacobian of the residuals.
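
A minimal sketch of a damped Gauss-Newton step on a toy nonlinear least-squares fit (the model and data below are made up for illustration):

import torch

def residuals(theta, x, y):
    return theta[0] * torch.exp(theta[1] * x) - y      # toy exponential model

x = torch.linspace(0.0, 1.0, 8)
y = 2.0 * torch.exp(0.5 * x)
theta = torch.tensor([1.0, 0.0])

J = torch.autograd.functional.jacobian(lambda t: residuals(t, x, y), theta)  # (8, 2)
r = residuals(theta, x, y)
G = J.T @ J                                   # Gauss-Newton curvature ≈ Hessian of ½‖r‖²
g = J.T @ r                                   # gradient of ½‖r‖²
delta = torch.linalg.solve(G + 1e-6 * torch.eye(2), -g)   # damped Gauss-Newton step
theta = theta + delta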

Key relationships

| Method | Curvature matrix |
| --- | --- |
| Vanilla Newton | Hessian $H = \nabla^2 L$ |
| Gauss-Newton | $G = J^\top J$ |
| Natural gradient | Fisher $F = \mathbb{E}[\nabla_\theta \log p \, \nabla_\theta \log p^\top]$ |

In classification, when the model is correctly specified, these curvature matrices essentially coincide: the Fisher matrix equals the expected Hessian of the negative log-likelihood, and for a softmax/cross-entropy output it also equals the generalized Gauss-Newton matrix.


Practical Considerations

When to use second-order methods

def should_use_second_order(
    n_params,
    condition_number=None,
    batch_size=None,
    training_time_budget=None
):
    """
    Heuristic: should a second-order optimizer be used?
    """
    # Rough K-FAC storage overhead (reference estimate only)
    kfac_storage = n_params * 4  # approx. 4 extra floats per parameter
    
    # Rough Hessian-Free cost per iteration (reference estimate only)
    hf_cost_per_iter = 50 * n_params  # roughly 50 CG iterations
    
    # Recommendation
    if n_params > 1e8:
        return "K-FAC (if memory allows)"
    elif n_params > 1e6:
        return "L-BFGS or K-FAC"
    elif condition_number and condition_number > 1000:
        return "L-BFGS or HF"
    else:
        return "Adam (usually good enough)"

Numerical stability

def stabilize_hessian_inverse(H, damping=1e-3):
    """
    Stabilized computation of the Hessian inverse
    """
    # Option 1: add damping
    H_reg = H + damping * torch.eye(H.size(0), device=H.device, dtype=H.dtype)
    
    # Option 2: Cholesky factorization (fails fast if not positive definite)
    try:
        L = torch.linalg.cholesky(H_reg)
        H_inv = torch.cholesky_inverse(L)
    except RuntimeError:
        # Fall back to an eigendecomposition with clamped eigenvalues
        eigenvalues, eigenvectors = torch.linalg.eigh(H_reg)
        eigenvalues = torch.clamp(eigenvalues, min=damping)
        H_inv = eigenvectors @ torch.diag(1 / eigenvalues) @ eigenvectors.T
    
    return H_inv

Footnotes

  1. Martens, J., & Grosse, R. (2015). Optimizing neural networks with Kronecker-factored approximate curvature. In Proceedings of the 32nd International Conference on Machine Learning (ICML).