Nexus高阶注意力机制

Nexus是一种新型的高阶注意力架构,通过递归框架增强Transformer的表示能力。该方法发表于ICLR 2025,能够在单层注意力内建模多跳关系,有效打破标准注意力的低秩瓶颈。1

1. 低秩瓶颈问题

标准自注意力的局限

标准Transformer的自注意力机制定义为:

其中 分别为查询、键、值矩阵。

核心问题:注意力权重矩阵 仅建模token对之间的直接成对交互,无法捕获高阶依赖关系。

多跳推理的困境

考虑三个token 的交互,标准注意力需要:

建模三元关系需要堆叠多层注意力或进行多步推理,导致:

  • 计算负担增加
  • 信息损失和梯度消失
  • 训练/推理效率低下

理论分析:线性瓶颈

Bhojanapalli等人(2020)指出当 时,标准注意力无法表达任意注意力权重 2

定理3.1(线性瓶颈)

  1. 给定 个输入 和对应的目标行随机矩阵 ,只要 ,总存在映射 使得:

  2. ,存在 ,但对于所有线性变换 ,式(14) 仍不成立

推论:标准注意力机制甚至无法表达秩为1的对数注意力权重矩阵,这深刻揭示了其表达能力不足的本质。


2. Nexus架构

核心思想

与标准方法使用静态线性投影不同,Nexus通过嵌套自注意力机制动态精炼Query和Key表示:

高阶注意力定义

首先,通过自注意力分别精炼查询和键:

完整的高阶注意力为:

递归扩展

构建递归高阶注意力以捕获更深层的依赖:

推广到 阶:

┌─────────────────────────────────────────────────────────────┐
│                    Nexus Layer                               │
│  ┌─────────┐    ┌─────────┐    ┌─────────┐                 │
│  │   Q     │───→│Inner Attn│───→│Refined Q│──┐             │
│  │(Linear) │    └─────────┘    └─────────┘  │             │
│  └─────────┘                            ┌───┴───┐         │
│  ┌─────────┐    ┌─────────┐    ┌─────────┐    │         │
│  │   K     │───→│Inner Attn│───→│Refined K│───→│Attn→Out │
│  │(Linear) │    └─────────┘    └─────────┘    │         │
│  └─────────┘                            ┌───┴───┘         │
│  ┌─────────┐                             │                 │
│  │   V     │────────────────────────────→│                 │
│  └─────────┘                             └─────────────────┘
└─────────────────────────────────────────────────────────────┘

3. 动态Query/Key精炼

语义对齐机制

Nexus的核心洞见是:Query和Key是内层注意力循环的输出。这意味着:

  1. 上下文感知投影:标准Transformer中 仅依赖 ,而Nexus中 通过内层注意力聚合了全局信息

  2. 语义高亮:内层Key注意力呈现明显的垂直条纹,表明特定位置被许多后续token关注——这相当于在主注意力计算前识别并聚合全局相关语义

  3. 预推理能力:这种设计在最终注意力计算前完成了”预推理”步骤,简化了主注意力的任务

注意力模式可视化

实验可视化表明:

注意力类型特点功能
Baseline对角线强聚焦 + 左侧垂直带建模序列顺序
Nexus Outer与baseline类似的因果结构保持基础能力
Nexus Inner Q内层查询聚合动态调整查询表示
Nexus Inner K垂直条纹模式语义高亮与全局聚合

4. 高阶相关性建模

多Token依赖

高阶注意力机制允许模型直接捕获多token依赖,无需堆叠多层:

其中

三元关系建模

对于三个token

  • 标准注意力:需要顺序成对计算
  • 高阶注意力:在单次注意力层中同时考虑其组合影响

与CoT推理的联系

递归高阶注意力类似于在注意力层内执行链式思维(Chain-of-Thought)推理:


5. 参数共享策略

问题与解决方案

高阶机制的朴素实现需要为每个递归步骤使用独立的投影矩阵,导致参数量随 线性增长。

权重共享策略

基于假设:将向量投影到Query或Key空间的语义变换在不同递归层次上基本相似

修改递归步骤,使内层注意力复用外层权重

参数量分析

配置参数量说明
标准注意力基准
m阶朴素实现随阶数线性增长
m阶权重共享与标准注意力相同

关键优势:Nexus的参数量与标准Transformer完全相同,仅增加计算密度(推理FLOPs),而不增加模型存储。


6. 复杂度分析

时间复杂度

标准自注意力的计算复杂度为 ,其中 为序列长度。

对于 阶高阶注意力:

其中 ,递归展开得:

实践权衡

  • :约2倍计算量,但显著性能提升
  • :约4倍计算量,进一步提升
  • 实验表明 是效率与性能的最佳平衡点

内存效率

方面标准TransformerNexus
模型参数量(相同)
推理FLOPs
存储需求(相同)

7. 打破线性瓶颈的理论证明

核心洞察

Nexus通过非线性映射 替代线性投影,从而增强表达能力。

高阶注意力的表达能力

