Residual Networks and the Theory of Skip Connections
Residual networks (ResNet), introduced by He et al. in 2016, use skip connections to successfully train deep neural networks of over 1,000 layers, solving the degradation problem in deep learning.[^1] This article takes a close look at the mathematics of residual learning, its gradient-flow properties, and the design ideas behind the major variant architectures.
Background: Degradation in Deep Networks

Defining Degradation

The conventional wisdom was that deeper networks should perform better. But He et al. observed that beyond a certain depth, adding more layers increases both training error and test error:
Training error
│
│            ┌─────────┐
│           ╱│  Deep   │ ← degradation: the deeper network performs worse
│          ╱ │ network │
│         ╱  └─────────┘
│        ╱
│       ╱  ┌─────────┐
│      ╱   │ Shallow │
│     ╱    │ network │
│    ╱     └─────────┘
└─────────────────────── Network depth
Degradation vs. Overfitting

| Property | Degradation | Overfitting |
|---|---|---|
| Training error | ↑ increases | ↓ decreases |
| Test error | ↑ increases | decreases, then rises |
| Root cause | optimization difficulty | memorizing noise |
| Remedy | residual learning | regularization |
The Root Cause of Degradation

Degradation is not caused by overfitting but by optimization difficulty: deeper networks are simply harder to optimize toward a good solution.

Core insight: if a shallow network has already reached a good solution, then the optimal strategy for any extra layers is to learn the identity mapping, i.e., to do nothing. But learning the "identity" is not easy for a stack of nonlinear layers.
The Residual Learning Framework

Core Idea

Suppose we want to learn a target mapping $\mathcal{H}(x)$. The residual learning framework has the network learn instead:

$$\mathcal{F}(x) = \mathcal{H}(x) - x$$

so the original mapping becomes:

$$\mathcal{H}(x) = \mathcal{F}(x) + x$$
Introducing the Identity Mapping

Plain approach:                            Residual approach:
input x ──→ [Conv] ──→ ... ──→ output      input x ──┬──→ [Conv] ──→ ... ──┬──→ output
                                                     │                     │
                                                     └── skip connection ←─┘

If the optimal mapping $\mathcal{H}(x)$ is the identity, then $\mathcal{F}(x) = 0$: the network only needs to drive its weights toward zero, which is far easier than learning an identity mapping through a stack of nonlinear layers.
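This is easy to check concretely. A minimal sketch (the single linear layer here stands in for an arbitrary residual branch):

import torch
import torch.nn as nn

# A residual branch with zero weights realizes the identity exactly
branch = nn.Linear(8, 8)
nn.init.zeros_(branch.weight)
nn.init.zeros_(branch.bias)

x = torch.randn(4, 8)
assert torch.allclose(x + branch(x), x)  # F ≡ 0  =>  H(x) = x + F(x) = x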
Mathematical Formalization

Forward pass through a residual block:

$$x_{l+1} = x_l + \mathcal{F}(x_l, W_l)$$

where $\mathcal{F}$ is the residual function and $x_l$ is the input to block $l$. Unrolling from block $l$ to a deeper block $L$:

$$x_L = x_l + \sum_{i=l}^{L-1} \mathcal{F}(x_i, W_i)$$

Backpropagation (chain rule):

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \left(1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} \mathcal{F}(x_i, W_i)\right)$$

The key point: the additive $1$ contributed by the skip connection guarantees effective gradient propagation, no matter how small the residual-branch term becomes!
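The block-level identity is easy to verify numerically with autograd for a single toy block (a sketch; the `tanh` nonlinearity and random matrix are arbitrary choices):

import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
W = torch.randn(4, 4)

def F(x):  # residual function F(x)
    return torch.tanh(x @ W)

def H(x):  # residual block H(x) = x + F(x)
    return x + F(x)

x = torch.randn(4)
# Jacobian of the block = identity + Jacobian of the residual branch
assert torch.allclose(jacobian(H, x), torch.eye(4) + jacobian(F, x))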
Gradient-Flow Analysis of Skip Connections

Alleviating the Vanishing-Gradient Problem

In a plain network, the backpropagated gradient is:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \prod_{i=l}^{L-1} \frac{\partial x_{i+1}}{\partial x_i}$$

With many layers, this long product makes the gradient decay or explode exponentially.
Gradient Protection from Residual Connections

In a residual network:

$$x_L = x_l + \sum_{i=l}^{L-1} \mathcal{F}(x_i, W_i)$$

During backpropagation:

$$\frac{\partial x_L}{\partial x_l} = I + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} \mathcal{F}(x_i, W_i)$$

Key properties (see the numerical sketch after this list):

- Gradients do not vanish: the $I$ term lets the gradient pass even when $\frac{\partial \mathcal{F}}{\partial x}$ is small
- Gradients do not explode: products still occur inside the residual branches, but the $I$ added at every block stabilizes the gradient scale
- Information propagates directly: signals travel straight from shallow layers to deep ones
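The sketch below makes the contrast measurable: the same stack of linear layers, run with and without skip connections (the depth, width, and `tanh` nonlinearity are arbitrary choices):

import torch
import torch.nn as nn

depth, dim = 50, 32
layers = [nn.Linear(dim, dim) for _ in range(depth)]

def input_grad_norm(residual):
    x = torch.randn(8, dim, requires_grad=True)
    h = x
    for layer in layers:
        h = h + torch.tanh(layer(h)) if residual else torch.tanh(layer(h))
    h.sum().backward()
    return x.grad.norm().item()

print("plain:   ", input_grad_norm(False))  # typically orders of magnitude smaller
print("residual:", input_grad_norm(True))   # stays at a healthy scale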
Visualizing Gradient Flow
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = torch.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual  # skip connection
        return torch.relu(out)
def analyze_gradient_flow(model, x):
    """Analyze the gradient flow through a model's residual blocks."""
    model.eval()
    # Hook to capture the gradient norm at each block
    gradients = {}

    def hook_fn(module, grad_input, grad_output):
        if grad_output[0] is not None:
            gradients[module] = grad_output[0].norm().item()

    # Register hooks (register_backward_hook is deprecated;
    # use register_full_backward_hook instead)
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, ResidualBlock):
            hooks.append(module.register_full_backward_hook(hook_fn))

    # Forward + backward pass
    output = model(x)
    loss = output.sum()
    loss.backward()

    # Clean up the hooks
    for h in hooks:
        h.remove()
    return gradients
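A usage sketch for the helper above (the 10-block stack and input shape are arbitrary):

model = nn.Sequential(*[ResidualBlock(16) for _ in range(10)])
grads = analyze_gradient_flow(model, torch.randn(2, 16, 32, 32))
for block, norm in grads.items():
    print(type(block).__name__, norm)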
The Importance of Identity Mappings

Pre-activation vs. Post-activation
ResNet originally used post-activation:

Original ResNet (post-activation):
x_l ──→ [Conv → BN → ReLU] → [Conv → BN] → [+] → ReLU → x_{l+1}

The later **pre-activation** design is cleaner:

Pre-activation ResNet:
x_l ──→ [BN → ReLU → Conv] → [BN → ReLU → Conv] → [+] → x_{l+1}
                                                   ↑       ↑
                                skip path adds in directly, no extra activation needed
Advantages of Pre-activation

- The residual branch $\mathcal{F}$ is a BN-ReLU-Conv composition, which makes the zero mapping easier to learn
- The skip path is untouched by BN, so information propagates more directly
- Smoother backpropagation: gradients flow straight back to the input
class PreActResidualBlock(nn.Module):
    """Pre-activation residual block.

    Better suited to training very deep networks.
    """
    def __init__(self, channels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x):
        residual = x
        # Pre-activation: BN and ReLU come before each convolution
        out = self.bn1(x)
        out = torch.relu(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = torch.relu(out)
        out = self.conv2(out)
        # Identity mapping: add the skip path with no activation afterwards
        out += residual
        return out

ResNet Architecture in Detail
Overall Structure
class ResNet(nn.Module):
    """ResNet-50-style backbone.

    Stem: [Conv, BN, ReLU, MaxPool]
    Four residual stages with [3, 4, 6, 3] bottleneck blocks.
    """
    def __init__(self, num_classes=1000):
        super().__init__()
        # Initial convolution (stem)
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        # Four residual stages; layer1 keeps stride 1 because the stem
        # has already downsampled the input to 56×56
        self.layer1 = self._make_layer(64, 256, 3, stride=1)     # output 56×56
        self.layer2 = self._make_layer(256, 512, 4, stride=2)    # output 28×28
        self.layer3 = self._make_layer(512, 1024, 6, stride=2)   # output 14×14
        self.layer4 = self._make_layer(1024, 2048, 3, stride=2)  # output 7×7
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        # The first block handles downsampling and channel matching
        layers.append(BottleneckBlock(in_channels, out_channels, stride=stride))
        # The remaining blocks keep the shape
        for _ in range(1, num_blocks):
            layers.append(BottleneckBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
class BottleneckBlock(nn.Module):
    """Bottleneck residual block.

    1×1 Conv (reduce) → 3×3 Conv → 1×1 Conv (expand)
    Parameters (ignoring BN/bias): C·(C/4) + 9·(C/4)² + (C/4)·C ≈ 1.06C²,
    versus 2·9C² = 18C² for a BasicBlock of the same width.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        mid_channels = out_channels // self.expansion
        # 1×1 reduction
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        # 3×3 convolution
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3,
                               stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        # 1×1 expansion
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        # Skip connection: project when the shape changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)  # skip connection
        return torch.relu(out)
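A shape sanity check for the model above (an ImageNet-sized input is assumed):

model = ResNet(num_classes=1000)
print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])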
Residual Block Types Compared

| Type | Structure | Parameters (width $C$) | Typical use |
|---|---|---|---|
| BasicBlock | 3×3 → 3×3 | $\approx 18C^2$ | small models |
| Bottleneck | 1×1 → 3×3 → 1×1 | $\approx 1.06C^2$ | large models |
| Pre-activation | (BN → ReLU → Conv) × 2 | same as above | very deep networks |
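The per-block counts can be sanity-checked against the BottleneckBlock defined above (BN and bias terms add a small constant on top of the estimate):

block = BottleneckBlock(256, 256)  # C = 256, identity shortcut
print(sum(p.numel() for p in block.parameters()))  # ≈ 1.06 * 256**2 ≈ 69.5k, plus BN/bias terms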
ResNet Variants

WideResNet

Core idea: increasing width (the channel count) is more effective than increasing depth.[^2]
class WideResidualBlock(nn.Module):
    """WideResNet residual block.

    Improves performance by increasing the width (channel count).
    """
    def __init__(self, in_channels, out_channels, stride=1, dropout=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1)
        self.dropout = nn.Dropout(dropout)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               stride=1, padding=1)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)

    def forward(self, x):
        out = torch.relu(self.bn1(x))
        out = self.conv1(out)
        out = torch.relu(self.bn2(out))
        out = self.dropout(out)  # dropout between the two convolutions
        out = self.conv2(out)
        return out + self.shortcut(x)
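A quick shape check (widening a 16-channel baseline by 10×, echoing the WRN-28-10 configuration; the numbers are illustrative):

block = WideResidualBlock(16, 160, stride=1, dropout=0.3)
print(block(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 160, 32, 32])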
ResNeXt

Core idea: use grouped convolutions to increase cardinality, the number of parallel transformation paths.[^3]
class ResNeXtBlock(nn.Module):
    """ResNeXt residual block.

    Grouped convolution: `cardinality` independent groups computed in
    parallel, then concatenated back together.
    """
    def __init__(self, in_channels, out_channels, stride=1, cardinality=32, width=4):
        super().__init__()
        mid_channels = cardinality * width
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3,
                               stride=stride, padding=1,
                               groups=cardinality)  # grouped convolution
        # BN + ReLU before the final 1×1 conv, matching the pre-activation pattern
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)

    def forward(self, x):
        out = self.conv1(torch.relu(self.bn1(x)))
        out = self.conv2(torch.relu(self.bn2(out)))
        out = self.conv3(torch.relu(self.bn3(out)))
        return out + self.shortcut(x)
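Shape check for the block above; with cardinality 32 and width 4, the grouped 3×3 convolution uses 1/32 of the parameters a dense 3×3 convolution of the same width would need:

block = ResNeXtBlock(256, 256, cardinality=32, width=4)
print(block(torch.randn(2, 256, 14, 14)).shape)  # torch.Size([2, 256, 14, 14])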
DenseNet

Core idea: each layer receives the feature maps of all preceding layers as input.[^4]
class DenseBlock(nn.Module):
    """DenseNet dense block.

    Every layer is connected to all preceding layers.
    """
    def __init__(self, num_layers, in_channels, growth_rate):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            # Each layer produces growth_rate new feature maps
            self.layers.append(self._make_layer(in_channels + i * growth_rate,
                                                growth_rate))

    def _make_layer(self, in_channels, out_channels):
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_feature = layer(torch.cat(features, dim=1))
            features.append(new_feature)
        return torch.cat(features, dim=1)
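A usage sketch: output channels grow linearly with depth, which is the dense feature reuse at work:

block = DenseBlock(num_layers=4, in_channels=64, growth_rate=32)
print(block(torch.randn(2, 64, 28, 28)).shape)  # torch.Size([2, 192, 28, 28]): 64 + 4 * 32 channels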
Architecture Comparison

| Architecture | Core idea | Skip connection | Characteristics |
|---|---|---|---|
| ResNet | residual learning | element-wise addition | simple and effective |
| WideResNet | increased width | same as ResNet | wider and faster |
| ResNeXt | grouped convolutions | same as ResNet | efficient and strong |
| DenseNet | dense connectivity | concatenation | feature reuse |
| SENet | channel attention | same as ResNet | adaptive recalibration |
Residual Connections and Neural ODEs

Connection and Analogy

A residual network can be read as a discretized ordinary differential equation (ODE): the update

$$x_{l+1} = x_l + f(x_l, \theta_l)$$

is one Euler step of size 1. As the number of layers goes to infinity and the step size goes to zero, it becomes:

$$\frac{dx(t)}{dt} = f(x(t), t, \theta)$$

This is exactly the starting point of Neural ODEs![^5]
The ODE Perspective in Code
class ODEBlock(nn.Module):
    """Neural ODE block.

    Replaces a fixed stack of layers with an ODE solver.
    """
    def __init__(self, odefunc):
        super().__init__()
        # odefunc must have the signature f(t, x) expected by torchdiffeq
        self.odefunc = odefunc

    def forward(self, x, t=None):
        # Third-party dependency: https://github.com/rtqichen/torchdiffeq
        from torchdiffeq import odeint
        if t is None:
            # avoid a mutable tensor as a default argument; integrate over [0, 1]
            t = torch.tensor([0.0, 1.0])
        # odeint returns the state at every time in t; return the final state
        return odeint(self.odefunc, x, t)[-1]
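A usage sketch, assuming the third-party torchdiffeq package is installed; the small MLP vector field is an arbitrary choice:

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 32), nn.Tanh(), nn.Linear(32, 8))

    def forward(self, t, x):  # torchdiffeq passes the current time t
        return self.net(x)

block = ODEBlock(ODEFunc())
print(block(torch.randn(4, 8)).shape)  # torch.Size([4, 8])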
Advantages

- Adaptive computation: the solver adapts the amount of work to the complexity of the input
- Memory efficiency: reversible computation (activations can be reconstructed instead of stored)
- Continuous depth: there is no discrete notion of "which layer"
Types of Skip Connections
1. Identity skip

# Direct addition
output = F(x) + x

2. Projection skip
# When dimensions do not match, project the skip path
output = F(x) + projection(x)

3. Gated skip
# Highway Network style: a learned gate balances the two paths
gate = torch.sigmoid(W_g(x))
output = gate * F(x) + (1 - gate) * x

4. Sparse skip
# Learnable sparse connection
gate = torch.sigmoid(W_g(x))
# The hard threshold is not differentiable; training typically needs a
# relaxation such as a straight-through estimator
mask = (gate > threshold).float()
output = mask * F(x) + (1 - mask) * x

Practical Guide
Implementation Essentials
class MyResNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_blocks):
        super().__init__()
        # Stem (no residual connection here)
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU()
        )
        # Residual layers
        self.layers = nn.ModuleList([
            ResidualBlock(hidden_channels)
            for _ in range(num_blocks)
        ])

    def forward(self, x):
        x = self.stem(x)
        for layer in self.layers:
            x = layer(x)  # the skip connection lives inside each block
        return x

Common Problems and Fixes
# Problem 1: shape mismatch between the residual branch and the skip path
class ProjectionResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Residual branch F(x); any conv stack works here
        self.F = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1)
        )
        # A 1×1 projection handles the dimension mismatch
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=stride)

    def forward(self, x):
        # The (projected) identity connection keeps information flowing
        return self.F(x) + self.shortcut(x)
# Problem 2: unstable training
# → use pre-activation blocks plus careful initialization
def init_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
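Applied after construction (MyResNet from the section above; the sizes are arbitrary):

net = MyResNet(in_channels=3, hidden_channels=64, num_blocks=8)
init_weights(net)  # modifies the parameters in place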
When to Use

# ✅ Good fits for residual connections:
# 1. Network depth ≥ 20 layers
# 2. Training very deep networks
# 3. Image classification, detection, segmentation
# 4. When feature reuse matters

# ❌ Scenarios where they may be unnecessary:
# 1. Network depth < 10 layers
# 2. Extremely shallow networks
# 3. Some generative tasks (skip connections can interfere with style)

Core Formula Quick Reference
| Concept | Formula |
|---|---|
| Residual learning | $\mathcal{F}(x) = \mathcal{H}(x) - x$ |
| Forward pass | $x_{l+1} = x_l + \mathcal{F}(x_l, W_l)$ |
| Gradient flow | $\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\left(1 + \frac{\partial}{\partial x_l}\sum_{i=l}^{L-1}\mathcal{F}(x_i, W_i)\right)$ |
| Identity mapping | $\mathcal{H}(x) = x \Rightarrow \mathcal{F}(x) = 0$ |
| Pre-activation | $x_{l+1} = x_l + \mathcal{F}(\mathrm{ReLU}(\mathrm{BN}(x_l)), W_l)$ |
References

Related Articles

- ResNet Basics: an introduction to ResNet
- CNNs and Image Classification: a review of CNN architectures
- Backpropagation and Gradient Flow: gradient-flow analysis
- Normalization Techniques: BN/LN/GroupNorm
- Mamba-2 SSD: how state-space models relate to residual connections
Footnotes

[^1]: He, K., Zhang, X., Ren, S., & Sun, J. (2016). "Deep Residual Learning for Image Recognition". CVPR 2016. https://arxiv.org/abs/1512.03385
[^2]: Zagoruyko, S., & Komodakis, N. (2016). "Wide Residual Networks". BMVC 2016. https://arxiv.org/abs/1605.07146
[^3]: Xie, S., et al. (2017). "Aggregated Residual Transformations for Deep Neural Networks". CVPR 2017. https://arxiv.org/abs/1611.05431
[^4]: Huang, G., et al. (2017). "Densely Connected Convolutional Networks". CVPR 2017. https://arxiv.org/abs/1608.06993
[^5]: Chen, R.T.Q., et al. (2018). "Neural Ordinary Differential Equations". NeurIPS 2018. https://arxiv.org/abs/1806.07366