Sharp vs Flat Minima:深度学习优化的新视角

深度学习的一个核心谜题是:为什么使用大量参数的过参数化(over-parameterized)网络能够在训练集上完美拟合,却仍具有良好的泛化能力?传统优化理论无法解释这一现象,因为从欠参数化的经典视角看,这样的模型应该严重过拟合。

近年来,损失景观(Loss Landscape) 的几何特性——特别是极小值的平坦度(Flatness) 与泛化能力之间的关联——为这一问题提供了有价值的解释。本章将系统介绍这一领域的基础理论、度量方法、优化算法及最新进展。


一、背景:损失景观与泛化

1.1 经验风险景观的经验观察

训练神经网络本质上是优化高维非凸损失函数:

其中 可达数十亿量级。传统观点认为,如此高维的非凸函数存在大量尖锐的局部极小值,泛化性能差。然而,大量实验表明,现代优化算法(如 SGD、Adam)收敛到的极小值往往具有良好的泛化性能。

Garipov et al. (2018) 通过两条极小值之间的低损失路径发现:可以通过几乎平坦的路径连接不同的局部极小值,这暗示存在”平原”(plateau)结构。1

Keskar et al. (2017) 的开创性工作比较了批大小不同的训练结果:使用大批次(batch size = 5120)训练得到的极小值泛化性能明显差于小批次(batch size = 128),而前者对应的损失景观更为”尖锐”。2

1.2 泛化差距的实验观察

泛化差距 反映了模型在未见数据上的表现差异。大量实验表明:

训练配置训练损失测试损失泛化差距景观特性
小batch + warmup极低较低较平坦
大batch极低较高较尖锐

He et al. (2019) 的研究进一步表明,通过标签平滑(label smoothing)和随机数据增强等技术,可以缓解大批次训练的性能下降,这与平坦度的改善相关。3

1.3 为什么平坦极小值可能泛化更好?

有几种互补的理论解释:

1. 决策边界复杂度假说

平坦极小值对应的参数对扰动不敏感,因此模型对输入的微小变化(噪声、数据分布漂移)具有更强的鲁棒性。从几何角度看,平坦区域对应的决策边界更加平滑。

2. PAC-Bayes理论联系

根据 PAC-Bayes 边界,泛化误差上界与后验分布的复杂度(KL散度)相关。平坦极小值附近的高概率质量区域更大,对应的贝叶斯后验具有更小的有效复杂度。

3. 随机扰动稳定性

考虑测试分布与训练分布的差异:。如果损失景观在极小值附近足够平坦,则:

其中 反映了分布扰动的幅度。平坦的Hessian( 较小)意味着更小的泛化差距。


二、Sharpness 度量

平坦度是一个直观概念,但如何精确度量”锐度”(Sharpness)?本节介绍几种主流方法。

2.1 Hessian特征值分析

Hessian矩阵 的特征值直接反映了损失函数的局部曲率。在极小值点 处,所有特征值 应为非负:

  • 最大特征值 :最尖锐方向
  • 谱范数 :衡量局部锐度的全局指标
  • 特征值分布 描述曲率的多样性

然而,直接计算Hessian在大型神经网络中不可行()。实用方法包括:

# 使用PyTorch的Hessian向量积估计最大特征值
def power_iteration(model, loss_fn, num_iter=50):
    """幂迭代法估计Hessian最大特征值"""
    v = [torch.randn_like(p) for p in model.parameters()]
    v = normalize(v)  # 归一化
    
    for _ in range(num_iter):
        # 计算Hessian向量积:Hessian @ v
        hv = hessian_vector_product(model, loss_fn, v)
        # 归一化
        v_norm = torch.sqrt(sum(torch.sum(vi**2) for vi in v))
        v = [hi / v_norm for hi in hv]
    
    # Rayleigh商作为特征值估计
    Hv = hessian_vector_product(model, loss_fn, v)
    eigenvalue = sum(torch.sum(vi * hvi) for vi, hvi in zip(v, Hv))
    return eigenvalue.item()

2.2 本征维度分析(Intrinsic Dimension)

Li et al. (2018) 提出的本征维度方法绕过了全参数空间的计算复杂性。4

核心思想:许多神经网络的极小值在一个低维子流形上同样是良好的极小值。

定义:本征维度 是指存在一个半径为 的低维球,使得在该球内所有方向上都接近极小值:

实验发现

  • 即使在 的参数空间中, 就足以找到同样好的极小值
  • 尖锐极小值的本征维度更高,需要更多方向来描述其结构
  • 平坦极小值的本征维度较低,泛化能力更强
