Approximate Normalization Transformer (anTransformer)
近似归一化Transformer(Approximately Normalized Transformer,简称 anTransformer)是一种创新的Transformer架构改进方法,通过标量乘法实现近似归一化,从而在不显著增加计算开销的情况下加速训练收敛。该方法由 Franke 等人在 NeurIPS 2025 上提出(arXiv:2505.22014),核心思想是利用高维随机向量范数的**紧密集中(concentration)**现象,以常数标量替代昂贵的逐元素归一化操作。
1. Introduction: Why Normalization Matters
归一化技术的核心作用
Normalization(归一化)技术是现代深度学习,尤其是 Transformer 架构成功的关键因素之一。12 传统观点认为归一化主要解决**内部协变量偏移(Internal Covariate Shift)**问题,但 anTransformer 的研究者提出了更深层次的理解:
归一化的两大核心效益:
-
输入尺度一致性(Input Scale Consistency)
- 归一化确保每层接收的输入具有一致的尺度
- 这允许选择更合适的(全局)学习率
- 可类比为对优化问题的对角线预处理(diagonal preconditioning)
-
防止残差流范数膨胀(Preventing Residual Norm Growth)
- 在标准 Transformer 中,残差连接上的表示范数会随层数增长
- 这导致深层需要显著放大输出幅度才能保持影响力
- Sun 等人将这种现象命名为 “深度诅咒”(Curse of Depth)3
形式化分析:为什么归一化能改善优化?
考虑一个简单的优化问题:
使用学习率 进行梯度下降:
每个分量 的收敛速度由收缩因子 决定。最优学习率满足:
其中 是条件数。当 (即所有特征值相等)时,收敛最快。
归一化的作用:通过对每个坐标进行归一化(乘以 ),使所有 ,从而可以选择最优学习率。
2. The Curse of Depth Problem
问题的数学描述
在标准 Transformer 中,假设每个层 对残差连接的贡献是独立的随机向量 ,且满足:
则第 层的有效贡献需要满足:
因此:
这意味着 深层输出的期望范数随深度增长,导致:
- 梯度不稳定(梯度尺度随深度指数变化)
- 深层贡献被稀释(需要更大缩放才能覆盖累积残差)
- 训练需要仔细的学习率预热(warm-up)
深度诅咒的实证证据

上图展示了在 0.5B 模型训练 10B tokens 过程中,各层输入范数的对数变化。可以看到:
- Vanilla GPT:深层获得越来越大的输入范数(指数增长)
- nGPT:完全消除了这个问题
- anGPT:有效缓解了这个问题
相关工作:解决深度诅咒
| 方法 | 策略 | 局限性 |
|---|---|---|
| Pre-LN | 在残差分支前归一化 | 输出方差随深度指数增长 |
| Post-LN | 在残差分支后归一化 | 需要学习率预热 |
| DeepNorm | 残差缩放 + 初始化调整 | 约束模型更新为常数 |
| LayerNorm Scaling | 按 缩放 LayerNorm 输出 | 需要修改归一化层 |
| nGPT | 所有表示在超球面上 | 需要额外归一化层,增加推理开销 |
3. Tight Concentration of High-Dimensional Random Vectors
测度集中现象(Concentration of Measure)
这是 anTransformer 理论的核心洞察。在高维空间中,随机向量的范数高度集中在其期望值附近——这被 anTransformer 的作者称为”维度的祝福”(Blessing of Dimensionality)。4
定理 1(球面上 Lipschitz 函数的集中性):
设 为单位球面上均匀分布的随机向量, 为 Lipschitz 函数。则对于任意 :
其中 为常数, 是 的 Lipschitz 常数。
核心含义:在高维空间 中, 以指数速度集中在其期望值附近!
应用于神经网络
在 anTransformer 中,考虑函数 ,其中 是网络组件(如前馈层)的输出。
假设条件:
- 输入 已归一化且近似均匀分布在单位球面上
- 是 Lipschitz 的(如线性映射)
则输出范数 会集中在其期望值附近:
这意味着我们可以使用常数 近似归一化因子 !
4. Approximate Normalization via Scalar Multiplication
核心思想
传统归一化需要计算:
anTransformer 提出用输入无关的归一化因子 替代:
这样归一化就变成了简单的标量乘法,计算开销从 降到了 !
各组件的归一化因子推导
线性映射
假设:
- ,且
- ,权重按输入维度归一化:
则:
因此归一化因子为:
残差更新(LERP)
nGPT 提出用**线性插值(LERP)**替代经典残差更新:
对于 的情况:
归一化因子:
激活函数
由于激活函数的非线性, 需要数值计算。对于 SiLU/Swish 激活:
(通过蒙特卡洛估计)
完整的归一化因子表
| 组件 | 归一化因子 | 说明 |
|---|---|---|
| QKV 映射 | : 模型维度, : 头维度 | |
| 输出映射 | - | |
| FFN 上投影 | SwiGLU 门控 | |
| FFN 下投影 | - | |
| 激活函数 | SiLU/Swish | |
| 残差 LERP | 为可学习插值参数 |
5. Removing Warm-up Requirements
为什么传统 Transformer 需要 Warm-up?
标准 Transformer(尤其是 Post-LN)需要学习率预热的原因:
- 初始化时梯度不稳定:未归一化的表示导致早期层梯度可能爆炸
- 参数尺度不一致:不同层参数的尺度差异导致统一学习率难以选择
anTransformer 如何消除 Warm-up?
anTransformer 通过以下设计消除了对 warm-up 的需求:
1. 参数归一化约束
对所有权重矩阵的输入维度进行归一化:
这确保了线性映射的输出范数有界。
2. 近似归一化表示
通过标量乘法近似归一化,确保表示范数围绕常数集中。
3. 优化器动态改善
实验表明,anTransformer 中 Adam 的一阶矩(momentum)方差更稳定:

