概述

Transformer 训练动力学长期以来被认为难以分析:其模型尺度、训练目标、优化器配置的差异都会导致完全不同的训练轨迹。然而,2025 年 NeurIPS Oral 论文 From Condensation to Rank Collapse: A Two-Stage Analysis of Transformer Training Dynamics(Chen & Luo, SJTU)1 给出了一个惊人的结论:Transformer 的训练过程在数学上必然经历两个不同的阶段,且两个阶段之间存在显著的相变

具体而言,作者在小初始化(Small Initialization)设定下,采用 Zhou 等人2 提出的梯度流(Gradient Flow)分析框架,系统地研究了线性化 Transformer 的训练动力学,并将其分解为两个截然不同的阶段:

  • 阶段 I(Condensation):外层参数 在随机初始化引入的非对称扰动驱动下发生有限时间 blow-up,迅速逃离小初始化区域并向任务相关方向聚集(condensation)。在这一阶段,注意力矩阵 几乎静止。
  • 阶段 II(Key-Query Rank Collapse):当外层参数进入准稳态(quasi-stationary)后,原本静默的 重新被激活,驱动归一化的键-查询矩阵发生渐近秩坍缩

这两个阶段并非相互独立的现象,而是由同一训练动力学在不同时间尺度上自然涌现的两面。论文的主要贡献可以归纳为四点1

  1. Blow-up 动力学(Theorem 1):在测度论意义上几乎所有高斯初始点都会发生有限时间 blow-up,从而消除对初始点二分性假设的依赖。
  2. Condensation 机制(Theorem 2):通过引入 Condensation Condition(Assumption 1),证明 condensation 在小初始化下必然出现。
  3. Key-Query 坍缩(Theorem 3):当外层参数进入准稳态(Assumption 2)后,键-查询矩阵发生渐近秩坍缩。
  4. 实验验证:在合成数据与 WikiText 真实任务上均观察到两阶段行为,且在小初始化下稳定出现。

本文件将系统性地展开这一理论框架,并与已有的 注意力秩崩溃谱理论Transformer Mean-Field 动力学 以及 规范表示假说与神经崩溃 等工作进行交叉对比,揭示 Transformer 训练动力学的统一图景。


1. 预备知识:Token Condensation 现象

1.1 Condensation 的形式化定义

Token Condensation(语义聚集)是指在训练早期,同语义类别的 token 表示在表示空间中聚类对齐的现象。Chen & Luo 在论文中采用更一般的方向性定义1

定义 1(Condensation)。设 是一个矩阵,其行(或列)为 。若

则称 condense 到方向

直观上,这意味着矩阵的所有行(或列)虽然尺度可能差异很大,但方向高度一致——这是一个比 rank-1 更弱的概念(仅要求方向一致,不要求尺度一致)。当方向唯一时,condensation 蕴含 rank-1 坍缩。

1.2 与 Neural Collapse 的关系

规范表示假说与神经崩溃(Papyan, Han & Donoho, 2020)描述了监督学习末期最后一层特征的类内方差趋零3。Condensation 与 Neural Collapse 都是类内方差减小的现象,但存在以下关键区别:

维度Token CondensationNeural Collapse
发生位置中间层 + 参数矩阵最后一层特征
发生时间训练早期训练末期
度量对象参数方向特征分布
类别先验可无(自监督也成立)必须监督

Chen & Luo 的工作将 Condensation 推广为参数矩阵方向性的早期收敛现象,从而将 Neural Collapse 纳入更广义的「方向性对齐」框架。

1.3 Condensation 与 Self-supervised Learning

在自监督学习中,类似的现象被称为 neural collapse in SSLdimensional collapse reduction:BYOL、SimSiam 等模型在训练早期就会让增强视图的表示高度对齐。Chen & Luo 论文中证明的 condensation 不需要监督信号——只要损失函数对参数矩阵具有乘法耦合结构(如 Transformer 的 ),condensation 就会发生。

1.4 实验观察

Condensation 现象在多种 Transformer 架构中普遍存在:

  • BERT 中间层的 attention head 在预训练早期发生 token 聚集;
  • GPT-2 不同层在训练过程中逐步形成方向一致的 query/key 投影;
  • ViT 的 patch embedding 在 ImageNet 上预训练时发生 patch-wise 聚类。

