概述

Grokking 是深度学习中一个令人困惑的现象:模型在训练损失已经接近零后很久,验证集性能仍然很差,但突然在某个时刻急剧提升。这种现象对传统的泛化理论提出了挑战。

2026 年的最新研究提供了第一性原理解释,揭示了这一现象背后的深层机制。1


1. 问题陈述

1.1 经典观察

考虑一个简单的任务:在模 算术上进行加法运算。

模型参数量训练样本训练准确率验证准确率
Transformer~100K100099.9%10%
Transformer~100K100099.9%99.9%

两个模型都能完美记忆训练数据,但泛化能力完全不同。

1.2 时间尺度问题

关键问题:为什么泛化需要这么长时间?

时间轴 ─────────────────────────────────────────────────────────▶

训练准确率    ████████████████████████████████████████████  100%
(训练准确率)  ██                                             10%
                 ↑                    ↑
              开始                    泛化发生
              训练                    (Grokking点)
                 │                    │
                 └────────────────────┘
                        ~10^5 steps

传统理论预测:

  • 如果网络能泛化,应该在训练过程中就泛化
  • 不应该存在”延迟泛化”

2. 表示相变理论

2.1 核心假设

Grokking 可以理解为发生在表示空间中的相变。

假设 1:表示的双稳态

神经网络的表示空间可以处于两种状态:

状态特征对应能力
混乱态(Disordered)表征结构接近随机仅能记忆训练样本
有序态(Ordered)表征捕获底层结构能泛化到新样本

假设 2:相变的能垒

从混乱态到有序态存在能垒:

在训练初期,网络被困在混乱态的局部最小值中。

2.2 Kolmogorov 复杂度视角

定义:数据的 Kolmogorov 复杂度 是生成 的最短程序长度。

关键洞察

  1. 记忆 vs 泛化

    • 记忆:需要 比特
    • 泛化:需要 比特
  2. 学习动态

    • SGD 首先学习”节省”的状态(低复杂度表示)
    • 但低复杂度表示被随机初始化的噪声”遮挡”
    • 需要时间”去除噪声”才能暴露底层模式

2.3 形式化

为参数 对应的表示向量。

表示复杂度

其中 是期望表示, 是表示的二阶导数。

学习动态方程

  • :复杂度梯度
  • :有效学习率(随训练递减)
  • :表示噪声

3. 两阶段学习模型

3.1 阶段 1:记忆阶段

持续时间

特征

  1. 训练损失 快速下降到接近零
  2. 验证损失 保持在高位
  3. 网络学习将每个训练样本映射到正确输出

表示状态

解释:网络学习的是”查表”而非”规则”。

3.2 阶段 2:泛化阶段

持续时间

特征

  1. 表示从混乱态逐渐变得有序
  2. 验证损失在临界点 急剧下降
  3. 网络学习到底层规则

表示状态

其中 是规则函数, 是残余噪声。

3.3 相变的临界条件

临界点 满足:

其中:

  • :有效学习率
  • :能垒高度
  • :梯度噪声方差

预测

  • 越大, 越大
  • 越大, 越小(更多探索)
  • 越小, 越大(收敛越慢)

4. 证据与验证

4.1 序参量演化

通过监测表示空间的序参量可以观察到相变:

# 序参量:表示的内部分布结构
def compute_order_parameter(model, data_loader):
    representations = []
    for x, y in data_loader:
        with torch.no_grad():
            r = model.get_representation(x)
        representations.append(r)
    
    representations = torch.cat(representations, dim=0)
    
    # 计算表示的相关矩阵
    corr = torch.corrcoef(representations.T)
    
    # 特征值分布的熵(序参量)
    eigenvalues = torch.linalg.eigvalsh(corr)
    eigenvalues = eigenvalues / eigenvalues.sum()
    entropy = -torch.sum(eigenvalues * torch.log(eigenvalues + 1e-10))
    
    return entropy, eigenvalues

4.2 实验观察

实验 1:表示可视化

记忆阶段                          泛化阶段

  ●  ●                              ●  ●  ●
  ●  ●  ●                           ●  ●  ●  ●
  ●  ●  ●  ●                        ●  ●  ●  ●  ●
  ●  ●  ●  ●  ●
    无明显结构                        明显的聚类结构

实验 2:表示相似性

表示相似性
     │
1.0  │                                    ┌──── grokking 点
     │                                   /
0.8  │                                  /
     │                                 /
0.6  │                                /
     │                               /
0.4  │                              /
     │                             /
0.2  │                            /
     │                           /