图中显示 anGPT 的 Adam 一阶矩方差在训练过程中更加稳定。
实验验证
在论文的实验中,anGPT 和 nGPT 均不使用学习率预热,而 GPT+ 基线使用 10% 训练步数的 warm-up。尽管如此,anGPT 仍能达到更快的收敛速度。
6. Removing Weight Decay Requirements
权重衰减的传统作用
传统训练中,权重衰减(Weight Decay)主要用于:
- 防止权重过大:避免数值不稳定
- 正则化:减少过拟合
- 维持表示尺度:间接控制激活值范围
anTransformer 的参数约束策略
anTransformer 对参数实施范数约束而非正则化:
关键区别:
| 方法 | 策略 | 效果 |
|---|---|---|
| Weight Decay | 正则化项 | 软约束,权重可无限增长 |
| 范数约束 | 硬约束,权重有界 |
为什么不需要权重衰减?
- 表示被约束在紧致空间:由于参数有界,激活值不会无限增长
- 归一化稳定梯度:反向传播时梯度尺度更一致
- 参数尺度均匀:所有参数组具有相似的量级,优化器行为更可预测
参数重参数化
为了确保均匀的优化动态,anTransformer 对可学习缩放参数 (如 )采用重参数化:
其中 是优化变量,初始化为 ,确保 。
有效学习率:
7. Scaling Law Implications
Chinchilla 缩放定律回顾
标准 Chinchilla 缩放定律5指出:
其中 是模型参数量, 是训练 token 数。
anGPT 的缩放特性
最优批量大小缩放
论文发现,anGPT 的最优批量大小随模型规模的变化趋势与标准 GPT 一致:
这与 Chinchilla 的发现高度吻合。
最优学习率缩放
关键优势:支持更大的批量大小
anGPT 可以在不损失性能的情况下使用更大的批量大小!
这意味着:
- 更少的训练步数
- 更高的 GPU 并行效率
- 潜在的更短训练时间
缩放趋势拟合