Chen & Luo 在合成 anchor function 数据集上量化了此过程:在不到 200 个训练步内,外层参数满足 condensation 条件的比例即从 0 升至接近 11


2. 预备知识:Rank Collapse 现象

2.1 Rank Collapse 的定义

Rank Collapse(秩坍缩)描述训练后期 token 表示或参数矩阵的有效秩显著下降。Chen & Luo 给出的定义为1

定义 2(渐近秩坍缩)。设 为参数矩阵。若极限

存在且 ,则称 发生 rank- 坍缩

注意此定义与 Condensation 的差异:Condensation 是方向性概念( 内积),而 Rank Collapse 是谱性概念(奇异值集中到少数几个)。

2.2 测量方法

实践中常用以下三个谱度量来检测 Rank Collapse:

稳定秩(Stable Rank)4

有效秩(Effective Rank):

核范数比(Nuclear Norm Ratio):

时,所有奇异值趋于相等(无坍缩);当 时,矩阵坍缩到 rank-

2.3 已有工作

Rank Collapse 在 Transformer 文献中有多个独立研究线索:

  1. 注意力矩阵的低秩结构 注意力秩崩溃谱理论:softmax 注意力矩阵在深层会发生特征值集中,趋向于秩-1 的均匀矩阵 5
  2. Context Aggregation 中的瓶颈:随着层数加深,token 表示会被压缩到越来越窄的子空间(与 低秩训练 相关)。
  3. Differential Transformer 的动机:通过差分注意力机制缓解注意力矩阵的秩坍缩问题5
  4. NTK 极限下的无限宽网络:在 NTK regime 下,训练不会改变参数的有效秩;但在 feature learning regime 下,rank collapse 是显著现象。

2.4 Rank Collapse 的负面后果

Rank Collapse 在训练末期会引发一系列问题:

  • 信息瓶颈:token 表示坍缩到低维子空间后,跨位置信息传递能力受限;
  • 梯度消失:当 attention 矩阵接近秩-1 时,反向传播的梯度也会相应集中;
  • 表示能力退化:分类边界变得模糊,模型难以区分细微语义差异。

Chen & Luo 的关键洞察是:这些「负面后果」其实是训练的一个内在阶段,而非单纯的退化现象——它代表着参数空间结构性的重组。


3. 两阶段理论框架

3.1 时间轴划分

Chen & Luo 的核心贡献是将训练时间轴 显式划分为两个阶段1

其中 外层参数进入准稳态的临界时刻。两个阶段的数学刻画如下:

阶段主导参数主要动力学表征度量
Stage I, , blow-up + directional convergencecosine similarity 矩阵出现块结构
Stage II, 线性 ODE 系统归一化键-查询矩阵有效秩下降

3.2 关键贡献:相变临界点

论文证明两个阶段之间存在清晰的相变临界点 。这个相变具有以下特征:

几何突变:损失景观在 附近的几何结构发生质变——从「外层参数主导的梯度场」切换到「注意力参数主导的梯度场」。

二阶导数不连续:损失函数 关于时间 的一阶导数 处连续,但二阶导数 出现明显的拐点1

尺度分离:在 时,外层参数的 Frobenius 范数 blow-up 到 ,而 始终保持 量级;在 时,外层参数方向冻结,而注意力参数开始演化。

3.3 与物理学相变的类比

Chen & Luo 的两阶段框架与物理学中的相变(一阶/二阶)有深刻的类比关系:

类比 1:一阶相变 vs 二阶相变

  • 一阶相变(如水的凝固):相变点处存在潜热(一阶导数不连续);
  • 二阶相变(如超导转变):一阶导数连续,但二阶导数(如比热)不连续6

Chen & Luo 的论文揭示的训练相变更接近二阶相变——Loss 曲线连续,但 在临界点处出现拐点。

类比 2:能量景观的两阶段弛豫
在统计物理中,复杂系统经常经历「快-慢两阶段弛豫」(fast-slow relaxation):

  • 快弛豫(quench):系统迅速下降到局部能量极小;
  • 慢弛豫(aging):系统在能量景观中缓慢漂移。

