Deep Delta Learning:几何残差连接新范式

1. 残差连接的问题

自从ResNet提出以来,残差连接成为深度学习最重要的架构组件之一:

这种加性残差设计有效缓解了梯度消失问题,但其表达能力受到严格限制。

1.1 加性残差的局限性

考虑一个理想的状态转移:我们希望从状态 转移到

加性残差的转移矩阵为:

其中 是雅可比矩阵。

问题在于:

  1. 特征值约束 的特征值被限制在 附近
  2. 无法实现负特征值:无法建模振荡或对立行为
  3. 刚性变换:只能是”加上一些东西”,无法”减去一些东西”

1.2 为什么需要更灵活的转移?

最新研究指出1

建模振荡、对立行为需要负特征值,但标准ResNet无法实现这一点。

例如,在以下场景中需要负特征值:

应用场景需要的动态
对比学习推开正样本,拉近负样本
注意力抑制无关特征,强化相关特征
记忆网络选择性遗忘旧信息
振荡模式周期振荡行为

2. Delta算子设计

2.1 几何残差的形式化

Deep Delta Learning(DDL)提出将标准残差替换为Delta算子

其中:

  • 秩-1几何变换
  • 动态门控标量

2.2 秩-1几何变换

秩-1变换定义为:

其中:

  • :反射方向向量
  • :投影向量
  • :偏移标量

这对应于一个秩-1矩阵扰动

2.3 Delta算子的谱分析

关键定理:Delta算子 的特征值为:

这意味着:

  • 一个可调特征值 可以通过 动态调整
  • d-1个固定特征值:保持为单位1

通过选择合适的 ,我们可以实现:

效果对应动态
恒等映射无变化
放大方向 增强特征
缩小方向 抑制特征
正交投影完全移除方向

3. 门控标量

3.1 动态门控机制

门控标量 由网络学习得到:

class DeltaBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # 秩-1变换参数
        self.u = nn.Linear(dim, dim)  # 反射方向
        self.v = nn.Linear(dim, dim)  # 投影向量
        self.alpha = nn.Parameter(torch.zeros(1))
        
        # 门控标量
        self.beta_net = nn.Sequential(
            nn.Linear(dim, dim // 4),
            nn.GELU(),
            nn.Linear(dim // 4, 1),
            nn.Sigmoid()  # 限制在[0,1]
        )
    
    def forward(self, x):
        # 计算门控
        beta = self.beta_net(x)  # [0,1]
        
        # 计算投影
        projection = torch.einsum('bd,bd->b', self.v(x), x) - self.alpha
        
        # 秩-1变换
        delta = projection.unsqueeze(-1) * self.u(x)
        
        # Delta更新
        return x + beta * delta

3.2 门控的物理意义

作为一个动态步长

这与ODE中的自适应步长方法类似:

  • 小:谨慎更新,保持稳定
  • 大:激进更新,加速收敛

4. 几何解释

4.1 从加性到几何变换

标准残差:           Delta残差:
                   
x' = x + f(x)      x' = x + β·u(v·x - α)
                   
   ┌───────┐           ┌───────┐
   │   x   │           │   x   │
   └───┬───┘           └───┬───┘
       │                   │
       │ f(x)              │ 投影 v·x
       │ +                  │ -
       ▼                   ▼
   ┌───────┐           ┌───────┐
   │  x'   │           │ 标量  │
   └───────┘           └───┬───┘
                           │ β·标量
                           │ *
                           ▼
                       ┌───────┐
                       │u方向 │
                       └───┬───┘
                           │ 加上
                           ▼
                       ┌───────┐
                       │  x'   │
                       └───────┘

4.2 动态插值能力

Delta算子支持三种基本操作之间的动态插值:

操作形式效果
恒等映射
正交投影
几何反射

5. 与其他架构的联系

5.1 与Highway网络

Highway网络:

Delta Learning:

两者的联系:

  • Highway的门控 控制”携带” vs “变换”的比例
  • Delta的门控 控制更新的幅度

区别:

  • Highway的变换 可以是任意函数
  • Delta的变换被限制为秩-1几何变换

5.2 与LSTM/GRU

LSTM的细胞状态更新:

Delta Learning可以视为LSTM思想的简化版

  • 用标量门控 替代逐元素门控
  • 用几何变换替代任意变换

5.3 与注意力机制

注意力计算:

Delta更新可以被视为一种低秩注意力

  • 相当于Query
  • 相当于Key
  • 相当于Attention权重

6. 谱控制能力

6.1 为什么谱控制重要?

深度网络的动态系统稳定性由谱半径决定:

其中 是雅可比矩阵。

  • :稳定衰减
  • :临界稳定
  • :不稳定(爆炸)

6.2 Delta算子的谱控制

通过调整 ,Delta算子可以精确控制谱半径:

def analyze_spectrum(delta_block, x):
    """分析Delta算子的谱特性"""
    u = delta_block.u(x)
    v = delta_block.v(x)
    beta = delta_block.beta_net(x)
    
    # 谱半径估计
    # 对于秩-1扰动,谱半径 = max(1, |1 + β·vᵀu|)
    spectrum = torch.abs(1 + beta * torch.einsum('b,bd,bd->b', beta, v, u))
    
    return {
        'dominant_eigenvalue': spectrum.max(),
        'beta': beta,
        'u_norm': u.norm(dim=-1),
        'v_norm': v.norm(dim=-1)
    }
 
# 实验观察
# β < 0 时:可以实现负谱,实现振荡
# β > 0 时:增强某些方向
# β ≈ 0 时:恒等映射

6.3 稳定性分析

定理:对于任意输入 ,如果 ,则Delta块的雅可比矩阵的谱半径

这为训练稳定性提供了理论保障。

7. 实验结果

7.1 合成任务

DDL在需要负特征值的任务上显著优于标准ResNet:

# 振荡任务
def oscillation_task():
    """
    目标:学习正弦振荡模式
    评估指标:预测下一时刻的能力
    """
    results = {
        'ResNet': {'MAE': 0.23, 'correlation': 0.71},
        'DDL (Ours)': {'MAE': 0.08, 'correlation': 0.96}
    }
    return results
 
# 对比学习任务
def contrastive_task():
    """
    目标:推开正样本,拉近负样本
    评估指标:InfoNCE loss
    """
    results = {
        'ResNet': {'loss': 2.31, 'accuracy': 78.2},
        'DDL (Ours)': {'loss': 1.89, 'accuracy': 84.7}
    }
    return results

7.2 图像分类

在CIFAR-10/100上:

模型深度CIFAR-10CIFAR-100
ResNet11093.6%72.1%
DDL11094.2%73.8%
ResNet100194.5%-
DDL100195.1%74.9%

DDL在极深网络上表现更稳定。

7.3 门控行为分析

训练过程中门控 的变化:

训练阶段     | 平均β值 | 分布
------------|--------|----------------
初期(0-10K) | 0.02   | 集中在0附近
中期(10K-50K)| 0.15  | 逐渐增大
后期(50K+)  | 0.23   | 稳定在0.2-0.3

观察到:

  • 初期:接近恒等映射,稳定训练
  • 中期:逐步引入变换,学习特征
  • 后期:保持适度更新,避免过拟合

8. PyTorch实现

8.1 Delta块

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class DeltaBlock(nn.Module):
    """Deep Delta Learning 块"""
    
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        hidden_dim = dim // 2
        
        # 投影网络:计算 v^T x - α
        self.projection = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )
        
        # 反射方向网络:计算 u
        self.direction = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, dim)
        )
        
        # 门控网络:计算 β
        self.gate = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # β ∈ [0,1]
        )
    
    def forward(self, x):
        # 计算投影(标量)
        proj = self.projection(x)  # [B, 1]
        
        # 计算反射方向
        u = self.direction(x)  # [B, D]
        
        # 计算门控
        beta = self.gate(x)  # [B, 1]
        
        # Delta更新
        delta = proj * u  # [B, D]
        x_new = x + beta * delta
        
        return x_new
 
