深度学习归一化技术全面解析

归一化技术是现代深度学习的基石之一,通过对层的输入或激活进行统计归一化,显著加速训练收敛并提升模型稳定性。12 本文系统梳理主流归一化方法的数学原理、实现细节与适用场景。

问题背景:Internal Covariate Shift

协变量偏移问题

在深度神经网络中,每一层的输入分布会随着前一层参数的更新而变化,这种现象称为内部协变量偏移(Internal Covariate Shift, ICS)

假设第 层的输入为:

更新后, 的分布发生变化,导致:

  1. 上层需要不断适应新的输入分布(缓慢收敛)
  2. 底层梯度可能消失或爆炸(训练不稳定)
  3. 饱和激活函数(如Sigmoid)更难优化

归一化的核心思想

通过对输入进行白化(Whitening)——即零均值、单位方差——可以:

  1. 限制激活值在梯度较大的区间
  2. 加速梯度传播
  3. 减少对初始化的敏感性

然而,完整白化(解相关)计算代价高昂,因此实践中使用简化的归一化方法。


Batch Normalization

核心算法

Batch Normalization1 在 mini-batch 上对输入进行归一化:

def batch_norm(x, gamma, beta, eps=1e-5, momentum=0.1, training=True):
    """
    Batch Normalization
    
    Args:
        x: 输入张量 (B, C, H, W) 或 (B, D)
        gamma: 缩放参数
        beta: 偏移参数
        eps: 数值稳定性
        momentum: 滑动平均动量
        training: 训练/推理模式
    
    Returns:
        归一化后的输出
    """
    if training:
        # 在 batch 维度上计算统计量
        mu = x.mean(dim=(0, 2, 3), keepdim=True)  # (1, C, 1, 1)
        var = x.var(dim=(0, 2, 3), keepdim=True)  # (1, C, 1, 1)
        
        # 存储用于推理
        running_mu = (1 - momentum) * running_mu + momentum * mu
        running_var = (1 - momentum) * running_var + momentum * var
    else:
        # 推理时使用滑动平均
        mu = running_mu
        var = running_var
    
    # 归一化
    x_norm = (x - mu) / torch.sqrt(var + eps)
    
    # 仿射变换(可学习的参数)
    y = gamma * x_norm + beta
    
    return y

数学形式化

对于特征维度

其中:

  • Batch 均值
  • Batch 方差
  • 可学习参数(缩放),(平移)

梯度推导

BatchNorm 的反向传播涉及多条路径:

class BatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        # 可学习参数
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
        # 滑动平均(推理用)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches', torch.zeros(1))
    
    def forward(self, x):
        if self.training:
            # 计算 batch 统计量
            mu = x.mean(dim=(0, 2, 3))
            var = x.var(dim=(0, 2, 3), unbiased=False)
            
            # 更新滑动平均
            self.running_mean = self.running_mean * self.momentum + mu * (1 - self.momentum)
            self.running_var = self.running_var * self.momentum + var * (1 - self.momentum)
        else:
            mu = self.running_mean
            var = self.running_var
        
        # 归一化 + 仿射
        x_norm = (x - mu.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + self.eps)
        return self.gamma.view(1, -1, 1, 1) * x_norm + self.beta.view(1, -1, 1, 1)

BatchNorm 的特性

特性说明
依赖 Batch需要足够大的 batch_size(通常 ≥ 32)
训练/推理差异训练用 batch 统计量,推理用滑动平均
通道独立每个通道独立归一化
序列长度无关适合 CNN,不直接适合变长序列

BatchNorm 的问题

  1. 小 Batch Size 下不稳定:统计量估计不准
  2. 不适合 RNN:序列长度变化时难以处理
  3. 训练/推理行为不一致:可能导致分布偏移
  4. 对 batch 内样本相关:分布式训练时可能有问题

Layer Normalization

核心算法

Layer Normalization3 对单个样本的所有特征进行归一化:

def layer_norm(x, gamma, beta, eps=1e-5):
    """
    Layer Normalization
    
    对每个样本的 feature 维度归一化
    
    Args:
        x: 输入张量 (B, D) 或 (B, H, W, D)
        gamma, beta: 可学习参数
        
    Returns:
        归一化后的输出
    """
    # 计算均值和方差(沿所有特征维度)
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True)
    
    # 归一化
    x_norm = (x - mu) / torch.sqrt(var + eps)
    
    return gamma * x_norm + beta