Transformer 训练的两阶段与这一物理图像惊人地一致:condensation 是快速 blow-up + 方向收敛,对应「快弛豫」;rank collapse 是缓慢的注意力重构,对应「慢弛豫」。


4. 理论分析

4.1 形式化设置

Chen & Luo 考虑一层 Transformer 模型1

其中注意力子层为:

参数矩阵 ,激活函数

小初始化设定:所有参数按 初始化,其中 控制初始化尺度7

4.2 有效动力学的导出

将损失函数 关于 进行 Taylor 展开,得到主阶形式1

通过引入凝缩方向(condensation direction):

重缩放时间 ,论文证明归一化参数 满足梯度上升形式(Proposition 1)1

并定义能量泛函

4.3 阶段 I:Condensation 速率定理

核心定理 1(Blow-up in finite time):在测度论意义上几乎所有高斯初始点都满足非退化初始化条件(Definition 4),且有效动力学会在有限时间内 blow-up1

证明基于 Riccati 型微分不等式。设 为能量函数,则

由此得到

因此当 时,

核心定理 2(Condensation):在 Condensation 条件(Assumption 1)下,有效动力学驱动 发生方向性收敛(condensation)1

Condensation 条件的形式为:

  • 对每个
  • 对每个 ,行积与列积方向对齐。

证明技巧:

  1. 几何一致性:在 Assumption 1 下,一旦在某 时刻对齐条件成立,它会在 上传播;
  2. 结构二分 的列分为凝缩类 与一致有界类
  3. 双边能量控制:通过双边估计 的上下界,凝缩指标在有限时间内占主导地位。

4.4 阶段 II:Rank Collapse 速率定理

核心定理 3(Asymptotic Rank Collapse):在阶段 II 的关键-查询动力学下,归一化的键矩阵和查询矩阵发生渐近秩坍缩。当 (见下式)具有唯一最大奇异值时,归一化矩阵渐近趋于秩-11

阶段 II 的关键-查询动力学是线性 ODE 系统

其中 由外层参数与数据决定:

由于这是一个线性 ODE 系统,其解可以显式写为矩阵指数形式:

时,矩阵指数被最大奇异值对应的特征方向主导,从而归一化矩阵趋向 rank-1——这是渐近秩坍缩。

4.5 对超参数的依赖

Chen & Luo 框架揭示了两阶段对几个关键超参数的依赖:

学习率:增大学习率会加快两阶段的过渡——blow-up 时间 为学习率),临界点 也随之提前。

初始化尺度 :更小的 会延长阶段 I 的 blow-up 时间,但会让 condensation 更彻底(方向性更强);同时让阶段 II 开始的更晚。

模型深度 :深层 Transformer 中,每层都经历一次微缩的两阶段循环——下层的 rank collapse 为上层的 condensation 提供更优的初始化。这种层级递归结构Transformer Mean-Field 动力学 中的层级相互作用高度一致8

数据规模 :在小数据下阶段 II 容易过拟合(rank collapse 过早发生);大数据下两阶段更平稳。

4.6 与 Loss Landscape 几何的关系

Chen & Luo 的两阶段框架与 损失景观的多重分形动力学9 高度相关:

  • 阶段 I 对应 Sharp Minima 区域:condensation 让外层参数快速收敛到一个狭窄的能量井;
  • 阶段 II 对应 Flat Minima 过渡:rank collapse 是损失景观几何的「软化」过程。

这一对比暗示:小初始化 两阶段相变 隐式正则化 是一个统一的因果链,与 权重衰减与损失景观 中的隐式偏好现象一致。


5. 实验验证

5.1 合成数据实验(Anchor Function)

Chen & Luo 采用 Zhang 等人10 提出的 anchor function 构造合成数据集。该数据集模拟简化的语言建模场景:每个输入是若干 token 序列,目标是预测某个锚点位置的目标值。

实验设置

  • 模型:一层 Transformer,tanh 激活,cross-entropy loss
  • 优化器:AdamW,小初始化
  • 度量:参数 cosine similarity 矩阵、Frobenius 范数变化、有效秩

三阶段观察

   Loss
    ↑
    |  \
    |   \____     /‾‾‾‾‾\
    |        \   /        \____
    |         \ /
    +———————————————————————————→ Steps
      Stage I    Stage II   Further
    (condens.)  (rank col.)  training

