Approximate Normalization Transformer (anTransformer)

近似归一化Transformer(Approximately Normalized Transformer,简称 anTransformer)是一种创新的Transformer架构改进方法,通过标量乘法实现近似归一化,从而在不显著增加计算开销的情况下加速训练收敛。该方法由 Franke 等人在 NeurIPS 2025 上提出(arXiv:2505.22014),核心思想是利用高维随机向量范数的**紧密集中(concentration)**现象,以常数标量替代昂贵的逐元素归一化操作。

1. Introduction: Why Normalization Matters

归一化技术的核心作用

Normalization(归一化)技术是现代深度学习,尤其是 Transformer 架构成功的关键因素之一。12 传统观点认为归一化主要解决**内部协变量偏移(Internal Covariate Shift)**问题,但 anTransformer 的研究者提出了更深层次的理解:

归一化的两大核心效益:

  1. 输入尺度一致性(Input Scale Consistency)

    • 归一化确保每层接收的输入具有一致的尺度
    • 这允许选择更合适的(全局)学习率
    • 可类比为对优化问题的对角线预处理(diagonal preconditioning)
  2. 防止残差流范数膨胀(Preventing Residual Norm Growth)

    • 在标准 Transformer 中,残差连接上的表示范数会随层数增长
    • 这导致深层需要显著放大输出幅度才能保持影响力
    • Sun 等人将这种现象命名为 “深度诅咒”(Curse of Depth)3

形式化分析:为什么归一化能改善优化?

考虑一个简单的优化问题:

使用学习率 进行梯度下降:

每个分量 的收敛速度由收缩因子 决定。最优学习率满足:

其中 是条件数。当 (即所有特征值相等)时,收敛最快。

归一化的作用:通过对每个坐标进行归一化(乘以 ),使所有 ,从而可以选择最优学习率。

2. The Curse of Depth Problem

问题的数学描述

在标准 Transformer 中,假设每个层 对残差连接的贡献是独立的随机向量 ,且满足:

则第 层的有效贡献需要满足:

因此:

这意味着 深层输出的期望范数随深度增长,导致:

  1. 梯度不稳定(梯度尺度随深度指数变化)
  2. 深层贡献被稀释(需要更大缩放才能覆盖累积残差)
  3. 训练需要仔细的学习率预热(warm-up)

深度诅咒的实证证据

Figure: Layer-wise input norm growth in vanilla GPT

上图展示了在 0.5B 模型训练 10B tokens 过程中,各层输入范数的对数变化。可以看到:

  • Vanilla GPT:深层获得越来越大的输入范数(指数增长)
  • nGPT:完全消除了这个问题
  • anGPT:有效缓解了这个问题

相关工作:解决深度诅咒

方法策略局限性
Pre-LN在残差分支前归一化输出方差随深度指数增长
Post-LN在残差分支后归一化需要学习率预热
DeepNorm残差缩放 + 初始化调整约束模型更新为常数
LayerNorm Scaling 缩放 LayerNorm 输出需要修改归一化层
nGPT所有表示在超球面上需要额外归一化层,增加推理开销

3. Tight Concentration of High-Dimensional Random Vectors

测度集中现象(Concentration of Measure)

这是 anTransformer 理论的核心洞察。在高维空间中,随机向量的范数高度集中在其期望值附近——这被 anTransformer 的作者称为”维度的祝福”(Blessing of Dimensionality)。4

定理 1(球面上 Lipschitz 函数的集中性)

为单位球面上均匀分布的随机向量, 为 Lipschitz 函数。则对于任意

其中 为常数, 的 Lipschitz 常数。

核心含义:在高维空间 中, 以指数速度集中在其期望值附近!

应用于神经网络

在 anTransformer 中,考虑函数 ,其中 是网络组件(如前馈层)的输出。

假设条件

  1. 输入 已归一化且近似均匀分布在单位球面上
  2. 是 Lipschitz 的(如线性映射)

则输出范数 会集中在其期望值附近:

这意味着我们可以使用常数 近似归一化因子

4. Approximate Normalization via Scalar Multiplication

核心思想

传统归一化需要计算:

anTransformer 提出用输入无关的归一化因子 替代:

这样归一化就变成了简单的标量乘法,计算开销从 降到了

各组件的归一化因子推导

线性映射

假设:

  • ,且
  • ,权重按输入维度归一化:

则:

因此归一化因子为:

残差更新(LERP)

nGPT 提出用**线性插值(LERP)**替代经典残差更新:

