概述
Grokking 是深度学习中一个令人困惑的现象:模型在训练损失已经接近零后很久,验证集性能仍然很差,但突然在某个时刻急剧提升。这种现象对传统的泛化理论提出了挑战。
2026 年的最新研究提供了第一性原理解释,揭示了这一现象背后的深层机制。1
1. 问题陈述
1.1 经典观察
考虑一个简单的任务:在模 算术上进行加法运算。
| 模型 | 参数量 | 训练样本 | 训练准确率 | 验证准确率 |
|---|---|---|---|---|
| Transformer | ~100K | 1000 | 99.9% | 10% |
| Transformer | ~100K | 1000 | 99.9% | 99.9% |
两个模型都能完美记忆训练数据,但泛化能力完全不同。
1.2 时间尺度问题
关键问题:为什么泛化需要这么长时间?
时间轴 ─────────────────────────────────────────────────────────▶
训练准确率 ████████████████████████████████████████████ 100%
(训练准确率) ██ 10%
↑ ↑
开始 泛化发生
训练 (Grokking点)
│ │
└────────────────────┘
~10^5 steps
传统理论预测:
- 如果网络能泛化,应该在训练过程中就泛化
- 不应该存在”延迟泛化”
2. 表示相变理论
2.1 核心假设
Grokking 可以理解为发生在表示空间中的相变。
假设 1:表示的双稳态
神经网络的表示空间可以处于两种状态:
| 状态 | 特征 | 对应能力 |
|---|---|---|
| 混乱态(Disordered) | 表征结构接近随机 | 仅能记忆训练样本 |
| 有序态(Ordered) | 表征捕获底层结构 | 能泛化到新样本 |
假设 2:相变的能垒
从混乱态到有序态存在能垒:
在训练初期,网络被困在混乱态的局部最小值中。
2.2 Kolmogorov 复杂度视角
定义:数据的 Kolmogorov 复杂度 是生成 的最短程序长度。
关键洞察:
-
记忆 vs 泛化
- 记忆:需要 比特
- 泛化:需要 比特
-
学习动态
- SGD 首先学习”节省”的状态(低复杂度表示)
- 但低复杂度表示被随机初始化的噪声”遮挡”
- 需要时间”去除噪声”才能暴露底层模式
2.3 形式化
令 为参数 对应的表示向量。
表示复杂度:
其中 是期望表示, 是表示的二阶导数。
学习动态方程:
- :复杂度梯度
- :有效学习率(随训练递减)
- :表示噪声
3. 两阶段学习模型
3.1 阶段 1:记忆阶段
持续时间:
特征:
- 训练损失 快速下降到接近零
- 验证损失 保持在高位
- 网络学习将每个训练样本映射到正确输出
表示状态:
解释:网络学习的是”查表”而非”规则”。
3.2 阶段 2:泛化阶段
持续时间:
特征:
- 表示从混乱态逐渐变得有序
- 验证损失在临界点 急剧下降
- 网络学习到底层规则
表示状态:
其中 是规则函数, 是残余噪声。
3.3 相变的临界条件
临界点 满足:
其中:
- :有效学习率
- :能垒高度
- :梯度噪声方差
预测:
- 越大, 越大
- 越大, 越小(更多探索)
- 越小, 越大(收敛越慢)
4. 证据与验证
4.1 序参量演化
通过监测表示空间的序参量可以观察到相变:
# 序参量:表示的内部分布结构
def compute_order_parameter(model, data_loader):
representations = []
for x, y in data_loader:
with torch.no_grad():
r = model.get_representation(x)
representations.append(r)
representations = torch.cat(representations, dim=0)
# 计算表示的相关矩阵
corr = torch.corrcoef(representations.T)
# 特征值分布的熵(序参量)
eigenvalues = torch.linalg.eigvalsh(corr)
eigenvalues = eigenvalues / eigenvalues.sum()
entropy = -torch.sum(eigenvalues * torch.log(eigenvalues + 1e-10))
return entropy, eigenvalues4.2 实验观察
实验 1:表示可视化
记忆阶段 泛化阶段
● ● ● ● ●
● ● ● ● ● ● ●
● ● ● ● ● ● ● ● ●
● ● ● ● ●
无明显结构 明显的聚类结构
实验 2:表示相似性
表示相似性
│
1.0 │ ┌──── grokking 点
│ /
0.8 │ /
│ /
0.6 │ /
│ /
0.4 │ /
│ /
0.2 │ /
│ /
0.0 ├──────────────────────────────▶
0 20K 40K 60K 80K steps
5. 影响 Grokking 的因素
5.1 数据集大小
数据集越小,Grokking 越明显:
| 训练样本数 | Grokking 时间 | 原因 |
|---|---|---|
| 100 | ~5K steps | 快速记忆,快速泛化 |
| 500 | ~30K steps | 中等 |
| 1000 | ~100K steps | 较长 |
| 5000 | 几乎不发生 | 充分数据 |
直觉:更多数据 = 更难记忆 = 更早开始泛化
5.2 模型大小
| 参数量 | Grokking 倾向 | 原因 |
|---|---|---|
| 小 | 不易发生 | 容量不足以记忆 |
| 中等 | 最易发生 | 容量恰好够记忆 |
| 大 | 易发生但慢 | 过度参数化 |
直觉:过度参数化是 grokking 的前提条件。
5.3 权重衰减
权重衰减系数 对 grokking 有显著影响:
泛化能力
│
高 │ ┌───────────────────★
│ /
│ / ★ 权重衰减过强
│ / │ 强制记忆,无泛化
│ /
│ /
低 │ /
│ /
│★
└────────────────────────────────▶ 权重衰减 λ
↑ 最佳点
λ_opt ≈ 10^{-4}
5.4 学习率
学习率通过影响探索-利用平衡来影响 grokking:
| 学习率 | Grokking 时间 | 解释 |
|---|---|---|
| 太小 | 长或无 | 探索不足 |
| 适中 | 适中 | 平衡 |
| 太大 | 长 | 随机性过高 |
6. 理论预测与实验验证
6.1 预测 1:Grokking 时间与能垒成正比
理论:
验证:在模 算术中, 越大,规则越复杂, 越大, 越大。
6.2 预测 2:噪声加速 Grokking
理论: 增大 → 减小
验证:添加标签噪声会加速泛化。
6.3 预测 3:表示退火
理论:从高学习率切换到低学习率(类似退火)可以加速 grokking。
验证:学习率 warmup + decay 策略有效。
7. 与其他现象的联系
7.1 Edge of Stability
Grokking 可以视为 EoS 在表示空间中的类比:
| 现象 | 空间 | 振荡/相变 |
|---|---|---|
| EoS | 参数空间 | 损失振荡 |
| Grokking | 表示空间 | 泛化相变 |
7.2 灾难性遗忘
Grokking 期间,网络学会”忘记”记忆样本的具体细节,同时”保留”规则知识。这与持续学习中的灾难性遗忘问题相关。
7.3 模式连接
Grokking 点附近的表示形成”模式连接”:不同初始化的网络在 grokking 后趋向于相似的表示。
详见 学习动态与Grokking。
8. 实践应用
8.1 促进 Grokking 的策略
# 策略 1:适当的权重衰减
weight_decay = 1e-4 # 不要太大或太小
# 策略 2:学习率 schedule
config = {
'lr_init': 1e-3,
'warmup_steps': 1000,
'decay_type': 'cosine',
'min_lr': 1e-5,
}
# 策略 3:早停要谨慎
# 不要在验证准确率还在上升时停止
early_stop_patience = 50 # 较长8.2 监测 Grokking
class GrokkingMonitor:
def __init__(self):
self.train_acc_history = []
self.val_acc_history = []
self.repr_entropy_history = []
def step(self, model, train_loader, val_loader):
# 计算准确率
train_acc = evaluate(model, train_loader)
val_acc = evaluate(model, val_loader)
# 计算表示序参量
repr_entropy = compute_order_parameter(model, train_loader)
self.train_acc_history.append(train_acc)
self.val_acc_history.append(val_acc)
self.repr_entropy_history.append(repr_entropy)
return {
'train_acc': train_acc,
'val_acc': val_acc,
'repr_entropy': repr_entropy,
}
def detect_grokking(self, threshold=0.1):
"""检测 grokking 发生"""
if len(self.val_acc_history) < 100:
return False
recent = self.val_acc_history[-100:]
if max(recent) - min(recent) > threshold:
# 检查是否在训练准确率已高后发生
if self.train_acc_history[-1] > 0.95:
return True
return False9. 开放问题
9.1 理论问题
- 能垒的来源:为什么从混乱态到有序态存在能垒?
- 临界指数:Grokking 相变的普适类是什么?
- 与 SGD 噪声的具体关系:噪声如何帮助克服能垒?
9.2 实践问题
- 预测 grokking 时间:能否在训练前预测 grokking 是否会发生?
- 加速 grokking:能否设计干预措施加速泛化?
- Grokking 的普遍性:Grokking 是否只发生在toy tasks?
参考
Footnotes
-
This document summarizes the latest theoretical advances in understanding grokking, including the representational phase transition theory (arXiv:2603.13331) and its implications for deep learning generalization. ↩