Nexus高阶注意力机制

概述

Nexus是ICML 2025提出的一种新型高阶注意力机制,旨在解决标准Transformer中自注意力的低秩瓶颈问题1

传统注意力机制仅建模token之间的成对(一阶)交互,难以捕获多跳关系和层级依赖。Nexus通过递归嵌套的自注意力来动态精炼Query和Key向量,使模型能够在最终注意力计算之前进行”预推理”。

核心思想

标准注意力的局限性

标准自注意力的计算如下:

其中 。注意力权重矩阵 仅建模成对交互,无法直接捕获三 token 或多 token 之间的高阶依赖。

高阶注意力的形式化

Nexus定义了一种高阶注意力机制:

其中精炼后的Query和Key通过自注意力获得:

这种设计使 编码了多 token 聚合信息,后续注意力操作可直接考虑多 token 依赖。

递归扩展

可递归定义 阶注意力:

递归步骤使模型能够捕获多层级依赖和更复杂的结构关系。

参数高效权重共享

高阶机制的核心问题是参数增长。Nexus提出权重共享策略:内外注意力层复用相同的投影矩阵。

形式上,对于标准注意力层参数

内层注意力复用外层权重的约束确保:

  • 参数复杂度 (相对于递归阶数
  • 仅增加计算密度,不增加模型存储

理论分析:打破低秩瓶颈

线性瓶颈定理1

为序列长度)时,标准注意力机制缺乏表达任意注意力权重的能力。

定理3.1:若 ,存在满足 的矩阵 ,但对于线性变换 ,仍有:

这意味着即使是一阶对数注意力权重矩阵,标准注意力也无法精确表示。

Nexus通过非线性映射 突破了这一限制。

计算复杂度

高阶注意力的时间复杂度为 ,但实践中 即可获得显著性能提升,实际开销约为标准注意力的2倍

实验结果

语言建模基准

在Pythia模型系列上评估:

模型规模PIQAHellaSwagSciQARC-EARC-CLogiQA
Pythia-160M64.241.394.269.133.227.4
Nexus-160M66.843.195.171.235.829.6

数学推理能力

Nexus在数学推理任务上显著超越基线:

  • GSM8K:提升 12.3%
  • MATH:提升 15.7%
  • 对Qwen2.5进行架构升级后,数学能力显著增强

可视化分析

Nexus的注意力矩阵展现出更密集的连接模式和更丰富的交互结构,表明其捕获了更复杂的多 token 关系。

实现细节

PyTorch伪代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class NexusAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # x: [batch, seq_len, d_model]
        B, N, D = x.shape
        
        # 标准 Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 重塑为多头形式
        Q = Q.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        
        # === 高阶注意力:Query精炼 ===
        # Inner attention on Q: Q @ Q^T @ Q
        Q_refined = F.softmax(
            torch.matmul(Q, Q.transpose(-2, -1)) / math.sqrt(self.d_k), 
            dim=-1
        )
        Q_refined = torch.matmul(Q_refined, Q)
        
        # Inner attention on K: K @ K^T @ K  
        K_refined = F.softmax(
            torch.matmul(K, K.transpose(-2, -1)) / math.sqrt(self.d_k),
            dim=-1
        )
        K_refined = torch.matmul(K_refined, K)
        
        # === 最终注意力(使用精炼后的Q, K)===
        scores = torch.matmul(Q_refined, K_refined.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        
        # 应用到 V
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(B, N, D)
        
        return self.W_o(context)

与相关工作的对比

方法表达力提升参数增加适用场景
标准Transformer一阶0通用
Linear Attention低秩近似0长序列
Nexus高阶O(1)推理密集任务
Attention on Attention二阶显著增加视觉任务

应用场景

  1. 数学推理:多步逻辑推导需要高阶依赖建模
  2. 代码生成:变量引用链、嵌套函数调用
  3. 科学问题:需要跨多个前提的推理
  4. 形式化验证:逻辑公式之间的推导关系

参考资料

相关链接

Footnotes

  1. Zhu et al. “Nexus: Higher-Order Attention Mechanisms in Transformers” ICML 2025. arXiv:2512.03377 2