门控注意力:NeurIPS 2025 最佳论文深度解析

引言

2025 年 NeurIPS 最佳论文之一《Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free》由 Qwen 团队提出,系统研究了注意力门控 (Gated Attention) 的设计空间。

核心发现

  • 在注意力输出后置门控(G1 配置:SDPA 输出 + 逐头逐元素 + sigmoid + 乘法)能同时解决三个核心问题
    1. 缓解 Attention Sink 现象(首 token 吸收过多注意力)
    2. 支持 32k→128k 的 RoPE-base 上下文扩展
    3. 允许使用更高学习率训练

这是简单但效果惊人的架构改进——仅添加每头一个 sigmoid 门控(参数开销 < 1%),却带来训练和推理的双重提升。1


一、背景与动机

1.1 Attention Sink 现象

问题:训练好的 LLM 中,第一个 token(通常是 BOS)会吸收不成比例的注意力分数。

现象

  • 即使输入与首 token 语义无关,注意力仍会”流”向它
  • 在长序列中尤其严重
  • 浪费了模型容量

已有解释

  • Softmax 性质:softmax 需要归一化 → 某处必须”承载”残差注意力
  • 超神经元 (Super-Neuron):特定神经元成为”信号汇”
  • 因果掩码:自回归中首 token 是唯一可被所有位置关注的 token

实践影响

  • 滑动窗口注意力会丢失 sink → 性能崩溃
  • 长度外推时需要重新校准 sink
  • KV Cache 压缩时需保留 sink

1.2 高学习率训练的不稳定

问题:标准 Attention 在高学习率下训练不稳定。

原因

  • 注意力的 softmax 梯度与分数值耦合
  • 大梯度 → softmax 饱和 → 梯度消失
  • 训练动力学难控制

1.3 长上下文扩展的挑战

问题:从 32k 扩展到 128k 通常需要:

  • 复杂的 RoPE 插值(YaRN、LongRoPE)
  • 额外的位置编码训练
  • 二次验证(“长度泛化测试”)

核心困难

  • 注意力分数分布随长度变化
  • 极端长度下数值不稳定

1.4 门控注意力的动机

如果我们在注意力的关键位置加一个 sigmoid 门控,能否同时解决上述三个问题?

这就是 Qwen 团队的系统研究。


二、5 个候选门控位置系统化

2.1 门控位置的分类

Qwen 团队系统分析了 5 个候选门控位置(G1-G5),并辅以多种门控函数变体。

2.1.1 G1: SDPA 输出后置门控

# 数学形式
y = σ(W_g · attn_output) ⊙ attn_output
# 等价
y = σ(W_g · x) ⊙ softmax(QK^T / √d) V

特点

  • 在注意力计算完成后立即门控
  • 逐头逐元素(per-head per-element)
  • sigmoid 激活

2.1.2 G2: QK^T 后置门控

# 数学形式
attn_logits = (QK^T / √d) ⊙ σ(W_g · x)
y = softmax(attn_logits) V

特点

  • 在 softmax 之前门控注意力分数
  • 改变注意力分布
  • 不直接修改输出

2.1.3 G3: Softmax 前置门控

# 数学形式
attn_pre = QK^T / √d
attn_pre = attn_pre ⊙ σ(W_gate(Q))
y = softmax(attn_pre) V

特点

  • 仅对 query 侧门控
  • 等价于”软检索过滤器”

2.1.4 G4: 残差连接门控

# 数学形式
y = x + σ(W_g · x) ⊙ attn(x)
# 等价
y = x + σ(W_g · x) ⊙ (softmax(QK^T / √d) V)

特点

  • 在整个注意力残差上加门控
  • 等价于”可学习的残差缩放”

2.1.5 G5: 多头合并后门控

# 数学形式
multi_head = Concat(head_1, ..., head_h) W_O
y = σ(W_g · x) ⊙ multi_head

特点

  • 在头合并后门控
  • 头之间共享门控

2.2 完整 5 位置对比表

位置数学形式是否逐头计算开销效果
G1最优
G2取决于实现中等
G3隐含在 Q较弱
G4中等
G5否(合并后)较弱

