概述

权重初始化是深度学习中最容易被忽视却最关键的环节之一。初始化不当会导致梯度消失/爆炸,使深层网络无法训练;恰当的初始化则让收敛速度提升一个数量级。1

本文系统推导 Xavier(Glorot)和 He(Kaiming)初始化的数学原理,介绍其他常用初始化方法,并给出 PyTorch 实践指南。


一、问题:为什么初始化如此重要?

1.1 现象

考虑一个简单的实验:训练一个 50 层的 MLP 做 MNIST 分类。

  • 用全零初始化:损失不下降,模型不学习
  • 用标准正态 :前几轮 loss 就变成 NaN
  • :训练慢且效果差
  • Kaiming 初始化:训练稳定,5 个 epoch 达到 98% 准确率

1.2 根本原因

深度网络的前向传播中,激活值的方差会随层数指数级变化:

反向传播同理。初始化方差的选择决定了信号是放大还是衰减


二、朴素初始化方法

2.1 全零/常数初始化

nn.init.zeros_(layer.weight)

问题:所有神经元学到的特征完全相同(对称性问题),无法学习。

例外:偏置 可以初始化为 0。

2.2 小随机数初始化

nn.init.uniform_(layer.weight, -0.01, 0.01)

问题:方差太小,深层网络信号衰减到 0。

2.3 大随机数初始化

nn.init.uniform_(layer.weight, -1.0, 1.0)

问题:方差太大,深层网络信号爆炸到 NaN。

结论:需要精心选择方差,使信号在层间保持稳定。


三、Xavier/Glorot 初始化(Glorot & Bengio 2010)

3.1 核心思想

Glorot & Bengio 在论文2中提出方差守恒原则:

每层激活的方差应与输入的方差相同(前向传播)
每层梯度的方差应与输出的梯度方差相同(反向传播)

3.2 推导

考虑一个简单的全连接层:

假设

  • 各元素独立同分布
  • 独立
  • 偏置初始化为 0
  • 激活函数 在零点附近近似恒等

前向传播的方差推导

为输入维度:

要让 ,需要:

即:

反向传播的方差推导

类似地,从输出回传:

要让梯度方差保持稳定:

即:

3.3 折中方案

前向和反向要求不同,取几何平均

这就是 Xavier 初始化。Uniform 形式:

Normal 形式:

3.4 适用激活函数

Xavier 假设激活函数在零点附近近似线性(导数约为 1),适合 sigmoid 和 tanh

对 tanh 特别有效;对 ReLU 效果一般(因为 ReLU 的负半轴破坏了恒等假设)。


四、He/Kaiming 初始化(He et al. 2015)

4.1 ReLU 的特殊性

ReLU 激活:

它有两个问题对 Xavier 假设构成挑战:

  1. 非对称性:负半轴归零,输出方差是输入的一半
  2. 零点的非线性:在零点处不连续

4.2 推导

考虑 ReLU 激活。设 关于 0 对称分布, 也是:

对 ReLU:

要让

这正是 Kaiming 初始化的核心公式。

4.3 PyTorch 接口

PyTorch 提供两种模式:

fan_in 模式(保持前向方差):

nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu')

方差:

其中 是 ReLU 的负半轴斜率(默认 0,对应普通 ReLU)。

fan_out 模式(保持反向梯度方差):

nn.init.kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')

方差:

4.4 PReLU 扩展

对 PReLU(带可学习斜率 ):

4.5 适用激活函数

  • ReLU
  • Leaky ReLU
  • PReLU
  • ELU(需要进一步调整)

不适用于 sigmoid/tanh(用 Xavier)。


五、LeCun 初始化(1998)

5.1 历史

LeCun 等人在 1998 年的 Efficient BackProp3中提出,是 Xavier 的前身:

仅前向方差守恒,不考虑反向。Uniform 形式:

5.2 现代视角

LeCun 初始化对 ReLU 略偏小(少了一个因子 2),但实践中仍可工作。


六、其他常用初始化

6.1 正交初始化(Orthogonal)

nn.init.orthogonal_(tensor, gain=1.0)

通过 SVD 生成正交矩阵:

优势

  • 完美保持信号范数
  • 避免梯度爆炸
  • 循环网络(RNN/LSTM)的标准选择

数学上,正交矩阵的特征值都是 1,信号经过一层不会被放大或缩小。

6.2 截断正态(Truncated Normal)

nn.init.trunc_normal_(tensor, mean=0, std=1, a=-2, b=2)

处截断,避免极端值。Transformer 常用此初始化。

6.3 LSUV(Layer-Sequential Unit-Variance)

Mishkin & Matas(2016)提出:

  1. 用正交初始化
  2. 逐层前向,计算输出方差
  3. 缩放权重使方差归一化

优势:自适应每个层的实际统计。

6.4 常量初始化(偏置专用)

nn.init.constant_(tensor, val)
  • 偏置 通常初始化为 0
  • 遗忘门偏置初始化为 1:LSTM 中保持初始记忆
  • 输出层偏置初始化为类别先验的对数

七、PyTorch 默认初始化表

7.1 各层默认行为

默认初始化
nn.Linear
nn.Conv2d
nn.BatchNorm2dweight=1, bias=0
nn.LSTM遗忘门 bias=1,其他 bias=0
nn.GRU
nn.Embedding
nn.ConvTranspose2d同 Conv2d

7.2 实际项目推荐

