1. 引言

神经网络训练过程中会涌现出各种有趣的现象:神经崩溃(Neural Collapse)观察到最后一层特征的类内方差趋近于零;维度崩溃(Dimension Collapse)观察到表示维度被压缩。这些现象看似独立,但规范表示假说(Canonical Representation Hypothesis, CRH)提出了一个统一的解释框架。1

CRH认为,在训练过程中,神经网络的表示(Representations)、权重(Weights)和神经元梯度(Neuron Gradients)会通过一组对齐方程相互关联。这种对齐源于梯度噪声(扩张表示)和权重衰减(收缩权重)之间的平衡——类似于物理中的涨落-耗散定理

2. 问题的数学形式化

2.1 设置

考虑一个 层神经网络,第 层:

其中:

  • :第 层的激活
  • :权重矩阵
  • :输出

2.2 梯度定义

定义神经元梯度

其中:

  • :反向传播到激活 的梯度
  • :反向传播到输入 的梯度

2.3 协方差矩阵

定义三组协方差矩阵:

其中 是权重的Gram矩阵。

3. 六种对齐关系

CRH提出了六个对齐关系,分为前向(Forward)和后向(Backward)两组:

3.1 前向对齐关系

3.1.1 RGA:表示-梯度对齐

表示-梯度对齐(Representation-Gradient Alignment, RGA):

即第 层激活的协方差与该层梯度的协方差成正比。

物理直觉:如果某个神经元方向上有大的激活方差(表示),那么该方向上也应该有大的梯度方差(学习信号)。

3.1.2 RWA:表示-权重对齐

表示-权重对齐(Representation-Weight Alignment, RWA):

即激活协方差与权重Gram矩阵成正比。

物理直觉:如果某个方向上有大的表示方差,那么沿着该方向的权重也应该有大的范数。

3.1.3 GWA:梯度-权重对齐

梯度-权重对齐(Gradient-Weight Alignment, GWA):

即梯度协方差与权重Gram矩阵成正比。

3.2 后向对齐关系

类似的三个关系定义在反向传播方向上:

3.3 对齐关系的几何解释

对齐关系几何意义
激活主方向 = 梯度主方向
激活主方向 = 权重主方向
梯度主方向 = 权重主方向

当所有对齐关系成立时,,即三者的主方向完全一致

4. 噪声-正则化平衡理论

4.1 问题的物理类比

神经网络的训练动态可以类比为统计力学系统

  • 梯度噪声:类比于温度,倾向于扩大系统熵(增加表示多样性)
  • 权重衰减:类比于势能,倾向于收缩系统(使权重趋向零点)

在平衡态,系统自由能最小化:

其中 是”温度”(梯度噪声强度), 是熵(表示多样性)。

4.2 平稳性条件

定理 4.1(平稳性条件):在训练的平稳态,有:

其中 是常数。

物理解释:权重扩展(通过梯度噪声)与权重收缩(通过权重衰减)达到平衡,导致协方差矩阵的特定对齐模式。

4.3 关键发现

推论:在局部最小值点,有:

这正是CRH预测的对齐关系!

5. 多项式对齐假说

5.1 打破精确对齐

在实际训练中,精确对齐很少成立。CRH提出了多项式对齐假说(Polynomial Alignment Hypothesis, PAH):

PAH预测:当对齐被打破时,关系变成幂律形式:

5.2 相位分类

相位前向关系后向关系指数
5
6
8

5.3 实验验证

研究团队在多个任务上观察到幂律对齐:

  • CIFAR-10/100分类
  • Transformer语言模型
  • 扩散模型

所有观察到的指数都落在 范围内,精确匹配PAH的理论预测

6. 与神经崩溃的联系

6.1 神经崩溃现象

神经崩溃(Neural Collapse, NC)描述了训练末期的四个阶段:2

  • NC1:类内变异性崩溃 →
  • NC2:类均值形成等角紧框架(ETF)
  • NC3:分类器与特征之间自对偶对齐
  • NC4:最近类中心(NCC)分类器优于线性分类器

6.2 CRH → NC的归约

定理 6.1(神经崩溃归约):在分类器的特殊情况下,CRH等价于神经崩溃(NC1-NC4)。

证明概要

  1. 对于分类器 ,输出激活 的协方差分解为:
  1. CRH的对齐条件 迫使:
  1. 自对偶性自动满足 (NC3)

6.3 神经特征ansatz

CRH与神经特征ansatz(NFA)有关:

GWA是对NFA的等变修正——它考虑了基变换的不变性。

7. 统一理论框架

7.1 CRH主定理

定理 7.1(规范表示定理)

  1. 方向冗余:如果任意两个前向对齐关系成立,则所有前向对齐关系成立
  2. 后向冗余:类似地
  3. 相互蕴含:如果一组前向对齐和一组后向对齐成立,则所有对齐成立

7.2 统一相图

CRH预测的训练动态相图:

                    GWA
                      ↑
     Phase 8 --------+--------> Phase 6
      |              |              |
      |              |              |
RWA   |        Phase 5         RGA  |
      |              |              |
      ↓              |              ↓
     Phase 4 <-------+--------> Phase 2
                      ↓
                   Phase 1
                   (初始态)

7.3 与其他现象的联系

