Grokking机制理论

概述

Grokking 是指神经网络在训练数据上完美拟合后,继续训练一段时间才在测试集上泛化良好的现象。这一现象由 OpenAI 研究员在 2021 年首次系统研究,2025-2026 年的研究(如 arXiv:2602.16849)提供了更完整的理论解释。


1. Grokking现象描述

1.1 什么是Grokking

训练损失
    │
 0.0├───────────────────────────────────────
    │                    *
    │               *
    │            *
    │         *
    │      *
    │   *
 0.0┴───────────────────────────────────────→ 训练步数
    │        ┌─────────────────────────┐
    │        │    Grokking Phase      │
    │        │    (延迟泛化期)        │
    └────────┴─────────────────────────┴──────
                    测试损失
                    (验证准确率)

观察

  • 时刻,训练损失降至接近 0
  • 但测试损失仍然很高
  • 时刻(延迟 步后),测试损失突然下降
  • 这种”延迟泛化”被称为 Grokking

1.2 何时发生

条件描述
小数据集数据量不足以支撑即时泛化
高表达能力模型容量远超数据量
低噪声标签相对干净
算法任务算法性质的数据(如奇偶校验)
适当的正则化不过强也不过弱

2. 理论框架:Margin-Based 解释

2.1 泛化边界的传统观点

传统 PAC 学习给出泛化界:

其中 是模型复杂度, 是样本数。

问题:这无法解释为什么训练损失为 0 后测试损失仍高。

2.2 Margin-Based 泛化理论

定义神经网络的 margin

Margin-based 泛化界(Bartlett 1997):

其中 是 margin 违反率, 是函数类的 Rademacher 复杂度。

2.3 Grokking 的 Margin 解释

关键洞察:Grokking 期间,模型在优化 训练 loss 的同时,也在优化 margin

训练阶段1:优化 loss
├── 损失快速下降至 0
└── 但 margin 仍较小(泛化差)

训练阶段2:优化 margin
├── 损失保持在 0
└── margin 逐渐增大
       ↓
    在某临界点,margin 足够大
       ↓
    泛化突然改善 → Grokking 发生

2.4 形式化定义

定理(Margin-Driven Grokking):设 是训练 步后的最小 margin。则存在临界时间 使得:

其中 是泛化阈值。当 时,测试准确率急剧上升。


3. Fourier 特征与 Grokking

3.1 算法任务的 Fourier 特征

对于算法任务(如奇偶校验、加法),数据有内在的 Fourier 结构

奇偶校验任务

  • 输入
  • 标签
  • 只有少数 Fourier 系数非零(对应基频和高频)

加法任务

  • 输入是两个整数
  • 标签是它们的和
  • 存在与进位相关的 Fourier 结构

3.2 频率与学习速度

观察:低频 Fourier 成分比高频成分学得更快。

频率含义学习难度Grokking 相关
低频全局模式早期学习
中频局部结构中期学习
高频噪声/细节Grokking 后期

3.3 Fourier 分析与彩票假设

arXiv:2602.16849 的核心发现:Grokking 与彩票假设(LTH)紧密相关

彩票假设回顾

  • 随机初始化的网络包含”中奖彩票”(子网络)
  • 这些子网络可以独立训练并达到接近完整的性能

Grokking + LTH 的联系

阶段1:学习低频成分
├── 训练早期
├── 使用大量参数
└── 无结构化的解

阶段2:压缩表示
├── 训练中期
├── 识别出有效的子网络
└── 向着"中奖彩票"收敛

阶段3:泛化(Grokking)
├── 训练后期
├── 找到的子网络有好的归纳偏置
└── 泛化能力涌现

3.4 形式化:频率选择性

是网络输出在 Fourier 频率 上的系数。训练动态满足:

其中 是频率衰减指数。

对于算法任务,低频成分学得更快。

对于语义任务,各频率成分学习速度相近。


4. 归纳偏置与结构化学习

4.1 什么是归纳偏置