2.3 系统实验结论

Qwen 团队在 1.7B 参数模型上系统比较所有配置:

最终实验结论:

G1 > G4 > G2 > G5 > G3
↑                       ↑
最优                   最差

G1 的优势

  1. 注意力分数 → 输出之间是直接乘法关系
  2. 门控不破坏 softmax 概率分布
  3. 逐头门控保留多头异质性
  4. 数值稳定(sigmoid 输出 ∈ [0, 1])

三、G1 配置详解

3.1 数学形式

设输入 ,多头参数 个头,每头维度

每头计算

其中 是每个头的门控投影。

3.2 实现细节

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
 
class GatedMultiHeadAttention(nn.Module):
    """Gated Multi-Head Attention (G1 配置)"""
    def __init__(self, d_model, n_heads, dropout=0.0, bias=False):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
 
        # QKV 投影
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)
 
        # 门控投影(每头一个)
        self.W_g = nn.Linear(d_model, d_model, bias=bias)
 
        # 输出投影
        self.W_o = nn.Linear(d_model, d_model, bias=bias)
 
        self.dropout = nn.Dropout(dropout)
 
    def forward(self, x, mask=None, kv_cache=None):
        """
        x: (B, T, d_model)
        """
        B, T, d = x.shape
 
        # QKV 投影
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
 
        # 门控投影:与 QKV 同样的多头切分
        G = self.W_g(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        G = torch.sigmoid(G)  # (B, n_heads, T, d_k)
 
        # 注意力分数
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
 
        # 因果掩码
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
 
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
 
        # 注意力输出
        out = attn @ V  # (B, n_heads, T, d_k)
 
        # G1 门控:逐头逐元素
        out = G * out  # 广播:G 的最后一维与 out 相同
 
        # 多头合并
        out = out.transpose(1, 2).contiguous().view(B, T, d)
        out = self.W_o(out)
 
        return out

3.3 训练 vs 推理

训练

  • 门控 随其他参数一起学习
  • 初始化:,使 初始
  • 学习率与其他参数相同

推理

  • 门控是确定性的
  • KV Cache 不需要修改
  • 仅每步额外一次 sigmoid 乘法

性能开销

  • 参数:每头 ,对 个头总计
  • 相比标准 MHA:增加 投影参数
  • 计算:每步增加 (门控计算)

四、G1 解决的关键问题

4.1 Attention Sink 缓解

核心现象

  • 标准 Attention 中,BOS token 经常吸收 30-80% 的注意力分数
  • 长序列中更严重

G1 如何缓解

门控作为输入相关的衰减因子

  • 当 head 想要”关注”无关 token 时, 接近 0
  • 当 head 想要”关注”相关 token 时, 接近 1

实验结果

  • BOS 注意力分数从 60-80% 降至 5-15%
  • 注意力分布更均匀
  • 长序列中分布更稳定
def visualize_attention_sinks(model, x, threshold=0.3):
    """可视化 attention sink 程度"""
    B, T, d = x.shape
    with torch.no_grad():
        # 提取第一层的注意力分数
        attn = model.get_attention_scores(x)  # (B, n_heads, T, T)
        # 统计首 token 注意力占比
        bos_attn = attn[:, :, :, 0]  # 所有 head、query 关注 BOS 的分数
        mean_bos_attn = bos_attn.mean(dim=(1, 2))  # (B,)
        n_sinks = (mean_bos_attn > threshold).sum().item()
    return mean_bos_attn.mean().item(), n_sinks
 
# 实验对比
# 标准 MHA:  mean_bos_attn = 0.65, n_sinks = 8/12 layers
# Gated MHA: mean_bos_attn = 0.10, n_sinks = 0/12 layers

4.2 长上下文扩展(32k → 128k)

核心问题

  • RoPE-base 在 32k 训练后无法直接外推到 128k
  • 注意力分数随长度变化(缩放失效)

G1 的帮助

  1. 门控自适应缩放

    • 短序列: 接近 1(正常)
    • 长序列: 降低(抑制无效注意力)
    • 等价于学习一个长度无关的”有效注意力”
  2. 不需要位置编码插值

    • 标准的 YaRN/LongRoPE 可省略
    • 门控自然处理长度变化

实验

  • 32k 训练 → 128k 推理
  • 标准 MHA:Perplexity 爆炸
  • Gated MHA:Perplexity 平滑上升

4.3 高学习率训练

问题根源

  • 注意力梯度 与 softmax 分数耦合
  • 高学习率 → 大梯度 → softmax 饱和 → 训练失败

G1 的解耦

门控 的梯度独立于 softmax:

这意味着:

  • 即使 softmax 部分饱和,门控仍能学习
  • 总梯度被分到两个独立通路
  • 等价于隐式梯度裁剪

实验

  • 标准 MHA:最高学习率
  • Gated MHA:最高学习率 (3x 提升)

4.4 训练效率

由于可以高学习率训练,总训练步数减少

  • 1.7B 模型:相同 loss 减少 30-40% 训练时间
  • 收敛更快
  • 内存相同

五、消融实验

5.1 门控函数的选择

门控函数公式效果
Sigmoid最优
Tanh略差
ReLU较差(输出非 [0,1])
SiLU较差
固定 1无门控基线

为什么 Sigmoid 最优

  • 输出 ∈ [0, 1]:可解释为”软开关”
  • 平滑梯度:训练友好
  • 单调:易于分析

5.2 是否逐头共享

配置参数量效果
逐头独立 (per-head)最优
头组合后单门控 (single)略差
所有头共享 (shared)较差

为什么逐头独立最优

  • 不同头学习不同模式
  • 一头可能关注局部,另一头关注全局
  • 共享门控会限制这种异质性

5.3 初始化策略

初始化效果
零初始化 (W_g = 0)起始 ,稳定训练
Kaiming 初始化起始 变化大,初期不稳定
Xavier 初始化类似 Kaiming

推荐:零初始化 (但需注意打破对称性 → 标准做法是单独线性层)

5.4 与其他门控组件的关系

Qwen 团队研究了与门控 FFN (GLU/SwiGLU) 的交互:

配置效果
仅 Gated Attention良好
仅 Gated FFN (SwiGLU)良好
同时使用最优
都不使用基线

互补性

  • Gated Attention 控制信息路由(哪些 token 被关注)
  • Gated FFN 控制信息处理(特征如何变换)
  • 两者正交,独立贡献

六、实际部署

6.1 Qwen3-Next 实现

Qwen3-Next 是首个大规模部署 Gated Attention 的工业模型:

  • 80B+ 参数
  • 256K 上下文
  • 训练效率提升 30%
class Qwen3NextAttention(nn.Module):
    """Qwen3-Next 风格的 Gated Attention"""
    def __init__(self, d_model, n_heads, n_kv_heads=None, bias=False):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads or n_heads
        self.d_k = d_model // n_heads
        self.n_rep = n_heads // self.n_kv_heads
 
        # Q (全头), KV (分组)
        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=bias)
        self.W_k = nn.Linear(d_model, self.n_kv_heads * self.d_k, bias=bias)
        self.W_v = nn.Linear(d_model, self.n_kv_heads * self.d_k, bias=bias)
        # 门控
        self.W_g = nn.Linear(d_model, n_heads * self.d_k, bias=bias)
        # 输出
        self.W_o = nn.Linear(d_model, d_model, bias=False)
 
    def forward(self, x, mask=None, kv_cache=None):
        B, T, d = x.shape
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        G = torch.sigmoid(self.W_g(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2))
 
        # KV 复制(GQA)
        K = K.repeat_interleave(self.n_rep, dim=1)
        V = V.repeat_interleave(self.n_rep, dim=1)
 
        # SDPA
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
 
        # G1 门控
        out = G * out
 
        out = out.transpose(1, 2).contiguous().view(B, T, d)
        return self.W_o(out)

