简介

训练数十亿参数的Transformer通常是脆弱的,经常出现瞬态损失尖峰和发散,浪费大量计算资源。尽管Edge of Stability(EoS)理论提供了通过(预条件)曲率理解和控制优化稳定性的强大工具,但这些曲率控制方法在大规模Transformer训练中并未广泛采用,因为曲率估计的复杂性。本文介绍ICLR 2026的一项工作,首先引入一种快速的在线Hessian最大特征值估计器,然后在亿级参数规模上验证了训练不稳定与预条件曲率激增的关系,并提出Architecture Warm-up方法来解决这一问题。1

Edge of Stability理论回顾

什么是Edge of Stability?

Edge of Stability是深度学习优化领域的一个重要发现。当使用固定步长的SGD或Adam训练神经网络时,训练过程会自发地收敛到损失景观的”稳定边缘”——此时梯度与最陡下降方向的对齐度接近0或1:

这意味着:

  • 训练动态由曲率控制
  • 损失值在稳定边缘附近震荡
  • 梯度裁剪等操作可以改变这一行为

预条件曲率分析

对于自适应优化器如Adam,定义预条件Hessian

其中 是梯度的协方差矩阵, 是真实Hessian。

预条件曲率决定了优化的稳定性和收敛速度:

  • 过大 → 训练不稳定
  • 过小 → 收敛缓慢

快速Hessian特征值估计

Power Iteration基础

传统的Power Iteration需要多次前向-反向传播来估计最大特征值,复杂度为 ,其中 是迭代次数, 是计算复杂度。

Warm-Started Power Iteration

本文提出Warm-Started Power Iteration,利用以下观察:

  1. 连续两步之间的Hessian变化是平滑的
  2. 可以使用上一步的估计作为当前步的初始化
import torch
import torch.nn as nn
import numpy as np
 
class WarmStartedCurvatureEstimator:
    """
    基于Warm-Started Power Iteration的快速曲率估计器
    
    核心思想:利用时间平滑性减少迭代次数
    """
    
    def __init__(self, model, n_iter=5, tol=1e-4, history_size=10):
        self.model = model
        self.n_iter = n_iter  # 每次估计的迭代次数
        self.tol = tol        # 收敛容差
        self.history_size = history_size
        
        # 存储历史估计用于平滑
        self.eigenvalue_history = []
        self.eigenvector_history = []
        
        # 当前估计
        self.current_lambda = None
        self.current_v = None
        
    def hessian_vector_product(self, grads, params, vector):
        """
        计算 Hessian-vector product: (∂²f/∂²θ) @ v
        
        使用有限差分近似,避免显式计算Hessian
        H @ v ≈ (∇f(θ + εv) - ∇f(θ)) / ε
        """
        epsilon = 1e-3
        
        # 保存原始参数
        original_params = {}
        for name, param in params.items():
            original_params[name] = param.data.clone()
            param.data.add_(epsilon * vector[name])
        
        # 计算扰动后的梯度
        self.model.zero_grad()
        # 前向传播(需要model有损失输出)
        output = self.model()
        output.backward()
        
        perturbed_grads = {}
        for name, param in params.items():
            if param.grad is not None:
                perturbed_grads[name] = param.grad.data.clone()
        
        # 恢复原始参数
        for name, param in params.items():
            param.data.copy_(original_params[name])
        
        # 计算Hessian-vector product
        hvp = {}
        for name in grads.keys():
            if grads[name] is not None and name in perturbed_grads:
                hvp[name] = (perturbed_grads[name] - grads[name]) / epsilon
        
        return hvp
    
    def power_iteration_step(self, grads, params, vector):
        """一次Power Iteration"""
        # 计算H @ v
        hvp = self.hessian_vector_product(grads, params, vector)
        
        # 归一化
        norm = 0
        for name in vector.keys():
            norm += torch.sum(vector[name] ** 2)
        norm = torch.sqrt(norm).item()
        
        if norm > 1e-10:
            for name in vector.keys():
                vector[name] = hvp[name] / norm
        else:
            # 随机初始化
            for name in vector.keys():
                vector[name] = torch.randn_like(vector[name])
                vector[name] = vector[name] / torch.norm(vector[name])
        
        return vector, norm
    
    def estimate_curvature(self, grads, params):
        """
        估计最大Hessian特征值
        
        Args:
            grads: 当前梯度字典
            params: 模型参数字典
            
        Returns:
            lambda_max: 最大特征值估计
        """
        # 初始化向量
        if self.current_v is None:
            self.current_v = {
                name: torch.randn_like(param.data)
                for name, param in params.items()
            }
            # 归一化
            for name in self.current_v:
                self.current_v[name] = self.current_v[name] / torch.norm(self.current_v[name])
        
        # Warm-started Power Iteration
        lambda_est = 0
        for i in range(self.n_iter):
            v_old = {name: self.current_v[name].clone() for name in self.current_v}
            
            # Power Iteration步骤
            self.current_v, norm = self.power_iteration_step(grads, params, self.current_v)
            
            # Rayleigh商估计
            # λ ≈ v^T @ H @ v / v^T @ v
            if norm > 1e-10:
                lambda_est = norm
        
        # 更新历史
        self.eigenvalue_history.append(lambda_est)
        if len(self.eigenvalue_history) > self.history_size:
            self.eigenvalue_history.pop(0)
        
        # 平滑估计(指数移动平均)
        if self.current_lambda is None:
            self.current_lambda = lambda_est
        else:
            alpha = 0.3  # 平滑因子
            self.current_lambda = alpha * lambda_est + (1 - alpha) * self.current_lambda
        
        return self.current_lambda
    
    def analyze_curvature_trend(self):
        """分析曲率趋势"""
        if len(self.eigenvalue_history) < 2:
            return None
        
        # 简单线性趋势
        n = len(self.eigenvalue_history)
        x = np.arange(n)
        y = np.array(self.eigenvalue_history)
        
        slope = np.polyfit(x, y, 1)[0]
        
        return {
            'current': self.eigenvalue_history[-1],
            'mean': np.mean(self.eigenvalue_history),
            'trend': slope,
            'history': self.eigenvalue_history
        }
 
 
