MUDDFormer多路动态密集连接
概述
MUDDFormer(Multiway Dynamic Dense connections)是ICML 2025提出的Transformer架构改进方法,旨在解决传统残差连接的信息稀释问题。1
标准残差连接通过 直接加和,这种静态融合策略在深层网络中会稀释关键信息。MUDDFormer通过动态生成连接权重,实现输入依赖的自适应信息流控制。
核心思想
传统残差连接的问题
标准残差连接存在以下问题:
- 信息稀释:加法操作使深层网络难以保留浅层关键信息
- 静态权重:无论输入内容如何,残差分支和恒等分支的权重比为
- 梯度路径:深层梯度必须经过非线性变换,梯度流效率低
多路动态密集连接
MUDDFormer为每个输入流(Query、Key、Value、残差)动态生成连接权重:
其中:
- 是输入流集合
- 是动态生成的门控权重
- 是对应的变换操作
动态权重生成
权重生成器 由以下组件构成:
- 序列位置编码:捕捉token位置信息
- 输入流特征:编码当前流的语义内容
- 温度参数:控制权重的软硬程度
架构设计
MUDDFormer块结构
输入 x
│
├──→ [Q流] ─→ W_q ─→ Q
│ ↑
│ G_Q(x)
│
├──→ [K流] ─→ W_k ─→ K
│ ↑
│ G_K(x)
│
├──→ [V流] ─→ W_v ─→ V
│ ↑
│ G_V(x)
│
└──→ [R流] ─→ F(x) ─→ 残差
↑
G_R(x)
│
└──→ 加权求和 ─→ LayerNorm ─→ 输出
权重生成网络
class MUDDWeightGenerator(nn.Module):
def __init__(self, d_model, n_streams=4):
super().__init__()
self.n_streams = n_streams
# 位置编码
self.pos_encoder = PositionalEncoding(d_model)
# 共享特征提取器
self.feature_extractor = nn.Sequential(
nn.Linear(d_model, d_model),
nn.GELU(),
nn.Linear(d_model, d_model)
)
# 每个流的独立生成器
self.weight_heads = nn.ModuleList([
nn.Linear(d_model, 1) for _ in range(n_streams)
])
def forward(self, x, positions):
# x: [B, N, D]
pos_emb = self.pos_encoder(positions) # 位置编码
feat = self.feature_extractor(x) + pos_emb # 融合特征
# 生成各流权重
logits = torch.cat([
head(feat) for head in self.weight_heads
], dim=-1) # [B, N, n_streams]
# 温度软化
weights = F.softmax(logits, dim=-1)
return weights.chunk(self.n_streams, dim=-1)与标准Transformer的集成
class MUDDAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
self.attention = nn.MultiheadAttention(d_model, n_heads, dropout)
self.mudd_gen = MUDDWeightGenerator(d_model, n_streams=4)
self.norm = nn.LayerNorm(d_model)
def forward(self, x, positions=None):
B, N, D = x.shape
if positions is None:
positions = torch.arange(N, device=x.device).unsqueeze(0).expand(B, -1)
# 生成MUDD权重
w_q, w_k, w_v, w_r = self.mudd_gen(x, positions)
# Q, K, V变换
Q = x @ self.W_q.weight.T
K = x @ self.W_k.weight.T
V = x @ self.W_v.weight.T
# 残差路径
residual = self.ffn(x)
# 加权融合
x_transformed = w_q * Q + w_k * K + w_v * V + w_r * residual
# 自注意力
attn_out, _ = self.attention(x_transformed, x_transformed, x_transformed)
return self.norm(x + attn_out)理论分析
信息保留能力
传统残差:,浅层信息以 比例传递到第 层。
MUDDFormer通过动态门控可实现:
其中 是数据依赖的门控值,关键信息可获得更高 。
参数效率
| 方法 | 额外参数 | 性能提升 |
|---|---|---|
| 标准残差 | 0% | 基线 |
| Dense连接 | ~20% | ~10% |
| MUDDFormer | 0.23% | 匹配1.8-2.4×更大模型 |
MUDDFormer仅增加0.23%参数,即可达到相当于训练计算量增加1.8-2.4倍的效果。
实验结果
语言建模
| 模型 | 参数量 | 训练Tokens | Validation PPL |
|---|---|---|---|
| Pythia-2.8B | 2.8B | 300B | 8.32 |
| MUDDPythia-2.8B | 2.8B | 300B | 7.89 |
| Pythia-6.9B | 6.9B | 300B | 7.91 |
MUDDPythia-2.8B在预训练困惑度上匹配Pythia-6.9B(参数量减少60%)。
下游任务
| 任务 | Pythia-6.9B | MUDDPythia-2.8B | Pythia-12B |
|---|---|---|---|
| LAMBADA | 68.2 | 69.1 | 71.3 |
| HellaSwag | 52.4 | 51.8 | 53.9 |
| PIQA | 76.8 | 77.2 | 78.1 |
计算效率
- 训练速度:与基线相同(无额外计算开销)
- 内存占用:增加 <1%
- 推理延迟:无显著增加
与DeepCrossAttention的对比
| 特性 | MUDDFormer | DeepCrossAttention |
|---|---|---|
| 连接方式 | 动态加权求和 | 深度交叉注意力 |
| 额外参数 | 0.23% | 0.2% |
| 加速比 | 1.8-2.4× | 3× |
| 设计重点 | 残差信息流 | 层间交互 |
应用建议
- 资源受限场景:MUDDFormer是在有限计算预算下提升性能的利器
- 深层模型:推荐用于12层以上的Transformer
- 与MoE结合:可与专家混合架构协同
参考资料
相关链接
Footnotes
-
Xiao et al. “MUDDFormer: Breaking Residual Bottlenecks in Transformers via Multiway Dynamic Dense Connections” ICML 2025. arXiv:2502.12170 ↩