6.2 与 MoE 结合

Gated Attention 与 Mixture of Experts 正交互补

  • Gated Attention 处理序列方向的路由
  • MoE 处理特征方向的路由
class GatedAttentionMoEBlock(nn.Module):
    """Gated Attention + MoE 组合"""
    def __init__(self, d_model, n_heads, n_experts, top_k):
        super().__init__()
        self.attn = GatedMultiHeadAttention(d_model, n_heads)
        self.moe = MoE(d_model, n_experts, top_k)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
 
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.moe(self.norm2(x))
        return x

6.3 训练技巧

  1. 学习率:可提升 2-3 倍
  2. Warmup:标准 2000 步
  3. 权重衰减:与 MHA 相同
  4. 初始化 + 标准偏差
  5. 梯度裁剪:可放宽到 1.0

七、理论分析

7.1 门控作为隐式正则化

视角:门控 软特征选择器

理论

  • :head 完全忽略输入
  • :head 完全保留输入
  • 中间值:软过滤

正则化效应

  • 鼓励”必要才激活”
  • 抑制噪声
  • 等价于 软正则化

7.2 与门控线性单元(GLU)的理论联系

GLU(Dauphin et al. 2017)形式:

Gated Attention 是 GLU 在注意力维度的推广:

  • GLU:门控 作用于 FFN 输出
  • Gated Attention:门控 作用于 Attention 输出

