MUDDFormer多路动态密集连接

概述

MUDDFormer(Multiway Dynamic Dense connections)是ICML 2025提出的Transformer架构改进方法,旨在解决传统残差连接的信息稀释问题1

标准残差连接通过 直接加和,这种静态融合策略在深层网络中会稀释关键信息。MUDDFormer通过动态生成连接权重,实现输入依赖的自适应信息流控制。

核心思想

传统残差连接的问题

标准残差连接存在以下问题:

  1. 信息稀释:加法操作使深层网络难以保留浅层关键信息
  2. 静态权重:无论输入内容如何,残差分支和恒等分支的权重比为
  3. 梯度路径:深层梯度必须经过非线性变换,梯度流效率低

多路动态密集连接

MUDDFormer为每个输入流(Query、Key、Value、残差)动态生成连接权重:

其中:

  • 是输入流集合
  • 是动态生成的门控权重
  • 是对应的变换操作

动态权重生成

权重生成器 由以下组件构成:

  1. 序列位置编码:捕捉token位置信息
  2. 输入流特征:编码当前流的语义内容
  3. 温度参数:控制权重的软硬程度

架构设计

MUDDFormer块结构

输入 x
    │
    ├──→ [Q流] ─→ W_q ─→ Q
    │              ↑
    │          G_Q(x)
    │
    ├──→ [K流] ─→ W_k ─→ K
    │              ↑
    │          G_K(x)
    │
    ├──→ [V流] ─→ W_v ─→ V
    │              ↑
    │          G_V(x)
    │
    └──→ [R流] ─→ F(x) ─→ 残差
                    ↑
                G_R(x)
    │
    └──→ 加权求和 ─→ LayerNorm ─→ 输出

权重生成网络

class MUDDWeightGenerator(nn.Module):
    def __init__(self, d_model, n_streams=4):
        super().__init__()
        self.n_streams = n_streams
        
        # 位置编码
        self.pos_encoder = PositionalEncoding(d_model)
        
        # 共享特征提取器
        self.feature_extractor = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model)
        )
        
        # 每个流的独立生成器
        self.weight_heads = nn.ModuleList([
            nn.Linear(d_model, 1) for _ in range(n_streams)
        ])
        
    def forward(self, x, positions):
        # x: [B, N, D]
        pos_emb = self.pos_encoder(positions)  # 位置编码
        feat = self.feature_extractor(x) + pos_emb  # 融合特征
        
        # 生成各流权重
        logits = torch.cat([
            head(feat) for head in self.weight_heads
        ], dim=-1)  # [B, N, n_streams]
        
        # 温度软化
        weights = F.softmax(logits, dim=-1)
        
        return weights.chunk(self.n_streams, dim=-1)

与标准Transformer的集成

class MUDDAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout)
        self.mudd_gen = MUDDWeightGenerator(d_model, n_streams=4)
        self.norm = nn.LayerNorm(d_model)
        
    def forward(self, x, positions=None):
        B, N, D = x.shape
        if positions is None:
            positions = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
        
        # 生成MUDD权重
        w_q, w_k, w_v, w_r = self.mudd_gen(x, positions)
        
        # Q, K, V变换
        Q = x @ self.W_q.weight.T
        K = x @ self.W_k.weight.T  
        V = x @ self.W_v.weight.T
        
        # 残差路径
        residual = self.ffn(x)
        
        # 加权融合
        x_transformed = w_q * Q + w_k * K + w_v * V + w_r * residual
        
        # 自注意力
        attn_out, _ = self.attention(x_transformed, x_transformed, x_transformed)
        
        return self.norm(x + attn_out)

理论分析

信息保留能力

传统残差:,浅层信息以 比例传递到第 层。

MUDDFormer通过动态门控可实现:

其中 是数据依赖的门控值,关键信息可获得更高

参数效率

方法额外参数性能提升
标准残差0%基线
Dense连接~20%~10%
MUDDFormer0.23%匹配1.8-2.4×更大模型

MUDDFormer仅增加0.23%参数,即可达到相当于训练计算量增加1.8-2.4倍的效果。

实验结果

语言建模

模型参数量训练TokensValidation PPL
Pythia-2.8B2.8B300B8.32
MUDDPythia-2.8B2.8B300B7.89
Pythia-6.9B6.9B300B7.91

MUDDPythia-2.8B在预训练困惑度上匹配Pythia-6.9B(参数量减少60%)。

下游任务

任务Pythia-6.9BMUDDPythia-2.8BPythia-12B
LAMBADA68.269.171.3
HellaSwag52.451.853.9
PIQA76.877.278.1

计算效率

  • 训练速度:与基线相同(无额外计算开销)
  • 内存占用:增加 <1%
  • 推理延迟:无显著增加

与DeepCrossAttention的对比

特性MUDDFormerDeepCrossAttention
连接方式动态加权求和深度交叉注意力
额外参数0.23%0.2%
加速比1.8-2.4×
设计重点残差信息流层间交互

应用建议

  1. 资源受限场景:MUDDFormer是在有限计算预算下提升性能的利器
  2. 深层模型:推荐用于12层以上的Transformer
  3. 与MoE结合:可与专家混合架构协同

参考资料

相关链接

Footnotes

  1. Xiao et al. “MUDDFormer: Breaking Residual Bottlenecks in Transformers via Multiway Dynamic Dense Connections” ICML 2025. arXiv:2502.12170