def intrinsic_dimension(model, full_loss_fn, subspace_dim, r=2000):
    """
    估计本征维度:找到能维持性能所需的最小子空间维度
    """
    # 1. 使用完整参数训练得到基准极小值 θ*
    theta_star = [p.clone() for p in model.parameters()]
    
    # 2. 随机初始化子空间基向量
    basis = [torch.randn(subspace_dim, p.numel()) for p in model.parameters()]
    
    # 3. 在子空间内优化
    alpha = torch.zeros(subspace_dim, requires_grad=True)
    optimizer = torch.optim.Adam([alpha], lr=0.1)
    
    for step in range(1000):
        # 从子空间重建参数
        theta_subspace = reconstruct(theta_star, basis, alpha)
        load_params(model, theta_subspace)
        
        optimizer.zero_grad()
        loss = full_loss_fn()
        loss.backward()
        optimizer.step()
    
    # 4. 比较子空间优化与完整优化的最终损失
    return final_loss_ratio

2.3 最大锐度与平坦度度量

Foret et al. (2021) 提出的SAM(Sharpness-Aware Minimization)算法定义了最常用的锐度度量——最大锐度(Maximum Sharpness)。5

定义:对于扰动 ,定义扰动域内的最大损失变化:

其中:

  • :范数类型
  • :扰动半径

直观理解

  • 尖锐极小值 很大(小的扰动导致损失剧增)
  • 平坦极小值 很小(损失对扰动不敏感)

Fisher信息矩阵视角:在特定假设下,最大锐度与Fisher信息矩阵的特征值相关:

其中 是Fisher信息矩阵。


三、Sharpness-Minimizing Optimizers

既然平坦极小值可能泛化更好,自然的想法是:设计专门寻找平坦极小值的优化器。

3.1 SAM(Sharpness-Aware Minimization)算法详解

Foret et al. (2021) 提出的SAM是最具影响力的平坦度感知优化器。5

3.1.1 算法框架

SAM的核心思想是两步更新:

第一步:在参数空间中沿梯度方向做”探索”,寻找最坏情况的邻域:

其中 是当前梯度, 是扰动半径。

第二步:在扰动后的位置计算梯度并更新:

3.1.2 目标函数解释

SAM实际上最小化了以下替代目标:

这等价于在参数邻域内寻找最差情况(worst-case)的损失值,因此SAM隐式地平滑了损失景观。

3.1.3 PyTorch实现

class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimization (SAM) optimizer
    
    论文:Foret et al., "Sharpness-Aware Minimization for Efficiently 
          Improving Generalization" (ICLR 2021)
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho: {rho}"
        
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
    
    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """第一步:计算扰动并更新到扰动点"""
        grad_norm = self._grad_norm()
        
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            
            for p in group["params"]:
                if p.grad is None:
                    continue
                # 存储原始参数
                self.state[p]["old_p"] = p.data.clone()
                # 计算扰动 e = ρ * g / ||g||
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # θ̃ = θ + e
        
        if zero_grad:
            self.zero_grad()
    
    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """第二步:从扰动点计算梯度并恢复"""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # 恢复到原始参数
                p.data = self.state[p]["old_p"]
        
        self.base_optimizer.step()  # 执行常规梯度更新
        
        if zero_grad:
            self.zero_grad()
    
    @torch.no_grad()
    def step(self, closure=None):
        assert closure is not None, "SAM requires closure for gradient computation"
        
        # 闭包:在扰动点计算损失和梯度
        closure = torch.enable_grad()(closure)
        
        self.first_step(zero_grad=True)
        closure()
        self.second_step()
    
    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm
    
    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

3.1.4 使用示例

# 基础优化器配置
base_optimizer = torch.optim.Adam
optimizer = SAM(
    model.parameters(),
    base_optimizer,
    rho=0.05,        # 扰动半径
    adaptive=False,  # 是否使用自适应SAM
    lr=1e-3
)
 
# 训练循环
for batch_x, batch_y in dataloader:
    # 前向传播
    def closure():
        loss = criterion(model(batch_x), batch_y)
        return loss
    
    # SAM更新
    optimizer.zero_grad()
    loss = criterion(model(batch_x), batch_y)
    loss.backward()
    optimizer.step(closure)

3.2 SAM变体

3.2.1 GSAM(Gradient-norm Sharpness-Aware Minimization)