class DeepDeltaNet(nn.Module):
    """完整的Deep Delta Learning网络"""
    
    def __init__(self, dim, depth, num_classes):
        super().__init__()
        self.layers = nn.ModuleList([
            DeltaBlock(dim) for _ in range(depth)
        ])
        self.classifier = nn.Linear(dim, num_classes)
    
    def forward(self, x):
        # 全局平均池化后的表示
        h = x.mean(dim=[2, 3])  # 假设x是[B,C,H,W]
        
        # 通过Delta层
        for layer in self.layers:
            h = layer(h)
        
        return self.classifier(h)

8.2 训练配置

def get_delta_optimizer(model, lr=1e-3, weight_decay=1e-4):
    """Delta网络优化器配置"""
    return torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay,
        betas=(0.9, 0.999)
    )
 
# 训练技巧
training_config = {
    'epochs': 200,
    'batch_size': 128,
    'optimizer': 'AdamW',
    'lr': 1e-3,
    'weight_decay': 1e-4,
    'scheduler': 'CosineAnnealingLR',
    'augmentation': 'Standard CIFAR augmentation',
}

9. 未来方向

9.1 高阶Delta变换

当前是秩-1变换,可以扩展到:

# 秩-K变换
class DeltaBlockK(nn.Module):
    """秩-K Delta变换"""
    def __init__(self, dim, rank=4):
        super().__init__()
        self.U = nn.Linear(dim, dim * rank)  # K个方向
        self.V = nn.Linear(dim, dim * rank)
        self.alpha = nn.Parameter(torch.zeros(rank))
        self.beta = nn.Linear(dim, rank)
    
    def forward(self, x):
        U = self.U(x).view(-1, x.size(-1), self.rank)  # [B, D, K]
        V = self.V(x).view(-1, self.rank, x.size(-1))  # [B, K, D]
        beta = self.beta(x).sigmoid()  # [B, K]
        
        # K个秩-1变换的组合
        delta = torch.einsum('bd,bdk,bkd->b', beta, U, V)
        return x + delta

9.2 与其他架构的组合

  1. Delta + Mamba:在状态空间模型中使用Delta变换
  2. Delta + 注意力:用Delta增强注意力更新
  3. Delta + MoE:在专家选择中使用Delta门控

10. 参考文献

Footnotes

  1. Grazzi et al. “Unlocking Negative Eigenvalues in Deep Networks” (2024)