引言

残差连接(Residual Connection)是现代深度神经网络的核心组件,从ResNet开始成为几乎所有深度架构的标准配置。He等人提出残差连接来解决深度网络的梯度消失问题,使得训练数百层的网络成为可能。

近年来,研究者从理论和实践两个角度对残差连接进行了深入改进。本文分析这些最新进展,包括LAuReL(Learned Augmented Residual Layer)和Branch Scaling等技术。


残差连接基础回顾

标准残差连接

给定输入 ,残差块定义为:

其中 是可学习的映射(通常包含多个卷积层或全连接层)。

梯度流优势

残差连接提供了梯度的”高速公路”:

即使 很小,梯度仍能直接流过恒等映射。

信号传播分析

初始化时,期望残差分支输出接近零:

这确保了深层网络的信息传递与浅层网络相近。


标准残差连接的局限性

1. 表达能力受限

恒等映射 缺乏灵活性,无法根据输入动态调整信息流。

2. 最优缩放困难

残差分支的输出可能与输入尺度不匹配:

固定的 可能不是最优的。

3. 隐式正则化不明确

残差连接对训练动态的隐式正则化效应缺乏理论理解。

4. 浅层信息丢失

深层网络的浅层特征可能无法有效传递到深层。


Branch Scaling:隐式架构正则化

核心发现

研究发现,BatchNorm等归一化技术与残差分支缩放存在等价关系:

理论分析

定理:对于带BatchNorm的ResNet,存在等价的分支缩放因子 ,使得:

其中 由BatchNorm参数决定。

Branch Scaling Manifestation

训练过程中,缩放因子的演化表现出隐式正则化效应:

# 观察Branch Scaling的演化
for block in resnet.blocks:
    effective_scale = block.bn.weight / block.bn.running_var.sqrt()
    # 通常收敛到某个稳定值
    print(f"Scale: {effective_scale.mean():.4f}")

实践观察

现象解释
缩放因子逐渐减小隐式降低残差分支贡献
不同层缩放因子不同自适应调整信息流
与学习率相关大学习率导致更小的缩放

LAuReL:学习增强的残差层

核心思想

LAuReL(Learned Augmented Residual Layer)用可学习的门控机制替代固定的恒等映射:

其中 是输入相关的门控。

架构对比

标准残差:                          LAuReL:
┌─────────┐                      ┌─────────┐
│    x    │                      │    x    │
└────┬────┘                      └────┬────┘
     │                                │
     ▼                                ▼
┌─────────┐                      ┌─────────┐
│    F    │                      │   g(x)  │
└────┬────┘                      └───┬─────┘
     │                                │
     ▼                                ▼
┌─────────┐      +      ┌─────────┐   ▼
│  x + F  │  ────────  │ 1-g(x)  │── + ── + ──
└─────────┘            └─────────┘   │
                                       ▼
                                  ┌─────────┐
                                  │ g*F + (1-g)*x │
                                  └─────────┘

门控机制设计

class LAuReLBlock(nn.Module):
    def __init__(self, dim, expansion=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, dim * expansion),
            nn.GELU(),
            nn.Linear(dim * expansion, dim)
        )
        # 可学习的门控
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        # 门控决定残差vs恒等的比例
        g = self.gate(self.norm1(x))
        return g * self.ff(self.norm2(x)) + (1 - g) * x

优势分析

特性标准残差LAuReL
恒等映射固定自适应
信息流控制隐式显式
参数增加0~2%
表达能力有限增强

实验结果

任务标准残差LAuReL提升
ImageNet分类78.3%79.1%+0.8%
COCO检测42.5 mAP43.2 mAP+0.7%
GLUE基准83.183.8+0.7%

Res-SE-Net:增强桥接连接

SE模块回顾

Squeeze-and-Excitation (SE) 模块通过通道注意力自适应调整特征:

Res-SE-Net改进

Res-SE-Net将SE模块应用于残差连接的”桥接”部分:

多尺度SE

class MultiScaleSE(nn.Module):
    def __init__(self, channels, scales=[1, 2, 4]):
        super().__init__()
        self.se_blocks = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(s),
                nn.Conv2d(channels, channels // r, 1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels // r, channels, 1),
                nn.Sigmoid()
            ) for s in scales
        ])
    
    def forward(self, x):
        # 多尺度注意力融合
        attention = sum(se(x) for se in self.se_blocks) / len(self.scales)
        return attention * x

其他残差连接变体

1. ReZero

可学习的缩放因子 ,初始化为0。

2. SkipInit

引入额外的可学习标量:

其中 通过梯度下降学习。

3. SkipDense

密集连接与残差连接的结合:

4. Highway Network

门控机制控制信息流:


DeepCrossAttention:交叉注意力增强残差

核心思想

DeepCrossAttention (DCA) 使用交叉注意力动态组合残差路径:

其中 由交叉注意力机制动态生成。

实现

class DeepCrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads)
        self.norm = nn.LayerNorm(dim)
        
        # 动态权重预测
        self.weight_predictor = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, 2)
        )
    
    def forward(self, x, branch_out):
        # 预测动态权重
        w = self.weight_predictor(x)
        w = F.softmax(w, dim=-1)
        
        # 加权组合
        alpha_skip, alpha_branch = w[..., 0], w[..., 1]
        return alpha_skip * x + alpha_branch * branch_out

实践建议

何时使用改进的残差连接

场景推荐方案
标准图像分类标准残差或LAuReL
超深网络(>1000层)ReZero或LAuReL
轻量级模型SkipInit
需要显式控制LAuReL或DCA

初始化建议

方法初始化策略
标准残差
LAuReL 初始化
ReZero
SkipInit

与其他组件的兼容性

  • 与BatchNorm:兼容良好,注意Branch Scaling效应
  • 与Dropout:通常在残差分支内使用
  • 与注意力机制:LAuReL和DCA天然适合

总结

技术演进

ResNet (2015)          LAuReL (2024)          未来方向
    │                      │                    │
    ▼                      ▼                    ▼
固定恒等映射        自适应门控机制          动态计算路径
隐式优化           显式控制               元学习优化

关键要点

  1. 残差连接仍是深度网络的核心,但固定形式存在局限
  2. LAuReL通过门控机制增强表达能力,同时保持高效
  3. Branch Scaling揭示了归一化与残差缩放的内在联系
  4. 选择合适的残差变体需要根据具体任务和计算约束

未来展望

  • 自适应残差连接:根据输入和任务动态调整
  • 理论深化:理解隐式正则化效应的深层机制
  • 与稀疏计算的结合:硬件友好的动态残差设计

参考