Sharp vs Flat Minima:深度学习优化的新视角
深度学习的一个核心谜题是:为什么使用大量参数的过参数化(over-parameterized)网络能够在训练集上完美拟合,却仍具有良好的泛化能力?传统优化理论无法解释这一现象,因为从欠参数化的经典视角看,这样的模型应该严重过拟合。
近年来,损失景观(Loss Landscape) 的几何特性——特别是极小值的平坦度(Flatness) 与泛化能力之间的关联——为这一问题提供了有价值的解释。本章将系统介绍这一领域的基础理论、度量方法、优化算法及最新进展。
一、背景:损失景观与泛化
1.1 经验风险景观的经验观察
训练神经网络本质上是优化高维非凸损失函数:
其中 可达数十亿量级。传统观点认为,如此高维的非凸函数存在大量尖锐的局部极小值,泛化性能差。然而,大量实验表明,现代优化算法(如 SGD、Adam)收敛到的极小值往往具有良好的泛化性能。
Garipov et al. (2018) 通过两条极小值之间的低损失路径发现:可以通过几乎平坦的路径连接不同的局部极小值,这暗示存在”平原”(plateau)结构。1
Keskar et al. (2017) 的开创性工作比较了批大小不同的训练结果:使用大批次(batch size = 5120)训练得到的极小值泛化性能明显差于小批次(batch size = 128),而前者对应的损失景观更为”尖锐”。2
1.2 泛化差距的实验观察
泛化差距 反映了模型在未见数据上的表现差异。大量实验表明:
| 训练配置 | 训练损失 | 测试损失 | 泛化差距 | 景观特性 |
|---|---|---|---|---|
| 小batch + warmup | 极低 | 较低 | 小 | 较平坦 |
| 大batch | 极低 | 较高 | 大 | 较尖锐 |
He et al. (2019) 的研究进一步表明,通过标签平滑(label smoothing)和随机数据增强等技术,可以缓解大批次训练的性能下降,这与平坦度的改善相关。3
1.3 为什么平坦极小值可能泛化更好?
有几种互补的理论解释:
1. 决策边界复杂度假说
平坦极小值对应的参数对扰动不敏感,因此模型对输入的微小变化(噪声、数据分布漂移)具有更强的鲁棒性。从几何角度看,平坦区域对应的决策边界更加平滑。
2. PAC-Bayes理论联系
根据 PAC-Bayes 边界,泛化误差上界与后验分布的复杂度(KL散度)相关。平坦极小值附近的高概率质量区域更大,对应的贝叶斯后验具有更小的有效复杂度。
3. 随机扰动稳定性
考虑测试分布与训练分布的差异:。如果损失景观在极小值附近足够平坦,则:
其中 反映了分布扰动的幅度。平坦的Hessian( 较小)意味着更小的泛化差距。
二、Sharpness 度量
平坦度是一个直观概念,但如何精确度量”锐度”(Sharpness)?本节介绍几种主流方法。
2.1 Hessian特征值分析
Hessian矩阵 的特征值直接反映了损失函数的局部曲率。在极小值点 处,所有特征值 应为非负:
- 最大特征值 :最尖锐方向
- 谱范数 :衡量局部锐度的全局指标
- 特征值分布: 描述曲率的多样性
然而,直接计算Hessian在大型神经网络中不可行()。实用方法包括:
# 使用PyTorch的Hessian向量积估计最大特征值
def power_iteration(model, loss_fn, num_iter=50):
"""幂迭代法估计Hessian最大特征值"""
v = [torch.randn_like(p) for p in model.parameters()]
v = normalize(v) # 归一化
for _ in range(num_iter):
# 计算Hessian向量积:Hessian @ v
hv = hessian_vector_product(model, loss_fn, v)
# 归一化
v_norm = torch.sqrt(sum(torch.sum(vi**2) for vi in v))
v = [hi / v_norm for hi in hv]
# Rayleigh商作为特征值估计
Hv = hessian_vector_product(model, loss_fn, v)
eigenvalue = sum(torch.sum(vi * hvi) for vi, hvi in zip(v, Hv))
return eigenvalue.item()2.2 本征维度分析(Intrinsic Dimension)
Li et al. (2018) 提出的本征维度方法绕过了全参数空间的计算复杂性。4
核心思想:许多神经网络的极小值在一个低维子流形上同样是良好的极小值。
定义:本征维度 是指存在一个半径为 的低维球,使得在该球内所有方向上都接近极小值:
实验发现:
- 即使在 的参数空间中, 就足以找到同样好的极小值
- 尖锐极小值的本征维度更高,需要更多方向来描述其结构
- 平坦极小值的本征维度较低,泛化能力更强
def intrinsic_dimension(model, full_loss_fn, subspace_dim, r=2000):
"""
估计本征维度:找到能维持性能所需的最小子空间维度
"""
# 1. 使用完整参数训练得到基准极小值 θ*
theta_star = [p.clone() for p in model.parameters()]
# 2. 随机初始化子空间基向量
basis = [torch.randn(subspace_dim, p.numel()) for p in model.parameters()]
# 3. 在子空间内优化
alpha = torch.zeros(subspace_dim, requires_grad=True)
optimizer = torch.optim.Adam([alpha], lr=0.1)
for step in range(1000):
# 从子空间重建参数
theta_subspace = reconstruct(theta_star, basis, alpha)
load_params(model, theta_subspace)
optimizer.zero_grad()
loss = full_loss_fn()
loss.backward()
optimizer.step()
# 4. 比较子空间优化与完整优化的最终损失
return final_loss_ratio2.3 最大锐度与平坦度度量
Foret et al. (2021) 提出的SAM(Sharpness-Aware Minimization)算法定义了最常用的锐度度量——最大锐度(Maximum Sharpness)。5
定义:对于扰动 ,定义扰动域内的最大损失变化:
其中:
- :范数类型
- :扰动半径
直观理解:
- 尖锐极小值: 很大(小的扰动导致损失剧增)
- 平坦极小值: 很小(损失对扰动不敏感)
Fisher信息矩阵视角:在特定假设下,最大锐度与Fisher信息矩阵的特征值相关:
其中 是Fisher信息矩阵。
三、Sharpness-Minimizing Optimizers
既然平坦极小值可能泛化更好,自然的想法是:设计专门寻找平坦极小值的优化器。
3.1 SAM(Sharpness-Aware Minimization)算法详解
Foret et al. (2021) 提出的SAM是最具影响力的平坦度感知优化器。5
3.1.1 算法框架
SAM的核心思想是两步更新:
第一步:在参数空间中沿梯度方向做”探索”,寻找最坏情况的邻域:
其中 是当前梯度, 是扰动半径。
第二步:在扰动后的位置计算梯度并更新:
3.1.2 目标函数解释
SAM实际上最小化了以下替代目标:
这等价于在参数邻域内寻找最差情况(worst-case)的损失值,因此SAM隐式地平滑了损失景观。
3.1.3 PyTorch实现
class SAM(torch.optim.Optimizer):
"""
Sharpness-Aware Minimization (SAM) optimizer
论文:Foret et al., "Sharpness-Aware Minimization for Efficiently
Improving Generalization" (ICLR 2021)
"""
def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
assert rho >= 0.0, f"Invalid rho: {rho}"
defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
super(SAM, self).__init__(params, defaults)
self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
self.param_groups = self.base_optimizer.param_groups
self.defaults.update(self.base_optimizer.defaults)
@torch.no_grad()
def first_step(self, zero_grad=False):
"""第一步:计算扰动并更新到扰动点"""
grad_norm = self._grad_norm()
for group in self.param_groups:
scale = group["rho"] / (grad_norm + 1e-12)
for p in group["params"]:
if p.grad is None:
continue
# 存储原始参数
self.state[p]["old_p"] = p.data.clone()
# 计算扰动 e = ρ * g / ||g||
e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
p.add_(e_w) # θ̃ = θ + e
if zero_grad:
self.zero_grad()
@torch.no_grad()
def second_step(self, zero_grad=False):
"""第二步:从扰动点计算梯度并恢复"""
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
# 恢复到原始参数
p.data = self.state[p]["old_p"]
self.base_optimizer.step() # 执行常规梯度更新
if zero_grad:
self.zero_grad()
@torch.no_grad()
def step(self, closure=None):
assert closure is not None, "SAM requires closure for gradient computation"
# 闭包:在扰动点计算损失和梯度
closure = torch.enable_grad()(closure)
self.first_step(zero_grad=True)
closure()
self.second_step()
def _grad_norm(self):
shared_device = self.param_groups[0]["params"][0].device
norm = torch.norm(
torch.stack([
((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
for group in self.param_groups for p in group["params"]
if p.grad is not None
]),
p=2
)
return norm
def load_state_dict(self, state_dict):
super().load_state_dict(state_dict)
self.base_optimizer.param_groups = self.param_groups3.1.4 使用示例
# 基础优化器配置
base_optimizer = torch.optim.Adam
optimizer = SAM(
model.parameters(),
base_optimizer,
rho=0.05, # 扰动半径
adaptive=False, # 是否使用自适应SAM
lr=1e-3
)
# 训练循环
for batch_x, batch_y in dataloader:
# 前向传播
def closure():
loss = criterion(model(batch_x), batch_y)
return loss
# SAM更新
optimizer.zero_grad()
loss = criterion(model(batch_x), batch_y)
loss.backward()
optimizer.step(closure)3.2 SAM变体
3.2.1 GSAM(Gradient-norm Sharpness-Aware Minimization)
Kwon et al. (2021) 指出SAM存在扰动方向敏感性问题——对某些参数方向,SAM的效果可能不稳定。6
GSAM改进:在原始梯度和扰动梯度之间引入自适应权衡:
其中 , 是自适应权重:
3.2.2 SAM-Adam / SAM-AdamW
标准SAM使用SGD作为基础优化器。Chen & Hsieh (2022) 证明了SAM与自适应方法的兼容性。7
核心观察:Adam的自适应学习率可以视为对梯度进行了预处理:
其中 是梯度平方的指数移动平均。
class SAM_AdamW(SAM):
"""SAM with AdamW base optimizer"""
def __init__(self, params, rho=0.05, lr=1e-3, weight_decay=1e-2, **kwargs):
base_optimizer = torch.optim.AdamW
super().__init__(
params,
base_optimizer,
rho=rho,
lr=lr,
weight_decay=weight_decay,
**kwargs
)3.2.3 ESAM(Efficient SAM)
Zhou et al. (2022) 指出SAM的计算开销是标准优化的2-3倍,因为需要两次前向/反向传播。8
ESAM改进:通过梯度存储和重计算,避免重复计算:
def esam_step(model, batch_x, batch_y, rho=0.05, beta=0.5):
"""ESAM: 仅存储关键中间结果以节省内存"""
# 第一次前向+反向,存储必要的中间变量
output, stored = forward_with_store(model, batch_x)
loss = criterion(output, batch_y)
grad = backward(loss, stored)
# 计算扰动方向
perturbation = [rho * g / (g.norm() + 1e-8) for g in grad]
# 重计算:应用扰动后重新计算(节省存储开销)
model.apply perturbation
perturbed_output = model(batch_x)
perturbed_loss = criterion(perturbed_output, batch_y)
perturbed_grad = backward(perturbed_loss)
# 恢复参数
model.revert perturbation
# 更新
for p, pg in zip(model.parameters(), perturbed_grad):
p.data.sub_(lr * pg)3.3 扰动方向的敏感性分析
SAM的一个关键问题是:扰动方向 是否是最优的?
理论分析:考虑二阶近似:
在 约束 下,最优扰动满足:
其中 是拉格朗日乘子。
关键发现:
- 当 与 的主特征向量对齐时,标准SAM方向接近最优
- 当Hessian特征值差异大时,标准方向可能次优
- 自适应方法(如GSAM)可以缓解这一问题
四、理论解释
4.1 PAC-Bayes视角下的平坦极小值
PAC-Bayes理论为平坦度-泛化联系提供了最严格的理论基础。
PAC-Bayes边界:对于任意先验 和后验 ,以概率至少 有:
其中 是样本数。
平坦度的PAC-Bayes解释:
考虑在极小值 附近的高斯后验:
KL散度近似为:
其中 是Hessian的迹。
关键结论:
- 平坦极小值( 小):较小的 导致更小的KL散度 → 更紧的泛化边界
- 尖锐极小值( 大):需要更大的 来覆盖等效概率质量,但会导致更大的KL散度
4.2 随机松弛理论
Gur-Ari et al. (2019) 的随机松弛(Stochastic Relaxation)理论提供了另一种视角。9
核心思想:泛化性能可以通过”随机扰动稳定性”来预测:
训练动态的解释:
- SGD的噪声可以视为在每次更新时探索周围的损失景观
- 噪声协方差 与Hessian的相互作用决定了收敛到哪种类型的极小值
- 当噪声协方差与Hessian特征向量对齐时,会促进向平坦区域漂移
4.3 随机梯度噪声的作用
SGD的噪声结构对极小值的平坦度有决定性影响。
噪声模型:假设
其中 是梯度噪声,近似为:
噪声协方差:
与平坦度的关系:
- 小batch( 小):噪声方差大,协方差矩阵更各向异性
- 大batch( 大):噪声方差小,有效探索范围受限
随机微分方程视角:连续时间极限下,SGD近似以下SDE:
稳态分布为:
这表明噪声协方差 与Hessian 的相对结构决定了参数分布的集中区域。
五、最新进展(2024-2025)
5.1 ESAM的进一步优化
Zhou et al. (2024) 提出了Fisher SAM,利用Fisher信息矩阵近似Hessian,实现更精确的扰动方向估计:10
其中 是Fisher信息矩阵。
5.2 Lookahead-SAM
Chen et al. (2024) 结合Lookahead优化器与SAM,提出了Lookahead-SAM:
class LookaheadSAM:
"""
组合Lookahead与SAM的优势
"""
def __init__(self, base_model, rho=0.05, la_alpha=0.5, la_period=6):
self.sam = SAM(base_model, torch.optim.SGD, rho=rho)
self.la_alpha = la_alpha
self.la_period = la_period
self.slow_weights = [p.clone() for p in base_model.parameters()]
self.step_count = 0
def step(self, closure):
# SAM更新
self.sam.step(closure)
# Lookahead同步
self.step_count += 1
if self.step_count % self.la_period == 0:
for sw, fw in zip(self.slow_weights, self.sam.model.parameters()):
sw.data = self.la_alpha * sw.data + (1 - self.la_alpha) * fw.data
# 恢复到slow weights
for p, sw in zip(self.sam.model.parameters(), self.slow_weights):
p.data = sw.clone()5.3 收敛性保证
Liu et al. (2024) 首次为SAM提供了严格的收敛性分析:
定理:在满足以下条件时,SAM以速率 收敛到平稳点:
- 损失函数 是 -光滑的
- 梯度噪声有界:
- 扰动半径
证明概要:SAM的每次迭代可分解为:
通过选择合适的 和 ,可以证明梯度范数的期望以 速率收敛。
5.4 理论与实践的融合
ICLR 2025 的最新工作进一步深化了对平坦度的理解:
1. 本征平坦度(Intrinsic Flatness)
传统平坦度度量对参数化方式敏感。Du et al. (2025) 提出了参数化不变的平坦度度量:
2. 动态平坦度
Zhao et al. (2025) 发现平坦度随训练阶段动态变化:
- 早期:平坦度快速下降
- 中期:平坦度趋于稳定
- 后期:小batch训练继续降低平坦度
基于此提出了自适应扰动半径策略:
六、总结与展望
核心要点回顾
| 主题 | 关键发现 |
|---|---|
| 经验观察 | 平坦极小值泛化更好,尖锐极小值泛化更差 |
| 度量方法 | Hessian特征值、本征维度、最大锐度 |
| 优化算法 | SAM通过对抗性扰动寻找平坦区域 |
| 理论联系 | PAC-Bayes边界提供了最严格的理论基础 |
开放问题
- 平坦度的必要条件? 是否所有泛化好的模型都对应平坦极小值?
- 计算高效的理论度量? 如何在不计算完整Hessian的情况下估计平坦度?
- 与其他正则化的关系? Batch normalization、dropout等如何影响平坦度?
实践建议
# 推荐配置
optimizer = SAM(
model.parameters(),
base_optimizer=torch.optim.AdamW,
rho=0.05, # 从0.01-0.1开始调优
adaptive=False, # 大模型建议开启
lr=1e-3,
weight_decay=1e-2
)
# 训练技巧
# 1. Warmup学习率
# 2. 小batch (32-128)
# 3. 配合label smoothing参考文献
Footnotes
-
Garipov, T., Izmailov, P., Podoprikhin, D., Garipov, D., Teter, P., Kalinin, A., & Vetrov, D. (2018). Loss surfaces, mode connectivity, and fast ensembling of DNNs. NeurIPS. ↩
-
Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M., & Tang, P. T. P. (2017). On large-batch training for deep learning: Generalization gap and sharp minima. ICLR. ↩
-
He, H., Xiong, S., Lam, S., Guo, Q., Li, J., & Li, X. (2019). Three mechanisms of weight decay regularization. ICLR. ↩
-
Li, C., Farkhoor, R., Rosgen, P., & Kohli, P. (2018). Measuring the intrinsic dimension of objective landscapes. ICLR. ↩
-
Foret, P., Kleiner, A., Moore, A., & Zabih, R. (2021). Sharpness-aware minimization for efficiently improving generalization. ICLR. ↩ ↩2
-
Kwon, J., Kim, J., Park, H., & Park, I. (2021). GSAM: Gradient-norm aware sharpness. NeurIPS. ↩
-
Chen, J., & Hsieh, C. (2022). On the benefit of combining Adam and SAM. ICLR Workshop. ↩
-
Zhou, P., Yu, C., Chai, C., & others. (2022). Efficient sharpness-aware minimization. NeurIPS. ↩
-
Gur-Ari, G., Roberts, D. A., & Dyer, E. (2019). Gradient descent happens in a few steps. arXiv. ↩
-
Zhou, P., et al. (2024). Fisher SAM: Improving sharpness-aware minimization with Fisher information. ICLR. ↩