现象CRH解释
神经崩溃末端层的完美对齐
维度崩溃低秩表示( 的有效秩降低)
特征复用跨层对齐(
尖锐极小值对齐度高 → 曲率大

8. 理论深度:涨落-耗散定理

8.1 形式类比

神经网络的训练动态与热力学系统有深刻联系:

热力学神经网络
自由能 损失
温度(涨落强度)梯度噪声方差
熵(无序程度)协方差矩阵的秩
平衡态局部最小值

8.2 涨落-耗散定理

在平衡态附近:

其中 是响应函数, 是涨落力。

类比到神经网络:

其中 是权重衰减系数。

8.3 深层含义

涨落-耗散框架揭示了:

  • 为什么神经网络存在普遍的对齐现象
  • 为什么不同架构(CNN、Transformer、MLP)都表现出类似的对齐
  • 如何通过调整超参数(、学习率)控制对齐程度

9. 实践意义

9.1 权重衰减的作用

CRH揭示了权重衰减的深层作用:

  • 强权重衰减(大 ):强制 小 → 对齐度低 → 表示更分散
  • 弱权重衰减(小 ):允许 大 → 对齐度高 → 表示更集中

9.2 学习率的影响

学习率与权重衰减的相互作用:

  • 高学习率 + 低权重衰减:噪声主导 → 保持对齐
  • 低学习率 + 高权重衰减:正则化主导 → 保持对齐
  • 最优平衡:对于特定任务,存在最优的 组合

9.3 早停策略

CRH为早停提供了理论基础:

  • 过度训练可能导致过度对齐(表示崩溃)
  • 最佳停止点对应于对齐度达到任务特定的阈值

9.4 代码实现

import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import numpy as np
 
def compute_covariance_matrix(tensor, batch_first=True):
    """
    计算张量的协方差矩阵
    
    Args:
        tensor: (n, d) 或 (d, n) 的张量
        batch_first: 是否为 (batch, dim) 格式
    
    Returns:
        cov: (d, d) 协方差矩阵
    """
    if batch_first:
        tensor = tensor.T  # -> (dim, batch)
    
    mean = tensor.mean(dim=1, keepdim=True)
    centered = tensor - mean
    cov = (centered @ centered.T) / (tensor.shape[1] - 1)
    return cov
 
def check_alignment(model, dataloader, device='cuda'):
    """
    检查模型的对齐关系
    
    Returns:
        metrics: dict,包含各种对齐度量
    """
    model.eval()
    
    # 收集激活和梯度
    activations = {}
    gradients = {}
    
    def hook_forward(name):
        def hook(module, input, output):
            activations[name] = output.detach()
        return hook
    
    def hook_backward(name):
        def hook(module, grad_input, grad_output):
            gradients[name] = grad_output[0].detach()
        return hook
    
    # 注册hook
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            handles.append(module.register_forward_hook(hook_forward(name)))
            handles.append(module.register_full_backward_hook(hook_backward(name)))
    
    # 前向+反向传播
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        output = model(x)
        loss = nn.functional.cross_entropy(output, y)
        loss.backward()
        break
    
    # 清理hook
    for h in handles:
        h.remove()
    
    # 计算对齐度量
    metrics = {}
    for name in activations:
        if name in gradients:
            H = compute_covariance_matrix(activations[name])
            G = compute_covariance_matrix(gradients[name])
            
            # Frobenius范数相似度
            H_normalized = H / torch.norm(H)
            G_normalized = G / torch.norm(G)
            
            alignment = torch.norm(H_normalized * G_normalized).item()
            metrics[f'{name}_alignment'] = alignment
    
    return metrics
 
def visualize_alignment_evolution(model, train_loader, epochs=100):
    """
    可视化对齐关系的演化
    """
    history = {'alignment': []}
    
    for epoch in range(epochs):
        metrics = check_alignment(model, train_loader)
        history['alignment'].append(metrics)
        
        # 训练一步
        model.train()
        for x, y in train_loader:
            x, y = x.to('cuda'), y.to('cuda')
            output = model(x)
            loss = nn.functional.cross_entropy(output, y)
            loss.backward()
            break
    
    # 绘图
    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots(figsize=(10, 6))
    for layer_name in history['alignment'][0].keys():
        values = [h[layer_name] for h in history['alignment']]
        ax.plot(values, label=layer_name)
    
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Alignment Score')
    ax.set_title('Alignment Evolution During Training')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.savefig('alignment_evolution.png', dpi=150)
    plt.show()
 
class AlignmentRegularizer(nn.Module):
    """
    对齐正则化器:鼓励表示-权重对齐
    """
    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha
    
    def forward(self, h, W):
        """
        Args:
            h: (batch, d) 激活
            W: (d, d) 权重矩阵
        Returns:
            loss: 对齐损失
        """
        # 计算协方差
        H = compute_covariance_matrix(h)
        Z = W @ W.T
        
        # 对齐损失:H 和 Z 的主方向应该对齐
        H_normalized = H / (torch.norm(H) + 1e-8)
        Z_normalized = Z / (torch.norm(Z) + 1e-8)
        
        # 使用负余弦相似度
        alignment_loss = -torch.sum(H_normalized * Z_normalized)
        
        return self.alpha * alignment_loss

10. 总结

规范表示假说(CRH)提供了一个统一框架来理解神经网络训练中的表示演化:

  1. 六个对齐关系:RGA、RWA、GWA及其后向版本
  2. 物理根源:梯度噪声与权重衰减的平衡(涨落-耗散定理)
  3. 多项式对齐:打破精确对齐时呈现幂律关系
  4. 神经崩溃归约:CRH蕴含神经崩溃的所有现象
  5. 实践指导:权重衰减、学习率、早停的理论基础

CRH的深层意义在于揭示了神经网络作为物理系统的一面——统计力学的概念框架可以有效地描述深度学习的训练动态。


Footnotes

  1. Liu, Y., et al. “Canonical Representation Hypothesis: Unified Theory of Neural Collapse and Alignment.” arXiv 2025. https://arxiv.org/abs/2410.03006

  2. 本理论与隐式正则化ResNet动态系统理论密切相关。