概述
权重初始化是深度学习中最容易被忽视却最关键的环节之一。初始化不当会导致梯度消失/爆炸,使深层网络无法训练;恰当的初始化则让收敛速度提升一个数量级。1
本文系统推导 Xavier(Glorot)和 He(Kaiming)初始化的数学原理,介绍其他常用初始化方法,并给出 PyTorch 实践指南。
一、问题:为什么初始化如此重要?
1.1 现象
考虑一个简单的实验:训练一个 50 层的 MLP 做 MNIST 分类。
- 用全零初始化:损失不下降,模型不学习
- 用标准正态 :前几轮 loss 就变成 NaN
- 用 :训练慢且效果差
- 用 Kaiming 初始化:训练稳定,5 个 epoch 达到 98% 准确率
1.2 根本原因
深度网络的前向传播中,激活值的方差会随层数指数级变化:
反向传播同理。初始化方差的选择决定了信号是放大还是衰减。
二、朴素初始化方法
2.1 全零/常数初始化
nn.init.zeros_(layer.weight)问题:所有神经元学到的特征完全相同(对称性问题),无法学习。
例外:偏置 可以初始化为 0。
2.2 小随机数初始化
nn.init.uniform_(layer.weight, -0.01, 0.01)问题:方差太小,深层网络信号衰减到 0。
2.3 大随机数初始化
nn.init.uniform_(layer.weight, -1.0, 1.0)问题:方差太大,深层网络信号爆炸到 NaN。
结论:需要精心选择方差,使信号在层间保持稳定。
三、Xavier/Glorot 初始化(Glorot & Bengio 2010)
3.1 核心思想
Glorot & Bengio 在论文2中提出方差守恒原则:
每层激活的方差应与输入的方差相同(前向传播)
每层梯度的方差应与输出的梯度方差相同(反向传播)
3.2 推导
考虑一个简单的全连接层:
假设:
- 各元素独立同分布
- 与 独立
- 偏置初始化为 0
- 激活函数 在零点附近近似恒等
前向传播的方差推导:
设 为输入维度:
要让 ,需要:
即:
反向传播的方差推导:
类似地,从输出回传:
要让梯度方差保持稳定:
即:
3.3 折中方案
前向和反向要求不同,取几何平均:
这就是 Xavier 初始化。Uniform 形式:
Normal 形式:
3.4 适用激活函数
Xavier 假设激活函数在零点附近近似线性(导数约为 1),适合 sigmoid 和 tanh。
对 tanh 特别有效;对 ReLU 效果一般(因为 ReLU 的负半轴破坏了恒等假设)。
四、He/Kaiming 初始化(He et al. 2015)
4.1 ReLU 的特殊性
ReLU 激活:
它有两个问题对 Xavier 假设构成挑战:
- 非对称性:负半轴归零,输出方差是输入的一半
- 零点的非线性:在零点处不连续
4.2 推导
考虑 ReLU 激活。设 关于 0 对称分布, 也是:
对 ReLU:
要让 :
这正是 Kaiming 初始化的核心公式。
4.3 PyTorch 接口
PyTorch 提供两种模式:
fan_in 模式(保持前向方差):
nn.init.kaiming_normal_(tensor, mode='fan_in', nonlinearity='relu')方差:
其中 是 ReLU 的负半轴斜率(默认 0,对应普通 ReLU)。
fan_out 模式(保持反向梯度方差):
nn.init.kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')方差:
4.4 PReLU 扩展
对 PReLU(带可学习斜率 ):
4.5 适用激活函数
- ReLU
- Leaky ReLU
- PReLU
- ELU(需要进一步调整)
不适用于 sigmoid/tanh(用 Xavier)。
五、LeCun 初始化(1998)
5.1 历史
LeCun 等人在 1998 年的 Efficient BackProp3中提出,是 Xavier 的前身:
仅前向方差守恒,不考虑反向。Uniform 形式:
5.2 现代视角
LeCun 初始化对 ReLU 略偏小(少了一个因子 2),但实践中仍可工作。
六、其他常用初始化
6.1 正交初始化(Orthogonal)
nn.init.orthogonal_(tensor, gain=1.0)通过 SVD 生成正交矩阵:。
优势:
- 完美保持信号范数
- 避免梯度爆炸
- 循环网络(RNN/LSTM)的标准选择
数学上,正交矩阵的特征值都是 1,信号经过一层不会被放大或缩小。
6.2 截断正态(Truncated Normal)
nn.init.trunc_normal_(tensor, mean=0, std=1, a=-2, b=2)在 处截断,避免极端值。Transformer 常用此初始化。
6.3 LSUV(Layer-Sequential Unit-Variance)
Mishkin & Matas(2016)提出:
- 用正交初始化
- 逐层前向,计算输出方差
- 缩放权重使方差归一化
优势:自适应每个层的实际统计。
6.4 常量初始化(偏置专用)
nn.init.constant_(tensor, val)- 偏置 通常初始化为 0
- 遗忘门偏置初始化为 1:LSTM 中保持初始记忆
- 输出层偏置初始化为类别先验的对数
七、PyTorch 默认初始化表
7.1 各层默认行为
| 层 | 默认初始化 |
|---|---|
nn.Linear | |
nn.Conv2d | , |
nn.BatchNorm2d | weight=1, bias=0 |
nn.LSTM | 遗忘门 bias=1,其他 bias=0 |
nn.GRU | |
nn.Embedding | |
nn.ConvTranspose2d | 同 Conv2d |
7.2 实际项目推荐
def init_weights(module):
"""通用初始化策略"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# 大部分层用 Kaiming
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.BatchNorm2d):
nn.init.constant_(module.weight, 1)
nn.init.constant_(module.bias, 0)
elif isinstance(module, nn.LSTM):
# LSTM 推荐正交 + 遗忘门偏置
for name, param in module.named_parameters():
if 'weight_ih' in name:
nn.init.xavier_uniform_(param)
elif 'weight_hh' in name:
nn.init.orthogonal_(param)
elif 'bias' in name:
nn.init.zeros_(param)
# 遗忘门偏置 = 1
n = param.size(0)
param.data[n//4:n//2].fill_(1.0)
# 应用
model.apply(init_weights)八、初始化的实践陷阱
8.1 常见错误
| 错误 | 后果 | 解决方案 |
|---|---|---|
| 忘记缩放 Embedding | 极大输入幅度 | Embedding 用小方差(0.01-0.1) |
| 共享层多次初始化 | 多次随机破坏学习 | 只初始化一次 |
| 加载预训练又随机初始化 | 覆盖预训练 | 加载权重后跳过初始化 |
| 自定义层忘记初始化 | 训练失败 | 在 __init__ 中显式初始化 |
8.2 调试工具
前向激活值监控:
activation_stats = {}
def hook(module, input, output):
activation_stats[module] = {
'mean': output.mean().item(),
'std': output.std().item(),
}
for name, module in model.named_modules():
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.register_forward_hook(hook)
# 训练一个 batch 后检查
for name, stats in activation_stats.items():
print(f"{name}: mean={stats['mean']:.3f}, std={stats['std']:.3f}")健康指标:
- 各层 std 维持在 0.5-2.0
- 没有 NaN 或极端值
8.3 失败模式
- 全死 ReLU:所有激活为 0
- 原因:学习率太大或初始化偏置为负
- 激活饱和(sigmoid/tanh):所有输出接近 ±1
- 原因:输入幅度太大
- 梯度爆炸:loss = NaN
- 原因:权重初始化太大或学习率太高
- 梯度消失:loss 不下降
- 原因:权重初始化太小或网络太深
九、初始化与归一化的关系
9.1 是否可以”用 BN 代替好的初始化”?
部分可以。BatchNorm 通过可学习的缩放/平移使激活分布稳定,部分缓解了初始化问题。但:
- BN 不能完全替代:没有 BN 的层(如输出层)仍需好初始化
- BN 本身有可学习参数需要初始化
- BN 引入额外计算开销
最佳实践:好初始化 + 适当归一化。
9.2 实践组合
| 场景 | 初始化 | 归一化 |
|---|---|---|
| MLP | Kaiming | - |
| CNN + ReLU | Kaiming | BatchNorm |
| RNN/LSTM | 正交 + 遗忘门 bias=1 | - |
| Transformer | Truncated Normal (0.02) | LayerNorm |
| GAN 生成器 | - |
十、PyTorch 完整实现
10.1 手动实现 Xavier/He
import math
def xavier_uniform_(tensor, gain=1.0):
"""Xavier 均匀初始化"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
bound = math.sqrt(3.0) * std
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def xavier_normal_(tensor, gain=1.0):
"""Xavier 正态初始化"""
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / (fan_in + fan_out))
with torch.no_grad():
return tensor.normal_(0, std)
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
"""He 均匀初始化"""
fan = _calculate_correct_fan(tensor, mode)
gain = _calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std
with torch.no_grad():
return tensor.uniform_(-bound, bound)
def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
"""He 正态初始化"""
fan = _calculate_correct_fan(tensor, mode)
gain = _calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
with torch.no_grad():
return tensor.normal_(0, std)10.2 自动推导 fan_in / fan_out
def _calculate_fan_in_and_fan_out(tensor):
"""计算张量的 fan_in 和 fan_out"""
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError("Tensor must have at least 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor.size(1)
fan_out = tensor.size(0)
else: # Conv
receptive_field = 1
for i in range(2, dimensions):
receptive_field *= tensor.size(i)
fan_in = tensor.size(1) * receptive_field
fan_out = tensor.size(0) * receptive_field
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def _calculate_gain(nonlinearity, a=0):
"""不同激活函数的 gain"""
gains = {
'linear': 1,
'sigmoid': 1,
'tanh': 5/3,
'relu': math.sqrt(2.0),
'leaky_relu': math.sqrt(2.0 / (1 + a**2)),
'selu': 0.75,
}
return gains.get(nonlinearity, 1)10.3 验证初始化效果
def test_initialization(init_fn, activation, n_layers=20, n_samples=1000):
"""验证初始化在多层网络中的信号保持能力"""
x = torch.randn(n_samples, 100)
for layer in range(n_layers):
W = init_fn(torch.empty(100, 100))
x = activation(x @ W.T)
print(f" Layer {n_layers}: mean={x.mean():.3f}, std={x.std():.3f}")
# 反向传播
x.sum().backward()
# 检查梯度(需重新前向,记录梯度)
# 测试
test_initialization(lambda t: nn.init.xavier_normal_(t), torch.tanh)
# Layer 20: mean=-0.005, std=0.687 ✓ 方差稳定
test_initialization(lambda t: nn.init.kaiming_normal_(t, nonlinearity='relu'), torch.relu)
# Layer 20: mean=0.420, std=0.788 ✓ 方差稳定十一、参考文献
附录:决策流程
激活函数是什么?
├─ tanh/sigmoid → Xavier 初始化
├─ ReLU/LeakyReLU → Kaiming 初始化
├─ SELU → LeCun Normal
└─ 线性 → 小常数
网络类型?
├─ CNN → Kaiming + BatchNorm
├─ RNN/LSTM → 正交 + 遗忘门 bias=1
├─ Transformer → Truncated Normal (0.02) + LayerNorm
└─ GAN → 谨慎初始化,避免过早模式坍缩
需要监控吗?
└─ 是:用 hook 监控激活和梯度方差
最后更新:2026-06-22
Footnotes
-
Glorot, X., & Bengio, Y. (2010). Understanding the difficulty of training deep feedforward neural networks. AISTATS. ↩
-
He, K., Zhang, X., Ren, S., & Sun, J. (2015). Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification. ICCV. ↩
-
LeCun, Y., Bottou, L., Orr, G. B., & Müller, K. R. (1998). Efficient BackProp. Neural Networks: Tricks of the Trade. ↩