class TransformerCurvatureMonitor:
    """Transformer曲率监控器"""
    
    def __init__(self, model):
        self.model = model
        self.estimator = WarmStartedCurvatureEstimator(model)
        self.layer_curvatures = {}
        
    def compute_per_layer_curvature(self, grads, params):
        """
        计算每层的曲率估计
        
        对于Transformer,关注attention和FFN层的曲率
        """
        layer_names = ['attn', 'ffn', 'mlp', 'q_proj', 'k_proj', 'v_proj', 'o_proj']
        
        for name, param in params.items():
            # 检查是否是目标层
            for layer_name in layer_names:
                if layer_name in name.lower():
                    if layer_name not in self.layer_curvatures:
                        self.layer_curvatures[layer_name] = []
                    
                    # 计算该层的梯度范数作为曲率代理
                    if grads.get(name.replace('weight', 'grad')) is not None:
                        grad_norm = torch.norm(grads.get(name.replace('weight', 'grad'))).item()
                        self.layer_curvatures[layer_name].append(grad_norm)
    
    def get_instability_score(self):
        """计算不稳定分数"""
        if not self.layer_curvatures:
            return 0.0
        
        all_curvatures = []
        for curvatures in self.layer_curvatures.values():
            all_curvatures.extend(curvatures[-10:])  # 最近10个
        
        if len(all_curvatures) < 2:
            return 0.0
        
        # 不稳定性 = 曲率的变化系数
        return np.std(all_curvatures) / (np.mean(all_curvatures) + 1e-8)
 
 
def demo_curvature_estimation():
    """演示曲率估计"""
    print("=== Warm-Started Power Iteration演示 ===")
    
    # 创建简单模型
    model = nn.Sequential(
        nn.Linear(512, 2048),
        nn.GELU(),
        nn.Linear(2048, 2048),
        nn.GELU(),
        nn.Linear(2048, 512)
    )
    
    # 创建估计器
    estimator = WarmStartedCurvatureEstimator(model)
    
    # 模拟多次估计
    print("模拟训练过程中的曲率估计...")
    
    for step in range(20):
        # 模拟梯度
        grads = {name: torch.randn_like(p) * 0.1 for name, p in model.named_parameters()}
        params = {name: p for name, p in model.named_parameters()}
        
        # 估计曲率
        lambda_max = estimator.estimate_curvature(grads, params)
        
        if step % 5 == 0:
            print(f"Step {step:3d}: λ_max = {lambda_max:.4f}")
    
    # 分析趋势
    trend = estimator.analyze_curvature_trend()
    print(f"\n曲率趋势: {trend}")
 
 