阶段 I 特征

  • 外层参数 的 Frobenius 范数从 迅速增长到
  • cosine similarity 矩阵出现块结构;
  • 有效秩单调下降;
  • 注意力参数 保持 不变。

阶段 II 特征

  • 外层参数方向冻结(cosine similarity );
  • 注意力参数开始主导优化;
  • 归一化的 发生 rank collapse;
  • 损失曲线进入平台期。

5.2 真实任务(WikiText)

在 WikiText(Merity et al., 2017)上,作者训练两层 Transformer(GeLU 激活 + 残差连接,省略 LayerNorm)来验证两阶段动力学的普遍性1

关键发现:即使在真实文本数据上,两阶段动力学依然稳健出现:

  • 两层都依次经历 condensation;
  • 当外层参数冻结后,注意力参数开始演化;
  • 最终归一化的 都发生 rank collapse。

这一结果验证了两阶段不是合成数据的特殊现象,而是 Transformer 训练的一般规律

5.3 训练曲线可视化

下图为示意图(基于论文 Figure 1-3 重绘):

graph LR
    A[Step 0<br/>random init] -->|Stage I<br/>condensation| B[Step t_c<br/>outer quasi-static]
    B -->|Stage II<br/>rank collapse| C[Step T<br/>final model]
    A -.->|cosine sim block| D[Stage I<br/>块结构形成]
    B -.->|effective rank drop| E[Stage II<br/>rank collapse]
    A -.->|F norm blow-up| F[Stage I<br/>||W||_F: ε → 1]

5.4 临界点检测算法

基于论文的 Assumption 2*(方向变化缓慢 + 损失变化缓慢),可以设计如下临界点检测算法

def detect_phase_transition(loss_history, param_dirs, threshold=0.01):
    """
    Detect the condensation-to-rank-collapse phase transition.
 
    Args:
        loss_history: list of loss values per step
        param_dirs: list of normalized outer parameter directions
        threshold: direction change threshold for "frozen" detection
 
    Returns:
        t_c: critical step index
    """
    n_steps = len(loss_history)
    dir_stability = []
 
    for t in range(1, n_steps):
        # Cosine similarity between consecutive normalized directions
        cos = np.dot(param_dirs[t], param_dirs[t-1])
        dir_stability.append(cos)
 
    # Critical point: when direction stability first exceeds threshold
    # AND loss has plateaued
    for t in range(1, n_steps - 1):
        if dir_stability[t] > (1 - threshold):
            # Check loss plateau: |Δloss| small
            recent_loss_change = abs(
                loss_history[t] - np.mean(loss_history[max(0, t-50):t])
            )
            if recent_loss_change < threshold:
                return t
 
    return None

6. 代码实现:训练过程两阶段度量

6.1 Token Condensation 度量

下面的 PyTorch 代码实现了训练过程中类内方差(intra-class variance)的实时度量:

import torch
import torch.nn.functional as F
 
 
def intra_class_variance(features: torch.Tensor,
                         labels: torch.Tensor,
                         eps: float = 1e-8) -> torch.Tensor:
    """
    Compute the average intra-class variance of token representations.
    Lower values indicate stronger condensation.
 
    Args:
        features: (B, N, d) token representations
        labels:   (B,) class labels
        eps:      numerical stability
 
    Returns:
        scalar: mean intra-class variance
    """
    unique_labels = torch.unique(labels)
    variances = []
 
    for c in unique_labels:
        mask = (labels == c)
        if mask.sum() < 2:
            continue
        class_feats = features[mask]             # (n_c, d)
        centroid = class_feats.mean(dim=0, keepdim=True)
        var = ((class_feats - centroid) ** 2).sum(dim=-1).mean()
        variances.append(var)
 
    if not variances:
        return torch.tensor(0.0, device=features.device)
    return torch.stack(variances).mean()
 
 