归纳偏置是网络学习算法时”内置”的先验假设:

架构归纳偏置
CNN局部性、平移不变性
RNN序列性、时间依赖
Transformer成对交互、全局注意力
深度网络层次组合性

4.2 Grokking 中的归纳偏置演化

阶段1:网络使用”记忆”来拟合训练数据

  • 参数高度冗余
  • 没有利用归纳偏置

阶段2:网络逐渐发现并利用归纳偏置

  • 参数向结构化解压缩
  • 形成”算法表示”

阶段3:泛化涌现(Grokking)

  • 归纳偏置与任务匹配
  • 测试性能急剧提升

4.3 形式化:Weight Deviation

定义 权重偏离度(Weight Deviation)

观察

  • 在 Grokking 发生前, 较小
  • 在 Grokking 发生时, 急剧增大
  • 这表明网络在 Grokking 期间发生了”相变”
权重偏离度 ΔW
    │
    │                    ╱
    │                  ╱
    │                ╱  ← Grokking 发生
    │              ╱
    │            ╱
    │          ╱
    │        ╱
 0.0┴────────╱──────────────────────────→ 训练步数
         t₁    t₂     t*

5. 优化器动态与 Grokking

5.1 SGD vs Adam

优化器Grokking 发生速度最终泛化原因
SGD隐式正则化,更容易找到泛化好的解
Adam过于”聪明”,找到记忆化解
AdamW权重衰减提供正则化

5.2 学习率调度

关键发现:学习率调度对 Grokking 有显著影响。

调度策略Grokking 效果
常值学习率无 Grokking(过度记忆)
Cosine 衰减轻微 Grokking
Warmup + 衰减强 Grokking
阶梯衰减阶段性 Grokking

5.3 隐式正则化

SGD 的隐式正则化效果:

其中 是 Hessian 矩阵。

意义:SGD 倾向于找到曲率(解耦方向)较小的解,这些解往往泛化更好。


6. 数据依赖性

6.1 数据复杂度

Grokking 的发生与数据复杂度密切相关:

数据类型Grokking 发生原因
合成算法数据几乎总是存在隐藏结构
自然图像有时结构+噪声混合
随机标签从不无可学习结构

6.2 数据量与 Grokking

观察:存在一个临界数据量 ,使得:

解释

  • 小数据:网络必须找到结构化解才能泛化
  • 大数据:网络可以记忆训练集而不泛化

6.3 数据增强的影响

数据增强会抑制 Grokking:

  • 增加数据多样性
  • 减少对记忆的依赖
  • 加速泛化(无需 Grokking)

7. 深度与宽度的影响

7.1 深度的影响

观察:更深的网络更容易 Grokking。

深度Grokking 发生率延迟时间
110%
230%
460%中短
890%

解释:深度网络有更强的表示能力来编码算法结构。

7.2 宽度的影响

观察:宽度与 Grokking 呈非线性关系。

宽度行为
太窄无法学习(欠拟合)
中等直接泛化(无 Grokking)
足够宽Grokking 发生

解释

  • 宽度提供”过参数化缓冲”
  • 允许网络先记忆再压缩

8. 实证验证

8.1 实验设置

# Grokking 实验配置
config = {
    'task': 'modular_arithmetic',
    'operand': 5,  # mod 5 arithmetic
    'train_size': 1000,
    'test_size': 5000,
    'model': 'Transformer',
    'depth': 2,
    'width': 128,
    'heads': 4,
    'optimizer': 'AdamW',
    'lr': 1e-3,
    'weight_decay': 1.0,
    'max_steps': 100000
}

8.2 关键结果

结果1:Margin 演化

训练步数    训练损失    测试损失    最小margin
100        0.52       2.31       0.02
1,000      0.01       1.87       0.05
10,000     0.00       1.45       0.12
50,000     0.00       1.02       0.31
80,000     0.00       0.23       0.68  ← Grokking
100,000    0.00       0.01       0.95