数学形式化

对于隐藏维度

LayerNorm vs BatchNorm

维度BatchNormLayerNorm
归一化轴Batch 轴 Feature 轴
依赖 Batch
序列处理需特殊处理原生支持
RNN 友好
Transformer通常不用常用
# PyTorch 实现
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
    
    def forward(self, x):
        # x: (B, *, D) → 计算最后一个维度的统计量
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

Group Normalization

核心思想

Group Normalization4 在 BatchNorm 和 LayerNorm 之间取得平衡——将通道分成 组,每组内部独立归一化:

def group_norm(x, gamma, beta, G, eps=1e-5):
    """
    Group Normalization
    
    Args:
        x: 输入张量 (B, C, H, W)
        G: 分组数量(必须能整除 C)
        gamma, beta: 可学习参数
    """
    B, C, H, W = x.shape
    x = x.view(B, G, C // G, H, W)
    
    # 在 H, W 维度上计算统计量
    mu = x.mean(dim=(2, 3, 4), keepdim=True)
    var = x.var(dim=(2, 3, 4), keepdim=True)
    
    x_norm = (x - mu) / torch.sqrt(var + eps)
    x_norm = x_norm.view(B, C, H, W)
    
    return gamma * x_norm + beta

与其他方法的关系

                    归一化维度
                    
                    Batch      Feature    Instance    Group
                      ↓           ↓          ↓          ↓
                    ┌─────────────────────────────────────┐
                    │                                     │
              None  │          ✓            ✓            │
                    │                                     │
                 Channel ──────────────────────────────────►
                    │                                     │
                    │             ✓           ✓            │
                    │                        │            │
                    │                        G groups    │
                    └─────────────────────────────────────┘
                    
□ = 归一化维度, ✓ = 归一化维度

Group Norm 的特性

特性BatchNormGroupNormLayerNorm
Batch 依赖
通道依赖全通道组内全通道
小 Batch
ImageNet 性能最好(大批量)次好较差
目标检测不稳定首选可用

RMSNorm

核心思想

RMSNorm5 简化 LayerNorm——只使用均方根(RMS)而不计算均值:

def rms_norm(x, gamma, eps=1e-5):
    """
    RMS Normalization
    
    仅使用均方根,不减去均值
    
    Args:
        x: 输入张量
        gamma: 缩放参数
    """
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return gamma * x / rms

数学形式化

RMSNorm vs LayerNorm

特性LayerNormRMSNorm
计算量均值 + 方差仅 RMS
可学习参数
移位不变性
性能相当略优或相当
TransformerGPT/BERT 用Llama/Mistral 用

RMSNorm 的优势

  1. 计算效率高:少一次均值计算
  2. 参数更少:不存储
  3. 效果相当:在 LLM 中广泛使用

权重归一化与谱归一化

权重归一化(Weight Normalization)

将权重向量分解为模长和方向:

def weight_norm(W, g, v):
    """
    Weight Normalization
    
    W = g * (v / ||v||)
    
    将权重分解为模长 g 和方向 v/||v||
    """
    norm = torch.norm(v, dim=-1, keepdim=True)
    return g * v / norm

谱归一化(Spectral Normalization)

限制权重矩阵的谱范数(最大奇异值)为1:

def spectral_norm(W, n_power_iterations=1):
    """
    谱归一化
    
    使 weight 的 spectral norm 等于 1
    
    用于:
    - WGAN 的 Lipschitz 约束
    - GAN 的训练稳定性
    """
    # 初始化 u, v
    u = torch.randn(W.shape[0], 1).normalize(dim=0)
    v = torch.randn(W.shape[1], 1).normalize(dim=0)
    
    for _ in range(n_power_iterations):
        # 幂迭代
        v = W.T @ u / torch.norm(W.T @ u)
        u = W @ v / torch.norm(W @ v)
    
    # spectral norm
    sigma = torch.norm(W @ v) / torch.norm(v)
    
    return W / sigma

归一化技术的对比

特性对比表

方法归一化维度Batch 依赖序列处理典型应用
BatchNorm需修改CNN 分类
LayerNorm原生Transformer
GroupNorm需修改目标检测
InstanceNorm需修改风格迁移
RMSNorm原生LLM

场景选择指南

输入形状       推荐方法
─────────────────────────────────
(B, C, H, W)  CNN 图像:
  - 分类/分割(大batch)→ BatchNorm
  - 检测/分割(小batch)→ GroupNorm
  - 风格迁移 → InstanceNorm

(B, T, D)     RNN/Transformer:
  - LSTM/GRU → LayerNorm
  - Transformer → LayerNorm/RMSNorm

(B, D)        全连接:
  - BatchNorm(少数情况)
  - LayerNorm(更常见)

(B, H, W, C)  Transformer ViT:
  - LayerNorm

归一化与初始化

初始化对归一化的影响

Xavier 初始化(适合 Sigmoid/Tanh):

He 初始化(适合 ReLU):

归一化网络的初始化

归一化层可以减少对精初始化需求,但仍需谨慎:

# BatchNorm 后:初始化影响减小
# LayerNorm 后:可以更激进地初始化
 
def initialize_with_ln(module):
    for name, param in module.named_parameters():
        if 'weight' in name:
            nn.init.kaiming_normal_(param)
        elif 'bias' in name:
            nn.init.zeros_(param)
        elif 'gamma' in name:  # LayerNorm
            nn.init.ones_(param)
        elif 'beta' in name:
            nn.init.zeros_(param)

归一化与梯度流

BatchNorm 的梯度分析

BatchNorm 反向传播涉及:

# 设 y = gamma * x_norm + beta
# x_norm = (x - mu) / sqrt(var + eps)
 
# 反向传播:
# dy/d(x_norm) = gamma
# dx_norm/d(x) 和 dx_norm/d(mu), dx_norm/d(var) 较复杂

梯度缩放效应

归一化将激活值缩放到单位方差,有助于:

  1. 减少梯度爆炸:避免激活值过大
  2. 保持梯度稳定:使不同层的梯度尺度相近
  3. 允许更高学习率:更稳定的训练

实践指南

PyTorch 归一化层速查

import torch.nn as nn
 
# Batch Normalization
nn.BatchNorm1d(num_features)   # (B, C)
nn.BatchNorm2d(num_features)   # (B, C, H, W)
nn.BatchNorm3d(num_features)   # (B, C, D, H, W)
 
# Layer Normalization
nn.LayerNorm(normalized_shape, eps=1e-5)
 
# Group Normalization
nn.GroupNorm(num_groups, num_channels)
 
# Instance Normalization
nn.InstanceNorm1d(num_features)
nn.InstanceNorm2d(num_features)
 
# RMSNorm (需安装 rmsnorm_torch 或自行实现)
class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.d = d
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))
    
    def forward(self, x):
        norm = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return x * norm * self.weight

