DeepCrossAttention深度交叉注意力
概述
DeepCrossAttention(DCA)是ICML 2025提出的增强Transformer残差学习的方法,通过深度方向交叉注意力动态融合不同层的输出信息。1
与MUDDFormer解决残差信息稀释问题不同,DCA专注于层间交互增强,通过交叉注意力机制实现更深层的信息流动。
问题背景
标准残差连接的限制
传统残差连接 存在以下局限:
- 静态融合:无论层深度如何,融合权重始终为1:1
- 信息选择性缺失:无法智能选择哪些层的信息更重要
- 梯度重复:深层梯度必须经过相同的非线性路径
深度信息流失
在深层Transformer中:
- 早期层捕获的底层特征(词法、语法)被逐渐稀释
- 后期层主要携带高层语义信息
- 关键的低层线索(如位置信息)在深层几乎消失
核心方法
动态权重组合
DCA引入可学习的输入依赖权重,动态决定如何组合各层输出:
其中:
- 是第 层的输出
- 是数据依赖的组合权重
深度方向交叉注意力
关键创新是跨深度的注意力机制:
class DeepCrossAttention(nn.Module):
def __init__(self, d_model, n_heads, max_depth=12):
super().__init__()
self.max_depth = max_depth
self.n_heads = n_heads
self.d_k = d_model // n_heads
# 深度编码
self.depth_embedding = nn.Embedding(max_depth, d_model)
# 交叉注意力查询生成
self.query_proj = nn.Linear(d_model, d_model)
# 注意力输出投影
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, layer_outputs, current_depth):
"""
layer_outputs: List of [B, N, D] tensors from previous layers
current_depth: 当前层深度
"""
B, N, D = layer_outputs[0].shape
num_layers = len(layer_outputs)
# 生成当前层查询
depth_emb = self.depth_embedding(
torch.tensor(current_depth, device=layer_outputs[0].device)
)
query = self.query_proj(depth_emb).unsqueeze(0).unsqueeze(0) # [1, 1, D]
query = query.expand(B, N, -1) # [B, N, D]
# 拼接历史层输出
history = torch.stack(layer_outputs, dim=2) # [B, N, num_layers, D]
# 生成各层的键
keys = self.depth_embedding(
torch.arange(num_layers, device=history.device)
).unsqueeze(0).unsqueeze(0) # [1, 1, num_layers, D]
keys = keys.expand(B, N, -1, -1) + history
# 展平用于注意力计算
history_flat = history.view(B, N, num_layers * D)
keys_flat = keys.view(B, N, num_layers, D)
# 多头注意力
Q = query.view(B * N, 1, D)
K = keys_flat.view(B * N, num_layers, D)
V = history_flat.view(B * N, num_layers, D)
scores = torch.bmm(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
context = torch.bmm(attn_weights, V) # [B*N, 1, D]
context = context.view(B, N, D)
return self.out_proj(context)理论分析
准确率-模型大小权衡
DCA的理论分析表明,当集体层秩与环境维度之比低于某个阈值时,DCA提供更优的准确率-模型大小权衡。
定义集体层秩 ,则:
其中 是与任务相关的常数。
梯度流改善
DCA为深层梯度提供多条路径:
- 标准路径:
- 跨层路径:(通过交叉注意力)
这有效缓解了深层网络的梯度消失问题。
实验结果
训练效率
| 方法 | 达到相同质量的时间 | 额外参数 |
|---|---|---|
| 标准Transformer | 1.0× | 0% |
| DCA | 0.33× | 0.2% |
DCA实现3倍训练加速,同时仅增加0.2%参数。
语言建模基准
在C4数据集上的困惑度对比:
| 模型规模 | 标准 | DCA | 提升 |
|---|---|---|---|
| 125M | 24.3 | 23.1 | 5.0% |
| 355M | 18.7 | 17.4 | 7.0% |
| 1B | 14.2 | 12.9 | 9.2% |
质量-计算权衡曲线
DCA在帕累托前沿上显著优于基线:
- 相同训练时间下,DCA模型困惑度更低
- 相同模型质量下,DCA训练时间减少60-70%
实现集成
与标准Transformer块集成
class DCATransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, max_depth=48):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads)
self.dca = DeepCrossAttention(d_model, n_heads, max_depth)
self.ffn = FeedForward(d_model)
self.norm1 = LayerNorm(d_model)
self.norm2 = LayerNorm(d_model)
self.norm_dca = LayerNorm(d_model)
self.layer_outputs = [] # 存储历史层输出
def forward(self, x):
# 标准自注意力
attn_out = self.attention(x)
x = self.norm1(x + attn_out)
# 前馈网络
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
# 存储用于DCA
self.layer_outputs.append(x)
if len(self.layer_outputs) > self.max_depth:
self.layer_outputs.pop(0)
# 深度交叉注意力(每N层应用一次)
if len(self.layer_outputs) % 6 == 0:
dca_out = self.dca(self.layer_outputs, len(self.layer_outputs))
x = self.norm_dca(x + dca_out)
return x与MUDDFormer的对比
| 维度 | DeepCrossAttention | MUDDFormer |
|---|---|---|
| 核心思想 | 深度方向交叉注意力 | 动态加权残差 |
| 关注点 | 层间信息选择 | 残差信息稀释 |
| 权重生成 | 注意力机制 | 门控网络 |
| 额外参数 | 0.2% | 0.23% |
| 训练加速 | 3× | 1.8-2.4× |
| 适用场景 | 深层堆叠 | 计算受限 |
应用实践
推荐配置
- DCA应用频率:每6-8层应用一次DCA
- max_depth设置:通常为总层数的1.5-2倍
- 与标准残差共存:保持标准残差作为主要路径
内存考虑
DCA需要存储历史层输出,内存开销约为:
对于长序列场景,建议限制max_depth。
参考资料
相关链接
Footnotes
-
Heddes et al. “DeepCrossAttention: Supercharging Transformer Residual Connections” ICML 2025. arXiv:2502.06785 ↩