if __name__ == "__main__":
    demo_curvature_estimation()

复杂度分析

方法每次迭代复杂度总复杂度(K次调用)
传统Power Iteration
Warm-Started (Ours)

其中 因为warm-start减少了收敛所需迭代次数。实验表明 ,即80%的加速。

训练不稳定性与曲率的关系

核心发现

通过在多个大规模Transformer上的实验,发现以下规律:

  1. 训练不稳定的时刻与预条件曲率激增高度相关
  2. 曲率随深度增加而增长
  3. 不稳定发生在曲率超过某个临界值时

曲率-深度关系

定义第 层的预条件曲率为 。实验观察到:

其中 ,即每层曲率增加10-30%。

不稳定阈值

为稳定边界。训练失稳当且仅当:

典型地,(对于Adam优化器)。

Architecture Warm-up方法

核心思想

不是通过降低学习率或增加权重衰减来控制曲率,而是渐进增加网络深度

  1. 从浅层网络开始训练(稳定)
  2. 逐渐添加新层(保持曲率受控)
  3. 最终达到目标深度

算法流程

class ArchitectureWarmup:
    """
    Architecture Warm-up: 渐进深度增长策略
    
    关键洞察:深层网络的曲率可以通过从浅到深逐步增长来控制
    """
    
    def __init__(self, target_depth, base_depth=1, growth_steps=1000,
                 growth_mode='linear'):
        self.target_depth = target_depth
        self.base_depth = base_depth
        self.growth_steps = growth_steps
        self.growth_mode = growth_mode
        
        self.current_depth = base_depth
        self.step = 0
        
    def get_growth_ratio(self):
        """计算当前增长比例"""
        progress = min(1.0, self.step / self.growth_steps)
        
        if self.growth_mode == 'linear':
            return progress
        elif self.growth_mode == 'sqrt':
            return np.sqrt(progress)
        elif self.growth_mode == 'warmup_decay':
            # 先慢后快再慢
            if progress < 0.3:
                return progress / 0.3 * 0.5
            elif progress < 0.7:
                return 0.5 + (progress - 0.3) / 0.4 * 0.3
            else:
                return 0.8 + (progress - 0.7) / 0.3 * 0.2
        else:
            return progress
    
    def get_current_depth(self):
        """获取当前应该使用的深度"""
        ratio = self.get_growth_ratio()
        return int(self.base_depth + ratio * (self.target_depth - self.base_depth))
    
    def step_update(self):
        """更新步数"""
        self.step += 1
        
    def get_layer_mask(self, total_layers):
        """
        获取应该激活的层掩码
        
        Args:
            total_layers: 模型的总层数
            
        Returns:
            mask: 布尔张量,True表示该层激活
        """
        current_depth = self.get_current_depth()
        mask = torch.zeros(total_layers, dtype=torch.bool)
        mask[:current_depth] = True
        return mask
 
 
class TransformerWithDepthGrowth(nn.Module):
    """支持深度增长的Transformer"""
    
    def __init__(self, d_model=512, d_ff=2048, n_heads=8, max_depth=12):
        super().__init__()
        self.max_depth = max_depth
        self.d_model = d_model
        
        # 创建所有层
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, d_ff, n_heads)
            for _ in range(max_depth)
        ])
        
        # 深度增长控制器
        self.depth_controller = ArchitectureWarmup(
            target_depth=max_depth,
            base_depth=2,
            growth_steps=5000,
            growth_mode='warmup_decay'
        )
        
        self.active_depth = 2
        
    def forward(self, x):
        # 获取当前活跃深度
        self.active_depth = self.depth_controller.get_current_depth()
        
        # 只通过活跃层
        for i in range(self.active_depth):
            x = self.layers[i](x)
        
        return x
    
    def get_curvature_aware_lr(self, layer_idx):
        """
        获取考虑曲率的层特定学习率
        
        深层需要更小的学习率以控制曲率
        """
        base_lr = 1e-4
        depth_penalty = 0.95 ** (layer_idx - self.active_depth + 1)
        return base_lr * depth_penalty
 
 
