Unifying Learning Dynamics and Generalization in Transformers

This article presents a comprehensive theoretical framework for rigorously analyzing the scaling law phenomenon in Large Language Models (LLMs), specifically addressing the empirically observed power-law relationship between model performance and computational resources from a learning dynamics perspective.1

1. Introduction: The Scaling Law Puzzle

1.1 Empirical Observations

神经缩放定律(Neural Scaling Laws)是深度学习领域最重要的经验规律之一:神经网络的性能(通常用测试损失或困惑度衡量)与其规模(参数量 )、训练数据量 )、计算量(FLOPs )之间存在可预测的幂律关系。1

经典缩放定律公式(Kaplan等人):

其中:

  • :测试集上的交叉熵损失
  • :模型参数量
  • :训练数据量
  • :不可约损失(irreducible loss)
  • , :缩放指数

1.2 The Theoretical Gap

尽管缩放定律已被广泛验证,但其理论机制尚不清楚。现有研究的问题:

研究类型局限性
Toy Models过于简化,无法解释真实训练动态
Empiricism缺乏理论保证,外推可靠性存疑
NTK Theory需要无限宽度假设,与实际不符
Phenomenology仅描述现象,不解释成因

1.3 Research Questions

本研究旨在回答以下核心问题:

核心问题:在训练基于Transformer的语言模型时,如何从理论角度确保计算资源分配的收敛性保证?

具体而言:

  1. 如何为多层Transformer建立精确的学习动态方程?
  2. 优化过程与泛化性能之间存在怎样的相位转换?
  3. 模型大小、训练时间、数据集大小如何独立影响泛化上界?

2. ODE System for Transformer Learning Dynamics

2.1 From Discrete Layers to Continuous Dynamics

借鉴Neural ODE的思想,Transformer的层叠可以视为连续动力学系统的离散化。

考虑一个 层的Transformer,其更新公式为:

其中 是第 层在时间 的隐藏状态。

2.2 The Transformer ODE System

本研究将多层Transformer的训练动态建模为ODE系统。对于decoder-only的生成式Transformer,利用其特殊性质简化复杂的矩阵运算为并行向量运算。

核心ODE系统

其中:

  • :时刻 的参数向量
  • :学习率
  • :损失函数

2.3 SGD Dynamics Formulation

在随机梯度下降(SGD)框架下,参数更新为:

其中 是经验风险。

ODE近似:当学习率足够小且批量足够大时,SGD可以近似为连续的随机微分方程:

其中 是噪声扩散系数, 是维纳过程。

2.4 Key Derivation for Learning Dynamics

论文的核心推导包括:

  1. Decoder-only属性利用:将复杂的矩阵运算简化为并行向量运算
  2. 多token生成建模:将序列到序列的预测建模为ODE系统
  3. 任意数据分布支持:放宽对数据分布的假设

3. Kernel Approximation and Convergence Analysis

3.1 From ODE to Kernel Behavior

训练后期,当模型接近收敛时,Transformer的动态可以近似为核方法(Kernel)行为。这一分析框架借鉴了神经正切核(Neural Tangent Kernel, NTK)理论,但突破了无限宽度假设的限制。

核近似假设:在适当条件下,多层Transformer的更新可以近似为:

3.2 Kernel Matrix Formulation

定义核矩阵 ,其中 是数据集大小:

核行为下的预测

其中 是测试样本与训练样本的核相似度向量。

3.3 Convergence Analysis

定理(收敛性保证):在核近似下,经验风险最小化器的期望风险收敛到最优预测器。

为在数据集 上训练的模型,则:

其中:

  • 是期望风险
  • 是最优预测器

3.4 Excess Risk Decomposition

泛化误差可以分解为:

核范数约束

4. Phase Transition from Exponential to Power-Law Decay

4.1 Two-Phase Structure

本研究的核心发现是泛化误差随计算资源增长呈现两阶段相位转换

┌─────────────────────────────────────────────────────────────────┐
│                  泛化误差的相位转换结构                           │
├─────────────────────────────────────────────────────────────────┤
│                                                                  │
│   Excess Risk                                                   │
│       │                                                         │
│       │    ╱                                                     │
│       │   ╱  阶段1: 计算受限                                     │
│       │  ╱   Exponential decay: Θ(exp(-C))                      │
│       │ ╱                                                        │
│       │╱──────────────────────────────────                       │
│       ╱                                   ╲                       │
│      ╱                                     ╲  阶段2: 数据受限    │
│     ╱                                       ╲  Power-law: Θ(C^(-1/6)) │
│    ╱                                         ╲                   │
│   ╱                                           ╲                  │
│  └─────────────────────────────────────────────→ Compute C      │
│                                                   ↑              │
│                                              临界点 C*           │
└─────────────────────────────────────────────────────────────────┘

