ResNet:深度残差学习

深度残差网络(ResNet)是深度学习历史上最重要的架构之一,由何恺明等人在2015年提出(CVPR 2016最佳论文),彻底解决了深层网络训练的退化问题。1

1. 问题背景:深度网络的退化

1.1 退化现象

传统观点认为,更深的网络应该具有更强的表达能力,从而获得更好的性能。然而,研究人员发现:

  • 当网络层数增加时,训练误差反而升高
  • 这种退化不是由过拟合引起的(测试集表现同样下降)
  • 简单的堆叠层数会导致性能下降
训练误差
    │
    │     56层
    │   ╱
    │  ╱  20层
    │ ╱
    │╱
    └──────────────── 层数

1.2 退化原因的直觉理解

假设我们构建一个深层网络,其中大部分层是恒等映射(identity mapping)。那么:

  • 深层网络至少应该能获得与浅层网络相同的性能
  • 因为多出来的层可以学习恒等映射而什么都不做

这意味着深度网络优化困难不是由过拟合造成的,而是由优化器无法有效学习恒等映射引起的。

2. 残差学习框架

2.1 核心思想

ResNet的核心思想是:不直接让堆叠层拟合期望的底层映射 ,而是让它们拟合残差映射

其中:

  • 是输入(也称为恒等 shortcut/shortcut connection)
  • 是堆叠的非线性层学习的残差函数
  • 是期望的底层映射

2.2 残差块的数学形式

一个标准的残差块可以表示为:

其中:

  • 是输入和输出向量
  • 是要学习的残差函数
  • 是用于匹配维度的线性投影(可选)

对于两层残差块:

其中 是ReLU激活函数。

2.3 残差学习的优势

为什么残差学习比直接学习更容易?

考虑两种情况:

  1. 直接学习:让堆叠层学习

    • 如果最优映射接近恒等映射,堆叠层需要将输出”调整”回接近输入
    • 这对于SGD来说是个困难的任务
  2. 残差学习:让堆叠层学习

    • 如果最优映射接近恒等映射,只需让
    • 学习零映射比学习恒等映射容易得多

恒等映射的”零化”:将权重置零比学习恒等映射更简单。

3. 网络架构

3.1 ResNet整体结构

ResNet的架构设计遵循以下原则:

阶段输出尺寸通道数残差块数量
Conv1112×112641
Conv2.x56×56643
Conv3.x28×281284
Conv4.x14×142566
Conv5.x7×75123

3.2 残差块类型

ResNet使用两种类型的残差块:

BasicBlock(用于ResNet-18/34)

Input → Conv(3×3) → ReLU → Conv(3×3) → Add → ReLU → Output
            ↓                              ↑
      Shortcut (identity) ─────────────────┘

Bottleneck(用于ResNet-50/101/152)

Input → Conv(1×1,↓) → ReLU → Conv(3×3) → ReLU → Conv(1×1,↑) → Add → ReLU
            ↓                                                    ↑
      Shortcut (Conv 1×1 if needed) ────────────────────────────┘

Bottleneck设计通过 卷积先降维再升维,减少了 卷积的计算量。

3.3 Shortcut Connection设计

论文提出了两种shortcut设计:

类型描述公式
A恒等shortcut,增维时用零填充
B投影shortcut,所有情况都使用
C所有shortcut都投影

实验表明 B 在稍高的计算成本下提供最好的结果,但三种变体差异不大。

4. 梯度流分析

4.1 梯度传播公式

考虑第 层的残差块:

通过递归展开:

对损失 求偏导:

链式法则:

4.2 恒等shortcut的优势

关键性质:梯度可以直接从深层传回浅层,不经过任何权重矩阵。

  • 即使 的Jacobian矩阵很小,梯度仍能通过恒等项 有效传播
  • 避免了传统深层网络中梯度消失的问题

4.3 梯度爆炸/消失的缓解

情况传统网络ResNet
梯度大小乘积链中快速衰减通过恒等shortcut稳定传播
最坏情况 小值 → 梯度消失(取决于 项)

5. 实验结果

5.1 ImageNet分类

模型Top-1 错误率参数量
VGG-1928.5%19.6M
GoogLeNet26.4%6.8M
Plain-3428.3%21.8M
ResNet-3424.5%21.8M
ResNet-5022.8%25.6M
ResNet-10121.6%44.5M
ResNet-15221.3%60.2M

5.2 CIFAR-10实验

深度Plain网络ResNet
20层8.75%8.75%
32层9.00%6.97%
44层9.26%6.14%
56层9.58%5.94%
110层10.2%5.61%

结论:Plain网络在深度增加时性能退化,而ResNet持续改善。

6. PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class BasicBlock(nn.Module):
    """Basic residual block for ResNet-18/34"""
    expansion = 1
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, 
            kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)  # 关键:恒等shortcut
        out = F.relu(out)
        return out
 
 
class Bottleneck(nn.Module):
    """Bottleneck residual block for ResNet-50/101/152"""
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 
                                kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                                kernel_size=3, stride=stride, 
                                padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion,
                                kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
 
 
class ResNet(nn.Module):
    """ResNet architecture"""
    
    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.in_channels = 64
        
        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, 
                                stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Residual stages
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        
        # Classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
    
    def _make_layer(self, block, out_channels, blocks, stride=1):
        layers = [block(self.in_channels, out_channels, stride)]
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
 
 
# Factory functions
def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes)
 
def resnet34(num_classes=1000):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes)
 
def resnet50(num_classes=1000):
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes)
 
def resnet101(num_classes=1000):
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes)

7. ResNet的后续发展

7.1 ResNet变体

变体主要改进论文
ResNeXt多分支聚合-transformer风格Aggregated Residual Transformations
Wide ResNet增加宽度,减少深度Wider Residual Networks
DenseNet密集连接Densely Connected Networks
SENet通道注意力Squeeze-and-Excitation Networks

7.2 Identity Mappings(ECCV 2016)

后续论文深入分析了恒等映射的重要性,提出:

  • 预激活设计:将BN-ReLU放在Conv之前
  • 更清晰的梯度传播路径
  • 进一步提升深度网络的可训练性

8. 核心要点总结

  1. 残差学习的本质:让网络学习恒等映射的”扰动”而非直接学习底层映射
  2. Shortcut的核心作用:提供梯度的”高速公路”,直接传回浅层
  3. 退化问题的解决:残差框架使深层网络的训练变得稳定
  4. 数学简洁性 启发了后续大量研究

参考文献

Footnotes

  1. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep Residual Learning for Image Recognition. CVPR 2016. https://arxiv.org/abs/1512.03385