上图展示了跨多个模型规模(从 46M 到 1B 参数)的最优超参数拟合。anGPT 保持了与标准 GPT 相似的缩放特性,同时支持更大的批量大小。
8. Experimental Results
实验设置
| 配置 | 详情 |
|---|---|
| 数据集 | SlimPajama (~627B tokens) |
| Tokenizer | GPT-NeoX (50k 词汇量) |
| 上下文窗口 | 2048 tokens |
| 训练轮次 | <1 epoch |
| 基线模型 | GPT+ (GPT + SwiGLU + RoPE + RMSNorm + QK-Norm) |
核心结果
收敛速度提升
anGPT 相比 GPT+ 实现高达 40% 的收敛加速!
具体而言:
- 在相同的训练步数下,anGPT 达到更低的验证损失
- 达到相同损失水平,anGPT 需要的训练步数减少 40%
计算开销
| 模型 | 相对训练时间 |
|---|---|
| GPT+ | 1.00× (基线) |
| nGPT | ~1.10× |
| anGPT | 1.03× |
anGPT 的额外计算开销仅为 3%,远低于 nGPT。
推理效率
anGPT 预期推理时间与 GPT+ 相同!
原因:推理时可以将常数归一化因子吸收到模型参数中,无需额外归一化层。
与 nGPT 的对比
| 特性 | nGPT | anGPT |
|---|---|---|
| 归一化方式 | 沿残差维度严格归一化 | 近似归一化(标量乘法) |
| 额外归一化层 | 需要 | 不需要 |
| 推理开销 | 增加 | 无增加 |
| 收敛速度 | 较快 | 相当或更快 |
| 参数约束 | 严格在超球面 | 紧致空间(有界) |
不同模型规模的表现
在 46M 到 1B 参数的范围内,anGPT 均一致性地优于 GPT+,且与 nGPT 表现相当或更好。
9. Implementation
架构对比
以下是 GPT+、nGPT 和 anGPT 的架构对比表:
| 组件 | GPT+ | nGPT | anGPT |
|---|---|---|---|
| 注意力层 | |||
| Pre-Norm | RMSNorm | norm + LERP | norm + LERP |
| QKV 映射 | |||
| QK-Norm | 额外 RMSNorm | - | - |
| 注意力权重 | |||
| 输出映射 | |||
| Post-Norm | RMSNorm | LERP + norm | LERP + norm |
| FFN 层 | |||
| 门控 | + RMSNorm | ||
| 激活 | SiLU | SiLU | SiLU |
| 上投影 | |||
| 下投影 | + RMSNorm | ||
| Post-Norm | RMSNorm | LERP + norm | LERP + norm |
代码实现示例
以下是一个简化的 anGPT 层实现:
import torch
import torch.nn as nn
import math
class ApproxNormalizedAttention(nn.Module):
"""
Approximate Normalized Attention Layer
Key modifications:
1. Remove all normalization layers (RMSNorm, QK-Norm)
2. Add constant normalization factors via scalar multiplication
3. Use LERP residual update instead of standard addition
"""
def __init__(self, d_model, n_heads, has_bias=True):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
# Normalization factors (derived from theory)
self.nu_qkv = math.sqrt(d_model / self.d_head) # QKV mapping
self.nu_proj = math.sqrt(self.d_head / d_model) # Output projection
# Learnable parameters (constrained to bounded space)
self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=has_bias)
self.out_proj = nn.Linear(d_model, d_model, bias=has_bias)
# LERP interpolation parameter (learnable, bounded)
self.alpha_A = nn.Parameter(torch.tensor(0.1))
# Normalize QKV and output projection weights
self._normalize_weights()
def _normalize_weights(self):
"""Normalize weights along input dimension"""
with torch.no_grad():
# Normalize QKV projection
w = self.qkv_proj.weight
norms = torch.norm(w, dim=1, keepdim=True) # Along output dim
self.qkv_proj.weight.div_(norms.clamp(min=1.0))
# Normalize output projection
w = self.out_proj.weight
norms = torch.norm(w, dim=1, keepdim=True)
self.out_proj.weight.div_(norms.clamp(min=1.0))
def forward(self, x, mask=None):
"""
Args:
x: Input tensor [batch, seq_len, d_model]
mask: Optional attention mask
"""
B, T, C = x.size()
# Normalize input (L2 norm)
x_norm = x / (x.norm(dim=-1, keepdim=True) + 1e-8)
# QKV projection with normalization factor
qkv = self.qkv_proj(x_norm) * self.nu_qkv
q, k, v = qkv.split(self.d_model, dim=-1)
# Reshape for multi-head attention
q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
# Scaled dot-product attention
scale = math.sqrt(self.d_head)
attn = (q @ k.transpose(-2, -1)) / scale
if mask is not None:
attn = attn.masked_fill(mask == 0, float('-inf'))
attn = torch.softmax(attn, dim=-1)
# Apply attention to values
out = attn @ v
out = out.transpose(1, 2).contiguous().view(B, T, C)
# Output projection with normalization factor
out = self.out_proj(out) * self.nu_proj
# LERP residual update
alpha = torch.sigmoid(self.alpha_A) # Bounded in (0, 1)
out = x_norm + alpha * (out - x_norm)
# Final L2 normalization
out = out / (out.norm(dim=-1, keepdim=True) + 1e-8)
return out
class ApproxNormalizedFFN(nn.Module):
"""
Approximate Normalized Feed-Forward Network with SwiGLU
"""
# Normalization factors
NU_UZ = 0.5 # sqrt(d_model / (4 * d_model))
NU_D = 2.0 # sqrt((4 * d_model) / d_model)
NU_ACTIVATION = 3.74 # Approximate via Monte Carlo
def __init__(self, d_model, d_ff=None, has_bias=True):
super().__init__()
d_ff = d_ff or 4 * d_model
self.d_model = d_model
self.d_ff = d_ff
# Gate and up projection (SwiGLU)
self.gate_proj = nn.Linear(d_model, d_ff, bias=has_bias)
self.up_proj = nn.Linear(d_model, d_ff, bias=has_bias)
# Down projection
self.down_proj = nn.Linear(d_ff, d_model, bias=has_bias)
# LERP interpolation parameter
self.alpha_F = nn.Parameter(torch.tensor(0.1))
# Normalize weights
self._normalize_weights()
def _normalize_weights(self):
"""Normalize weights along input dimension"""
with torch.no_grad():
for proj in [self.gate_proj, self.up_proj, self.down_proj]:
w = proj.weight
norms = torch.norm(w, dim=1, keepdim=True)
proj.weight.div_(norms.clamp(min=1.0))
def forward(self, x):
"""
Args:
x: Input tensor [batch, seq_len, d_model]
"""
# Normalize input
x_norm = x / (x.norm(dim=-1, keepdim=True) + 1e-8)
# SwiGLU: gate * SiLU(up)
gate = self.gate_proj(x_norm) * self.NU_UZ
up = self.up_proj(x_norm) * self.NU_UZ
x_ff = torch.nn.functional.silu(gate) * up * self.NU_ACTIVATION
# Down projection
out = self.down_proj(x_ff) * self.NU_D
# LERP residual update
alpha = torch.sigmoid(self.alpha_F)
out = x_norm + alpha * (out - x_norm)
# Final L2 normalization
out = out / (out.norm(dim=-1, keepdim=True) + 1e-8)
return out
class ApproxNormalizedTransformerBlock(nn.Module):
"""
Complete anTransformer Block
"""
def __init__(self, d_model, n_heads, d_ff=None):
super().__init__()
self.attn = ApproxNormalizedAttention(d_model, n_heads)
self.ffn = ApproxNormalizedFFN(d_model, d_ff)
def forward(self, x, mask=None):
# Attention block
x = self.attn(x, mask)
# FFN block
x = self.ffn(x)
return x训练配置对比
# GPT+ Configuration (requires warm-up and weight decay)
gpt_plus_config = {
"learning_rate": 1e-4,
"warmup_ratio": 0.1, # 10% of total steps
"weight_decay": 0.1,
"use_cosine_lr": True,
}
# anGPT Configuration (no warm-up, no weight decay)
# Key difference: fewer hyperparameters to tune!
angpt_config = {
"learning_rate": 1e-4, # Same range works well
"warmup_ratio": 0.0, # NOT needed!
"weight_decay": 0.0, # NOT needed!
"use_cosine_lr": True,
# Additional anGPT-specific parameters
"alpha_init": 0.01, # For scaling parameters
"param_norm_bound": 1.0, # Max norm for parameters
}10. Conclusion and Future Directions
核心贡献总结
- 理论贡献:利用高维随机向量范数紧密集中现象,为近似归一化提供了理论依据
- 架构贡献:提出 anTransformer,通过标量乘法替代昂贵的归一化操作
- 实践贡献:大幅减少训练超参数(无需 warm-up、无需 weight decay)
- 效率贡献:仅增加 3% 训练时间,推理时间无额外开销
与相关工作的关系
| 方法 | 核心思想 | anGPT 借鉴 |
|---|---|---|
| RMSNorm | 统计归一化 | ✓ 完全移除 |
| Normalized GPT (nGPT) | 超球面表示 | ✓ LERP 更新 |
| Chinchilla Scaling | 最优资源分配 | ✓ 保持缩放特性 |
开放问题
- 更深的网络:anGPT 能否有效训练超过 1000 层的 Transformer?
- 多模态扩展:在视觉 Transformer 等非文本任务上的表现如何?
- 与其他技术结合:与 MoE、Flash Attention 等的兼容性
参考资源
- 原始论文: arXiv:2505.22014 - Learning in Compact Spaces with Approximately Normalized Transformer (NeurIPS 2025)
- 作者: Jörg K.H. Franke, Urs Spiegelhalter, Marianna Nezhurina, Jenia Jitsev, Frank Hutter, Michael Hefenbrock
References
Footnotes
-
Ba, J. L., Kiros, J. R., & Hinton, G. E. (2016). Layer Normalization. arXiv:1607.06450. ↩
-
Xiong, R., Yang, Y., He, D., et al. (2020). On Layer Normalization in the Pre-Training Transformer. ICML 2020. ↩
-
Sun, Y., et al. (2025). The Curse of Depth in Large Language Models. NeurIPS 2025. ↩
-
Vershynin, R. (2018). High-Dimensional Probability: An Introduction with Applications in Data Science. Cambridge University Press. ↩
-
Hoffmann, J., Borgeaud, S., Mensch, A., et al. (2022). Training Compute-Optimal Large Language Models. NeurIPS 2022. ↩