对于 的情况:

归一化因子:

激活函数

由于激活函数的非线性, 需要数值计算。对于 SiLU/Swish 激活:

(通过蒙特卡洛估计)

完整的归一化因子表

组件归一化因子 说明
QKV 映射: 模型维度, : 头维度
输出映射-
FFN 上投影SwiGLU 门控
FFN 下投影-
激活函数SiLU/Swish
残差 LERP 为可学习插值参数

5. Removing Warm-up Requirements

为什么传统 Transformer 需要 Warm-up?

标准 Transformer(尤其是 Post-LN)需要学习率预热的原因:

  1. 初始化时梯度不稳定:未归一化的表示导致早期层梯度可能爆炸
  2. 参数尺度不一致:不同层参数的尺度差异导致统一学习率难以选择

anTransformer 如何消除 Warm-up?

anTransformer 通过以下设计消除了对 warm-up 的需求:

1. 参数归一化约束

对所有权重矩阵的输入维度进行归一化:

这确保了线性映射的输出范数有界。

2. 近似归一化表示

通过标量乘法近似归一化,确保表示范数围绕常数集中。

3. 优化器动态改善

实验表明,anTransformer 中 Adam 的一阶矩(momentum)方差更稳定:

Figure: Variance of Adam first moment

图中显示 anGPT 的 Adam 一阶矩方差在训练过程中更加稳定。

实验验证

在论文的实验中,anGPT 和 nGPT 均不使用学习率预热,而 GPT+ 基线使用 10% 训练步数的 warm-up。尽管如此,anGPT 仍能达到更快的收敛速度。

6. Removing Weight Decay Requirements

权重衰减的传统作用

传统训练中,权重衰减(Weight Decay)主要用于:

  1. 防止权重过大:避免数值不稳定
  2. 正则化:减少过拟合
  3. 维持表示尺度:间接控制激活值范围

anTransformer 的参数约束策略

anTransformer 对参数实施范数约束而非正则化:

关键区别

方法策略效果
Weight Decay 正则化项软约束,权重可无限增长
范数约束硬约束,权重有界

为什么不需要权重衰减?

  1. 表示被约束在紧致空间:由于参数有界,激活值不会无限增长
  2. 归一化稳定梯度:反向传播时梯度尺度更一致
  3. 参数尺度均匀:所有参数组具有相似的量级,优化器行为更可预测

参数重参数化

为了确保均匀的优化动态,anTransformer 对可学习缩放参数 (如 )采用重参数化:

其中 是优化变量,初始化为 ,确保

有效学习率:

7. Scaling Law Implications

Chinchilla 缩放定律回顾

标准 Chinchilla 缩放定律5指出:

其中 是模型参数量, 是训练 token 数。

anGPT 的缩放特性

最优批量大小缩放

论文发现,anGPT 的最优批量大小随模型规模的变化趋势与标准 GPT 一致:

这与 Chinchilla 的发现高度吻合。

最优学习率缩放

关键优势:支持更大的批量大小

anGPT 可以在不损失性能的情况下使用更大的批量大小!

这意味着:

  • 更少的训练步数
  • 更高的 GPU 并行效率
  • 潜在的更短训练时间

缩放趋势拟合

Figure: Scaling trends for optimal batch size and learning rate

上图展示了跨多个模型规模(从 46M 到 1B 参数)的最优超参数拟合。anGPT 保持了与标准 GPT 相似的缩放特性,同时支持更大的批量大小。

8. Experimental Results

实验设置

配置详情
数据集SlimPajama (~627B tokens)
TokenizerGPT-NeoX (50k 词汇量)
上下文窗口2048 tokens
训练轮次<1 epoch
基线模型GPT+ (GPT + SwiGLU + RoPE + RMSNorm + QK-Norm)

核心结果

收敛速度提升

anGPT 相比 GPT+ 实现高达 40% 的收敛加速!

具体而言:

  • 在相同的训练步数下,anGPT 达到更低的验证损失
  • 达到相同损失水平,anGPT 需要的训练步数减少 40%

计算开销

模型相对训练时间
GPT+1.00× (基线)
nGPT~1.10×
anGPT1.03×

anGPT 的额外计算开销仅为 3%,远低于 nGPT。

推理效率

anGPT 预期推理时间与 GPT+ 相同!

原因:推理时可以将常数归一化因子吸收到模型参数中,无需额外归一化层。

与 nGPT 的对比