def condensation_score(features: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Compute normalized condensation score in [0, 1].
    1.0 = perfect condensation, 0.0 = random.
    """
    total_var = features.var(dim=0).sum()
    intra_var = intra_class_variance(features, labels)
    # 1 - (intra / total): condensation reduces intra-class variance
    return float(1.0 - (intra_var / (total_var + 1e-8)).clamp(0, 1))

6.2 Rank Collapse 检测(基于 SVD 的有效秩)

import torch
 
 
def effective_rank(matrix: torch.Tensor, eps: float = 1e-8) -> float:
    """
    Compute the effective rank of a matrix via SVD entropy.
    Roy & Vetterli (2007), defined as exp(H(p)) where
        p_i = sigma_i / sum_j sigma_j,  H = -sum p_i log p_i
    """
    # matrix: (m, n)
    s = torch.linalg.svdvals(matrix.float())
    s = s[s > eps]
    if len(s) == 0:
        return 0.0
    p = s / s.sum()
    entropy = -(p * torch.log(p)).sum()
    return float(torch.exp(entropy))
 
 
def stable_rank(matrix: torch.Tensor, eps: float = 1e-8) -> float:
    """
    Stable rank: ||W||_F^2 / ||W||_2^2.
    More robust to noise than naive rank.
    """
    fro_sq = torch.linalg.matrix_norm(matrix, ord='fro') ** 2
    spec_sq = torch.linalg.matrix_norm(matrix, ord=2) ** 2
    return float(fro_sq / (spec_sq + eps))
 
 
class PhaseTransitionMonitor:
    """
    Track condensation + rank collapse metrics during training.
 
    Usage:
        monitor = PhaseTransitionMonitor()
        for step in range(N):
            ...
            monitor.update(step, features, labels, W_q, W_k, W_v)
        monitor.plot()
    """
 
    def __init__(self):
        self.history = {
            'step': [], 'condensation': [], 'rank_WQ': [],
            'rank_WK': [], 'rank_WV': [], 'stab_WQ': [], 'stab_WK': [],
            'stab_WV': [], 'norm_WQ': [], 'norm_WK': [], 'norm_WV': [],
        }
 
    def update(self, step, features, labels, W_q, W_k, W_v):
        with torch.no_grad():
            self.history['step'].append(step)
            self.history['condensation'].append(
                condensation_score(features, labels)
            )
            self.history['rank_WQ'].append(effective_rank(W_q))
            self.history['rank_WK'].append(effective_rank(W_k))
            self.history['rank_WV'].append(effective_rank(W_v))
            self.history['stab_WQ'].append(stable_rank(W_q))
            self.history['stab_WK'].append(stable_rank(W_k))
            self.history['stab_WV'].append(stable_rank(W_v))
            self.history['norm_WQ'].append(float(W_q.norm()))
            self.history['norm_WK'].append(float(W_k.norm()))
            self.history['norm_WV'].append(float(W_v.norm()))
 
    def detect_transition(self, threshold=0.95) -> int:
        """Heuristic: t_c is first step where outer-parameter stable rank
        stops changing AND attention-parameter stable rank starts changing."""
        sw = torch.tensor(self.history['stab_WV'])
        sa = torch.tensor(self.history['stab_WQ'])
        for t in range(1, len(sw) - 1):
            if abs(sw[t] - sw[t-1]) < 0.01 and abs(sa[t] - sa[t-1]) > 0.01:
                return t
        return -1

6.3 训练循环中的监控集成

import torch
import torch.nn as nn
 
 
def train_with_phase_monitor(model, dataloader, epochs, device='cuda'):
    monitor = PhaseTransitionMonitor()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
 
    step = 0
    for epoch in range(epochs):
        for batch in dataloader:
            inputs, labels = batch
            inputs = inputs.to(device)
            labels = labels.to(device)
 
            # Forward
            features, logits = model(inputs)
            loss = F.cross_entropy(logits, labels)
 
            # Update monitor (BEFORE optimizer step)
            with torch.no_grad():
                monitor.update(
                    step, features, labels,
                    model.attn.W_q.weight,    # (d, d)
                    model.attn.W_k.weight,    # (d, d)
                    model.attn.W_v.weight,    # (d, d)
                )
 
            # Backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
            step += 1
 
        # End of epoch: check phase transition
        t_c = monitor.detect_transition()
        if t_c > 0 and step > t_c:
            print(f"[Phase II detected @ step {t_c}] "
                  f"rank_WQ={monitor.history['rank_WQ'][-1]:.2f}, "
                  f"rank_WK={monitor.history['rank_WK'][-1]:.2f}")
 
    return monitor

7. 与其他现象的关系

7.1 与 Grokking 的关系

Grokking 描述了训练损失已收敛后,测试准确率突然跃升的现象。Grokking 与本论文两阶段的关系如下:

维度Condensation → Rank CollapseGrokking
时间尺度训练早期 → 中期训练末期(持续过拟合之后)
参数变化外层 → 注意力权重从小幅度 → 大幅度
度量有效秩下降测试准确率跃升
触发条件小初始化算法任务 + 权重衰减

深层联系:Grokking 的「延迟泛化」可以被理解为第三阶段——在 rank collapse 之后,模型经历一个缓慢的「表示重组」阶段,最终找到泛化解。这一推测尚未被严格证明,但与 深度学习中的相变现象6 中描述的多种相变高度一致。

7.2 与 Neural Collapse 的关系

如 §1.2 所述,规范表示假说3 描述的 Neural Collapse 是监督学习末期的最后一层特征对齐。Chen & Luo 的 Rank Collapse 是中间阶段——它们构成一个完整的「坍缩谱」:

早期        中期          后期
condensation → rank collapse → neural collapse
(外层参数)    (注意力矩阵)    (最终特征)

这一谱系与 Transformer 中的神经崩溃理论 中的描述高度吻合。

7.3 与 NTK / Feature Learning 的关系

NTK regime(无限宽网络)下,训练不会改变参数的有效秩——参数近似静止;而 feature learning regime 下,参数显著演化,发生 rank collapse。

Chen & Luo 的小初始化设定恰好横跨两个 regime

  • 阶段 I(condensation):显著的 feature learning;
  • 阶段 II(rank collapse):特征重组主导。

这与 交替梯度流与特征学习11 中描述的「特征学习与 lazy 训练交替」现象相一致。

7.4 与 Critical Learning Periods 的关系

Critical Learning Periods(关键学习期)描述训练早期对最终性能的决定性影响——若在早期引入随机标签或大幅扰动,最终性能会显著下降。

Chen & Luo 的阶段 I 恰好对应关键学习期:condensation 确定了参数的方向,这些方向在阶段 II 中几乎不再改变。因此阶段 I 是性能的决定性阶段,这为关键学习期提供了机制性解释。

7.5 与 Lottery Ticket Hypothesis 的关系

Lottery Ticket Hypothesis(Frankle & Carlin)认为密集网络中存在稀疏子网络(「中奖彩票」),单独训练即可达到原网络性能。Chen & Luo 的两阶段框架为 lottery ticket 提供了一个可能解释:

  • 阶段 I 的 condensation 实际上是一种隐式剪枝——只有与目标方向对齐的参数存活;
  • 阶段 II 的 rank collapse 是对剩余参数的精细调整

若能找到与 condensation 方向对应的稀疏子网络,则可能加速训练——这与 低秩训练与 LoRA 的思想一致12


8. 实践意义

8.1 学习率调度

根据两阶段框架,可以设计两阶段学习率调度

阶段 I(condensation):使用较大学习率。原因:

  • 阶段 I 需要快速 blow-up 逃离小初始化;
  • 大学习率加速方向性收敛;
  • 方向对齐是粗粒度的,对学习率精确性不敏感。

阶段 II(rank collapse):使用较小学习率。原因:

  • 阶段 II 是精细调整阶段;
  • 小学习率避免 rank collapse 过快(导致信息瓶颈);
  • 配合 warm restart 可能进一步提升性能。
# Example: two-stage LR schedule based on phase detection
def two_stage_lr_schedule(base_lr, monitor, patience=100):
    t_c = monitor.detect_transition()
    if t_c < 0:
        return base_lr                  # Stage I: large LR
    return base_lr * 0.1                # Stage II: small LR

8.2 正则化设计

避免过早 Rank Collapse 的正则化策略:

  1. 谱归一化(Spectral Normalization):约束 ,延缓 rank collapse;
  2. DropAttention:在注意力矩阵上引入随机扰动,阻止过快谱集中;
  3. Differential Attention:通过差分注意力机制(谱分析视角下的秩崩溃)从结构上缓解 rank collapse5
  4. 权重衰减与两阶段:在阶段 I 使用较大权重衰减(防止方向漂移),在阶段 II 使用较小权重衰减(保护精细结构)。

8.3 早停策略

临界点附近的早停

  • 在检测到 后,可以适度继续训练让 rank collapse 完成;
  • 但若阶段 II 持续过久(rank 已接近 1),则应提前停止——继续训练只会强化信息瓶颈;
  • 一种经验策略:检测阶段 II 中 rank 下降速率,当 接近 0 时停止。

8.4 架构选择

深度对两阶段临界点的影响

  • 浅层模型():两阶段清晰, 易于检测;
  • 中等深度():每层独立经历两阶段,但跨层耦合复杂;
  • 深层模型():阶段 I 与阶段 II 可能重叠,需仔细设计初始化尺度。

宽度对两阶段的影响

  • 宽模型( 大):condensation 方向更分散,rank collapse 不完全;
  • 窄模型( 小):condensation 方向集中,rank collapse 接近 rank-1。

8.5 知识蒸馏

两阶段框架为知识蒸馏提供了新的策略:

  1. 阶段 I 蒸馏:让学生模型在早期阶段就学会教师模型的 condensation 方向;
  2. 阶段 II 蒸馏:让学生模型经历相同的 rank collapse 路径。

这种轨迹蒸馏(trajectory distillation)相比传统的输出蒸馏更能传递结构化的知识。


9. 相关工作与交叉引用

Chen & Luo 的两阶段相变框架与以下工作密切相关:

9.1 谱分析与秩崩溃

9.2 Mean-Field 与动力学

9.3 表示理论

9.4 损失景观

9.5 其他相关


10. 未来方向与开放问题

10.1 多阶段扩展

Chen & Luo 的工作主要揭示两阶段。但在实际训练中,可能存在三阶段甚至多阶段

Stage 0 (初始化) → Stage I (condensation) → Stage II (rank collapse) → Stage III (?) → ...

开放问题:是否存在第三阶段?候选现象包括:

  • 训练末期的 grokking 跃升;
  • 模型合并(model merging)中的对称破缺;
  • 持续学习中的灾难性遗忘。

10.2 不同架构的相变图谱

将两阶段框架推广到其他架构:

架构阶段 I阶段 II特殊现象
Transformercondensationrank collapse本文主题
Mamba / SSMstate matrix condensationA/B matrix rank collapse与注意力不同的相变结构
Mixture-of-Expertsrouter specializationexpert rank collapse多阶段交错
Diffusion ModelUNet condensationcross-attn rank collapse时间步相关的相变

10.3 与联邦学习 / 持续学习的关系

联邦学习:客户端的异构数据可能导致不同的 。如何聚合不同阶段的模型是一个开放问题。

持续学习:阶段切换可能与任务切换耦合——每个新任务都可能触发一次新的 condensation。

10.4 跨数据集迁移的临界点稳定性

开放问题:在数据集 A 上训练得到的 与在数据集 B 上训练得到的 是否存在稳定的比例关系?

初步证据表明 数据复杂度(如有效秩、数据维度)有强相关性,但精确关系尚未建立。

10.5 与 RLHF / 对齐的关系

RLHF 训练中通常采用三阶段:SFT → Reward Model → PPO。两阶段框架能否推广到 RLHF 的多阶段训练?

开放问题:RLHF 中每个阶段的 condensation/rank collapse 行为是什么?是否存在对齐阶段特有的相变


11. 局限性与讨论

Chen & Luo 的论文明确指出以下局限性1

  1. 二分类简化:理论分析局限于二分类任务,无法直接处理多分类或序列到序列学习。
  2. 小初始化依赖:两阶段动力学在小初始化下显著,大初始化下可能不成立。
  3. 一层模型:理论仅严格建立在一层 Transformer 上;多层情况仅通过实验验证。
  4. 渐近秩坍缩:Theorem 3 给出渐近行为,但有限时间内的精确收敛率尚未刻画。

尽管如此,这项工作仍具有重要的理论与实践价值——它为 Transformer 训练动力学提供了第一个不依赖特定任务的统一分析框架


12. 总结

Chen & Luo (NeurIPS 2025 Oral) 的 From Condensation to Rank Collapse 提出了 Transformer 训练动力学的两阶段相变框架1

阶段主导参数核心动力学关键定理关键现象
Iblow-up + directional convergenceTheorem 2 (Condensation)外层参数方向对齐
II线性 ODE → rank collapseTheorem 3 (Asymptotic rank collapse)归一化矩阵趋向 rank-1

这一框架具有以下深远意义:

  1. 理论层面:首次将 token condensation 与 rank collapse 统一为训练动力学的两个内在阶段,揭示了 Transformer 训练的几何结构。
  2. 实验层面:在合成数据与 WikiText 真实任务上都得到验证,证明两阶段动力学的普遍性。
  3. 实践层面:为学习率调度、正则化设计、早停策略、架构选择提供了机制性指导。
  4. 跨领域联系:与 神经崩溃深度学习相变Mean-Field 动力学 等建立了密切联系。

未来研究可以向多阶段扩展、跨架构推广、跨数据集迁移等方向深入,最终形成 Transformer 训练动力学的统一相变图谱


参考文献

Footnotes

  1. Chen, Z.-A. & Luo, T. (2025). From Condensation to Rank Collapse: A Two-Stage Analysis of Transformer Training Dynamics. NeurIPS 2025 Oral. arXiv:2510.06954. https://openreview.net/forum?id=gm5mkiTGOy 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

  2. Zhou, H., Zhou, Q., Luo, T., Zhang, Y. & Xu, Z.-Q. J. (2022). Towards Understanding the Condensation of Neural Networks at Initial Training. NeurIPS 2022. 该工作建立了小初始化下梯度流分析的数学框架,本文是其向 Transformer 的推广。

  3. Papyan, V., Han, X. Y. & Donoho, D. L. (2020). Prevalence of Neural Collapse during the terminal phase of deep learning training. Proceedings of the National Academy of Sciences, 117(40):24652-24663. 与本文 condensation 现象共同构成神经网络的”坍缩谱”,可参考本知识库 规范表示假说与神经崩溃 的详细分析。 2 3

  4. Vershynin, R. (2018). High-Dimensional Probability: An Introduction with Applications in Data Science. Cambridge University Press. 稳定秩与有效秩的数学基础。

  5. Mind the Gap Team (2024). Mind the Gap: Spectral Analysis of Rank Collapse and Signal Propagation in Attention Layers. arXiv:2410.07799. 与本文 rank collapse 阶段互补,从静态谱角度分析注意力矩阵。可参考 注意力秩崩溃谱理论 2 3 4

  6. Cohen, J. M., Kaur, S., Li, Y., Kolter, J. Z. & Talwalkar, A. Characterizing Possible Failure Modes in Physics-Informed ML 及本知识库 深度学习中的相变现象 中关于 Edge of Stability、Progressive Sharpening 的讨论。 2 3

  7. 小初始化(Small Initialization)作为深度学习训练理论的核心设定,已被广泛研究。本文的关键发现之一是该设定下两阶段动力学的稳健性。综述参见 Xu et al. (2025). An Overview of Condensation Phenomenon in Deep Learning. arXiv:2504.09484.

  8. 与本知识库 Transformer Mean-Field 动力学 共享数学语言——注意力作为 Wasserstein 梯度流,token 聚集对应平均场极限下的同步现象。 2

  9. 与本知识库 损失景观的多重分形动力学 相关——阶段 I 对应 sharp minima,阶段 II 对应 flat minima 过渡。 2

  10. Zhang, Z., Lin, P., Wang, Z., Zhang, Y. & Xu, Z.-Q. J. (2024). Anchor Function: A Type of Benchmark Functions for Studying Language Models. arXiv:2401.08309. Chen & Luo 在合成实验中采用 anchor function 作为可控的基准任务。

  11. 与本知识库 交替梯度流与特征学习 中描述的”lazy 训练与特征学习交替”现象高度一致——本论文的两阶段是其在 Transformer 中的具体实例。 2

  12. 本知识库 低秩训练与 LoRA 之外 讨论了 rank collapse 与参数高效微调的桥梁。本文的渐近秩坍缩定理为 LoRA 类方法的有效性提供了理论支撑。