门控注意力:NeurIPS 2025 最佳论文深度解析
引言
2025 年 NeurIPS 最佳论文之一《Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free》由 Qwen 团队提出,系统研究了注意力门控 (Gated Attention) 的设计空间。
核心发现:
- 在注意力输出后置门控(G1 配置:SDPA 输出 + 逐头逐元素 + sigmoid + 乘法)能同时解决三个核心问题:
- 缓解 Attention Sink 现象(首 token 吸收过多注意力)
- 支持 32k→128k 的 RoPE-base 上下文扩展
- 允许使用更高学习率训练
这是简单但效果惊人的架构改进——仅添加每头一个 sigmoid 门控(参数开销 < 1%),却带来训练和推理的双重提升。1
一、背景与动机
1.1 Attention Sink 现象
问题:训练好的 LLM 中,第一个 token(通常是 BOS)会吸收不成比例的注意力分数。
现象:
- 即使输入与首 token 语义无关,注意力仍会”流”向它
- 在长序列中尤其严重
- 浪费了模型容量
已有解释:
- Softmax 性质:softmax 需要归一化 → 某处必须”承载”残差注意力
- 超神经元 (Super-Neuron):特定神经元成为”信号汇”
- 因果掩码:自回归中首 token 是唯一可被所有位置关注的 token
实践影响:
- 滑动窗口注意力会丢失 sink → 性能崩溃
- 长度外推时需要重新校准 sink
- KV Cache 压缩时需保留 sink
1.2 高学习率训练的不稳定
问题:标准 Attention 在高学习率下训练不稳定。
原因:
- 注意力的 softmax 梯度与分数值耦合
- 大梯度 → softmax 饱和 → 梯度消失
- 训练动力学难控制
1.3 长上下文扩展的挑战
问题:从 32k 扩展到 128k 通常需要:
- 复杂的 RoPE 插值(YaRN、LongRoPE)
- 额外的位置编码训练
- 二次验证(“长度泛化测试”)
核心困难:
- 注意力分数分布随长度变化
- 极端长度下数值不稳定
1.4 门控注意力的动机
如果我们在注意力的关键位置加一个 sigmoid 门控,能否同时解决上述三个问题?
这就是 Qwen 团队的系统研究。
二、5 个候选门控位置系统化
2.1 门控位置的分类
Qwen 团队系统分析了 5 个候选门控位置(G1-G5),并辅以多种门控函数变体。
2.1.1 G1: SDPA 输出后置门控
# 数学形式
y = σ(W_g · attn_output) ⊙ attn_output
# 等价
y = σ(W_g · x) ⊙ softmax(QK^T / √d) V特点:
- 在注意力计算完成后立即门控
- 逐头逐元素(per-head per-element)
- sigmoid 激活
2.1.2 G2: QK^T 后置门控
# 数学形式
attn_logits = (QK^T / √d) ⊙ σ(W_g · x)
y = softmax(attn_logits) V特点:
- 在 softmax 之前门控注意力分数
- 改变注意力分布
- 不直接修改输出
2.1.3 G3: Softmax 前置门控
# 数学形式
attn_pre = QK^T / √d
attn_pre = attn_pre ⊙ σ(W_gate(Q))
y = softmax(attn_pre) V特点:
- 仅对 query 侧门控
- 等价于”软检索过滤器”
2.1.4 G4: 残差连接门控
# 数学形式
y = x + σ(W_g · x) ⊙ attn(x)
# 等价
y = x + σ(W_g · x) ⊙ (softmax(QK^T / √d) V)特点:
- 在整个注意力残差上加门控
- 等价于”可学习的残差缩放”
2.1.5 G5: 多头合并后门控
# 数学形式
multi_head = Concat(head_1, ..., head_h) W_O
y = σ(W_g · x) ⊙ multi_head特点:
- 在头合并后门控
- 头之间共享门控
2.2 完整 5 位置对比表
| 位置 | 数学形式 | 是否逐头 | 计算开销 | 效果 |
|---|---|---|---|---|
| G1 | 是 | 小 | 最优 | |
| G2 | 取决于实现 | 中 | 中等 | |
| G3 | 隐含在 Q | 中 | 较弱 | |
| G4 | 是 | 小 | 中等 | |
| G5 | 否(合并后) | 小 | 较弱 |
2.3 系统实验结论
Qwen 团队在 1.7B 参数模型上系统比较所有配置:
最终实验结论:
G1 > G4 > G2 > G5 > G3
↑ ↑
最优 最差
G1 的优势:
- 注意力分数 → 输出之间是直接乘法关系
- 门控不破坏 softmax 概率分布
- 逐头门控保留多头异质性
- 数值稳定(sigmoid 输出 ∈ [0, 1])
三、G1 配置详解
3.1 数学形式
设输入 ,多头参数 个头,每头维度 。
每头计算:
其中 是每个头的门控投影。
3.2 实现细节
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class GatedMultiHeadAttention(nn.Module):
"""Gated Multi-Head Attention (G1 配置)"""
def __init__(self, d_model, n_heads, dropout=0.0, bias=False):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
# QKV 投影
self.W_q = nn.Linear(d_model, d_model, bias=bias)
self.W_k = nn.Linear(d_model, d_model, bias=bias)
self.W_v = nn.Linear(d_model, d_model, bias=bias)
# 门控投影(每头一个)
self.W_g = nn.Linear(d_model, d_model, bias=bias)
# 输出投影
self.W_o = nn.Linear(d_model, d_model, bias=bias)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None, kv_cache=None):
"""
x: (B, T, d_model)
"""
B, T, d = x.shape
# QKV 投影
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
# 门控投影:与 QKV 同样的多头切分
G = self.W_g(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
G = torch.sigmoid(G) # (B, n_heads, T, d_k)
# 注意力分数
scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
# 因果掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
# 注意力输出
out = attn @ V # (B, n_heads, T, d_k)
# G1 门控:逐头逐元素
out = G * out # 广播:G 的最后一维与 out 相同
# 多头合并
out = out.transpose(1, 2).contiguous().view(B, T, d)
out = self.W_o(out)
return out3.3 训练 vs 推理
训练:
- 门控 随其他参数一起学习
- 初始化:,使 初始
- 学习率与其他参数相同
推理:
- 门控是确定性的
- KV Cache 不需要修改
- 仅每步额外一次 sigmoid 乘法
性能开销:
- 参数:每头 ≈ ,对 个头总计
- 相比标准 MHA:增加 投影参数
- 计算:每步增加 (门控计算)
四、G1 解决的关键问题
4.1 Attention Sink 缓解
核心现象:
- 标准 Attention 中,BOS token 经常吸收 30-80% 的注意力分数
- 长序列中更严重
G1 如何缓解:
门控作为输入相关的衰减因子:
- 当 head 想要”关注”无关 token 时, 接近 0
- 当 head 想要”关注”相关 token 时, 接近 1
实验结果:
- BOS 注意力分数从 60-80% 降至 5-15%
- 注意力分布更均匀
- 长序列中分布更稳定
def visualize_attention_sinks(model, x, threshold=0.3):
"""可视化 attention sink 程度"""
B, T, d = x.shape
with torch.no_grad():
# 提取第一层的注意力分数
attn = model.get_attention_scores(x) # (B, n_heads, T, T)
# 统计首 token 注意力占比
bos_attn = attn[:, :, :, 0] # 所有 head、query 关注 BOS 的分数
mean_bos_attn = bos_attn.mean(dim=(1, 2)) # (B,)
n_sinks = (mean_bos_attn > threshold).sum().item()
return mean_bos_attn.mean().item(), n_sinks
# 实验对比
# 标准 MHA: mean_bos_attn = 0.65, n_sinks = 8/12 layers
# Gated MHA: mean_bos_attn = 0.10, n_sinks = 0/12 layers4.2 长上下文扩展(32k → 128k)
核心问题:
- RoPE-base 在 32k 训练后无法直接外推到 128k
- 注意力分数随长度变化(缩放失效)
G1 的帮助:
-
门控自适应缩放:
- 短序列: 接近 1(正常)
- 长序列: 降低(抑制无效注意力)
- 等价于学习一个长度无关的”有效注意力”
-
不需要位置编码插值:
- 标准的 YaRN/LongRoPE 可省略
- 门控自然处理长度变化
实验:
- 32k 训练 → 128k 推理
- 标准 MHA:Perplexity 爆炸
- Gated MHA:Perplexity 平滑上升
4.3 高学习率训练
问题根源:
- 注意力梯度 与 softmax 分数耦合
- 高学习率 → 大梯度 → softmax 饱和 → 训练失败
G1 的解耦:
门控 的梯度独立于 softmax:
这意味着:
- 即使 softmax 部分饱和,门控仍能学习
- 总梯度被分到两个独立通路
- 等价于隐式梯度裁剪
实验:
- 标准 MHA:最高学习率
- Gated MHA:最高学习率 (3x 提升)
4.4 训练效率
由于可以高学习率训练,总训练步数减少:
- 1.7B 模型:相同 loss 减少 30-40% 训练时间
- 收敛更快
- 内存相同
五、消融实验
5.1 门控函数的选择
| 门控函数 | 公式 | 效果 |
|---|---|---|
| Sigmoid | 最优 | |
| Tanh | 略差 | |
| ReLU | 较差(输出非 [0,1]) | |
| SiLU | 较差 | |
| 固定 1 | 无门控 | 基线 |
为什么 Sigmoid 最优:
- 输出 ∈ [0, 1]:可解释为”软开关”
- 平滑梯度:训练友好
- 单调:易于分析
5.2 是否逐头共享
| 配置 | 参数量 | 效果 |
|---|---|---|
| 逐头独立 (per-head) | 最优 | |
| 头组合后单门控 (single) | 略差 | |
| 所有头共享 (shared) | 较差 |
为什么逐头独立最优:
- 不同头学习不同模式
- 一头可能关注局部,另一头关注全局
- 共享门控会限制这种异质性
5.3 初始化策略
| 初始化 | 效果 |
|---|---|
| 零初始化 (W_g = 0) | 起始 ,稳定训练 |
| Kaiming 初始化 | 起始 变化大,初期不稳定 |
| Xavier 初始化 | 类似 Kaiming |
推荐:零初始化 (但需注意打破对称性 → 标准做法是单独线性层)
5.4 与其他门控组件的关系
Qwen 团队研究了与门控 FFN (GLU/SwiGLU) 的交互:
| 配置 | 效果 |
|---|---|
| 仅 Gated Attention | 良好 |
| 仅 Gated FFN (SwiGLU) | 良好 |
| 同时使用 | 最优 |
| 都不使用 | 基线 |
互补性:
- Gated Attention 控制信息路由(哪些 token 被关注)
- Gated FFN 控制信息处理(特征如何变换)
- 两者正交,独立贡献
六、实际部署
6.1 Qwen3-Next 实现
Qwen3-Next 是首个大规模部署 Gated Attention 的工业模型:
- 80B+ 参数
- 256K 上下文
- 训练效率提升 30%
class Qwen3NextAttention(nn.Module):
"""Qwen3-Next 风格的 Gated Attention"""
def __init__(self, d_model, n_heads, n_kv_heads=None, bias=False):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads or n_heads
self.d_k = d_model // n_heads
self.n_rep = n_heads // self.n_kv_heads
# Q (全头), KV (分组)
self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=bias)
self.W_k = nn.Linear(d_model, self.n_kv_heads * self.d_k, bias=bias)
self.W_v = nn.Linear(d_model, self.n_kv_heads * self.d_k, bias=bias)
# 门控
self.W_g = nn.Linear(d_model, n_heads * self.d_k, bias=bias)
# 输出
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, mask=None, kv_cache=None):
B, T, d = x.shape
Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
G = torch.sigmoid(self.W_g(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2))
# KV 复制(GQA)
K = K.repeat_interleave(self.n_rep, dim=1)
V = V.repeat_interleave(self.n_rep, dim=1)
# SDPA
out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
# G1 门控
out = G * out
out = out.transpose(1, 2).contiguous().view(B, T, d)
return self.W_o(out)6.2 与 MoE 结合
Gated Attention 与 Mixture of Experts 正交互补:
- Gated Attention 处理序列方向的路由
- MoE 处理特征方向的路由
class GatedAttentionMoEBlock(nn.Module):
"""Gated Attention + MoE 组合"""
def __init__(self, d_model, n_heads, n_experts, top_k):
super().__init__()
self.attn = GatedMultiHeadAttention(d_model, n_heads)
self.moe = MoE(d_model, n_experts, top_k)
self.norm1 = nn.RMSNorm(d_model)
self.norm2 = nn.RMSNorm(d_model)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.moe(self.norm2(x))
return x6.3 训练技巧
- 学习率:可提升 2-3 倍
- Warmup:标准 2000 步
- 权重衰减:与 MHA 相同
- 初始化: + 标准偏差
- 梯度裁剪:可放宽到 1.0
七、理论分析
7.1 门控作为隐式正则化
视角:门控 是软特征选择器。
理论:
- 当 :head 完全忽略输入
- 当 :head 完全保留输入
- 中间值:软过滤
正则化效应:
- 鼓励”必要才激活”
- 抑制噪声
- 等价于 软正则化
7.2 与门控线性单元(GLU)的理论联系
GLU(Dauphin et al. 2017)形式:
Gated Attention 是 GLU 在注意力维度的推广:
- GLU:门控 作用于 FFN 输出
- Gated Attention:门控 作用于 Attention 输出
共同理论:
- 乘法交互(gating)比加法(residual)更具表达力
- 门控实现条件计算(conditional computation)
- 与 Highway Network (Srivastava 2015) 一脉相承
7.3 注意力熵的调节
定义:注意力熵
门控对熵的影响:
- :输出 ,但 不变
- 实际效果:降低等效注意力熵
- 鼓励稀疏化的有效注意力
def attention_entropy_with_gating(scores, gate):
"""计算门控后的有效注意力熵"""
attn = F.softmax(scores, dim=-1)
# 门控后的有效输出
gated_attn = gate * attn
gated_attn = gated_attn / (gated_attn.sum(dim=-1, keepdim=True) + 1e-9)
entropy = -(gated_attn * torch.log(gated_attn + 1e-9)).sum(dim=-1)
return entropy.mean()7.4 与 Hebbian 学习的联系
门控 的学习规则:
- 当 大时, 变化快
- 当 小时, 稳定
- 这与STDP(spike-timing-dependent plasticity)有形式相似
猜想:门控可能模拟生物神经元的突触可塑性,但需要更深入研究。
八、与其他架构改进的对比
8.1 vs Multi-Head Latent Attention (MLA, DeepSeek-V3)
| 维度 | MLA | Gated Attention |
|---|---|---|
| 核心思想 | 低秩压缩 KV | 门控输出 |
| 节省 | KV Cache 8x | 无 |
| 表达力 | 略低 | 略高 |
| 长上下文 | 改善 | 显著改善 |
| 训练 | 复杂 | 简单 |
可叠加:MLA 思想与 Gated Attention 兼容,可联合使用。
8.2 vs Sliding Window Attention (SWA)
| 维度 | SWA | Gated Attention |
|---|---|---|
| 核心思想 | 限制感受野 | 门控注意力 |
| 复杂度 | ||
| 内存 | 小 | 大 |
| 表现 | 强 | 强 |
| 训练技巧 | 需处理 sink | 自动处理 sink |
注:Gated Attention + SWA 仍可能需要 Sink 处理。
8.3 vs Gated FFN (SwiGLU)
| 维度 | SwiGLU | Gated Attention |
|---|---|---|
| 位置 | FFN | Attention |
| 作用 | 特征变换 | 信息路由 |
| 模式 | 静态门控 | 动态门控(输入依赖) |
| 表达力 | 中 | 高 |
九、未来方向
9.1 开放问题
- 最佳门控位置:G1 在大多数任务最优,但某些任务 G2/G4 可能更好
- 门控与位置编码交互:RoPE 下的门控行为
- 稀疏门控:将 G1 稀疏化为 Top-k
- 多层级门控:每层不同的门控配置
- 跨模态门控:视觉-语言的门控策略
9.2 工业影响
- Qwen 系列:Qwen3-Next 已采用
- DeepSeek-V4(预计):可能采用
- Llama-4(预计):可能采用
- 开源:Gated Attention 实现简单,可广泛采用
9.3 与其他方向结合
- 稀疏注意力:Gated Attention + Block-Sparse
- MoE:Gated Attention + MoE 的稀疏专家
- 线性注意力:门控化 Linear Attention
- Mamba:门控化 Mamba 块
十、完整 PyTorch 训练示例
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader
class GatedTransformerBlock(nn.Module):
"""完整 Gated Transformer Block"""
def __init__(self, d_model, n_heads, d_ff, dropout=0.1, use_gated_attn=True):
super().__init__()
self.use_gated_attn = use_gated_attn
if use_gated_attn:
self.attn = GatedMultiHeadAttention(d_model, n_heads, dropout=dropout)
else:
self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
# SwiGLU FFN(推荐)
self.ffn = SwiGLU(d_model, d_ff)
self.norm1 = nn.RMSNorm(d_model)
self.norm2 = nn.RMSNorm(d_model)
def forward(self, x, mask=None):
# Pre-norm
x = x + self.attn(self.norm1(x), mask=mask)
x = x + self.ffn(self.norm2(x))
return x
class SwiGLU(nn.Module):
"""SwiGLU FFN"""
def __init__(self, d_model, d_ff):
super().__init__()
self.W1 = nn.Linear(d_model, d_ff, bias=False)
self.W2 = nn.Linear(d_model, d_ff, bias=False)
self.W3 = nn.Linear(d_ff, d_model, bias=False)
def forward(self, x):
return self.W3(F.silu(self.W1(x)) * self.W2(x))
class GatedTransformerLM(nn.Module):
"""完整 Gated Transformer 语言模型"""
def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_len=32768, dropout=0.1):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_len, d_model)
self.layers = nn.ModuleList([
GatedTransformerBlock(d_model, n_heads, d_ff, dropout, use_gated_attn=True)
for _ in range(n_layers)
])
self.norm = nn.RMSNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# 共享嵌入(可选)
# self.lm_head.weight = self.token_emb.weight
def forward(self, ids, targets=None):
B, T = ids.shape
pos = torch.arange(T, device=ids.device)
x = self.token_emb(ids) + self.pos_emb(pos)
# 因果掩码
mask = torch.tril(torch.ones(T, T, device=ids.device)).unsqueeze(0).unsqueeze(0)
for layer in self.layers:
x = layer(x, mask=mask)
x = self.norm(x)
logits = self.lm_head(x)
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss
return logits, None
# 训练循环
def train_gated_lm():
model = GatedTransformerLM(
vocab_size=32000, d_model=2048, n_heads=16, n_layers=24,
d_ff=8192, max_len=32768
)
# 高学习率训练(门控允许)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000)
model.cuda()
for step in range(10000):
# 训练步骤(伪代码)
x = next(iter(DataLoader([]))) # 真实数据
logits, loss = model(x[:, :-1], x[:, 1:])
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
if step % 100 == 0:
print(f"step {step}: loss {loss.item():.4f}")
if __name__ == "__main__":
# 测试
model = GatedTransformerLM(vocab_size=1000, d_model=128, n_heads=4, n_layers=4, d_ff=512)
x = torch.randint(0, 1000, (2, 64))
logits, loss = model(x[:, :-1], x[:, 1:])
print(f"logits: {logits.shape}, loss: {loss.item():.4f}")总结
Gated Attention(Qwen 团队,NeurIPS 2025 Best Paper)是一个简单但强大的架构改进:
- 核心思想:在 SDPA 输出后置逐头 sigmoid 门控(G1 配置)
- 三大收益:缓解 attention sink、支持 32k→128k 扩展、允许高学习率
- 实现简单:每头加一个线性层 + sigmoid,开销 < 33%
- 工业可用:Qwen3-Next 已采用,可广泛部署
关键洞察:
门控实现”条件计算”——根据输入动态决定每个头的信息保留量
这与 Gated FFN (GLU/SwiGLU) 在 FFN 维度的成功完全平行,是 2025 年最重要的架构发现之一。1
参考资料
Footnotes
-
主要参考:Qwen 团队 NeurIPS 2025 Best Paper “Gated Attention for Large Language Models” (https://openreview.net/pdf?id=1b7whO4SfY)。相关工作包括:Dauphin et al. 2017 (GLU)、Shazeer 2020 (GLU variants)、Su et al. 2024 (RoFormer/RoPE)、Xiao et al. 2024 (YaRN) 等。 ↩ ↩2