引言:Grokking现象

Grokking(顿悟)是深度学习领域中一种引人入胜的现象,最初由Power等人于2022年在论文《Grokking: Generalization on the Test Set》中正式提出1。这一现象描述的是神经网络训练过程中的一种反直觉行为:模型在训练集上的准确率达到100%之后,仍需要继续训练很长时间才能在测试集上获得良好的泛化性能

import torch
import matplotlib.pyplot as plt
 
def plot_grokking_dynamics(train_losses, test_losses, train_accs, test_accs):
    """
    典型的Grokking学习曲线示意:
    - 训练准确率迅速达到100%
    - 测试准确率长时间停留在随机猜测水平
    - 突然在某个时间点"顿悟",测试准确率跃升至高水平
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    steps = range(len(train_losses))
    
    ax1.plot(steps, train_losses, label='训练损失')
    ax1.plot(steps, test_losses, label='测试损失')
    ax1.set_xlabel('训练步数')
    ax1.set_ylabel('损失值')
    ax1.legend()
    ax1.set_title('损失曲线')
    
    ax2.plot(steps, train_accs, label='训练准确率')
    ax2.plot(steps, test_accs, label='测试准确率')
    ax2.axhline(y=0.5, color='gray', linestyle='--', label='随机猜测')
    ax2.set_xlabel('训练步数')
    ax2.set_ylabel('准确率')
    ax2.legend()
    ax2.set_title('准确率曲线(典型Grokking)')
    
    plt.tight_layout()
    plt.savefig('grokking_dynamics.png')
    plt.show()

从数学角度来看,Grokking揭示了优化目标与泛化目标之间存在时间上的不对称性

模型首先最小化训练损失(记忆阶段),然后才逐渐学会泛化(理解阶段)。这种”先记忆后理解”的过程,与人类学习某些技能时的顿悟体验有着奇妙的相似之处。

记忆与泛化的分离

何时开始泛化

理解Grokking现象的关键在于认识到记忆(memorization)和泛化(generalization)是两个可以分离的学习过程。在传统机器学习中,我们通常假设模型在学习过程中会同时追求这两个目标。但Grokking表明,当模型容量远超过任务所需的复杂度时,这两个过程可能在时间上完全分离。

泛化开始的条件

根据Anthropic的研究2,泛化通常在以下条件下开始:

  1. 达到完美训练记忆:训练准确率达到100%
  2. 度过记忆稳定期:权重空间进入相对稳定的配置
  3. 发现更好的权重配置:Loss landscape中存在更平坦的泛化解
// 简化的Groking判定逻辑
bool isGrokkingOccurred(vector<double>& test_accs, int window = 100) {
    // 检测测试准确率的突变
    double recent_avg = average(test_accs.end() - window, test_accs.end());
    double previous_avg = average(test_accs.end() - 2*window, test_accs.end() - window);
    
    // 如果近期平均比之前显著提高,说明发生了Grokking
    return (recent_avg - previous_avg) > 0.3;  // 30%的跃升阈值
}

权重衰减的作用

权重衰减(Weight Decay)是触发和调节Grokking的关键超参数之一。Power等人发现,当引入适当的权重衰减时,泛化开始的时间点会发生显著变化。

权重衰减的数学形式为:

其中 是正则化系数。研究表明:

权重衰减强度训练后期行为泛化开始时间
持续过拟合,测试性能不改善永不泛化
适中出现明显Grokking训练后期开始泛化
过强训练困难可能无法记忆

权重衰减的作用机制在于:它惩罚权重的复杂性,促使网络寻找更简洁的解。当训练数据被完全记忆后,网络仍有”动力”去探索更高效的表示方式——这正是泛化的本质。

训练动态阶段分析

深入理解Grokking需要将训练过程分解为多个阶段。Neel Nanda的”玩具模型”(Toy Models)研究3提供了极为清晰的阶段划分:

初始化阶段

训练开始时,权重从随机初始化状态开始。此时:

  • 训练和测试损失都较高
  • 模型处于”混沌”状态,预测接近随机
  • 权重更新主要受梯度主导

初始化阶段的关键特征可用以下数学描述:

其中 通常取 以保持前向传播的方差稳定。

记忆阶段

记忆阶段是Grokking现象的核心,此时模型:

  • 快速记忆训练数据:训练准确率迅速接近100%
  • 测试性能停滞:测试准确率可能仅略高于随机猜测
  • 权重空间演变:权重范数逐渐增大
阶段示意:
步数 0-500:    训练准确率 10% → 99%
步数 500-2000: 测试准确率 50% → 55% (几乎停滞)
步数 2000+:    测试准确率开始上升...

记忆阶段的本质是神经网络利用其过参数化特性,记忆训练样本。这个过程通常很快,因为:

  • 神经网络具有巨大的参数容量
  • 随机梯度下降能找到记忆单个样本的解
  • 这些”记忆解”通常不是最优的泛化解

泛化阶段

泛化阶段是Grokking的精髓所在。此时:

  • 模型从”记住答案”转向”理解原理”
  • 测试准确率开始显著提升
  • 权重配置趋向更简洁的结构

泛化阶段的发生暗示:存在比纯记忆更高效的任务解决方法。这通常与数据的内在结构有关——例如算术运算中的进位规律、模式识别中的对称性等。

电路形成过程

Grokking研究与电路复杂度(circuit complexity)理论密切相关。Anthropic的研究表明,泛化过程本质上是神经网络逐渐形成结构化电路的过程。

从分散表示到结构化电路

在记忆阶段,模型对每个训练样本可能都有”专用”的表示方式——这是一种分散的、去中心化的知识存储方式。

记忆阶段(分散表示):
输入A → 神经元1 → 神经元5 → 输出A'
输入B → 神经元2 → 神经元6 → 输出B'
输入C → 神经元3 → 神经元7 → 输出C'
...(每个样本独立处理)

泛化阶段(结构化电路):
输入X → 算法模块 → 算法模块 → 输出Y
                    ↑
输入Z → ────────────┘
(共享的计算结构)

结构化电路的关键特征是参数共享层次化组织。当网络学会识别任务的底层结构后,相同的电路可以处理所有符合该结构的输入。

渐进式电路组装

电路的形成是一个渐进过程,而非一蹴而就:

class CircuitFormationMonitor:
    def track_circuit_evolution(self, model, train_loader):
        """
        监控电路形成的阶段
        
        阶段1: 弱电路形成(Weak Circuit Formation)
            - 某些注意头/神经元开始对相关模式有微弱响应
            - 响应具有一定的方向性但不精确
        
        阶段2: 电路加强(Circuit Strengthening)
            - 相关权重逐渐增大
            - 不相关权重被压制
        
        阶段3: 电路精炼(Circuit Refinement)
            - 电路结构趋于稳定
            - 权重值精细调整
        """
        activation_patterns = []
        weight_magnitudes = []
        
        for epoch in range(num_epochs):
            # 记录激活模式和权重
            activations = self.record_activations(model, train_loader)
            weights = self.record_weights(model)
            
            activation_patterns.append(activations)
            weight_magnitudes.append(weights)
            
            # 分析电路复杂度
            complexity = self.compute_circuit_complexity(activations, weights)
            print(f"Epoch {epoch}: 电路复杂度 = {complexity:.4f}")

理论解释

任务结构复杂性

Grokking现象与任务的固有复杂性密切相关。Neel Nanda的研究将任务分为三类:

任务类型描述Grokking行为
简单查表直接记忆即可无Grokking,直接泛化
结构化任务存在隐藏模式明显Grokking
随机噪声无可学习的模式永不泛化

以模运算任务为例:

这个任务具有清晰的结构:

  • 加法和乘法的代数性质
  • 模运算的周期性
  • 参数(a, b, n)定义了具体实例

神经网络需要学习的是算法本身(如何做模运算),而非简单的输入-输出映射。这正是Grokking发生的本质原因。

损失Landscape

Grokking现象可以从损失景观(Loss Landscape)的角度理解。

Loss Landscape示意图:

        记忆解(局部极小)
              ↓
    ╭─────────────────────╮
   ╱                       ╲
  ╱   全局最小             ╲
 │  (泛化解)              │
  ╲                         ╱
   ╲                       ╱
    ╰─────────────────────╯
    
    训练初期 → 记忆阶段 → 泛化阶段

关键洞察:

  1. 记忆解的陷阱:存在大量局部极小值对应于”完美记忆训练数据但无法泛化”的解
  2. 泛化解的吸引域:泛化解通常在更平坦的区域,具有更大的吸引域
  3. 权重衰减的作用:权重衰减帮助模型逃离记忆解的陷阱,进入泛化区域

从优化动力学的角度,Grokking可以描述为:

权重衰减项 扮演了”推动”模型离开局部极小的角色,使其有机会发现更优的泛化解。

实践意义:理解神经网络学习

Grokking研究对深度学习实践有多方面的指导意义:

训练策略

  1. 更长的训练时间:对于复杂任务,不要过早停止训练
  2. 适当的正则化:权重衰减对于泛化至关重要
  3. 监控测试性能:训练准确率100%不代表模型已经”学会”

模型选择

  • 过参数化的双刃剑:足够的参数是记忆的必要条件,但可能导致泛化困难
  • 电路复杂度的权衡:更简单的架构可能更早泛化,但容量受限

调试与诊断

当模型出现以下情况时,可能正在经历Grokking:

def diagnose_grokking(trainer):
    """
    Grokking诊断检查清单:
    
    1. 训练准确率达到100%但测试准确率很低
    2. 继续训练后测试性能突然改善
    3. 权重范数在泛化前达到峰值然后下降
    4. 不同随机种子导致泛化时间差异大
    """
    symptoms = {
        'perfect_train_acc': trainer.best_train_acc >= 0.9999,
        'poor_test_acc': trainer.best_test_acc < 0.6,
        'extended_training': trainer.current_step > trainer.step_at_100_train_acc * 2,
        'late_improvement': trainer.test_acc_improved_after_step(trainer.step_at_100_train_acc)
    }
    
    if sum(symptoms.values()) >= 3:
        print("⚠️ 检测到典型的Grokking行为")
        return "grokking_detected"
    return "normal"

参考文献


相关主题:自适应优化器理论深度学习基础链式推理

Footnotes

  1. Power, A., et al. (2022). Grokking: Generalization on the Test Set. ICML 2022. https://arxiv.org/abs/2201.02177

  2. Anthropic Research Team. Toy Models of Superposition. Anthropic Technical Report. https://transformer-circuits.pub/2022/toy_model/index.html

  3. Nanda, N. (2023). Growing NNs: A Toy Model of Grokking. Neel Nanda’s Blog. https://nickc.substack.com/p/grokking-and-failure-of-mode-connectivity