引言
标准Transformer的Softmax注意力计算复杂度为 ,其中 为序列长度。这一特性严重限制了模型处理长序列的能力。稀疏注意力机制通过只计算部分token对之间的注意力,将复杂度降至 或 。同时,长度外推技术使模型能够在推理时处理比训练时更长的序列。
本文系统分析稀疏注意力的设计原理和主流方法,探讨长度外推的技术路线。
稀疏注意力的动机
稠密注意力的分散问题
随着序列长度增加,标准注意力的分布趋向分散:
这导致两个核心问题:
- 注意力稀释:每个token分得的注意力质量下降
- 计算浪费:大量计算用于不重要的token对
稀疏性的自然假设
语言和图像数据具有天然的稀疏结构:
- 局部性:相邻token通常高度相关
- 稀疏依赖:远距离依赖是”选择性”的
- 幂律分布:注意力权重呈长尾分布
稀疏注意力模式分类
模式图示
稠密注意力: 局部稀疏: 全局+局部:
■■■■■■■■ ■■□□□□□ ■■■■■■■■
■■■■■■■■ ■■□□□□□ □□■■■□□
■■■■■■■■ → □□□□□□□ → □□■■■□□
■■■■■■■■ □□□□□□□ □□□■■□□
■■■■■■■■ □□□□□□□ □□□■■□□
■■■■■■■■ □□□□□□□ □□□■■□□
■■■■■■■■ □□□□□□□ □□□■■□□
分类体系
| 类型 | 模式 | 复杂度 | 代表方法 |
|---|---|---|---|
| 固定模式 | 预定义稀疏结构 | Local Attention | |
| 可学习模式 | 通过训练学习 | Sparse Transformer | |
| 动态模式 | 输入自适应 | Longformer | |
| 组合模式 | 多模式混合 | 可变 | BigBird |
主流稀疏注意力方法
1. Local Attention(局部注意力)
每个token只关注固定窗口 内的token:
def local_attention(q, k, v, window_size=128):
n = q.shape[1]
scores = q @ k.transpose(-2, -1) # [B, H, n, n]
# 创建局部掩码
mask = torch.zeros_like(scores)
for i in range(n):
lo = max(0, i - window_size)
hi = min(n, i + window_size + 1)
mask[..., i, lo:hi] = 1
scores = scores.masked_fill(mask == 0, -inf)
return softmax(scores) @ v优点:计算高效,复杂度
缺点:无法捕获长距离依赖
2. Longformer:局部+全局注意力
Longformer组合了三种注意力模式:
┌─────────────────────────────────────────────┐
│ [Global] □□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□ │
│ [Local] ■■■■■□□□□□□■■■■■■■■■■□□□□□□■■■■■□□ │
│ [Dilated] ■□□■□□□■□□■□□□■□□■□□□■□□■□□□■□□ │
└─────────────────────────────────────────────┘
- 局部窗口:捕获局部上下文
- 全局注意力:特殊token(如
[CLS])关注所有位置 - 扩张窗口:跳跃式覆盖长距离
3. H₂O(Heavy-Hitter Oracle)
H₂O基于观察:语言模型倾向于”重击”(heavy-hitter)少数关键token。通过动态维护这些关键token的稀疏集合:
累积注意力分数,选择top-作为稀疏关注对象。
4. StreamingLLM
StreamingLLM提出了Attention Sink概念:语言模型会形成4个左右的”注意力接收器”,即使这些token语义上不重要。
解决方案是保留这些sink token的key-value:
def streaming_forward(x, kv_cache, sink_tokens=4):
B, n, D = x.shape
# 检测sink token(通过注意力分散度)
sinks = detect_sinks(x, k=sink_tokens)
# 滑动窗口:只保留最近L个token
window = x[..., -window_size:, :]
# 组合sink + window
context = torch.cat([sinks, window], dim=1)
return attention(context)长度外推技术
问题定义
设模型在长度为 的序列上训练,目标是处理长度为 的序列。
核心挑战
- 位置编码泛化:训练时未见过的位置
- 注意力分布外推:超出分布的查询-键交互
- 分布漂移:token统计特性的变化
技术路线
1. 位置编码改进
RoPE(Rotary Position Embedding)
RoPE通过旋转编码位置信息:
注意力分数:
关键性质:相对位置编码,不依赖绝对距离。
ALiBi(Attention with Linear Biases)
其中 是位置衰减系数。
2. SWAN-GPT:动态缩放
SWAN-GPT通过动态缩放注意力分数实现鲁棒的长度外推:
其中 是查询特定的缩放因子:
3. 训练策略
渐进式训练
Phase 1: 2K → 8K tokens
Phase 2: 8K → 32K tokens
Phase 3: 32K → 128K tokens
上下文窗口扩展
其中 随训练进度衰减。
稀疏注意力的实现优化
Flash Attention的稀疏扩展
def flash_sparse_attention(q, k, v, sparsity_mask):
# 使用block-wise计算,只处理非稀疏位置
return FlashAttention(
q, k, v,
block_mask=sparsity_mask, # 稀疏掩码
cu_seqlens_q=get_cu_seqlens(sparsity_mask)
)内存优化
| 策略 | 内存复杂度 | 适用场景 |
|---|---|---|
| KV Cache | 推理加速 | |
| 稀疏KV Cache | 长序列生成 | |
| 量化KV Cache | 内存受限场景 |
性能对比
LongBench基准测试
| 方法 | 32K上下文 | 64K上下文 | 128K上下文 |
|---|---|---|---|
| Full Attention | 45.2 | OOM | OOM |
| Longformer | 44.1 | 43.8 | 42.5 |
| H₂O | 44.8 | 44.2 | 43.1 |
| StreamingLLM | 43.5 | 43.2 | 42.8 |
| SWAN-GPT | 45.0 | 44.6 | 44.3 |
总结与展望
稀疏注意力机制和长度外推技术是解决Transformer长上下文处理能力的关键:
- 稀疏模式:局部+全局的组合策略最为有效
- Attention Sink:理解其机制有助于设计更好的稀疏方案
- 外推策略:RoPE+动态缩放是目前最鲁棒的方法
未来方向:
- 自适应稀疏:根据输入动态调整稀疏模式
- 硬件协同:专为稀疏操作优化的GPU内核
- 理论保证:稀疏注意力的表达力界限