ResNet:深度残差学习
深度残差网络(ResNet)是深度学习历史上最重要的架构之一,由何恺明等人在2015年提出(CVPR 2016最佳论文),彻底解决了深层网络训练的退化问题。1
1. 问题背景:深度网络的退化
1.1 退化现象
传统观点认为,更深的网络应该具有更强的表达能力,从而获得更好的性能。然而,研究人员发现:
- 当网络层数增加时,训练误差反而升高
- 这种退化不是由过拟合引起的(测试集表现同样下降)
- 简单的堆叠层数会导致性能下降
训练误差
│
│ 56层
│ ╱
│ ╱ 20层
│ ╱
│╱
└──────────────── 层数
1.2 退化原因的直觉理解
假设我们构建一个深层网络,其中大部分层是恒等映射(identity mapping)。那么:
- 深层网络至少应该能获得与浅层网络相同的性能
- 因为多出来的层可以学习恒等映射而什么都不做
这意味着深度网络优化困难不是由过拟合造成的,而是由优化器无法有效学习恒等映射引起的。
2. 残差学习框架
2.1 核心思想
ResNet的核心思想是:不直接让堆叠层拟合期望的底层映射 ,而是让它们拟合残差映射 。
其中:
- 是输入(也称为恒等 shortcut/shortcut connection)
- 是堆叠的非线性层学习的残差函数
- 是期望的底层映射
2.2 残差块的数学形式
一个标准的残差块可以表示为:
其中:
- 是输入和输出向量
- 是要学习的残差函数
- 是用于匹配维度的线性投影(可选)
对于两层残差块:
其中 是ReLU激活函数。
2.3 残差学习的优势
为什么残差学习比直接学习更容易?
考虑两种情况:
-
直接学习:让堆叠层学习
- 如果最优映射接近恒等映射,堆叠层需要将输出”调整”回接近输入
- 这对于SGD来说是个困难的任务
-
残差学习:让堆叠层学习
- 如果最优映射接近恒等映射,只需让
- 学习零映射比学习恒等映射容易得多
恒等映射的”零化”:将权重置零比学习恒等映射更简单。
3. 网络架构
3.1 ResNet整体结构
ResNet的架构设计遵循以下原则:
| 阶段 | 输出尺寸 | 通道数 | 残差块数量 |
|---|---|---|---|
| Conv1 | 112×112 | 64 | 1 |
| Conv2.x | 56×56 | 64 | 3 |
| Conv3.x | 28×28 | 128 | 4 |
| Conv4.x | 14×14 | 256 | 6 |
| Conv5.x | 7×7 | 512 | 3 |
3.2 残差块类型
ResNet使用两种类型的残差块:
BasicBlock(用于ResNet-18/34):
Input → Conv(3×3) → ReLU → Conv(3×3) → Add → ReLU → Output
↓ ↑
Shortcut (identity) ─────────────────┘
Bottleneck(用于ResNet-50/101/152):
Input → Conv(1×1,↓) → ReLU → Conv(3×3) → ReLU → Conv(1×1,↑) → Add → ReLU
↓ ↑
Shortcut (Conv 1×1 if needed) ────────────────────────────┘
Bottleneck设计通过 卷积先降维再升维,减少了 卷积的计算量。
3.3 Shortcut Connection设计
论文提出了两种shortcut设计:
| 类型 | 描述 | 公式 |
|---|---|---|
| A | 恒等shortcut,增维时用零填充 | |
| B | 投影shortcut,所有情况都使用 | |
| C | 所有shortcut都投影 |
实验表明 B 在稍高的计算成本下提供最好的结果,但三种变体差异不大。
4. 梯度流分析
4.1 梯度传播公式
考虑第 层的残差块:
通过递归展开:
对损失 求偏导:
链式法则:
4.2 恒等shortcut的优势
关键性质:梯度可以直接从深层传回浅层,不经过任何权重矩阵。
- 即使 的Jacobian矩阵很小,梯度仍能通过恒等项 有效传播
- 避免了传统深层网络中梯度消失的问题
4.3 梯度爆炸/消失的缓解
| 情况 | 传统网络 | ResNet |
|---|---|---|
| 梯度大小 | 乘积链中快速衰减 | 通过恒等shortcut稳定传播 |
| 最坏情况 | 小值 → 梯度消失 | (取决于 项) |
5. 实验结果
5.1 ImageNet分类
| 模型 | Top-1 错误率 | 参数量 |
|---|---|---|
| VGG-19 | 28.5% | 19.6M |
| GoogLeNet | 26.4% | 6.8M |
| Plain-34 | 28.3% | 21.8M |
| ResNet-34 | 24.5% | 21.8M |
| ResNet-50 | 22.8% | 25.6M |
| ResNet-101 | 21.6% | 44.5M |
| ResNet-152 | 21.3% | 60.2M |
5.2 CIFAR-10实验
| 深度 | Plain网络 | ResNet |
|---|---|---|
| 20层 | 8.75% | 8.75% |
| 32层 | 9.00% | 6.97% |
| 44层 | 9.26% | 6.14% |
| 56层 | 9.58% | 5.94% |
| 110层 | 10.2% | 5.61% |
结论:Plain网络在深度增加时性能退化,而ResNet持续改善。
6. PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class BasicBlock(nn.Module):
"""Basic residual block for ResNet-18/34"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(
in_channels, out_channels,
kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(
out_channels, out_channels,
kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(out_channels)
# Shortcut connection
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * self.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels * self.expansion)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x) # 关键:恒等shortcut
out = F.relu(out)
return out
class Bottleneck(nn.Module):
"""Bottleneck residual block for ResNet-50/101/152"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels,
kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * self.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * self.expansion,
kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels * self.expansion)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet(nn.Module):
"""ResNet architecture"""
def __init__(self, block, layers, num_classes=1000):
super().__init__()
self.in_channels = 64
# Stem
self.conv1 = nn.Conv2d(3, 64, kernel_size=7,
stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Residual stages
self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# Classifier
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, blocks, stride=1):
layers = [block(self.in_channels, out_channels, stride)]
self.in_channels = out_channels * block.expansion
for _ in range(1, blocks):
layers.append(block(self.in_channels, out_channels))
return nn.Sequential(*layers)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
# Factory functions
def resnet18(num_classes=1000):
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
def resnet34(num_classes=1000):
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
def resnet50(num_classes=1000):
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
def resnet101(num_classes=1000):
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)7. ResNet的后续发展
7.1 ResNet变体
| 变体 | 主要改进 | 论文 |
|---|---|---|
| ResNeXt | 多分支聚合-transformer风格 | Aggregated Residual Transformations |
| Wide ResNet | 增加宽度,减少深度 | Wider Residual Networks |
| DenseNet | 密集连接 | Densely Connected Networks |
| SENet | 通道注意力 | Squeeze-and-Excitation Networks |
7.2 Identity Mappings(ECCV 2016)
后续论文深入分析了恒等映射的重要性,提出:
- 预激活设计:将BN-ReLU放在Conv之前
- 更清晰的梯度传播路径
- 进一步提升深度网络的可训练性
8. 核心要点总结
- 残差学习的本质:让网络学习恒等映射的”扰动”而非直接学习底层映射
- Shortcut的核心作用:提供梯度的”高速公路”,直接传回浅层
- 退化问题的解决:残差框架使深层网络的训练变得稳定
- 数学简洁性: 启发了后续大量研究
参考文献
Footnotes
-
He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR 2016. https://arxiv.org/abs/1512.03385 ↩