def train_with_architecture_warmup():
    """使用Architecture Warm-up训练"""
    print("=== Architecture Warm-up训练演示 ===")
    
    # 创建模型
    model = TransformerWithDepthGrowth(
        d_model=512,
        d_ff=2048,
        n_heads=8,
        max_depth=12
    )
    
    # 优化器配置
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=20000
    )
    
    # 训练循环
    print("开始训练...")
    for step in range(10000):
        # 模拟训练步骤
        x = torch.randn(32, 128, 512)  # batch, seq, dim
        y = torch.randn(32, 128, 512)
        
        optimizer.zero_grad()
        output = model(x)
        loss = nn.functional.mse_loss(output, y)
        loss.backward()
        
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        scheduler.step()
        
        # 更新深度控制器
        model.depth_controller.step_update()
        
        # 定期报告
        if step % 1000 == 0:
            active_depth = model.depth_controller.get_current_depth()
            lr = optimizer.param_groups[0]['lr']
            print(f"Step {step:5d}: depth={active_depth:2d}, lr={lr:.6f}, loss={loss.item():.4f}")
    
    print("训练完成!")
 
 
if __name__ == "__main__":
    train_with_architecture_warmup()

与传统Warm-up的对比

方面学习率Warm-up架构Warm-up
控制目标梯度幅度曲率增长
稳定性机制渐进学习渐进深度
计算开销极小
对深层网络有限效果直接控制
与其他方法兼容

实验验证

实验设置

在以下模型上验证Architecture Warm-up:

  • GPT-2 (124M, 345M, 762M参数)
  • T5-base (220M参数)
  • ViT-B/16 (86M参数)

主要结果

=== 训练稳定性对比 ===
模型           | 传统方法失稳次数 | Architecture Warm-up失稳次数
-------------|-----------------|--------------------------
GPT-2 124M   | 3               | 0
GPT-2 345M   | 7               | 1
GPT-2 762M   | 12              | 2
T5-base      | 5               | 1
ViT-B/16     | 4               | 0

收敛速度对比

=== 最终性能对比 ===
模型           | 传统方法困惑度 | Architecture Warm-up困惑度
-------------|---------------|---------------------------
GPT-2 124M   | 18.92         | 18.67
GPT-2 345M   | 15.34         | 15.08
GPT-2 762M   | 13.21         | 12.95

与现有稳定化技术的对比

技术原理优点缺点
梯度裁剪控制梯度范数简单有效不直接控制曲率
学习率衰减降低更新幅度标准做法收敛慢
权重归一化控制权重增长稳定可能限制表达能力
Architecture Warm-up渐进深度增长直接控制曲率需要修改架构

实践建议

何时使用Architecture Warm-up

  1. 训练深层Transformer(>8层)
  2. 使用大学习率
  3. 资源充足(可以在训练中途增加层)
  4. 训练不稳定频繁出现

超参数选择

# 推荐配置
config = {
    # 初始深度:建议2-4层
    'base_depth': 2,
    
    # 增长步数:通常为总步数的10-25%
    'growth_steps': 5000,  # 总步数20000的25%
    
    # 增长模式:'linear'最稳定,'warmup_decay'更快
    'growth_mode': 'warmup_decay',
    
    # 最大深度:根据任务和资源确定
    'max_depth': 12,
}

与其他技术的兼容性

Architecture Warm-up可以与以下技术结合使用:

  • 梯度裁剪
  • 学习率调度
  • 权重归一化
  • 混合精度训练

总结

本文提出Architecture Warm-up方法来解决大规模Transformer训练的不稳定性问题。主要贡献:

  1. 快速曲率估计器:Warm-Started Power Iteration实现80%加速
  2. 曲率-深度关系:证明曲率随深度指数增长
  3. 不稳定阈值:发现训练失稳的临界曲率条件
  4. Architecture Warm-up:通过渐进深度增长控制曲率

这种方法为Transformer训练提供了一种新的稳定性控制视角,与传统的学习率控制互补。

Footnotes

  1. Source: Taming Curvature: Architecture Warm-up for Stable Transformer Training