def init_weights(module):
    """通用初始化策略"""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # 大部分层用 Kaiming
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LSTM):
        # LSTM 推荐正交 + 遗忘门偏置
        for name, param in module.named_parameters():
            if 'weight_ih' in name:
                nn.init.xavier_uniform_(param)
            elif 'weight_hh' in name:
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)
                # 遗忘门偏置 = 1
                n = param.size(0)
                param.data[n//4:n//2].fill_(1.0)
 
# 应用
model.apply(init_weights)

八、初始化的实践陷阱

8.1 常见错误

错误后果解决方案
忘记缩放 Embedding极大输入幅度Embedding 用小方差(0.01-0.1)
共享层多次初始化多次随机破坏学习只初始化一次
加载预训练又随机初始化覆盖预训练加载权重后跳过初始化
自定义层忘记初始化训练失败__init__ 中显式初始化

8.2 调试工具

前向激活值监控

activation_stats = {}
def hook(module, input, output):
    activation_stats[module] = {
        'mean': output.mean().item(),
        'std': output.std().item(),
    }
 
for name, module in model.named_modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        module.register_forward_hook(hook)
 
# 训练一个 batch 后检查
for name, stats in activation_stats.items():
    print(f"{name}: mean={stats['mean']:.3f}, std={stats['std']:.3f}")

健康指标:

  • 各层 std 维持在 0.5-2.0
  • 没有 NaN 或极端值

8.3 失败模式

  1. 全死 ReLU:所有激活为 0
    • 原因:学习率太大或初始化偏置为负
  2. 激活饱和(sigmoid/tanh):所有输出接近 ±1
    • 原因:输入幅度太大
  3. 梯度爆炸:loss = NaN
    • 原因:权重初始化太大或学习率太高
  4. 梯度消失:loss 不下降
    • 原因:权重初始化太小或网络太深

九、初始化与归一化的关系

9.1 是否可以”用 BN 代替好的初始化”?

部分可以。BatchNorm 通过可学习的缩放/平移使激活分布稳定,部分缓解了初始化问题。但:

  • BN 不能完全替代:没有 BN 的层(如输出层)仍需好初始化
  • BN 本身有可学习参数需要初始化
  • BN 引入额外计算开销

最佳实践好初始化 + 适当归一化

9.2 实践组合

场景初始化归一化
MLPKaiming-
CNN + ReLUKaimingBatchNorm
RNN/LSTM正交 + 遗忘门 bias=1-
TransformerTruncated Normal (0.02)LayerNorm
GAN 生成器-

十、PyTorch 完整实现

10.1 手动实现 Xavier/He

import math
 
def xavier_uniform_(tensor, gain=1.0):
    """Xavier 均匀初始化"""
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
 
 
def xavier_normal_(tensor, gain=1.0):
    """Xavier 正态初始化"""
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    with torch.no_grad():
        return tensor.normal_(0, std)
 
 
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """He 均匀初始化"""
    fan = _calculate_correct_fan(tensor, mode)
    gain = _calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
 
 
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """He 正态初始化"""
    fan = _calculate_correct_fan(tensor, mode)
    gain = _calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)

10.2 自动推导 fan_in / fan_out

def _calculate_fan_in_and_fan_out(tensor):
    """计算张量的 fan_in 和 fan_out"""
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError("Tensor must have at least 2 dimensions")
    
    if dimensions == 2:  # Linear
        fan_in = tensor.size(1)
        fan_out = tensor.size(0)
    else:  # Conv
        receptive_field = 1
        for i in range(2, dimensions):
            receptive_field *= tensor.size(i)
        fan_in = tensor.size(1) * receptive_field
        fan_out = tensor.size(0) * receptive_field
    
    return fan_in, fan_out
 
 
def _calculate_correct_fan(tensor, mode):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out
 
 
def _calculate_gain(nonlinearity, a=0):
    """不同激活函数的 gain"""
    gains = {
        'linear': 1,
        'sigmoid': 1,
        'tanh': 5/3,
        'relu': math.sqrt(2.0),
        'leaky_relu': math.sqrt(2.0 / (1 + a**2)),
        'selu': 0.75,
    }
    return gains.get(nonlinearity, 1)

10.3 验证初始化效果

def test_initialization(init_fn, activation, n_layers=20, n_samples=1000):
    """验证初始化在多层网络中的信号保持能力"""
    x = torch.randn(n_samples, 100)
    
    for layer in range(n_layers):
        W = init_fn(torch.empty(100, 100))
        x = activation(x @ W.T)
        
    print(f"  Layer {n_layers}: mean={x.mean():.3f}, std={x.std():.3f}")
    
    # 反向传播
    x.sum().backward()
    # 检查梯度(需重新前向,记录梯度)
 
# 测试
test_initialization(lambda t: nn.init.xavier_normal_(t), torch.tanh)
# Layer 20: mean=-0.005, std=0.687  ✓ 方差稳定
 
test_initialization(lambda t: nn.init.kaiming_normal_(t, nonlinearity='relu'), torch.relu)
# Layer 20: mean=0.420, std=0.788  ✓ 方差稳定

十一、参考文献


附录:决策流程

激活函数是什么?
├─ tanh/sigmoid → Xavier 初始化
├─ ReLU/LeakyReLU → Kaiming 初始化
├─ SELU → LeCun Normal
└─ 线性 → 小常数

网络类型?
├─ CNN → Kaiming + BatchNorm
├─ RNN/LSTM → 正交 + 遗忘门 bias=1
├─ Transformer → Truncated Normal (0.02) + LayerNorm
└─ GAN → 谨慎初始化,避免过早模式坍缩

需要监控吗?
└─ 是:用 hook 监控激活和梯度方差

最后更新:2026-06-22

Footnotes

  1. Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. AISTATS.

  2. He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. ICCV.

  3. LeCun, Y., Bottou, L., Orr, G. B., & Müller, K. R. (1998). Efficient BackProp. Neural Networks: Tricks of the Trade.