特性nGPTanGPT
归一化方式沿残差维度严格归一化近似归一化(标量乘法)
额外归一化层需要不需要
推理开销增加无增加
收敛速度较快相当或更快
参数约束严格在超球面紧致空间(有界)

不同模型规模的表现

在 46M 到 1B 参数的范围内,anGPT 均一致性地优于 GPT+,且与 nGPT 表现相当或更好。

9. Implementation

架构对比

以下是 GPT+、nGPT 和 anGPT 的架构对比表:

组件GPT+nGPTanGPT
注意力层
Pre-NormRMSNormnorm + LERPnorm + LERP
QKV 映射
QK-Norm额外 RMSNorm--
注意力权重
输出映射
Post-NormRMSNormLERP + normLERP + norm
FFN 层
门控 + RMSNorm
激活SiLUSiLUSiLU
上投影
下投影 + RMSNorm
Post-NormRMSNormLERP + normLERP + norm

代码实现示例

以下是一个简化的 anGPT 层实现:

import torch
import torch.nn as nn
import math
 
class ApproxNormalizedAttention(nn.Module):
    """
    Approximate Normalized Attention Layer
    
    Key modifications:
    1. Remove all normalization layers (RMSNorm, QK-Norm)
    2. Add constant normalization factors via scalar multiplication
    3. Use LERP residual update instead of standard addition
    """
    
    def __init__(self, d_model, n_heads, has_bias=True):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        
        # Normalization factors (derived from theory)
        self.nu_qkv = math.sqrt(d_model / self.d_head)  # QKV mapping
        self.nu_proj = math.sqrt(self.d_head / d_model)  # Output projection
        
        # Learnable parameters (constrained to bounded space)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=has_bias)
        self.out_proj = nn.Linear(d_model, d_model, bias=has_bias)
        
        # LERP interpolation parameter (learnable, bounded)
        self.alpha_A = nn.Parameter(torch.tensor(0.1))
        
        # Normalize QKV and output projection weights
        self._normalize_weights()
    
    def _normalize_weights(self):
        """Normalize weights along input dimension"""
        with torch.no_grad():
            # Normalize QKV projection
            w = self.qkv_proj.weight
            norms = torch.norm(w, dim=1, keepdim=True)  # Along output dim
            self.qkv_proj.weight.div_(norms.clamp(min=1.0))
            
            # Normalize output projection
            w = self.out_proj.weight
            norms = torch.norm(w, dim=1, keepdim=True)
            self.out_proj.weight.div_(norms.clamp(min=1.0))
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            mask: Optional attention mask
        """
        B, T, C = x.size()
        
        # Normalize input (L2 norm)
        x_norm = x / (x.norm(dim=-1, keepdim=True) + 1e-8)
        
        # QKV projection with normalization factor
        qkv = self.qkv_proj(x_norm) * self.nu_qkv
        q, k, v = qkv.split(self.d_model, dim=-1)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        
        # Scaled dot-product attention
        scale = math.sqrt(self.d_head)
        attn = (q @ k.transpose(-2, -1)) / scale
        
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        
        attn = torch.softmax(attn, dim=-1)
        
        # Apply attention to values
        out = attn @ v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        
        # Output projection with normalization factor
        out = self.out_proj(out) * self.nu_proj
        
        # LERP residual update
        alpha = torch.sigmoid(self.alpha_A)  # Bounded in (0, 1)
        out = x_norm + alpha * (out - x_norm)
        
        # Final L2 normalization
        out = out / (out.norm(dim=-1, keepdim=True) + 1e-8)
        
        return out
 
 
class ApproxNormalizedFFN(nn.Module):
    """
    Approximate Normalized Feed-Forward Network with SwiGLU
    """
    
    # Normalization factors
    NU_UZ = 0.5  # sqrt(d_model / (4 * d_model))
    NU_D = 2.0   # sqrt((4 * d_model) / d_model)
    NU_ACTIVATION = 3.74  # Approximate via Monte Carlo
    
    def __init__(self, d_model, d_ff=None, has_bias=True):
        super().__init__()
        d_ff = d_ff or 4 * d_model
        self.d_model = d_model
        self.d_ff = d_ff
        
        # Gate and up projection (SwiGLU)
        self.gate_proj = nn.Linear(d_model, d_ff, bias=has_bias)
        self.up_proj = nn.Linear(d_model, d_ff, bias=has_bias)
        
        # Down projection
        self.down_proj = nn.Linear(d_ff, d_model, bias=has_bias)
        
        # LERP interpolation parameter
        self.alpha_F = nn.Parameter(torch.tensor(0.1))
        
        # Normalize weights
        self._normalize_weights()
    
    def _normalize_weights(self):
        """Normalize weights along input dimension"""
        with torch.no_grad():
            for proj in [self.gate_proj, self.up_proj, self.down_proj]:
                w = proj.weight
                norms = torch.norm(w, dim=1, keepdim=True)
                proj.weight.div_(norms.clamp(min=1.0))
    
    def forward(self, x):
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
        """
        # Normalize input
        x_norm = x / (x.norm(dim=-1, keepdim=True) + 1e-8)
        
        # SwiGLU: gate * SiLU(up)
        gate = self.gate_proj(x_norm) * self.NU_UZ
        up = self.up_proj(x_norm) * self.NU_UZ
        x_ff = torch.nn.functional.silu(gate) * up * self.NU_ACTIVATION
        
        # Down projection
        out = self.down_proj(x_ff) * self.NU_D
        
        # LERP residual update
        alpha = torch.sigmoid(self.alpha_F)
        out = x_norm + alpha * (out - x_norm)
        
        # Final L2 normalization
        out = out / (out.norm(dim=-1, keepdim=True) + 1e-8)
        
        return out
 
 