4.2 Stage 1: Compute-Starved Phase

优化主导阶段:在训练初期,模型处于”计算受限”状态。

指数衰减特性

其中 是衰减系数, 是计算成本。

物理直觉

  • 此时模型尚未学习到数据的关键结构
  • 优化过程主导泛化动态
  • 增加计算资源可以快速改善性能

4.3 Critical Threshold

临界点由以下条件定义:

其中 是数据集大小, 是与模型架构相关的常数。

临界条件公式

4.4 Stage 2: Data-Limited Phase

统计主导阶段:越过临界点后,系统进入”数据受限”状态。

幂律衰减特性

物理直觉

  • 优化问题已基本解决
  • 泛化误差受限于数据分布的固有复杂性
  • 进一步增加训练时间或模型容量,收益递减

4.5 Mathematical Characterization

定理(两阶段泛化界):设 为总计算成本,则期望风险的泛化上界为:

其中 是噪声方差, 是数据集大小。

4.6 Empirical Validation

实验验证使用GPT-2系列模型,结果显示:

阶段预测行为观测到的缩放指数
计算受限 (log-linear)
数据受限

这与理论预测高度一致,证实了相位转换的存在。

5. Separate Scaling Laws for Model Size, Training Time, and Dataset Size

5.1 Individual Scaling Exponents

本研究的一个关键贡献是推导出模型大小、训练时间和数据集大小独立影响泛化上界的缩放定律。

5.2 Model Size Scaling Law

模型大小与性能的关系

其中 是参数量, 是模型缩放指数。

独立缩放界

直觉解释

  • 更大的模型有更大的假设空间
  • 但同时也更容易过拟合
  • 存在最优模型大小与数据集大小的平衡

5.3 Training Time Scaling Law

训练时间与性能的关系

其中 是训练步数或epoch数, 是时间缩放指数。

独立缩放界

5.4 Dataset Size Scaling Law

数据集大小与性能的关系

其中 是数据集大小, 是数据缩放指数。

独立缩放界

5.5 Joint Scaling Framework

整合三个变量的完整公式:

计算约束下的最优分配

5.6 Chinchilla Scaling Law Connection

本研究与Hoffmann等人提出的Chinchilla定律的关系:

定律公式关注点
Kaplan定律模型大小
Chinchilla计算最优分配
本研究两阶段 学习动态视角

6. Theoretical Upper Bound on Excess Risk

6.1 Excess Risk Definition

定义:超额风险(Excess Risk)定义为模型期望风险与最优贝叶斯风险之差:

其中 是给定数据分布下的最优预测器。

6.2 Main Upper Bound Theorem

定理(超额风险上界):对于训练 单位计算的Transformer模型,其超额风险的期望满足:

6.3 Detailed Bound Components

分解形式

误差项形式来源
优化动态
有限样本
模型近似

6.4 Concentration Inequalities

使用McDiarmid不等式和经验过程理论,获得高概率界:

其中 是与损失函数方差相关的常数。

6.5 Rate Optimality

最优性证明:文中证明 的速率是渐进最优的:

这意味着任何算法都不能达到比 更快的收敛速率(在数据受限阶段)。

7. Connections to Empirical Observations

7.1 GPT-2 Series Validation

论文使用GPT-2系列(124M, 355M, 774M, 1.5B参数)验证理论预测:

观测理论预测验证结果
相位转换点✓ 高度一致
幂律指数✓ 接近观测
临界噪声 相关✓ 数据质量敏感

7.2 Scaling Law Breakdown Conditions

本研究识别出缩放定律可能失效(breakdown)的三种情况:

7.2.1 Lambert W Function Correction

在高精度区域,对数修正变得显著:

其中 是Lambert W函数。

7.2.2 Dataset Noise Growth

当数据噪声随样本量增加时:

导致缩放收益递减。

7.2.3 Model Capacity Saturation

当模型容量相对于数据复杂度饱和时:

进一步增加模型大小不会提升性能。

7.3 Relationship with Emergent Abilities

本研究为涌现能力提供了动力学解释:

涌现能力的相位转换解释:某些”突现”的能力可能是模型从计算受限阶段向数据受限阶段转换的标志。

┌────────────────────────────────────────────────────────┐
│              涌现能力的相位转换视角                       │
├────────────────────────────────────────────────────────┤
│                                                         │
│   能力水平                                              │
│       │                                                │
│       │              ╱╲                                 │
│       │             ╱  ╲   ← 看似"突现"的能力          │
│       │            ╱    ╲                              │
│       │           ╱      ╲                             │
│       │    ──────╱────────╲─────────────→ 计算量       │
│       │         ↑         ↑                             │
│       │      临界点1    临界点2                         │
│       │      (学习开始)  (能力涌现)                     │
└────────────────────────────────────────────────────────┘