共同理论

  • 乘法交互(gating)比加法(residual)更具表达力
  • 门控实现条件计算(conditional computation)
  • 与 Highway Network (Srivastava 2015) 一脉相承

7.3 注意力熵的调节

定义:注意力熵

门控对熵的影响

  • :输出 ,但 不变
  • 实际效果:降低等效注意力熵
  • 鼓励稀疏化的有效注意力
def attention_entropy_with_gating(scores, gate):
    """计算门控后的有效注意力熵"""
    attn = F.softmax(scores, dim=-1)
    # 门控后的有效输出
    gated_attn = gate * attn
    gated_attn = gated_attn / (gated_attn.sum(dim=-1, keepdim=True) + 1e-9)
    entropy = -(gated_attn * torch.log(gated_attn + 1e-9)).sum(dim=-1)
    return entropy.mean()

7.4 与 Hebbian 学习的联系

门控 的学习规则

  • 大时, 变化快
  • 小时, 稳定
  • 这与STDP(spike-timing-dependent plasticity)有形式相似

猜想:门控可能模拟生物神经元的突触可塑性,但需要更深入研究。


八、与其他架构改进的对比

8.1 vs Multi-Head Latent Attention (MLA, DeepSeek-V3)

维度MLAGated Attention
核心思想低秩压缩 KV门控输出
节省KV Cache 8x
表达力略低略高
长上下文改善显著改善
训练复杂简单

可叠加:MLA 思想与 Gated Attention 兼容,可联合使用。

8.2 vs Sliding Window Attention (SWA)

维度SWAGated Attention
核心思想限制感受野门控注意力
复杂度
内存
表现
训练技巧需处理 sink自动处理 sink

:Gated Attention + SWA 仍可能需要 Sink 处理。

8.3 vs Gated FFN (SwiGLU)

维度SwiGLUGated Attention
位置FFNAttention
作用特征变换信息路由
模式静态门控动态门控(输入依赖)
表达力

九、未来方向

9.1 开放问题

  1. 最佳门控位置:G1 在大多数任务最优,但某些任务 G2/G4 可能更好
  2. 门控与位置编码交互:RoPE 下的门控行为
  3. 稀疏门控:将 G1 稀疏化为 Top-k
  4. 多层级门控:每层不同的门控配置
  5. 跨模态门控:视觉-语言的门控策略

9.2 工业影响

  • Qwen 系列:Qwen3-Next 已采用
  • DeepSeek-V4(预计):可能采用
  • Llama-4(预计):可能采用
  • 开源:Gated Attention 实现简单,可广泛采用

9.3 与其他方向结合

  • 稀疏注意力:Gated Attention + Block-Sparse
  • MoE:Gated Attention + MoE 的稀疏专家
  • 线性注意力:门控化 Linear Attention
  • Mamba:门控化 Mamba 块

