引言

标准Transformer的Softmax注意力计算复杂度为 ,其中 为序列长度。这一特性严重限制了模型处理长序列的能力。稀疏注意力机制通过只计算部分token对之间的注意力,将复杂度降至 。同时,长度外推技术使模型能够在推理时处理比训练时更长的序列。

本文系统分析稀疏注意力的设计原理和主流方法,探讨长度外推的技术路线。


稀疏注意力的动机

稠密注意力的分散问题

随着序列长度增加,标准注意力的分布趋向分散:

这导致两个核心问题:

  1. 注意力稀释:每个token分得的注意力质量下降
  2. 计算浪费:大量计算用于不重要的token对

稀疏性的自然假设

语言和图像数据具有天然的稀疏结构:

  • 局部性:相邻token通常高度相关
  • 稀疏依赖:远距离依赖是”选择性”的
  • 幂律分布:注意力权重呈长尾分布

稀疏注意力模式分类

模式图示

稠密注意力:        局部稀疏:         全局+局部:
■■■■■■■■          ■■□□□□□          ■■■■■■■■
■■■■■■■■          ■■□□□□□          □□■■■□□
■■■■■■■■    →     □□□□□□□    →     □□■■■□□
■■■■■■■■          □□□□□□□          □□□■■□□
■■■■■■■■          □□□□□□□          □□□■■□□
■■■■■■■■          □□□□□□□          □□□■■□□
■■■■■■■■          □□□□□□□          □□□■■□□

分类体系

类型模式复杂度代表方法
固定模式预定义稀疏结构Local Attention
可学习模式通过训练学习Sparse Transformer
动态模式输入自适应Longformer
组合模式多模式混合可变BigBird

主流稀疏注意力方法

1. Local Attention(局部注意力)

每个token只关注固定窗口 内的token:

def local_attention(q, k, v, window_size=128):
    n = q.shape[1]
    scores = q @ k.transpose(-2, -1)  # [B, H, n, n]
    
    # 创建局部掩码
    mask = torch.zeros_like(scores)
    for i in range(n):
        lo = max(0, i - window_size)
        hi = min(n, i + window_size + 1)
        mask[..., i, lo:hi] = 1
    
    scores = scores.masked_fill(mask == 0, -inf)
    return softmax(scores) @ v

优点:计算高效,复杂度
缺点:无法捕获长距离依赖

2. Longformer:局部+全局注意力

Longformer组合了三种注意力模式:

┌─────────────────────────────────────────────┐
│  [Global]  □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□  │
│  [Local]   ■■■■■□□□□□□■■■■■■■■■■□□□□□□■■■■■□□  │
│  [Dilated] ■□□■□□□■□□■□□□■□□■□□□■□□■□□□■□□  │
└─────────────────────────────────────────────┘
  • 局部窗口:捕获局部上下文
  • 全局注意力:特殊token(如[CLS])关注所有位置
  • 扩张窗口:跳跃式覆盖长距离

3. H₂O(Heavy-Hitter Oracle)

H₂O基于观察:语言模型倾向于”重击”(heavy-hitter)少数关键token。通过动态维护这些关键token的稀疏集合:

累积注意力分数,选择top-作为稀疏关注对象。

4. StreamingLLM

StreamingLLM提出了Attention Sink概念:语言模型会形成4个左右的”注意力接收器”,即使这些token语义上不重要。

解决方案是保留这些sink token的key-value:

def streaming_forward(x, kv_cache, sink_tokens=4):
    B, n, D = x.shape
    
    # 检测sink token(通过注意力分散度)
    sinks = detect_sinks(x, k=sink_tokens)
    
    # 滑动窗口:只保留最近L个token
    window = x[..., -window_size:, :]
    
    # 组合sink + window
    context = torch.cat([sinks, window], dim=1)
    
    return attention(context)

长度外推技术

问题定义

设模型在长度为 的序列上训练,目标是处理长度为 的序列。

核心挑战

  1. 位置编码泛化:训练时未见过的位置
  2. 注意力分布外推:超出分布的查询-键交互
  3. 分布漂移:token统计特性的变化

技术路线

1. 位置编码改进

RoPE(Rotary Position Embedding)

RoPE通过旋转编码位置信息:

注意力分数:

关键性质:相对位置编码,不依赖绝对距离。

ALiBi(Attention with Linear Biases)

其中 是位置衰减系数。

2. SWAN-GPT:动态缩放

SWAN-GPT通过动态缩放注意力分数实现鲁棒的长度外推:

其中 是查询特定的缩放因子:

3. 训练策略

渐进式训练

Phase 1: 2K → 8K tokens
Phase 2: 8K → 32K tokens
Phase 3: 32K → 128K tokens

上下文窗口扩展

其中 随训练进度衰减。


稀疏注意力的实现优化

Flash Attention的稀疏扩展

def flash_sparse_attention(q, k, v, sparsity_mask):
    # 使用block-wise计算,只处理非稀疏位置
    return FlashAttention(
        q, k, v,
        block_mask=sparsity_mask,  # 稀疏掩码
        cu_seqlens_q=get_cu_seqlens(sparsity_mask)
    )

内存优化

策略内存复杂度适用场景
KV Cache推理加速
稀疏KV Cache长序列生成
量化KV Cache内存受限场景

性能对比

LongBench基准测试

方法32K上下文64K上下文128K上下文
Full Attention45.2OOMOOM
Longformer44.143.842.5
H₂O44.844.243.1
StreamingLLM43.543.242.8
SWAN-GPT45.044.644.3

总结与展望

稀疏注意力机制和长度外推技术是解决Transformer长上下文处理能力的关键:

  1. 稀疏模式:局部+全局的组合策略最为有效
  2. Attention Sink:理解其机制有助于设计更好的稀疏方案
  3. 外推策略:RoPE+动态缩放是目前最鲁棒的方法

未来方向

  • 自适应稀疏:根据输入动态调整稀疏模式
  • 硬件协同:专为稀疏操作优化的GPU内核
  • 理论保证:稀疏注意力的表达力界限

参考