0.0  ├──────────────────────────────▶
     0     20K    40K    60K    80K   steps

5. 影响 Grokking 的因素

5.1 数据集大小

数据集越小,Grokking 越明显:

训练样本数Grokking 时间原因
100~5K steps快速记忆,快速泛化
500~30K steps中等
1000~100K steps较长
5000几乎不发生充分数据

直觉:更多数据 = 更难记忆 = 更早开始泛化

5.2 模型大小

参数量Grokking 倾向原因
不易发生容量不足以记忆
中等最易发生容量恰好够记忆
易发生但慢过度参数化

直觉:过度参数化是 grokking 的前提条件。

5.3 权重衰减

权重衰减系数 对 grokking 有显著影响:

泛化能力
     │
高   │        ┌───────────────────★
     │       /
     │      /               ★ 权重衰减过强
     │     /                │  强制记忆,无泛化
     │    /
     │   /
低   │  /
     │ /
     │★
     └────────────────────────────────▶ 权重衰减 λ
              ↑ 最佳点
              λ_opt ≈ 10^{-4}

5.4 学习率

学习率通过影响探索-利用平衡来影响 grokking:

学习率Grokking 时间解释
太小长或无探索不足
适中适中平衡
太大随机性过高

6. 理论预测与实验验证

6.1 预测 1:Grokking 时间与能垒成正比

理论

验证:在模 算术中, 越大,规则越复杂, 越大, 越大。

6.2 预测 2:噪声加速 Grokking

理论 增大 → 减小

验证:添加标签噪声会加速泛化。

6.3 预测 3:表示退火

理论:从高学习率切换到低学习率(类似退火)可以加速 grokking。

验证:学习率 warmup + decay 策略有效。


7. 与其他现象的联系

7.1 Edge of Stability

Grokking 可以视为 EoS 在表示空间中的类比:

现象空间振荡/相变
EoS参数空间损失振荡
Grokking表示空间泛化相变

7.2 灾难性遗忘

Grokking 期间,网络学会”忘记”记忆样本的具体细节,同时”保留”规则知识。这与持续学习中的灾难性遗忘问题相关。

7.3 模式连接

Grokking 点附近的表示形成”模式连接”:不同初始化的网络在 grokking 后趋向于相似的表示。

详见 学习动态与Grokking


8. 实践应用

8.1 促进 Grokking 的策略

# 策略 1:适当的权重衰减
weight_decay = 1e-4  # 不要太大或太小
 
# 策略 2:学习率 schedule
config = {
    'lr_init': 1e-3,
    'warmup_steps': 1000,
    'decay_type': 'cosine',
    'min_lr': 1e-5,
}
 
# 策略 3:早停要谨慎
# 不要在验证准确率还在上升时停止
early_stop_patience = 50  # 较长

8.2 监测 Grokking

class GrokkingMonitor:
    def __init__(self):
        self.train_acc_history = []
        self.val_acc_history = []
        self.repr_entropy_history = []
    
    def step(self, model, train_loader, val_loader):
        # 计算准确率
        train_acc = evaluate(model, train_loader)
        val_acc = evaluate(model, val_loader)
        
        # 计算表示序参量
        repr_entropy = compute_order_parameter(model, train_loader)
        
        self.train_acc_history.append(train_acc)
        self.val_acc_history.append(val_acc)
        self.repr_entropy_history.append(repr_entropy)
        
        return {
            'train_acc': train_acc,
            'val_acc': val_acc,
            'repr_entropy': repr_entropy,
        }
    
    def detect_grokking(self, threshold=0.1):
        """检测 grokking 发生"""
        if len(self.val_acc_history) < 100:
            return False
        
        recent = self.val_acc_history[-100:]
        if max(recent) - min(recent) > threshold:
            # 检查是否在训练准确率已高后发生
            if self.train_acc_history[-1] > 0.95:
                return True
        return False

9. 开放问题

9.1 理论问题

  1. 能垒的来源:为什么从混乱态到有序态存在能垒?
  2. 临界指数:Grokking 相变的普适类是什么?
  3. 与 SGD 噪声的具体关系:噪声如何帮助克服能垒?

9.2 实践问题

  1. 预测 grokking 时间:能否在训练前预测 grokking 是否会发生?
  2. 加速 grokking:能否设计干预措施加速泛化?
  3. Grokking 的普遍性:Grokking 是否只发生在toy tasks?

参考

Footnotes

  1. This document summarizes the latest theoretical advances in understanding grokking, including the representational phase transition theory (arXiv:2603.13331) and its implications for deep learning generalization.