DeepCrossAttention深度交叉注意力

概述

DeepCrossAttention(DCA)是ICML 2025提出的增强Transformer残差学习的方法,通过深度方向交叉注意力动态融合不同层的输出信息。1

与MUDDFormer解决残差信息稀释问题不同,DCA专注于层间交互增强,通过交叉注意力机制实现更深层的信息流动。

问题背景

标准残差连接的限制

传统残差连接 存在以下局限:

  1. 静态融合:无论层深度如何,融合权重始终为1:1
  2. 信息选择性缺失:无法智能选择哪些层的信息更重要
  3. 梯度重复:深层梯度必须经过相同的非线性路径

深度信息流失

在深层Transformer中:

  • 早期层捕获的底层特征(词法、语法)被逐渐稀释
  • 后期层主要携带高层语义信息
  • 关键的低层线索(如位置信息)在深层几乎消失

核心方法

动态权重组合

DCA引入可学习的输入依赖权重,动态决定如何组合各层输出:

其中:

  • 是第 层的输出
  • 是数据依赖的组合权重

深度方向交叉注意力

关键创新是跨深度的注意力机制

class DeepCrossAttention(nn.Module):
    def __init__(self, d_model, n_heads, max_depth=12):
        super().__init__()
        self.max_depth = max_depth
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 深度编码
        self.depth_embedding = nn.Embedding(max_depth, d_model)
        
        # 交叉注意力查询生成
        self.query_proj = nn.Linear(d_model, d_model)
        
        # 注意力输出投影
        self.out_proj = nn.Linear(d_model, d_model)
        
    def forward(self, layer_outputs, current_depth):
        """
        layer_outputs: List of [B, N, D] tensors from previous layers
        current_depth: 当前层深度
        """
        B, N, D = layer_outputs[0].shape
        num_layers = len(layer_outputs)
        
        # 生成当前层查询
        depth_emb = self.depth_embedding(
            torch.tensor(current_depth, device=layer_outputs[0].device)
        )
        query = self.query_proj(depth_emb).unsqueeze(0).unsqueeze(0)  # [1, 1, D]
        query = query.expand(B, N, -1)  # [B, N, D]
        
        # 拼接历史层输出
        history = torch.stack(layer_outputs, dim=2)  # [B, N, num_layers, D]
        
        # 生成各层的键
        keys = self.depth_embedding(
            torch.arange(num_layers, device=history.device)
        ).unsqueeze(0).unsqueeze(0)  # [1, 1, num_layers, D]
        keys = keys.expand(B, N, -1, -1) + history
        
        # 展平用于注意力计算
        history_flat = history.view(B, N, num_layers * D)
        keys_flat = keys.view(B, N, num_layers, D)
        
        # 多头注意力
        Q = query.view(B * N, 1, D)
        K = keys_flat.view(B * N, num_layers, D)
        V = history_flat.view(B * N, num_layers, D)
        
        scores = torch.bmm(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        
        context = torch.bmm(attn_weights, V)  # [B*N, 1, D]
        context = context.view(B, N, D)
        
        return self.out_proj(context)

理论分析

准确率-模型大小权衡

DCA的理论分析表明,当集体层秩与环境维度之比低于某个阈值时,DCA提供更优的准确率-模型大小权衡。

定义集体层秩 ,则:

其中 是与任务相关的常数。

梯度流改善

DCA为深层梯度提供多条路径

  • 标准路径:
  • 跨层路径:(通过交叉注意力)

这有效缓解了深层网络的梯度消失问题。

实验结果

训练效率

方法达到相同质量的时间额外参数
标准Transformer1.0×0%
DCA0.33×0.2%

DCA实现3倍训练加速,同时仅增加0.2%参数。

语言建模基准

在C4数据集上的困惑度对比:

模型规模标准DCA提升
125M24.323.15.0%
355M18.717.47.0%
1B14.212.99.2%

质量-计算权衡曲线

DCA在帕累托前沿上显著优于基线:

  • 相同训练时间下,DCA模型困惑度更低
  • 相同模型质量下,DCA训练时间减少60-70%

实现集成

与标准Transformer块集成

class DCATransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, max_depth=48):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.dca = DeepCrossAttention(d_model, n_heads, max_depth)
        self.ffn = FeedForward(d_model)
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm_dca = LayerNorm(d_model)
        
        self.layer_outputs = []  # 存储历史层输出
        
    def forward(self, x):
        # 标准自注意力
        attn_out = self.attention(x)
        x = self.norm1(x + attn_out)
        
        # 前馈网络
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        
        # 存储用于DCA
        self.layer_outputs.append(x)
        if len(self.layer_outputs) > self.max_depth:
            self.layer_outputs.pop(0)
        
        # 深度交叉注意力(每N层应用一次)
        if len(self.layer_outputs) % 6 == 0:
            dca_out = self.dca(self.layer_outputs, len(self.layer_outputs))
            x = self.norm_dca(x + dca_out)
        
        return x

与MUDDFormer的对比

维度DeepCrossAttentionMUDDFormer
核心思想深度方向交叉注意力动态加权残差
关注点层间信息选择残差信息稀释
权重生成注意力机制门控网络
额外参数0.2%0.23%
训练加速1.8-2.4×
适用场景深层堆叠计算受限

应用实践

推荐配置

  1. DCA应用频率:每6-8层应用一次DCA
  2. max_depth设置:通常为总层数的1.5-2倍
  3. 与标准残差共存:保持标准残差作为主要路径

内存考虑

DCA需要存储历史层输出,内存开销约为:

对于长序列场景,建议限制max_depth。

参考资料

相关链接

Footnotes

  1. Heddes et al. “DeepCrossAttention: Supercharging Transformer Residual Connections” ICML 2025. arXiv:2502.06785