十、完整 PyTorch 训练示例

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader
 
 
class GatedTransformerBlock(nn.Module):
    """完整 Gated Transformer Block"""
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1, use_gated_attn=True):
        super().__init__()
        self.use_gated_attn = use_gated_attn
        if use_gated_attn:
            self.attn = GatedMultiHeadAttention(d_model, n_heads, dropout=dropout)
        else:
            self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        # SwiGLU FFN(推荐)
        self.ffn = SwiGLU(d_model, d_ff)
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
 
    def forward(self, x, mask=None):
        # Pre-norm
        x = x + self.attn(self.norm1(x), mask=mask)
        x = x + self.ffn(self.norm2(x))
        return x
 
 
class SwiGLU(nn.Module):
    """SwiGLU FFN"""
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.W1 = nn.Linear(d_model, d_ff, bias=False)
        self.W2 = nn.Linear(d_model, d_ff, bias=False)
        self.W3 = nn.Linear(d_ff, d_model, bias=False)
 
    def forward(self, x):
        return self.W3(F.silu(self.W1(x)) * self.W2(x))
 
 
class GatedTransformerLM(nn.Module):
    """完整 Gated Transformer 语言模型"""
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_len=32768, dropout=0.1):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([
            GatedTransformerBlock(d_model, n_heads, d_ff, dropout, use_gated_attn=True)
            for _ in range(n_layers)
        ])
        self.norm = nn.RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # 共享嵌入(可选)
        # self.lm_head.weight = self.token_emb.weight
 
    def forward(self, ids, targets=None):
        B, T = ids.shape
        pos = torch.arange(T, device=ids.device)
        x = self.token_emb(ids) + self.pos_emb(pos)
 
        # 因果掩码
        mask = torch.tril(torch.ones(T, T, device=ids.device)).unsqueeze(0).unsqueeze(0)
 
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.norm(x)
        logits = self.lm_head(x)
 
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss
        return logits, None
 
 
# 训练循环
def train_gated_lm():
    model = GatedTransformerLM(
        vocab_size=32000, d_model=2048, n_heads=16, n_layers=24,
        d_ff=8192, max_len=32768
    )
    # 高学习率训练(门控允许)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000)
 
    model.cuda()
    for step in range(10000):
        # 训练步骤(伪代码)
        x = next(iter(DataLoader([])))  # 真实数据
        logits, loss = model(x[:, :-1], x[:, 1:])
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print(f"step {step}: loss {loss.item():.4f}")
 
 
if __name__ == "__main__":
    # 测试
    model = GatedTransformerLM(vocab_size=1000, d_model=128, n_heads=4, n_layers=4, d_ff=512)
    x = torch.randint(0, 1000, (2, 64))
    logits, loss = model(x[:, :-1], x[:, 1:])
    print(f"logits: {logits.shape}, loss: {loss.item():.4f}")

总结

Gated Attention(Qwen 团队,NeurIPS 2025 Best Paper)是一个简单但强大的架构改进:

  1. 核心思想:在 SDPA 输出后置逐头 sigmoid 门控(G1 配置)
  2. 三大收益:缓解 attention sink、支持 32k→128k 扩展、允许高学习率
  3. 实现简单:每头加一个线性层 + sigmoid,开销 < 33%
  4. 工业可用:Qwen3-Next 已采用,可广泛部署

关键洞察

门控实现”条件计算”——根据输入动态决定每个头的信息保留量

这与 Gated FFN (GLU/SwiGLU) 在 FFN 维度的成功完全平行,是 2025 年最重要的架构发现之一。1


参考资料

Footnotes

  1. 主要参考:Qwen 团队 NeurIPS 2025 Best Paper “Gated Attention for Large Language Models” (https://openreview.net/pdf?id=1b7whO4SfY)。相关工作包括:Dauphin et al. 2017 (GLU)、Shazeer 2020 (GLU variants)、Su et al. 2024 (RoFormer/RoPE)、Xiao et al. 2024 (YaRN) 等。 2