使用建议

# CNN 图像分类(大 batch)
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),  # 大 batch 效果好
    nn.ReLU(),
    ...
)
 
# Transformer / LLM
class TransformerBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)  # Pre-norm
        self.attn = MultiHeadAttention(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model)
    
    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # Pre-norm
        x = x + self.ffn(self.norm2(x))
        return x
 
# 目标检测(小 batch)
class DetectionModel(nn.Module):
    def __init__(self):
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.GroupNorm(32, 64),  # 不依赖 batch
            nn.ReLU(),
            ...
        )

核心公式速查

方法归一化公式
BatchNorm
LayerNorm
GroupNorm
RMSNorm

参考

相关文章

Footnotes

  1. Ioffe, S., & Szegedy, C. (2015). “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift”. ICML 2015. https://arxiv.org/abs/1502.03167 2

  2. Santurkar, S., et al. (2018). “How Does Batch Normalization Help Optimization?“. NeurIPS 2018. https://arxiv.org/abs/1805.11604

  3. Ba, J.L., Kiros, J.R., & Hinton, G.E. (2016). “Layer Normalization”. arXiv:1607.06450. https://arxiv.org/abs/1607.06450

  4. Wu, Y., & He, K. (2018). “Group Normalization”. ECCV 2018. https://arxiv.org/abs/1803.08494

  5. Zhang, B., & Sennrich, R. (2019). “Root Mean Square Layer Normalization”. NeurIPS 2019. https://arxiv.org/abs/1910.07467