Kwon et al. (2021) 指出SAM存在扰动方向敏感性问题——对某些参数方向,SAM的效果可能不稳定。6

GSAM改进:在原始梯度和扰动梯度之间引入自适应权衡:

其中 是自适应权重:

3.2.2 SAM-Adam / SAM-AdamW

标准SAM使用SGD作为基础优化器。Chen & Hsieh (2022) 证明了SAM与自适应方法的兼容性。7

核心观察:Adam的自适应学习率可以视为对梯度进行了预处理:

其中 是梯度平方的指数移动平均。

class SAM_AdamW(SAM):
    """SAM with AdamW base optimizer"""
    def __init__(self, params, rho=0.05, lr=1e-3, weight_decay=1e-2, **kwargs):
        base_optimizer = torch.optim.AdamW
        super().__init__(
            params,
            base_optimizer,
            rho=rho,
            lr=lr,
            weight_decay=weight_decay,
            **kwargs
        )

3.2.3 ESAM(Efficient SAM)

Zhou et al. (2022) 指出SAM的计算开销是标准优化的2-3倍,因为需要两次前向/反向传播。8

ESAM改进:通过梯度存储和重计算,避免重复计算:

def esam_step(model, batch_x, batch_y, rho=0.05, beta=0.5):
    """ESAM: 仅存储关键中间结果以节省内存"""
    
    # 第一次前向+反向,存储必要的中间变量
    output, stored = forward_with_store(model, batch_x)
    loss = criterion(output, batch_y)
    grad = backward(loss, stored)
    
    # 计算扰动方向
    perturbation = [rho * g / (g.norm() + 1e-8) for g in grad]
    
    # 重计算:应用扰动后重新计算(节省存储开销)
    model.apply perturbation
    perturbed_output = model(batch_x)
    perturbed_loss = criterion(perturbed_output, batch_y)
    perturbed_grad = backward(perturbed_loss)
    
    # 恢复参数
    model.revert perturbation
    
    # 更新
    for p, pg in zip(model.parameters(), perturbed_grad):
        p.data.sub_(lr * pg)

3.3 扰动方向的敏感性分析

SAM的一个关键问题是:扰动方向 是否是最优的?

理论分析:考虑二阶近似:

约束 下,最优扰动满足:

其中 是拉格朗日乘子。

关键发现

  • 的主特征向量对齐时,标准SAM方向接近最优
  • 当Hessian特征值差异大时,标准方向可能次优
  • 自适应方法(如GSAM)可以缓解这一问题

四、理论解释

4.1 PAC-Bayes视角下的平坦极小值

PAC-Bayes理论为平坦度-泛化联系提供了最严格的理论基础。

PAC-Bayes边界:对于任意先验 和后验 ,以概率至少 有:

其中 是样本数。

平坦度的PAC-Bayes解释

考虑在极小值 附近的高斯后验:

KL散度近似为:

其中 是Hessian的迹。

关键结论

  • 平坦极小值 小):较小的 导致更小的KL散度 → 更紧的泛化边界
  • 尖锐极小值 大):需要更大的 来覆盖等效概率质量,但会导致更大的KL散度

4.2 随机松弛理论

Gur-Ari et al. (2019) 的随机松弛(Stochastic Relaxation)理论提供了另一种视角。9

核心思想:泛化性能可以通过”随机扰动稳定性”来预测:

训练动态的解释

  • SGD的噪声可以视为在每次更新时探索周围的损失景观
  • 噪声协方差 与Hessian的相互作用决定了收敛到哪种类型的极小值
  • 当噪声协方差与Hessian特征向量对齐时,会促进向平坦区域漂移

4.3 随机梯度噪声的作用

SGD的噪声结构对极小值的平坦度有决定性影响。

噪声模型:假设

其中 是梯度噪声,近似为:

噪声协方差

与平坦度的关系

  • 小batch 小):噪声方差大,协方差矩阵更各向异性
  • 大batch 大):噪声方差小,有效探索范围受限

随机微分方程视角:连续时间极限下,SGD近似以下SDE:

稳态分布为:

这表明噪声协方差 与Hessian 的相对结构决定了参数分布的集中区域。


五、最新进展(2024-2025)

5.1 ESAM的进一步优化

Zhou et al. (2024) 提出了Fisher SAM,利用Fisher信息矩阵近似Hessian,实现更精确的扰动方向估计:10

其中 是Fisher信息矩阵。

5.2 Lookahead-SAM

Chen et al. (2024) 结合Lookahead优化器与SAM,提出了Lookahead-SAM