7.4 Implications for Practice

对训练实践的启示

阶段推荐策略
计算受限增加batch size、延长训练、使用更大模型
数据受限收集更多高质量数据、改进数据质量
临界区域平衡模型大小与数据量,避免资源浪费

8. Implementation Framework

8.1 Conceptual Algorithm

以下伪代码展示了两阶段训练框架的概念:

def compute_optimal_training_budget(N, D, model_complexity):
    """
    Compute optimal compute allocation based on theory.
    
    Args:
        N: Number of parameters
        D: Dataset size
        model_complexity: Architectural complexity factor
    
    Returns:
        C_star: Critical compute threshold
        budget: Optimal total compute budget
    """
    # Critical threshold from theory
    C_star = (D ** 0.6) * (model_complexity ** 0.4)
    
    # Phase-dependent allocation
    if current_compute < C_star:
        # Compute-starved phase: exponential decay regime
        return allocate_for_exponential_regime(C_star)
    else:
        # Data-limited phase: power-law regime
        return allocate_for_power_law_regime(D)

8.2 Loss Curve Prediction

def predict_loss_curve(C, N, D, C_star=None):
    """
    Predict loss based on two-phase theory.
    
    Args:
        C: Compute budget
        N: Model parameters
        D: Dataset size
        C_star: Critical threshold (optional)
    
    Returns:
        Expected loss
    """
    if C_star is None:
        C_star = (D ** 0.6) * (N ** 0.4)
    
    if C < C_star:
        # Compute-starved phase
        base_loss = initial_loss
        loss = base_loss * np.exp(-C / C_star)
    else:
        # Data-limited phase
        base_loss = irreducible_loss(C_star)
        exponent = -1/6
        loss = base_loss * (C / C_star) ** exponent
    
    # Add statistical correction
    noise_term = noise_std / np.sqrt(D)
    
    return loss + noise_term

8.3 Phase Detection

def detect_training_phase(loss_history, compute_history):
    """
    Detect current training phase from loss curve.
    
    Returns:
        phase: 'compute_starved' or 'data_limited'
        confidence: Detection confidence
    """
    # Compute empirical decay rate
    log_losses = np.log(loss_history)
    derivatives = np.gradient(log_losses, compute_history)
    
    # Check for phase transition
    if np.mean(derivatives[:len(derivatives)//2]) < -0.1:
        return 'compute_starved', 0.85
    else:
        return 'data_limited', 0.75

9.1 Relation to Neural Tangent Kernel

方面NTK理论本研究
宽度假设无限宽度有限宽度,多层
动态类型线性非线性(ODE系统)
收敛速率两阶段:指数 +
数据假设i.i.d.任意分布

9.2 Relation to Mean Field Theory

本研究与Mean Field Theory的联系:

  • 使用Mean Field方法分析大宽度极限
  • 建立了参数空间离散化与相应ODE系统的联系
  • 证明了梯度流在足够小的权重衰减下达到全局最小值

9.3 Relation to Phase Transition Scaling

与Phase-Transitional Scaling (PTS)框架的联系:

PTS框架本研究
sigmoid响应指数到幂律转换
能力获取阈值 临界点
数据复杂度控制 数据大小控制相变

10. Conclusions and Open Questions

10.1 Summary of Contributions

本研究的主要贡献:

  1. ODE系统建模:首次对基于SGD训练的多层Transformer进行完整的学习动态分析
  2. 两阶段泛化界:建立了具有明确相位转换的泛化误差上界
  3. 独立缩放定律:推导了模型大小、训练时间、数据集大小的独立缩放关系
  4. 幂律指数确定:理论预测 的幂律指数
  5. 缩放定律失效条件:识别了Lambert W修正、噪声增长、容量饱和等失效模式

10.2 Limitations

当前框架的局限性:

局限说明未来方向
序列长度主要关注固定长度设置变长序列泛化
注意力机制简化的注意力模型完整Softmax注意力
优化器SGD假设Adam等自适应方法
架构变体标准TransformerMoE, SSM等

10.3 Future Research Directions

  • 将框架扩展到混合专家(MoE)架构
  • 分析注意力头数量和深度的影响
  • 研究课程学习和 curriculum scaling
  • 探索与涌现能力的更深层联系

References


Related Topics

Footnotes

  1. Yang, C. (2025). Unifying Learning Dynamics and Generalization in Transformers Scaling Law. arXiv:2512.22088. https://arxiv.org/abs/2512.22088 2