定义非线性映射:

上下文感知的聚合表示,而非简单的线性变换。

注意力模式对比

可视化实验表明,Nexus的注意力矩阵展现出:

  • 更复杂的互联模式
  • 更高的连接度和多样性
  • 更强的多面关系建模能力

这验证了Nexus能够捕获标准注意力和其它高阶注意力无法表达的依赖关系。


8. 实验结果

Pythia模型对比

在Pythia(70M-1B)上的零样本准确率对比:

规模模型ARC-CARC-EHellaswagLogiQAPiQASciQ平均
70MPythia0.2080.3590.3560.2760.5690.6150.397
70MNexus0.2040.3820.3580.2870.5860.6850.417
160MPythia0.2000.3850.3800.2600.6000.6860.419
160MNexus0.2110.4050.3850.2850.6050.7130.434
410MPythia0.2250.3940.3750.2850.6010.7080.431
410MNexus0.2260.4150.3840.2940.6080.7330.443
1BPythia0.2320.4400.3950.2960.6250.7580.458
1BNexus0.2300.4550.3990.2900.6360.7770.465

关键发现

  • Nexus在所有规模上平均准确率均优于基线
  • SciQ任务提升最大(+6%@70M),该任务需要多步推理
  • 在需要复杂逻辑依赖的任务上优势明显

消融实验

模型配置投影共享阶数平均准确率
Baseline---0.397
Nexus-QQ20.400
Nexus-QKQ, K20.409
Nexus-QKVQ, K, V20.409
Nexus-QK-SharedQ, K20.406
Nexus-RecursiveQ, K30.415

结论

  1. Q和K的高阶处理是核心,V无需高阶
  2. 权重共享仅轻微性能下降(0.409→0.406)
  3. 增加递归深度()可进一步提升

Qwen2.5模型升级实验

在数学推理基准上的表现:

基础模型方法MATH-500AIME24GPQA-Diamond平均
Qwen2.5-1.5BStandard SFT0.7860.1940.2760.419
Qwen2.5-1.5BNexus-SFT0.8010.1940.2800.425
Qwen2.5-7BStandard SFT0.9210.4520.4010.591
Qwen2.5-7BNexus-SFT0.9210.4750.4070.601

关键发现:Nexus可作为现有LLM的”架构升级套件”,在微调阶段即可带来推理能力提升。


9. 与相关工作的比较

vs. 高效注意力

方法目标权衡
Linformer复杂度低秩近似损失表达能力
Performer复杂度核近似可能不准确
Reformer依赖LSH近似
Nexus增强表达能力参数不增加,计算略增

vs. 高阶方法

方法应用领域特点
Attention on Attention图像描述增加注意力层
Deformable Attention视觉可变形卷积思想
高阶关系Transformer多模态多级结构
NexusNLP + 多模态递归+权重共享

vs. 标准Transformer

Nexus保留了标准Transformer的核心特性:

  • 因果结构(对角线聚焦)
  • 全局上下文建模
  • 可扩展的深度和宽度

同时增强了:

  • 多token依赖捕获
  • 层级关系建模
  • 推理密度

10. 实际应用建议

何时使用Nexus

适合场景

  • 多跳推理任务(LogiQA、ARC-C)
  • 数学问题求解(MATH、AIME)
  • 需要建模层级依赖的任务
  • 希望在不增加参数量的情况下提升模型能力

可能不适合

  • 计算资源极其受限的场景
  • 对延迟要求极高的在线推理

实现建议

class NexusAttention(nn.Module):
    def __init__(self, d_model, num_heads, order=2, share_weights=True):
        super().__init__()
        self.order = order
        self.share_weights = share_weights
        
        # 共享或独立的投影矩阵
        if share_weights:
            self.W_q = self.W_k = self.W_v = nn.Linear(d_model, d_model)
        else:
            self.W_q = nn.Linear(d_model, d_model)
            self.W_k = nn.Linear(d_model, d_model)
            self.W_v = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        # 标准投影
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        
        # 递归高阶注意力
        for _ in range(self.order):
            # 内层注意力精炼Q和K
            Q = self._inner_attention(Q, Q, Q) @ self.W_q.weight
            K = self._inner_attention(K, K, K) @ self.W_k.weight
        
        # 外层主注意力
        return self._outer_attention(Q, K, V)

11. 总结

Nexus通过递归高阶注意力机制解决了标准Transformer的核心瓶颈:

方面标准注意力Nexus
投影方式线性静态非线性递归
依赖建模成对交互多token依赖
参数量(共享后)
表达能力低秩瓶颈打破瓶颈
推理密度1x~2x(m=2)

Nexus不仅是预训练的全新架构,更可作为现有LLM的”升级套件”,通过微调即可释放更强的推理能力。


参考


相关词条

Footnotes

  1. Chen et al., Nexus: Higher-Order Attention Mechanisms in Transformers, ICLR 2025

  2. Bhojanapalli et al., Low-rank bottleneck in multi-head attention models, ICML 2020