Residual Networks and the Theory of Skip Connections

Residual networks (ResNet), proposed by He et al. in 2016, introduced skip connections that made it possible to train deep neural networks with over 1,000 layers, solving the degradation problem of deep learning.[1] This article analyzes the mathematics of residual learning, its gradient-flow properties, and the design ideas behind the main variant architectures.

Background: Degradation in Deep Networks

What Degradation Is

The traditional view is that deeper networks should perform better. He et al. observed, however, that beyond a certain depth, adding more layers increases both the training error and the test error.

Training error
    │
    │        ┌─────────┐
    │       ╱│  Deep   │  ← Degradation: the deeper network's error is higher
    │      ╱ │ network │
    │     ╱  └─────────┘
    │    ╱
    │   ╱  ┌─────────┐
    │  ╱   │ Shallow │
    │ ╱    │ network │
    │╱     └─────────┘
    └─────────────────────── Network depth

Degradation vs. Overfitting

| Property       | Degradation             | Overfitting                |
|----------------|--------------------------|----------------------------|
| Training error | ↑ increases              | ↓ decreases                |
| Test error     | ↑ increases              | decreases, then increases  |
| Root cause     | optimization difficulty  | memorizing noise           |
| Remedy         | residual learning        | regularization             |

The Root Cause of Degradation

Degradation is not caused by overfitting but by optimization difficulty: deeper networks are harder to optimize toward a good solution.

Core insight: if a shallower network has already reached a good solution, then the best thing any extra layers can do is learn the identity mapping, i.e., do nothing. But learning the identity is not easy for a stack of nonlinear layers.


The Residual Learning Framework

Core Idea

Suppose we want to learn an underlying mapping $H(x)$. Instead of fitting $H(x)$ directly, the residual learning framework has the network learn the residual

$$F(x) = H(x) - x$$

The original mapping is then recovered as

$$H(x) = F(x) + x$$

Introducing the Identity Mapping

Plain network:                           Residual network:

input x ──→ [Conv] ──→ ... ──→ output    input x ──┬──→ [Conv] ──→ ... ──┬──→ output
                                                   │                     │
                                                   └──── skip connection ┘

If the optimal $H(x)$ is the identity mapping, then $F(x) = 0$: the network only needs to push the weights of the residual branch toward zero, which is far easier than learning the identity from scratch.
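
A minimal sketch of this point (not from the original paper; the class name is made up): if the last convolution of the residual branch is zero-initialized, the block computes the identity exactly at initialization.

import torch
import torch.nn as nn

class TinyResidualBlock(nn.Module):
    """Illustrative residual block: H(x) = x + F(x)."""
    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )
        # Zero-initialize the last conv so that F(x) = 0 at the start.
        nn.init.zeros_(self.body[-1].weight)
        nn.init.zeros_(self.body[-1].bias)

    def forward(self, x):
        return x + self.body(x)

x = torch.randn(1, 8, 16, 16)
block = TinyResidualBlock(8)
print(torch.allclose(block(x), x))  # True: the block starts out as the identity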

Mathematical Formulation

Forward pass through a residual block:

$$x_{l+1} = x_l + F(x_l, W_l)$$

where $F(x_l, W_l)$ is the residual function and $x_l$ is the input to the $l$-th residual block.

Backward pass (chain rule): unrolling the recursion from block $l$ to any deeper block $L$ gives $x_L = x_l + \sum_{i=l}^{L-1} F(x_i, W_i)$, so

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \, \frac{\partial x_L}{\partial x_l}$$

The key point is that

$$\frac{\partial x_L}{\partial x_l} = 1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{L-1} F(x_i, W_i)$$

The additive identity term guarantees effective gradient propagation: a portion of $\frac{\partial \mathcal{L}}{\partial x_L}$ reaches $x_l$ unattenuated, no matter how deep the network is.


Gradient-Flow Analysis of Skip Connections

Mitigating the Vanishing-Gradient Problem

In a plain network with $x_{l+1} = f_l(x_l)$, the backpropagated gradient is a long product of Jacobians:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L} \prod_{i=l}^{L-1} \frac{\partial x_{i+1}}{\partial x_i}$$

When the network is very deep, this product makes the gradient decay or grow exponentially.

How the Residual Connection Protects the Gradient

In a residual network:

$$x_{l+1} = x_l + F(x_l, W_l)$$

so each factor in the backward pass becomes

$$\frac{\partial x_{i+1}}{\partial x_i} = I + \frac{\partial F(x_i, W_i)}{\partial x_i}$$

Key Properties

  1. The gradient does not vanish: the $I$ term guarantees that gradient flows through even when $\frac{\partial F}{\partial x_i}$ is very small
  2. The gradient scale stays stable: products of factors remain, but each factor has the form $I + \frac{\partial F}{\partial x_i}$, which anchors it near the identity instead of letting it collapse or blow up
  3. Information propagates directly: the signal can travel straight from shallow layers to deep layers, and gradients can travel straight back

Visualizing the Gradient Flow

import torch
import torch.nn as nn
 
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
    
    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = torch.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += residual  # skip connection
        return torch.relu(out)
 
def analyze_gradient_flow(model, x):
    """分析残差网络中的梯度流"""
    model.eval()
    
    # Hook to capture gradients
    gradients = {}
    
    def hook_fn(module, grad_input, grad_output):
        if grad_output[0] is not None:
            gradients[module] = grad_output[0].norm().item()
    
    # Register backward hooks on every residual block
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, ResidualBlock):
            hooks.append(module.register_full_backward_hook(hook_fn))
    
    # Forward + backward pass
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    # Remove hooks and return the per-block gradient norms
    for h in hooks:
        h.remove()
    return gradients
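
A short usage sketch (not in the original text; the block count and tensor sizes are arbitrary): stacking many ResidualBlocks and printing the captured norms shows that the gradient reaching early blocks stays on the same order of magnitude as at the deep end.

# Stack 50 residual blocks and inspect the per-block gradient norms.
model = nn.Sequential(*[ResidualBlock(16) for _ in range(50)])
x = torch.randn(4, 16, 8, 8)

grads = analyze_gradient_flow(model, x)
# The dict is filled during backward, so entries appear from the deepest block
# to the shallowest; the norms stay on a similar scale throughout.
for i, norm in enumerate(grads.values()):
    print(f"hooked block {i:02d}: grad-output norm = {norm:.4f}")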

Why the Identity Mapping Matters

Pre-activation vs. Post-activation

The original ResNet used post-activation:

Original ResNet (post-activation):
x_l ──→ [Conv → BN → ReLU] → [Conv → BN] → [+] → ReLU → x_{l+1}

The later **pre-activation** design is cleaner:

Pre-activation ResNet:
x_l ──→ [BN → ReLU → Conv] → [BN → ReLU → Conv] → [+] → x_{l+1}
        (the skip is added directly; no activation follows the addition)

Advantages of Pre-activation

  1. The residual branch is a BN → ReLU → Conv composition, which makes it easier to drive toward the zero mapping
  2. The skip path does not pass through BN, so information propagates directly
  3. Backpropagation is smoother: the gradient reaches the block input without obstruction

class PreActResidualBlock(nn.Module):
    """
    预激活残差块
    
    更适合训练超深网络
    """
    def __init__(self, channels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
    
    def forward(self, x):
        residual = x
        
        # Pre-activation ordering: BN → ReLU → Conv
        out = self.bn1(x)
        out = torch.relu(out)
        out = self.conv1(out)
        
        out = self.bn2(out)
        out = torch.relu(out)
        out = self.conv2(out)
        
        # Identity skip connection
        out += residual
        return out

ResNet Architecture in Detail

Overall Structure

class ResNet(nn.Module):
    """
    ResNet-50-style backbone

    Stem: [Conv, BN, ReLU, MaxPool]
    4 residual stages with [3, 4, 6, 3] bottleneck blocks
    """
    def __init__(self, num_classes=1000):
        super().__init__()
        
        # Initial convolution (stem)
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
        )
        
        # Four residual stages ([3, 4, 6, 3] blocks, as in ResNet-50)
        self.layer1 = self._make_layer(64, 256, 3, stride=1)     # output 56×56
        self.layer2 = self._make_layer(256, 512, 4, stride=2)    # output 28×28
        self.layer3 = self._make_layer(512, 1024, 6, stride=2)   # output 14×14
        self.layer4 = self._make_layer(1024, 2048, 3, stride=2)  # output 7×7
        
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(2048, num_classes)
    
    def _make_layer(self, in_channels, out_channels, num_blocks, stride=1):
        layers = []
        
        # First block handles downsampling and channel matching
        layers.append(BottleneckBlock(in_channels, out_channels, stride=stride))
        
        # Remaining blocks keep the resolution and width
        for _ in range(1, num_blocks):
            layers.append(BottleneckBlock(out_channels, out_channels))
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
 
 
class BottleneckBlock(nn.Module):
    """
    Bottleneck 残差块
    
    1×1 Conv (降维) → 3×3 Conv → 1×1 Conv (升维)
    
    参数量: 1×1×C×(C/4) + 3×3×(C/4)×(C/4) + 1×1×(C/4)×C
         ≈ 0.19C² (相比 BasicBlock 的 2C² 更高效)
    """
    expansion = 4
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        
        mid_channels = out_channels // self.expansion
        
        # 1×1 conv: reduce channels
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        
        # 3×3 conv
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 
                               stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        
        # 1×1 conv: expand channels
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        
        # Skip connection (projection when shapes differ)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)  # skip connection
        return torch.relu(out)

Comparison of Residual Block Types

| Type           | Structure               | Params (width C)   | Compute            | Typical use        |
|----------------|--------------------------|--------------------|--------------------|--------------------|
| BasicBlock     | 3×3 → 3×3               | ≈ 18 C²            | higher             | small models       |
| Bottleneck     | 1×1 → 3×3 → 1×1         | ≈ 1.06 C²          | lower              | large models       |
| Pre-activation | (BN → ReLU → Conv) × 2  | same as BasicBlock | same as BasicBlock | very deep networks |
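
As a rough sanity check (not from the original article), the two blocks defined earlier can be compared directly; BN and bias parameters are included, so the counts differ slightly from the C² estimates above.

# Parameter comparison at width 256, reusing the blocks defined above.
count = lambda m: sum(p.numel() for p in m.parameters())

basic = ResidualBlock(256)               # two 3×3 convs at width 256
bottleneck = BottleneckBlock(256, 256)   # 1×1 → 3×3 → 1×1 with 64 mid channels

print(f"BasicBlock-style block: {count(basic):,} params")       # ≈ 1.2M
print(f"Bottleneck block:       {count(bottleneck):,} params")  # ≈ 70K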

ResNet Variant Architectures

WideResNet

Core idea: increasing the width (number of channels) is more effective than adding depth.[2]

class WideResidualBlock(nn.Module):
    """
    WideResNet residual block

    Improves performance by increasing width (number of channels) rather than depth
    """
    def __init__(self, in_channels, out_channels, stride=1, dropout=0.0):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, 
                               stride=stride, padding=1)
        self.dropout = nn.Dropout(dropout)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 
                               stride=1, padding=1)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)
    
    def forward(self, x):
        out = torch.relu(self.bn1(x))
        out = self.conv1(out)
        out = torch.relu(self.bn2(out))
        out = self.dropout(out)
        out = self.conv2(out)
        return out + self.shortcut(x)

ResNeXt

Core idea: use grouped convolutions and increase the cardinality (the number of parallel groups).[3]

class ResNeXtBlock(nn.Module):
    """
    ResNeXt residual block

    Grouped convolution: `cardinality` independent groups, concatenated along channels
    """
    def __init__(self, in_channels, out_channels, stride=1, cardinality=32, width=4):
        super().__init__()
        
        mid_channels = cardinality * width
        
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 
                               stride=stride, padding=1, 
                               groups=cardinality)  # grouped convolution
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)
    
    def forward(self, x):
        out = self.conv1(torch.relu(self.bn1(x)))
        out = self.conv2(torch.relu(self.bn2(out)))
        out = self.conv3(out)
        return out + self.shortcut(x)

DenseNet

Core idea: each layer takes the feature maps of all preceding layers as input.[4]

class DenseBlock(nn.Module):
    """
    DenseNet dense block

    Every layer is connected to all preceding layers via concatenation
    """
    def __init__(self, num_layers, in_channels, growth_rate):
        super().__init__()
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            # Each layer produces growth_rate new feature maps
            self.layers.append(self._make_layer(in_channels + i * growth_rate, 
                                                growth_rate))
    
    def _make_layer(self, in_channels, out_channels):
        return nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, 3, padding=1)
        )
    
    def forward(self, x):
        features = [x]
        for layer in self.layers:
            new_feature = layer(torch.cat(features, dim=1))
            features.append(new_feature)
        return torch.cat(features, dim=1)

Architecture Comparison

| Architecture | Core idea            | Skip connection        | Notes                    |
|--------------|----------------------|------------------------|--------------------------|
| ResNet       | residual learning    | element-wise addition  | simple and effective     |
| WideResNet   | increased width      | same as ResNet         | wider and faster         |
| ResNeXt      | grouped convolutions | same as ResNet         | efficient and strong     |
| DenseNet     | dense connectivity   | concatenation          | feature reuse            |
| SENet        | channel attention    | same as ResNet         | adaptive recalibration   |
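
The table mentions SENet's channel attention without showing it. Below is a minimal, illustrative squeeze-and-excitation module (the class name and the reduction ratio of 16 are assumptions, not taken from this article); in SE-ResNet the gate is applied to the residual branch F(x) before the skip addition.

class SEBlock(nn.Module):
    """Squeeze-and-Excitation: recalibrate channels with a learned gate."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)   # squeeze: one value per channel
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),                      # excitation: per-channel gate in (0, 1)
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w                           # channel-wise recalibration

feat = torch.randn(2, 64, 32, 32)
print(SEBlock(64)(feat).shape)  # torch.Size([2, 64, 32, 32])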

Residual Connections and Neural ODEs

Connection and Analogy

A residual network can be read as a discretized ordinary differential equation (ODE): the update

$$x_{l+1} = x_l + F(x_l)$$

is exactly one forward-Euler step $x(t + \Delta t) = x(t) + \Delta t \, f(x(t), t)$ with $\Delta t = 1$.

As the number of layers goes to infinity and the step size goes to zero, the recursion becomes

$$\frac{dx(t)}{dt} = f(x(t), t)$$

This is exactly the starting point of Neural ODEs.[5]

Advantages of the ODE View

class ODEBlock(nn.Module):
    """
    Neural ODE Block
    
    Replaces a fixed number of layers with an ODE solver
    """
    def __init__(self, odefunc):
        super().__init__()
        self.odefunc = odefunc
    
    def forward(self, x, t=None):
        from torchdiffeq import odeint  # third-party dependency: torchdiffeq
        if t is None:
            t = torch.tensor([0.0, 1.0], device=x.device)
        return odeint(self.odefunc, x, t)
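
An illustrative usage sketch (requires the third-party torchdiffeq package; the ODEFunc class here is a made-up example, not part of that library):

class ODEFunc(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, t, x):            # odeint expects the signature f(t, x)
        return torch.tanh(self.conv(x))

block = ODEBlock(ODEFunc(16))
x = torch.randn(2, 16, 8, 8)
states = block(x)                       # states[k] is the solution at time t[k]
print(states.shape)                     # torch.Size([2, 2, 16, 8, 8])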

Advantages

  1. Adaptive computation: the solver can spend more or fewer function evaluations depending on the input
  2. Memory efficiency: with the adjoint/reversible formulation, intermediate activations need not be stored
  3. Continuous depth: there is no discrete notion of "layer $l$"

Types of Skip Connections

1. Identity skip

# Direct element-wise addition (F denotes the residual branch)
output = F(x) + x

2. Projection skip

# Used when dimensions do not match (projection is typically a 1×1 conv)
output = F(x) + projection(x)

3. Gated skip

# Highway-Network style: a learned gate trades off the two paths
gate = torch.sigmoid(W_g(x))
output = gate * F(x) + (1 - gate) * x

4. Sparse skip

# Learnable sparse connection (the hard threshold blocks gradients;
# a straight-through estimator is typically used in practice)
gate = torch.sigmoid(W_g(x))
mask = (gate > threshold).float()
output = mask * F(x) + (1 - mask) * x
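
A self-contained sketch of the gated variant (the module name, layer choices, and gate bias initialization are mine, following the Highway-Network idea):

class GatedSkipBlock(nn.Module):
    """Highway-style gated skip: output = g * F(x) + (1 - g) * x."""
    def __init__(self, channels):
        super().__init__()
        self.residual = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )
        # Gate network W_g; a negative bias initially favors the identity path.
        self.gate = nn.Conv2d(channels, channels, 1)
        nn.init.constant_(self.gate.bias, -2.0)

    def forward(self, x):
        g = torch.sigmoid(self.gate(x))
        return g * self.residual(x) + (1.0 - g) * x

x = torch.randn(1, 32, 16, 16)
print(GatedSkipBlock(32)(x).shape)  # torch.Size([1, 32, 16, 16])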

Practical Guide

Implementation Essentials for Residual Networks

class MyResNet(nn.Module):
    def __init__(self, in_channels, hidden_channels, num_blocks):
        super().__init__()
        
        # Stem (no residual connection)
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding=1),
            nn.BatchNorm2d(hidden_channels),
            nn.ReLU()
        )
        
        # Residual blocks
        self.layers = nn.ModuleList([
            ResidualBlock(hidden_channels) 
            for _ in range(num_blocks)
        ])
    
    def forward(self, x):
        x = self.stem(x)
        for layer in self.layers:
            x = layer(x)  # the skip connection lives inside each block
        return x
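
A quick shape check (assuming the ResidualBlock defined earlier in this article):

model = MyResNet(in_channels=3, hidden_channels=64, num_blocks=8)
x = torch.randn(1, 3, 32, 32)
print(model(x).shape)  # torch.Size([1, 64, 32, 32])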

Common Problems and Fixes

# Problem 1: the skip path and the residual path have mismatched shapes
class ResidualBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Residual branch F(x)
        self.F = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
        )
        # A projection shortcut handles the dimension mismatch
        self.shortcut = nn.Sequential()
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=stride)
    
    def forward(self, x):
        # The (projected) identity connection keeps information flowing
        return self.F(x) + self.shortcut(x)
 
 
# Problem 2: unstable training
# → use pre-activation blocks plus proper weight initialization
def init_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
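
Usage is a single call on any model containing Conv2d / BatchNorm2d layers (the small Sequential below is just a stand-in):

model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU())
init_weights(model)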

When (Not) to Use Residual Connections

# ✅ Scenarios where residual connections help:
# 1. Network depth ≥ 20 layers
# 2. Training very deep networks
# 3. Image classification, detection, segmentation
# 4. When feature reuse is desired

# ❌ Scenarios where they may be unnecessary:
# 1. Networks shallower than about 10 layers
# 2. Extremely small, shallow models
# 3. Some generative tasks (the identity path can interfere with the desired transformation, e.g., style changes)

Core Formula Quick Reference

| Concept            | Formula                                                                                                                            |
|--------------------|------------------------------------------------------------------------------------------------------------------------------------|
| Residual learning  | $F(x) = H(x) - x$                                                                                                                   |
| Forward pass       | $x_{l+1} = x_l + F(x_l, W_l)$                                                                                                       |
| Gradient flow      | $\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\left(1 + \frac{\partial}{\partial x_l}\sum_{i=l}^{L-1} F(x_i, W_i)\right)$ |
| Identity mapping   | $h(x_l) = x_l$ (the skip path applies no transformation)                                                                            |
| Pre-activation     | residual branch ordered BN → ReLU → Conv; no activation after the addition                                                         |

References


Footnotes

  1. He, K., Zhang, X., Ren, S., & Sun, J. (2016). “Deep Residual Learning for Image Recognition”. CVPR 2016. https://arxiv.org/abs/1512.03385

  2. Zagoruyko, S., & Komodakis, N. (2016). “Wide Residual Networks”. BMVC 2016. https://arxiv.org/abs/1605.07146

  3. Xie, S., et al. (2017). “Aggregated Residual Transformations for Deep Neural Networks”. CVPR 2017. https://arxiv.org/abs/1611.05431

  4. Huang, G., et al. (2017). “Densely Connected Convolutional Networks”. CVPR 2017. https://arxiv.org/abs/1608.06993

  5. Chen, R.T.Q., et al. (2018). “Neural Ordinary Differential Equations”. NeurIPS 2018. https://arxiv.org/abs/1806.07366