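Training Dynamics and the Critical Learning Rate: Edge of Stability

A key finding about the training dynamics of deep neural networks is the existence of a **critical learning rate**. When the learning rate is set near or slightly above this critical value, a distinctive phenomenon called the **Edge of Stability** appears [1]. This document analyzes the theoretical basis of this phenomenon, how it can be verified experimentally, and what it implies for training in practice.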

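Critical Learning Rate Theory

Basic Concepts

The critical learning rate $\eta_{\text{crit}}$ is the learning rate at which training runs on the boundary between stable and unstable dynamics:

$$\eta_{\text{crit}} = \frac{2}{\lambda_{\max}}$$

where $\lambda_{\max}$ is the largest eigenvalue of the Hessian of the loss (referred to below as the sharpness).

Stability Analysis

Consider the one-dimensional dynamics near a minimum $\theta^*$ (where $\nabla L(\theta^*) = 0$).

Gradient descent update:

$$\theta_{t+1} = \theta_t - \eta\,\nabla L(\theta_t)$$

At a point where the Hessian is $\lambda$, define $\delta_t = \theta_t - \theta^*$. With the quadratic approximation $\nabla L(\theta_t) \approx \lambda\,\delta_t$:

$$\delta_{t+1} = (1 - \eta\lambda)\,\delta_t$$

Solving this linear difference equation gives:

$$\delta_t = (1 - \eta\lambda)^t\,\delta_0$$

Stability condition:

$$|1 - \eta\lambda| < 1$$

For $\lambda > 0$ this gives:

$$0 < \eta < \frac{2}{\lambda}$$

In the multi-dimensional case each Hessian eigendirection evolves independently with factor $(1 - \eta\lambda_i)$, so the sharpest direction sets the limit $\eta < 2/\lambda_{\max}$, which is the critical learning rate defined above.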

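```python
import numpy as np
import torch
import matplotlib.pyplot as plt


def analyze_stability(lr, eigenvalues, tol=0.05):
    """
    Analyze the stability of each Hessian eigendirection at a given learning rate.

    On a local quadratic, GD multiplies the component along eigenvector i by
    (1 - lr * λ_i) every step, so |1 - lr * λ_i| is that direction's stability
    factor. The stability boundary is lr * λ_i = 2.

    Args:
        lr: learning rate
        eigenvalues: list of Hessian eigenvalues
        tol: half-width of the band around lr * λ = 2 treated as "marginal"

    Returns:
        dict: stability analysis of each eigendirection
    """
    eigenvalues = np.asarray(eigenvalues, dtype=float)
    lr_lam = lr * eigenvalues

    # Stability factor per direction
    stability_factors = np.abs(1 - lr_lam)

    # Classify by lr * λ. Negative-curvature directions (lr * λ < 0) also have
    # factor > 1, but they are saddle directions, so they are counted separately.
    negative_curvature = lr_lam < 0
    marginally_stable = np.abs(lr_lam - 2) <= tol
    stable = (lr_lam >= 0) & (lr_lam < 2 - tol)
    unstable = lr_lam > 2 + tol

    return {
        'eigenvalues': eigenvalues,
        'stability_factors': stability_factors,
        'n_stable': int(np.sum(stable)),
        'n_marginally_stable': int(np.sum(marginally_stable)),
        'n_unstable': int(np.sum(unstable)),
        'n_negative_curvature': int(np.sum(negative_curvature)),
        'max_stability_factor': float(np.max(stability_factors)),
        'min_stability_factor': float(np.min(stability_factors)),
        'is_stable': bool(np.max(lr_lam) < 2 - tol),
        # EOS: the sharpest direction sits on the boundary and none lies beyond it
        'is_eos': bool(np.any(marginally_stable) and not np.any(unstable)),
    }


def compute_critical_lr(hessian_eigenvalues):
    """
    Compute the critical learning rate η_crit = 2 / λ_max.
    """
    lambda_max = np.max(hessian_eigenvalues)
    if lambda_max <= 0:
        raise ValueError("λ_max must be positive for a finite critical learning rate")
    return 2.0 / lambda_max


def plot_stability_diagram(eigenvalues, lr_range, tol=0.05):
    """
    Plot the stability phase diagram over learning rate η and eigenvalue λ
    (eigenvalues assumed positive and sorted in descending order).
    """
    lrs = np.asarray(lr_range, dtype=float)
    lambdas = np.asarray(eigenvalues, dtype=float)

    stability_matrix = np.zeros((len(lambdas), len(lrs)))

    for i, lam in enumerate(lambdas):
        for j, lr in enumerate(lrs):
            prod = lr * lam
            if abs(prod - 2) <= tol:
                stability_matrix[i, j] = 1  # boundary (η·λ ≈ 2)
            elif prod < 2:
                stability_matrix[i, j] = 0  # stable
            else:
                stability_matrix[i, j] = 2  # unstable

    plt.figure(figsize=(12, 8))
    # imshow draws row 0 at the top, hence the descending-order assumption
    plt.imshow(stability_matrix, aspect='auto', cmap='RdYlGn_r',
               extent=[lrs[0], lrs[-1], lambdas[-1], lambdas[0]])
    plt.colorbar(label='Stability Region (0=stable, 1=boundary, 2=unstable)')

    # Critical curve η = 2 / λ
    plt.plot(2.0 / lambdas, lambdas, color='white', linestyle='--', alpha=0.7)
    plt.xlim(lrs[0], lrs[-1])

    plt.xlabel('Learning Rate')
    plt.ylabel('Hessian Eigenvalue')
    plt.title('Stability Diagram: λ vs η')
    plt.show()
```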
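A quick numerical check of the stability condition $\eta < 2/\lambda_{\max}$ follows. This is a sketch on a synthetic quadratic $L(\theta) = \tfrac12\theta^\top H\theta$ with a hypothetical diagonal Hessian, not a neural network. It runs plain GD with step sizes just below and just above $2/\lambda_{\max}$. The helper name `run_gd_on_quadratic` is illustrative.

```python
import numpy as np


def run_gd_on_quadratic(H, lr, theta0, n_steps=200):
    """Plain GD on L(θ) = ½ θᵀHθ (gradient Hθ); returns the final ‖θ‖."""
    theta = np.array(theta0, dtype=float)
    for _ in range(n_steps):
        theta = theta - lr * (H @ theta)
    return float(np.linalg.norm(theta))


# Hypothetical Hessian with eigenvalues 10, 3, 0.5 (λ_max = 10, so η_crit = 0.2)
H = np.diag([10.0, 3.0, 0.5])
theta0 = np.ones(3)
lam_max = np.linalg.eigvalsh(H).max()
lr_crit = 2.0 / lam_max

below = run_gd_on_quadratic(H, 0.95 * lr_crit, theta0)
above = run_gd_on_quadratic(H, 1.05 * lr_crit, theta0)
print(f"lr_crit = {lr_crit:.3f}")
print(f"0.95 * lr_crit: final ||theta|| = {below:.3e}")
print(f"1.05 * lr_crit: final ||theta|| = {above:.3e}")
```

Only the sharpest direction decides the outcome. The flat λ = 0.5 direction converges slowly in both runs, but just above η_crit the λ = 10 component grows by a factor of 1.1 per step.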

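The Edge of Stability Phenomenon

Description

Edge of Stability (EOS) phenomenon [1]

When the learning rate is close to or slightly above $\eta_{\text{crit}} = 2/\lambda_{\max}$, training shows distinctive dynamics:

  1. Initial instability: the loss may rise at the start of training
  2. Adaptive stabilization: the network parameters adjust themselves so that $\eta\,\lambda_{\max} \approx 2$
  3. Long-term stability: training eventually settles, although the sharpness stays near the critical value $2/\eta$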
```
Learning rate η
   │
2/λ_max ──────────────────────  ← critical line
   │
   │    ┌─────────────────────┐
   │    │     EOS region      │
   │    │                     │
η* │────┼─────────────────────┼─────────
   │    │ initial instability │
   │    │                     │
   │    │ loss may rise       │
   │    │                     │
   │    └─────────────────────┘
   │
   └────────────────────────────────→ Training Time
```
```python
class EdgeOfStabilityAnalyzer:
    """
    Edge of Stability analyzer: records the loss, the sharpness λ_max(t)
    and η·λ_max(t) during training.
    """
    def __init__(self):
        self.loss_history = []
        self.sharpness_history = []  # λ_max(t)
        self.effective_lr_history = []  # η * λ_max(t)
        self.steps = []

    def update(self, step, loss, sharpness, lr):
        """Record the key metrics of one training step."""
        self.steps.append(step)
        self.loss_history.append(loss)
        self.sharpness_history.append(sharpness)
        self.effective_lr_history.append(lr * sharpness)

    def detect_eos(self, window=50):
        """
        Detect the EOS phenomenon.

        Criteria:
        1. The loss rises during an initial period
        2. Later on, η·λ_max fluctuates around 2 (i.e. λ_max ≈ 2/η)
        """
        if len(self.loss_history) < window:
            return False, {}

        # Initial instability: loss at the end of the first window exceeds the start
        initial_losses = self.loss_history[:window]
        initial_instability = initial_losses[-1] > initial_losses[0]

        # Does the sharpness settle near 2/η over the last window?
        effective_lr = np.array(self.effective_lr_history)
        later_eos = effective_lr[-window:]

        mean_eos = np.mean(later_eos)
        std_eos = np.std(later_eos)

        # The tolerances 0.5 and 0.3 are heuristic thresholds
        is_eos = (initial_instability and
                  np.abs(mean_eos - 2.0) < 0.5 and
                  std_eos < 0.3)

        return bool(is_eos), {
            'initial_instability': initial_instability,
            'mean_eos': mean_eos,
            'std_eos': std_eos,
            'final_sharpness': self.sharpness_history[-1],
            'initial_sharpness': self.sharpness_history[0]
        }

    def plot_dynamics(self, save_path=None):
        """Visualize the EOS dynamics."""
        fig, axes = plt.subplots(3, 1, figsize=(12, 10))

        steps = self.steps

        # 1. Loss curve
        axes[0].plot(steps, self.loss_history, 'b-', linewidth=1)
        axes[0].set_ylabel('Loss')
        axes[0].set_title('Training Loss')
        axes[0].grid(True, alpha=0.3)

        # 2. Sharpness curve
        axes[1].plot(steps, self.sharpness_history, 'g-', linewidth=1)
        axes[1].set_ylabel('λ_max (Sharpness)')
        axes[1].set_title('Hessian Maximum Eigenvalue')
        axes[1].grid(True, alpha=0.3)

        # 3. Effective learning rate η·λ_max against the boundary 2
        axes[2].plot(steps, self.effective_lr_history, 'r-', linewidth=1)
        axes[2].axhline(y=2.0, color='black', linestyle='--',
                        label='Stability Boundary (η*λ=2)')
        axes[2].set_xlabel('Training Step')
        axes[2].set_ylabel('η * λ_max')
        axes[2].set_title('Edge of Stability Analysis')
        axes[2].legend()
        axes[2].grid(True, alpha=0.3)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
        plt.show()
```
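`EdgeOfStabilityAnalyzer.update(step, loss, sharpness, lr)` needs a sharpness value at each step, but the text does not say how to obtain one. One common option, assumed here, is power iteration on Hessian-vector products computed with `torch.autograd.grad`. The sketch below is one such estimator. The function name `power_iteration_sharpness` and its `loss_fn` closure interface are hypothetical. Power iteration converges to the eigenvalue of largest magnitude, so a large negative eigenvalue would be returned instead of λ_max.

```python
import torch


def power_iteration_sharpness(loss_fn, params, n_iters=100, tol=1e-6):
    """
    Estimate the top Hessian eigenvalue (by magnitude) of loss_fn() with
    respect to params, via power iteration on Hessian-vector products.

    loss_fn: zero-argument callable returning a scalar loss (hypothetical
    interface, e.g. a closure over one fixed batch).
    """
    params = [p for p in params if p.requires_grad]
    loss = loss_fn()
    # Keep the graph of the gradient so it can be differentiated again
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # Random unit starting vector, shaped like the parameters
    v = [torch.randn_like(p) for p in params]
    norm = torch.sqrt(sum((x * x).sum() for x in v))
    v = [x / norm for x in v]

    eig = 0.0
    for _ in range(n_iters):
        # Hessian-vector product: Hv = d(g · v)/dθ
        hv = torch.autograd.grad(grads, params, grad_outputs=v, retain_graph=True)
        # Rayleigh quotient vᵀHv for the current unit vector v
        new_eig = sum((h * x).sum() for h, x in zip(hv, v)).item()
        norm = torch.sqrt(sum((h * h).sum() for h in hv))
        v = [h / (norm + 1e-12) for h in hv]
        converged = abs(new_eig - eig) < tol * max(1.0, abs(new_eig))
        eig = new_eig
        if converged:
            break
    return eig
```

For a network, a call like `power_iteration_sharpness(lambda: criterion(model(x), y), model.parameters())` would return one λ_max estimate. Here `criterion`, `model`, `x` and `y` stand in for your own objects. Each iteration needs a second backward pass through the gradient graph, so this is noticeably more expensive than a training step.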

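Theoretical Explanation

Theoretical basis of EOS

  1. Transient dynamics: early in training the network is not yet near a local minimum, and the Hessian may have large eigenvalues
  2. Adaptive adjustment: the unstable dynamics change the curvature of the network
  3. Basin of attraction: the network is eventually drawn into a region whose sharpness is about $2/\eta$

```python
def theoretical_eos_analysis(lr, initial_sharpness, target_eos=2.0):
    """
    Theoretical EOS analysis: compare the initial sharpness with the
    sharpness level target_eos / lr at which training sits at EOS.
    """
    # Critical sharpness
    target_sharpness = target_eos / lr

    # Factor by which the sharpness must change to go from its initial value to the target
    reduction_factor = initial_sharpness / target_sharpness

    return {
        'critical_sharpness': target_sharpness,
        'initial_sharpness': initial_sharpness,
        'reduction_factor': reduction_factor,
        'lr_crit': 2.0 / initial_sharpness,
        'is_above_crit': lr > 2.0 / initial_sharpness
    }
```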
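A worked example with hypothetical numbers ($\eta = 0.1$, initial sharpness $\lambda_{\max}^{(0)} = 50$) makes the quantities returned by `theoretical_eos_analysis` concrete:

$$
\begin{aligned}
\lambda_{\text{crit}} &= \frac{2}{\eta} = \frac{2}{0.1} = 20,\\
\text{reduction factor} &= \frac{\lambda_{\max}^{(0)}}{\lambda_{\text{crit}}} = \frac{50}{20} = 2.5,\\
\eta_{\text{crit}} &= \frac{2}{\lambda_{\max}^{(0)}} = \frac{2}{50} = 0.04 < \eta = 0.1 .
\end{aligned}
$$

With these assumed numbers `is_above_crit` is true. Under the EOS picture described above, the sharpness would have to fall by roughly a factor of 2.5 before training settles near $2/\eta$.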

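Types of Training Dynamics

Three training modes

Depending on how the learning rate compares with the critical learning rate, training falls into three modes. The thresholds follow `classify_training_mode` below, where $\lambda_{\max}^{(0)}$ is the initial sharpness:

| Mode | Learning-rate range | Sharpness evolution | Loss dynamics |
|---|---|---|---|
| Over-damped | $\eta < 0.5/\lambda_{\max}^{(0)}$ | decreases fairly quickly | decreases monotonically |
| Critical (EOS) | $0.5/\lambda_{\max}^{(0)} \le \eta \le 2/\lambda_{\max}^{(0)}$ | stabilizes around $2/\eta$ | may rise early, then stabilizes |
| Under-damped | $\eta > 2/\lambda_{\max}^{(0)}$ | diverges | diverges / unstable |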
```python
def classify_training_mode(lr, sharpness_history, loss_history):
    """
    Classify the training mode from the learning rate and the initial sharpness.

    The lower threshold 0.5 / λ_max(0) is a heuristic. loss_history is
    accepted for interface symmetry but is not used by this rule.
    """
    initial_sharpness = sharpness_history[0]

    lr_crit_low = 0.5 / initial_sharpness
    lr_crit_high = 2.0 / initial_sharpness

    # Effective learning rate η·λ_max at every recorded step
    effective_lrs = [lr * s for s in sharpness_history]
    mean_effective_lr = np.mean(effective_lrs)

    if lr < lr_crit_low:
        mode = "over-damped"
        sharpness_trend = "decreasing"
    elif lr > lr_crit_high:
        mode = "under-damped/divergent"
        sharpness_trend = "increasing or divergent"
    else:
        mode = "edge-of-stability"
        sharpness_trend = "stabilizing around 2/lr"

    return {
        'mode': mode,
        'sharpness_trend': sharpness_trend,
        'lr_crit_low': lr_crit_low,
        'lr_crit_high': lr_crit_high,
        'mean_effective_lr': mean_effective_lr,
        'is_stable': mean_effective_lr < 2.0
    }
```
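On a pure one-dimensional quadratic, the iteration $\delta_{t+1} = (1 - \eta\lambda)\,\delta_t$ has three textbook regimes:

- monotone convergence for $0 < \eta\lambda < 1$
- sign-alternating (oscillating) convergence for $1 < \eta\lambda < 2$
- divergence for $\eta\lambda > 2$

This picture is simpler than the mode labels above. Real networks can leave the EOS regime only because their curvature changes during training, which a fixed quadratic cannot do. The sketch below prints the first few iterates for λ = 1. The helper names are illustrative.

```python
def iterate_1d(eta_lambda, delta0=1.0, n_steps=6):
    """Iterates δ_{t+1} = (1 - ηλ) δ_t of GD on a 1-D quadratic."""
    deltas = [delta0]
    for _ in range(n_steps):
        deltas.append((1.0 - eta_lambda) * deltas[-1])
    return deltas


def quadratic_regime(eta_lambda):
    """Qualitative behaviour of GD on a 1-D quadratic, assuming ηλ > 0."""
    if eta_lambda <= 0:
        raise ValueError("expects η·λ > 0 (positive curvature)")
    if eta_lambda < 1:
        return "monotone convergence"
    if eta_lambda == 1:
        return "converges in one step"
    if eta_lambda < 2:
        return "oscillating convergence"
    if eta_lambda == 2:
        return "marginal: |δ| stays constant while the sign flips"
    return "divergence"


for eta_lambda in (0.5, 1.5, 2.5):
    traj = iterate_1d(eta_lambda)
    print(eta_lambda, quadratic_regime(eta_lambda), [round(d, 3) for d in traj])
```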

Typical Dynamics Curves

```
Loss
  │
  │     under-damped
  │       ↓
  │      ╱ ╲
  │     ╱   ╲
  │    ╱     ╲        critical (EOS)
  │   ╱       ╲         ↓
  │  ╱         ╲       ┌────
  │ ╱           ╲─────╱     unstable early,
  │╱             ╲   ╱       then stabilizes
  │              ╲─╱
  │               ↓
  │                over-damped:
  │                monotone decrease
  └──────────────────────────→ Time
```

Factors Affecting EOS

1. Batch size

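```python
from torch.utils.data import DataLoader


def analyze_batch_size_effect(batch_sizes, model_fn, dataloader, lr=0.1):
    """
    Analyze how the batch size affects EOS.

    model_fn builds a freshly initialized model, so each batch size starts
    from a new initialization instead of continuing to train one shared model.
    train_with_analysis is a training helper assumed to be defined elsewhere.
    """
    results = {}

    for bs in batch_sizes:
        # Rebuild the data loader with the new batch size
        loader = DataLoader(dataloader.dataset, batch_size=bs, shuffle=True)

        # Train a fresh model and record its dynamics
        model = model_fn()
        analyzer = EdgeOfStabilityAnalyzer()
        train_with_analysis(model, loader, analyzer, lr=lr)

        # Analyze the results
        mode_info = classify_training_mode(
            lr,
            analyzer.sharpness_history,
            analyzer.loss_history
        )

        results[bs] = {
            'final_sharpness': analyzer.sharpness_history[-1],
            'mode': mode_info['mode'],
            'loss_convergence': analyzer.loss_history[-1]
        }

    return results

# Typical results:
# batch_size=32:   final_sharpness≈15, mode=EOS
# batch_size=64:   final_sharpness≈18, mode=EOS
# batch_size=256:  final_sharpness≈22, mode=EOS
# batch_size=1024: final_sharpness≈25, mode=EOS
```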
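The batch-size, architecture and weight-decay experiments all call `train_with_analysis`, which the text never defines. Below is a minimal sketch of what such a helper could look like. Its signature is hypothetical and only chosen to match the calls used here (`lr=` or `optimizer=` keyword). It assumes the loader yields `(inputs, targets)` pairs, that the analyzer exposes `update(step, loss, sharpness, lr)`, and that sharpness comes from an optional user-supplied `sharpness_fn`.

```python
import torch
import torch.nn as nn


def train_with_analysis(model, loader, analyzer, lr=0.1, optimizer=None,
                        loss_fn=None, sharpness_fn=None, n_epochs=1):
    """
    Hypothetical training helper that feeds per-step metrics to an analyzer.

    Assumptions: loader yields (inputs, targets); analyzer has
    update(step, loss, sharpness, lr); sharpness_fn(model, inputs, targets)
    returns a float and must not touch .grad (e.g. it uses torch.autograd.grad).
    When sharpness_fn is omitted, NaN is recorded as the sharpness.
    """
    if loss_fn is None:
        loss_fn = nn.MSELoss()
    if optimizer is None:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    step = 0
    for _ in range(n_epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()

            # Optional sharpness measurement at the pre-update parameters
            sharpness = (float(sharpness_fn(model, inputs, targets))
                         if sharpness_fn is not None else float('nan'))
            current_lr = optimizer.param_groups[0]['lr']
            analyzer.update(step, loss.item(), sharpness, current_lr)

            optimizer.step()
            step += 1
    return analyzer
```

Keeping the analyzer duck-typed means the same loop can feed `EdgeOfStabilityAnalyzer` or a simple recorder in tests.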

2. Network architecture

```python
def analyze_architecture_effect(architectures, dataloader, lr=0.1):
    """
    Analyze how the network architecture affects EOS.

    architectures: dict mapping a name to a zero-argument model constructor.
    train_with_analysis is a training helper assumed to be defined elsewhere.
    """
    results = {}

    for arch_name, arch_fn in architectures.items():
        model = arch_fn()

        # Train and analyze
        analyzer = EdgeOfStabilityAnalyzer()
        train_with_analysis(model, dataloader, analyzer, lr=lr)

        results[arch_name] = {
            'sharpness_evolution': analyzer.sharpness_history,
            'final_sharpness': analyzer.sharpness_history[-1],
            'mode': classify_training_mode(
                lr,
                analyzer.sharpness_history,
                analyzer.loss_history
            )['mode']
        }

    return results

# Typical results:
# ResNet-18:    lower final sharpness, good generalization
# ResNet-50:    moderate final sharpness
# Plain CNN:    higher final sharpness, possibly unstable
# MLP:          depends on width
```

3. Weight decay

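```python
def analyze_weight_decay_effect(weight_decays, model_fn, dataloader, lr=0.1):
    """
    Analyze how weight decay affects EOS.

    model_fn builds a freshly initialized model for each setting.
    train_with_analysis is a training helper assumed to be defined elsewhere.
    """
    results = {}

    for wd in weight_decays:
        # Fresh model and optimizer for each weight-decay value
        model = model_fn()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=wd)

        analyzer = EdgeOfStabilityAnalyzer()
        train_with_analysis(model, dataloader, analyzer, optimizer=optimizer)

        # Weight decay shifts the sharpness baseline: the larger wd, the lower
        # the baseline of the measured loss sharpness. SGD's weight_decay also
        # adds wd * w to the gradient, i.e. it minimizes loss + (wd / 2) * ||w||^2,
        # whose Hessian is H + wd * I. If sharpness_history measures λ_max of the
        # unregularized loss, the curvature governing GD stability is λ_max + wd.
        effective_sharpness = [s + wd for s in analyzer.sharpness_history]

        results[wd] = {
            'sharpness': analyzer.sharpness_history,
            'effective_sharpness': effective_sharpness,
            'loss': analyzer.loss_history
        }

    return results
```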
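`torch.optim.SGD` with `weight_decay` (and no momentum) takes exactly the same step as plain GD on the L2-regularized objective $L(w) + \tfrac{wd}{2}\lVert w\rVert^2$. That objective's Hessian is $H + wd\,I$. The sketch below checks both facts on a synthetic quadratic whose Hessian `H` is a random PSD matrix chosen for illustration.

```python
import numpy as np
import torch

torch.manual_seed(0)
wd, lr = 0.1, 0.05
A = torch.randn(4, 4)
H = A @ A.T                      # synthetic symmetric PSD "Hessian" (hypothetical)
w0 = torch.randn(4)


def loss_fn(w):
    return 0.5 * w @ H @ w       # quadratic loss whose Hessian is H


# (a) one step of torch.optim.SGD with weight_decay (no momentum)
w_a = w0.clone().requires_grad_(True)
opt = torch.optim.SGD([w_a], lr=lr, weight_decay=wd)
opt.zero_grad()
loss_fn(w_a).backward()
opt.step()

# (b) one plain gradient step on loss + (wd / 2) * ||w||^2
w_b = w0.clone().requires_grad_(True)
(loss_fn(w_b) + 0.5 * wd * w_b.pow(2).sum()).backward()
with torch.no_grad():
    w_b -= lr * w_b.grad

print("same update:", torch.allclose(w_a.detach(), w_b.detach(), atol=1e-5))

# The regularized objective has Hessian H + wd * I: every eigenvalue shifts by wd
H64 = H.double()
eig = np.linalg.eigvalsh(H64.numpy())
eig_reg = np.linalg.eigvalsh((H64 + wd * torch.eye(4, dtype=torch.float64)).numpy())
print("lambda_max shift:", eig_reg.max() - eig.max())
```

With momentum or decoupled weight decay (as in AdamW), the equivalence in (a) and (b) no longer holds exactly.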

EOS and Generalization

Sharpness and generalization

Experiments suggest that a moderately high final sharpness may correlate with better generalization.

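```python
def find_optimal_range(sharpnesses, accuracies, margin=0.01):
    """
    Hypothetical helper: the sharpness range covered by runs whose test
    accuracy is within `margin` of the best run.
    """
    sharpnesses = np.asarray(sharpnesses, dtype=float)
    accuracies = np.asarray(accuracies, dtype=float)
    near_best = accuracies >= accuracies.max() - margin
    return float(sharpnesses[near_best].min()), float(sharpnesses[near_best].max())


def analyze_sharpness_generalization_relationship(experiments):
    """
    Analyze the relationship between sharpness and generalization.

    experiments: results of runs with different learning rates, each entry
    containing 'final_sharpness' and 'test_accuracy'.
    """
    sharpness_gen_pairs = []

    for exp_name, exp_data in experiments.items():
        sharpness = exp_data['final_sharpness']
        test_acc = exp_data['test_accuracy']

        sharpness_gen_pairs.append((sharpness, test_acc))

    # Sort by sharpness
    sharpness_gen_pairs.sort(key=lambda x: x[0])

    # Pearson correlation (needs at least two runs)
    sharpnesses = [p[0] for p in sharpness_gen_pairs]
    accuracies = [p[1] for p in sharpness_gen_pairs]

    correlation = np.corrcoef(sharpnesses, accuracies)[0, 1]

    return {
        'pairs': sharpness_gen_pairs,
        'correlation': correlation,
        'optimal_sharpness_range': find_optimal_range(sharpnesses, accuracies)
    }
```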
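With only a handful of runs, the Pearson correlation from `np.corrcoef` measures linear association only and can be dominated by a single outlier. A rank-based (Spearman) correlation is a common complementary check for a monotone sharpness/accuracy trend. Below is a dependency-free sketch that assumes no tied values. `scipy.stats.spearmanr` handles ties properly. The function name is illustrative.

```python
import numpy as np


def spearman_rank_correlation(x, y):
    """Spearman's rho as the Pearson correlation of ranks (no tie correction)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # argsort of argsort turns values into 0-based ranks
    rank_x = np.argsort(np.argsort(x))
    rank_y = np.argsort(np.argsort(y))
    return float(np.corrcoef(rank_x, rank_y)[0, 1])


# A monotone but non-linear relation: Spearman gives 1, Pearson gives less
sharp = [1, 2, 3, 4, 5]
acc = [0.1, 0.2, 0.4, 0.8, 1.6]
print(spearman_rank_correlation(sharp, acc), np.corrcoef(sharp, acc)[0, 1])
```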

EOS as Regularization

Hypothesis: the EOS phenomenon may act as a form of implicit regularization.

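```python
class EOSRegularizationHypothesis:
    """
    EOS regularization hypothesis.

    When training is in the EOS regime:
    1. The parameters move along a "ridge" of the loss landscape
    2. These dynamics may promote better feature learning
    3. Training eventually converges to a region that generalizes better
    """

    def __init__(self, lr, target_eos=2.0):
        self.lr = lr
        self.target_eos = target_eos

    def should_use_eos(self, task_type):
        """
        Decide whether a task is a good fit for an EOS learning rate.
        Returns None for task types this heuristic does not cover.
        """
        if task_type in ['image_classification', 'language_modeling']:
            # Complex tasks may benefit from EOS
            return True
        elif task_type in ['simple_regression']:
            # Simple tasks do not need it
            return False

        return None
```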

Practical Guidance

1. Finding the critical learning rate

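```python
import copy


def find_critical_learning_rate(model, dataloader, lr_range=(1e-4, 10),
                                rel_tol=0.01, n_steps=100):
    """
    Find the critical learning rate by bisection.

    Assumes stability is monotone in lr (stable below a threshold, unstable
    above it). The search runs in log space because lr spans several orders
    of magnitude. train_for_steps is a training helper assumed to be defined
    elsewhere.
    """
    def check_stability(lr):
        """Check whether training at this lr is stable."""
        # Train a copy so every trial starts from the same initial weights
        trial_model = copy.deepcopy(model)
        analyzer = EdgeOfStabilityAnalyzer()
        train_for_steps(trial_model, dataloader, analyzer, lr=lr, n_steps=n_steps)

        losses = np.asarray(analyzer.loss_history, dtype=float)
        if not np.all(np.isfinite(losses)):
            return False  # NaN/inf loss: training diverged

        # Stability criterion (heuristic thresholds): the loss does not keep
        # rising and the sharpness does not blow up
        loss_trend = np.polyfit(np.arange(len(losses)), losses, 1)[0]
        sharp_trend = np.polyfit(np.arange(len(analyzer.sharpness_history)),
                                 analyzer.sharpness_history, 1)[0]

        return loss_trend < 0.1 and sharp_trend < 100

    # Bisection in log space; stops when hi / lo < 1 + rel_tol
    left, right = np.log(lr_range[0]), np.log(lr_range[1])
    while right - left > np.log1p(rel_tol):
        mid = (left + right) / 2
        if check_stability(np.exp(mid)):
            left = mid
        else:
            right = mid

    return float(np.exp(left)), float(np.exp(right))


def plot_lr_finding_result(lrs, stabilities, lr_crit_low, lr_crit_high):
    """Visualize the learning-rate search results."""
    plt.figure(figsize=(10, 6))

    plt.semilogx(lrs, stabilities, 'b-o', markersize=4)
    plt.axvline(x=lr_crit_low, color='g', linestyle='--',
                label=f'Stable boundary: {lr_crit_low:.4f}')
    plt.axvline(x=lr_crit_high, color='r', linestyle='--',
                label=f'EOS boundary: {lr_crit_high:.4f}')

    plt.xlabel('Learning Rate (log scale)')
    plt.ylabel('Training Stability')
    plt.title('Critical Learning Rate Analysis')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
```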
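The bisection idea in `find_critical_learning_rate` can be sanity-checked on a problem whose answer is known in advance. For GD on a quadratic, the stability analysis predicts $\eta_{\text{crit}} = 2/\lambda_{\max}$. In the sketch below, "stable" means ‖θ‖ shrank over a finite horizon. That criterion tolerates very slow growth, so the measured threshold lands slightly above $2/\lambda_{\max}$. The Hessian and helper names are hypothetical.

```python
import numpy as np


def gd_is_stable(H, lr, n_steps=500):
    """Finite-horizon stability check: does GD on ½ θᵀHθ shrink ‖θ‖?"""
    theta = np.ones(H.shape[0])
    start = np.linalg.norm(theta)
    for _ in range(n_steps):
        theta = theta - lr * (H @ theta)
    return np.linalg.norm(theta) < start


def bisect_critical_lr(is_stable, lo=1e-4, hi=10.0, rel_tol=1e-3):
    """Log-space bisection for the largest stable lr (assumes monotone stability)."""
    lo, hi = np.log(lo), np.log(hi)
    while hi - lo > np.log1p(rel_tol):
        mid = 0.5 * (lo + hi)
        if is_stable(np.exp(mid)):
            lo = mid
        else:
            hi = mid
    return float(np.exp(lo)), float(np.exp(hi))


H = np.diag([25.0, 4.0, 1.0])  # hypothetical: λ_max = 25, so 2 / λ_max = 0.08
lo, hi = bisect_critical_lr(lambda lr: gd_is_stable(H, lr))
print(f"critical lr bracket: [{lo:.5f}, {hi:.5f}]  (2/λ_max = {2 / 25:.5f})")
```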

2. Adaptive EOS learning rate

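```python
class EOSAwareLRScheduler:
    """
    EOS-aware learning-rate scheduler.

    Strategy: use a relatively high learning rate early (to exploit the EOS
    effect) and lower it later for fine-grained optimization.
    """
    def __init__(self, base_lr, total_steps, eos_target=2.0, safety=0.9):
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.eos_target = eos_target
        self.safety = safety
        # Named step_count so it does not shadow the step() method
        self.step_count = 0

    def step(self, current_sharpness=None):
        """Return the learning rate for the current step."""
        progress = min(self.step_count / self.total_steps, 1.0)
        self.step_count += 1

        # Standard cosine decay: high lr early, lower lr later
        lr = self.base_lr * (0.5 * (1 + np.cos(np.pi * progress)))

        if current_sharpness is not None and current_sharpness > 0:
            # Sharpness-aware cap: stay just below the critical lr eos_target / λ_max
            lr_crit = self.eos_target / current_sharpness
            lr = min(lr, lr_crit * self.safety)

        return lr

# Usage example (model, optimizer, criterion, dataloader, num_epochs and
# estimate_sharpness are assumed to be defined elsewhere)
scheduler = EOSAwareLRScheduler(base_lr=0.1, total_steps=100000)

for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['target'])
        loss.backward()

        # Sharpness estimate (e.g. a stochastic K-FAC approximation or power
        # iteration); it is costly, so one might refresh it only periodically
        sharpness = estimate_sharpness(model, batch)

        # Adaptive learning rate for this step
        lr = scheduler.step(current_sharpness=sharpness)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        optimizer.step()
```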
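The effect of the sharpness cap $\eta \le 0.9 \cdot 2/\lambda_{\max}$ can be illustrated on a quadratic where λ_max is known exactly. This is a synthetic sketch. In a real network the sharpness estimate would be noisy and stale. A quadratic also has no self-stabilization, so the fixed rate simply diverges here. The point is only what the cap does. The Hessian and helper name are hypothetical.

```python
import numpy as np


def gd_final_norm(H, lr_fn, n_steps=100):
    """Run GD on ½ θᵀHθ with a per-step learning rate lr_fn(t, λ_max)."""
    theta = np.ones(H.shape[0])
    lam_max = np.linalg.eigvalsh(H).max()  # exact sharpness of the quadratic
    for t in range(n_steps):
        theta = theta - lr_fn(t, lam_max) * (H @ theta)
    return float(np.linalg.norm(theta))


H = np.diag([50.0, 5.0, 1.0])  # hypothetical: λ_max = 50, so 2 / λ_max = 0.04
base_lr = 0.1                  # η·λ_max = 5: far beyond the boundary

fixed = gd_final_norm(H, lambda t, lam: base_lr)
capped = gd_final_norm(H, lambda t, lam: min(base_lr, 0.9 * 2.0 / lam))
print(f"fixed lr:  final ||theta|| = {fixed:.3e}")
print(f"capped lr: final ||theta|| = {capped:.3e}")
```

The cap trades speed for stability: at lr = 0.036 the flat λ = 1 direction still shrinks by only about 3.6% per step.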

3. Jointly adjusting batch size and learning rate

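```python
class JointBatchLRScheduler:
    """
    Joint scheduling of batch size and learning rate.

    When the batch size grows by a factor k:
      - linear scaling multiplies lr by k, keeping lr / batch_size fixed;
      - the more conservative √k scaling keeps the variance of each SGD
        update (proportional to lr² / batch_size) fixed.
    """
    def __init__(self, base_batch_size=32, base_lr=0.1, eos_target=2.0):
        self.base_batch_size = base_batch_size
        self.base_lr = base_lr
        self.eos_target = eos_target
        self.current_batch_size = base_batch_size

    def get_lr(self, batch_size, sharpness=None):
        """Compute the learning rates for a given batch size."""
        k = batch_size / self.base_batch_size

        # Linear scaling rule
        lr_linear = self.base_lr * k

        # More conservative √k scaling
        lr_sqrt = self.base_lr * np.sqrt(k)

        # EOS consideration: keep lr * λ_max in a reasonable range, typically
        # [0.5, 2.0]. With a sharpness estimate, cap lr at eos_target / λ_max;
        # otherwise fall back to a fixed upper bound of 0.1.
        if sharpness is not None and sharpness > 0:
            cap = self.eos_target / sharpness
        else:
            cap = 0.1
        lr_eos_adjusted = min(lr_linear, cap)

        return {
            'linear': lr_linear,
            'sqrt': lr_sqrt,
            'eos_adjusted': lr_eos_adjusted
        }


def plot_batch_lr_relationship():
    """Plot the learning rate against the batch size."""
    batch_sizes = [16, 32, 64, 128, 256, 512]
    scheduler = JointBatchLRScheduler()

    lrs_linear = []
    lrs_sqrt = []

    for bs in batch_sizes:
        lrs = scheduler.get_lr(bs)
        lrs_linear.append(lrs['linear'])
        lrs_sqrt.append(lrs['sqrt'])

    plt.figure(figsize=(10, 6))
    plt.plot(batch_sizes, lrs_linear, 'b-o', label='Linear scaling')
    plt.plot(batch_sizes, lrs_sqrt, 'r-s', label='√k scaling')
    plt.xlabel('Batch Size')
    plt.ylabel('Learning Rate')
    plt.title('Learning Rate vs Batch Size')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
```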
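As a worked example, take the defaults `base_batch_size=32` and `base_lr=0.1`. Assume a hypothetical target batch size of 512 and a sharpness of $\lambda_{\max} = 10$:

$$
k = \frac{512}{32} = 16,\qquad
\eta_{\text{linear}} = 0.1 \times 16 = 1.6,\qquad
\eta_{\sqrt{k}} = 0.1 \times \sqrt{16} = 0.4,\qquad
\frac{2}{\lambda_{\max}} = \frac{2}{10} = 0.2 .
$$

Under these assumed numbers, both scaling rules would exceed the quadratic stability bound $2/\lambda_{\max}$. That is the situation the `eos_adjusted` cap is meant to guard against.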

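Connections to Other Phenomena

1. Sharpness-Aware Minimization (SAM)

SAM aims to improve generalization by explicitly minimizing sharpness. It minimizes the worst-case loss within a radius $\rho$ of the current parameters: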

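```python
import torch


class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimization (minimal sketch with plain SGD as the base update).

    In essence it actively lowers sharpness during optimization, steering
    training toward "flatter" regions: the gradient is taken at the
    worst-case point w + ρ·g/‖g‖ and applied at the original w.

    closure() must zero the gradients, recompute the loss, call backward()
    and return the loss.
    """
    def __init__(self, params, lr, rho=0.05):
        defaults = dict(lr=lr, rho=rho)  # rho: perturbation radius
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure):
        # 1. Gradient at the current parameters
        with torch.enable_grad():
            loss = closure()

        # Global gradient norm over all parameters
        grad_norm = torch.norm(torch.stack([
            p.grad.norm(2)
            for group in self.param_groups for p in group['params']
            if p.grad is not None
        ]), 2)

        # 2. Move to the perturbed point w + e_w, with e_w = ρ · g / ‖g‖
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)
                self.state[p]['e_w'] = e_w

        # 3. Recompute the gradient at the perturbed point
        with torch.enable_grad():
            closure()

        # 4. Restore the original parameters using the stored e_w (not the new
        #    gradient), then apply the SGD update with the perturbed gradient
        for group in self.param_groups:
            for p in group['params']:
                e_w = self.state[p].pop('e_w', None)
                if e_w is not None:
                    p.sub_(e_w)
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])

        return loss
```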

2. Learning Rate Warmup

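```python
def cosine_warmup_schedule(epoch, warmup_epochs=5, total_epochs=100, base_lr=0.1):
    """
    Cosine annealing with warmup.

    Warmup can help avoid the instability that occurs early in training,
    when the sharpness is high.
    """
    if epoch < warmup_epochs:
        # Linear warmup; (epoch + 1) so that epoch 0 does not get lr = 0
        return base_lr * (epoch + 1) / warmup_epochs
    else:
        # Cosine annealing
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return base_lr * 0.5 * (1 + np.cos(np.pi * progress))
```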

Footnotes

  1. Cohen et al., “Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability”, ICLR 2021.