class ApproxNormalizedTransformerBlock(nn.Module):
    """
    Complete anTransformer Block
    """
    
    def __init__(self, d_model, n_heads, d_ff=None):
        super().__init__()
        self.attn = ApproxNormalizedAttention(d_model, n_heads)
        self.ffn = ApproxNormalizedFFN(d_model, d_ff)
    
    def forward(self, x, mask=None):
        # Attention block
        x = self.attn(x, mask)
        # FFN block
        x = self.ffn(x)
        return x

训练配置对比

# GPT+ Configuration (requires warm-up and weight decay)
gpt_plus_config = {
    "learning_rate": 1e-4,
    "warmup_ratio": 0.1,  # 10% of total steps
    "weight_decay": 0.1,
    "use_cosine_lr": True,
}
 
# anGPT Configuration (no warm-up, no weight decay)
# Key difference: fewer hyperparameters to tune!
angpt_config = {
    "learning_rate": 1e-4,  # Same range works well
    "warmup_ratio": 0.0,    # NOT needed!
    "weight_decay": 0.0,     # NOT needed!
    "use_cosine_lr": True,
    
    # Additional anGPT-specific parameters
    "alpha_init": 0.01,     # For scaling parameters
    "param_norm_bound": 1.0, # Max norm for parameters
}

10. Conclusion and Future Directions

核心贡献总结

  1. 理论贡献:利用高维随机向量范数紧密集中现象,为近似归一化提供了理论依据
  2. 架构贡献:提出 anTransformer,通过标量乘法替代昂贵的归一化操作
  3. 实践贡献:大幅减少训练超参数(无需 warm-up、无需 weight decay)
  4. 效率贡献:仅增加 3% 训练时间,推理时间无额外开销

与相关工作的关系

方法核心思想anGPT 借鉴
RMSNorm统计归一化✓ 完全移除
Normalized GPT (nGPT)超球面表示✓ LERP 更新
Chinchilla Scaling最优资源分配✓ 保持缩放特性

开放问题

  1. 更深的网络:anGPT 能否有效训练超过 1000 层的 Transformer?
  2. 多模态扩展:在视觉 Transformer 等非文本任务上的表现如何?
  3. 与其他技术结合:与 MoE、Flash Attention 等的兼容性

参考资源

  • 原始论文: arXiv:2505.22014 - Learning in Compact Spaces with Approximately Normalized Transformer (NeurIPS 2025)
  • 作者: Jörg K.H. Franke, Urs Spiegelhalter, Marianna Nezhurina, Jenia Jitsev, Frank Hutter, Michael Hefenbrock

References

Footnotes

  1. Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer Normalization. arXiv:1607.06450.

  2. Xiong, R., Yang, Y., He, D., et al. (2020). On Layer Normalization in the Pre-Training Transformer. ICML 2020.

  3. Sun, Y., et al. (2025). The Curse of Depth in Large Language Models. NeurIPS 2025.

  4. Vershynin, R. (2018). High-Dimensional Probability: An Introduction with Applications in Data Science. Cambridge University Press.

  5. Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). Training Compute-Optimal Large Language Models. NeurIPS 2022.