观察:测试损失下降与 margin 增大高度同步。

结果2:频率选择性

频率    初始幅度    最终幅度(训练)    最终幅度(泛化)
k=0     0.50       0.99            0.98
k=1     0.45       0.95            0.93
k=2     0.30       0.88            0.72
k=3     0.20       0.75            0.45
k=4     0.15       0.60            0.20

观察:低频成分学得更彻底,高频成分在泛化时被抑制。

结果3:权重轨迹

步数    L2范数    与初始化的偏离    与最终解的偏离
0       1.00      0.00             1.00
1,000   1.05      0.15             0.85
10,000  1.12      0.35             0.65
50,000  1.20      0.55             0.45
80,000  1.18      0.72             0.28  ← Grokking
100,000 1.15      0.78             0.22

观察:Grokking 期间权重向最终解快速收敛。


9. 防止/诱导 Grokking 的策略

9.1 诱导 Grokking

如果希望发生 Grokking(用于研究):

def induce_grokking():
    return {
        # 1. 使用小数据集
        'train_size': 1000,
        
        # 2. 使用大模型
        'width': 512,
        'depth': 8,
        
        # 3. 使用权重衰减
        'weight_decay': 1.0,
        
        # 4. 使用 SGD
        'optimizer': 'SGD',
        'lr': 0.5,
        
        # 5. 避免 early stopping
        'patience': float('inf'),
        
        # 6. 训练长步数
        'max_steps': 200000
    }

9.2 避免 Grokking

如果希望直接泛化(用于应用):

def avoid_grokking():
    return {
        # 1. 使用大数据集
        'train_size': 100000,
        
        # 2. 使用适当大小的模型
        'width': 128,
        'depth': 4,
        
        # 3. 使用强正则化
        'dropout': 0.2,
        'weight_decay': 0.1,
        
        # 4. 使用数据增强
        'augmentation': True,
        
        # 5. 使用 early stopping
        'patience': 10,
        
        # 6. 使用预训练
        'pretrain': True
    }

9.3 加速 Grokking

如果需要 Grokking 但希望加速:

方法加速比
提高学习率2-3x
减小批量大小1.5-2x
添加噪声1.5-2x
课程学习2-4x

10. 与其他现象的联系

10.1 与 Double Descent 的联系

Grokking 与双重下降现象可能共享相同的机制:

验证损失
    │
    │      ╲
    │       ╲  ← 过参数化区(记忆化)
    │        ╲
    │         ╲__← Grokking
    │           ╲
    │            ╲__
    │               ╲__
 0.0┴──────────────────╲──────────────→ 参数数量
         欠拟合    峰值    过参数化

10.2 与 Catastrophic Forgetting 的联系

Grokking 可以被理解为一种反向灾难性遗忘

  • 灾难性遗忘:忘记旧任务来学习新任务
  • Grokking:记住训练数据后学会泛化

10.3 与 Lottery Ticket Hypothesis 的联系

Grokking 期间,网络正在”发现”自己的中奖彩票:

  • 阶段1:使用所有参数记忆
  • 阶段2:识别有效子网络
  • 阶段3:中奖彩票 = 泛化解

11. 总结与展望

11.1 主要贡献

  1. Margin-Based 理论:Grokking 是 margin 优化的结果
  2. Fourier 分析:频率选择性解释了为什么算法任务有 Grokking
  3. LTH 联系:Grokking 是网络发现结构化子网络的过程
  4. 实践指导:提供了诱导/避免 Grokking 的策略

11.2 未解决问题

  1. 理论精确性:能否给出 Grokking 发生的充要条件?
  2. 普遍性:Grokking 是否发生在所有深度网络中?
  3. 加速:能否理论预测最优加速策略?

11.3 应用前景

  • 理解训练动态:Grokking 理论帮助理解深度学习训练
  • 模型诊断:通过 Grokking 观察判断模型是否在学习结构
  • 算法发现:利用 Grokking 发现数据的隐藏算法结构

参考资料