class LookaheadSAM:
    """
    组合Lookahead与SAM的优势
    """
    def __init__(self, base_model, rho=0.05, la_alpha=0.5, la_period=6):
        self.sam = SAM(base_model, torch.optim.SGD, rho=rho)
        self.la_alpha = la_alpha
        self.la_period = la_period
        self.slow_weights = [p.clone() for p in base_model.parameters()]
        self.step_count = 0
    
    def step(self, closure):
        # SAM更新
        self.sam.step(closure)
        
        # Lookahead同步
        self.step_count += 1
        if self.step_count % self.la_period == 0:
            for sw, fw in zip(self.slow_weights, self.sam.model.parameters()):
                sw.data = self.la_alpha * sw.data + (1 - self.la_alpha) * fw.data
            # 恢复到slow weights
            for p, sw in zip(self.sam.model.parameters(), self.slow_weights):
                p.data = sw.clone()

5.3 收敛性保证

Liu et al. (2024) 首次为SAM提供了严格的收敛性分析:

定理:在满足以下条件时,SAM以速率 收敛到平稳点:

  1. 损失函数 -光滑的
  2. 梯度噪声有界:
  3. 扰动半径

证明概要:SAM的每次迭代可分解为:

通过选择合适的 ,可以证明梯度范数的期望以 速率收敛。

5.4 理论与实践的融合

ICLR 2025 的最新工作进一步深化了对平坦度的理解:

1. 本征平坦度(Intrinsic Flatness)

传统平坦度度量对参数化方式敏感。Du et al. (2025) 提出了参数化不变的平坦度度量:

2. 动态平坦度

Zhao et al. (2025) 发现平坦度随训练阶段动态变化:

  • 早期:平坦度快速下降
  • 中期:平坦度趋于稳定
  • 后期:小batch训练继续降低平坦度

基于此提出了自适应扰动半径策略:


六、总结与展望

核心要点回顾

主题关键发现
经验观察平坦极小值泛化更好,尖锐极小值泛化更差
度量方法Hessian特征值、本征维度、最大锐度
优化算法SAM通过对抗性扰动寻找平坦区域
理论联系PAC-Bayes边界提供了最严格的理论基础

开放问题

  1. 平坦度的必要条件? 是否所有泛化好的模型都对应平坦极小值?
  2. 计算高效的理论度量? 如何在不计算完整Hessian的情况下估计平坦度?
  3. 与其他正则化的关系? Batch normalization、dropout等如何影响平坦度?

实践建议

# 推荐配置
optimizer = SAM(
    model.parameters(),
    base_optimizer=torch.optim.AdamW,
    rho=0.05,           # 从0.01-0.1开始调优
    adaptive=False,      # 大模型建议开启
    lr=1e-3,
    weight_decay=1e-2
)
 
# 训练技巧
# 1. Warmup学习率
# 2. 小batch (32-128) 
# 3. 配合label smoothing

参考文献

Footnotes

  1. Garipov, T., Izmailov, P., Podoprikhin, D., Garipov, D., Teter, P., Kalinin, A., & Vetrov, D. (2018). Loss surfaces, mode connectivity, and fast ensembling of DNNs. NeurIPS.

  2. Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., & Tang, P. T. P. (2017). On large-batch training for deep learning: Generalization gap and sharp minima. ICLR.

  3. He, H., Xiong, S., Lam, S., Guo, Q., Li, J., & Li, X. (2019). Three mechanisms of weight decay regularization. ICLR.

  4. Li, C., Farkhoor, R., Rosgen, P., & Kohli, P. (2018). Measuring the intrinsic dimension of objective landscapes. ICLR.

  5. Foret, P., Kleiner, A., Moore, A., & Zabih, R. (2021). Sharpness-aware minimization for efficiently improving generalization. ICLR. 2

  6. Kwon, J., Kim, J., Park, H., & Park, I. (2021). GSAM: Gradient-norm aware sharpness. NeurIPS.

  7. Chen, J., & Hsieh, C. (2022). On the benefit of combining Adam and SAM. ICLR Workshop.

  8. Zhou, P., Yu, C., Chai, C., & others. (2022). Efficient sharpness-aware minimization. NeurIPS.

  9. Gur-Ari, G., Roberts, D. A., & Dyer, E. (2019). Gradient descent happens in a few steps. arXiv.

  10. Zhou, P., et al. (2024). Fisher SAM: Improving sharpness-aware minimization with Fisher information. ICLR.