注意力机制变体综合对比
1. 引言:注意力机制的演进
注意力机制(Attention Mechanism)自2017年由Vaswani等人提出以来,已成为现代深度学习尤其是序列建模领域的核心技术。1 原始Transformer架构中的自注意力(Self-Attention)机制通过计算序列中所有位置之间的依赖关系,实现了并行化序列建模,但其核心问题在于计算复杂度与序列长度呈平方关系 ,这严重制约了模型处理长序列的能力。
1.1 从原始Transformer到高效注意力
随着应用场景从短文本向长文档、长视频、基因组序列等扩展,研究者提出了大量高效注意力变体。这些变体在**效率(Efficiency)与表达力(Expressivity)**之间寻求不同平衡:
1.2 效率与表达力的基本矛盾
注意力机制的效率与表达力之间存在根本性张力:
| 维度 | 高效率方案 | 高表达力方案 |
|---|---|---|
| 计算复杂度 | 或 | |
| 内存占用 | ||
| 建模能力 | 受限于稀疏模式 | 全局感受野 |
| 适用场景 | 长序列、资源受限 | 短序列、精度敏感 |
1.3 本文档的目标
本文旨在建立一个系统性分类框架,对现代注意力机制变体进行全面梳理和对比分析。主要内容包括:
- 注意力机制的统一分类体系
- 稀疏注意力模式的技术对比
- 线性注意力变体的理论分析
- FlashAttention系列的发展演进
- 混合注意力架构的设计原则
- 实践场景下的选择指南
2. 注意力机制分类体系
2.1 按计算复杂度分类
根据时间复杂度的渐进增长率,可以将注意力机制分为三类:
| 类别 | 时间复杂度 | 代表方法 |
|---|---|---|
| 二次复杂度 | 标准注意力、FlashAttention | |
| 次二次复杂度 | LSH Attention、分块稀疏 | |
| 线性复杂度 | 线性注意力、核方法 |
标准注意力的核心运算是 Query-Key 矩阵乘积:
这导致了 的注意力矩阵存储和 级别的计算量。
2.2 按近似性质分类
根据是否引入近似,可以分为:
精确注意力:保留原始注意力矩阵的完整信息,包括:
- 标准Transformer注意力
- FlashAttention系列(IO感知优化,不损失精度)
- 局部窗口注意力(在窗口内精确)
近似注意力:通过各种技术近似原始注意力,包括:
- 核函数近似(Performer)
- 低秩分解(Nyströmformer)
- 哈希聚类(Reformer)
- 随机投影(Random Feature方法)
2.3 按稀疏模式分类
Dense Attention(密集注意力):每个位置与所有其他位置交互,代表即原始Transformer。
Sparse Attention(稀疏注意力):仅计算部分位置对之间的注意力分数:
- 固定稀疏:预定义稀疏模式(局部窗口、块稀疏)
- 动态稀疏:数据驱动的稀疏模式学习(路由、聚类)
- 分层稀疏:多尺度、分层组织的稀疏模式
Linear Attention(线性注意力):通过数学变换将二次复杂度降低到线性:
其中 是核函数,将输入映射到特征空间。
3. 稀疏注意力模式对比
稀疏注意力的核心思想是:并非所有位置对之间的交互都是必需的。通过选择性计算部分关键交互,可以在保持主要建模能力的同时显著降低计算成本。
3.1 Global + Local模式
Global + Local模式结合了全局注意力(捕获长程依赖)和局部注意力(捕获局部结构)两种机制。
3.1.1 Longformer
Longformer2由Allen AI研究院提出,设计了一种组合式稀疏注意力模式:
注意力模式配置:
# Longformer 注意力模式示意
class LongformerAttention:
def __init__(self, window_size=512, num_global_tokens=2):
self.window_size = window_size
self.num_global_tokens = num_global_tokens # CLS + 特殊token
def compute_attention(self, q, k, v):
# 1. 全局注意力:特殊token与所有位置交互
global_attn = self.global_attention(q[:self.num_global_tokens], k, v)
# 2. 滑动窗口注意力:每个位置与局部窗口内位置交互
local_attn = self.sliding_window_attention(q, k, v, self.window_size)
# 3. 扩张滑动窗口:间隔采样增加感受野
dilated_attn = self.dilated_attention(q, k, v, dilation=2)
return combine_attentions(global_attn, local_attn, dilated_attn)复杂度分析:对于序列长度 ,窗口大小 ,全局token数量 :
当 时,总复杂度接近线性。
设计特点:
- 全局注意力应用于特殊token(如[CLS])和可学习的[global] token
- 滑动窗口用于大多数普通token
- 膨胀窗口(Dilated Window)进一步扩大感受野而不增加计算量
3.1.2 BigBird
BigBird3由Google Research提出,证明了稀疏注意力可以逼近全注意力(Universal Approximation)。其注意力模式包含三个组件:
稀疏模式组成:
| 组件 | 描述 | 示意图 |
|---|---|---|
| 全局注意力 | 所有token与预设的全局token交互 | |
| 随机注意力 | 每个token与随机选取的token交互 | |
| 滑动窗口注意力 | 每个token与局部窗口内token交互 |
# BigBird 稀疏注意力实现
class BigBirdSparseAttention:
def __init__(self, num_random=3, window_size=3):
self.num_random = num_random
self.window_size = window_size
def build_sparse_mask(self, seq_len, num_global=2):
# 1. 全局token(首尾)
global_indices = list(range(num_global)) + [seq_len - 1]
# 2. 滑动窗口(中心位置)
window_indices = [create_window(i, self.window_size) for i in range(seq_len)]
# 3. 随机连接
random_indices = [random.sample(range(seq_len), self.num_random)
for _ in range(seq_len)]
# 合并并去重
sparse_mask = [set(global_indices + w + r)
for w, r in zip(window_indices, random_indices)]
return sparse_mask理论保证:BigBird论文证明了这种稀疏模式在图灵机模拟能力上与全注意力等价,为稀疏注意力的有效性提供了理论支撑。
3.2 固定稀疏模式
固定稀疏模式使用预定义的稀疏结构,不依赖数据或学习。
3.2.1 块稀疏注意力(Block Sparse Attention)
块稀疏注意力4将注意力矩阵划分为固定大小的块,只计算部分块内的注意力:
其中 表示位置 所属的稀疏块集合。
// 块稀疏注意力计算伪代码
void block_sparse_attention(
const Tensor& Q, // (batch, heads, N, d_k)
const Tensor& K, // (batch, heads, N, d_k)
const Tensor& V, // (batch, heads, N, d_v)
const Mask& mask, // 块稀疏掩码
Tensor& output) {
int block_size = 64;
int num_blocks = N / block_size;
for (int b = 0; b < batch; ++b) {
for (int h = 0; h < num_heads; ++h) {
for (int i = 0; i < num_blocks; ++i) {
for (int j = 0; j < num_blocks; ++j) {
if (mask[i][j]) { // 稀疏掩码
// 计算块内注意力
auto S_block = matmul(
Q[b][h][i*block_size:(i+1)*block_size],
K[b][h][j*block_size:(j+1)*block_size].transpose()
) / sqrt(d_k);
auto attn_block = softmax(S_block);
auto out_block = matmul(attn_block,
V[b][h][j*block_size:(j+1)*block_size]);
output[b][h][i*block_size:(i+1)*block_size] += out_block;
}
}
}
}
}
}3.2.2 局部窗口 + 膨胀的组合
通过膨胀(Dilation)机制,可以在不增加参数量的前提下指数级扩大感受野:
其中 是膨胀率, 表示从索引0开始、步长为 采样的子序列。
3.3 动态稀疏模式
动态稀疏模式根据输入数据动态决定注意力连接,具有更强的自适应能力。
3.3.1 Routing Transformer
Routing Transformer5使用 -means聚类将token路由到不同的”桶”中:
核心机制:
class RoutingAttention:
def __init__(self, num_routes, num_heads):
self.num_routes = num_routes
self.num_heads = num_heads
def route_tokens(self, keys, queries):
# 对keys进行k-means聚类
routes = []
for h in range(self.num_heads):
# 计算每个token属于哪个聚类中心
cluster_ids = kmeans(
keys[:, h, :, :], # (batch, seq, d_k)
k=self.num_routes,
centroids=learnable_centroids[h]
)
routes.append(cluster_ids)
return routes
def attend(self, q, k, v, routes):
# 在路由层面计算注意力
for h in range(self.num_heads):
for route_id in range(self.num_routes):
# 获取该路由的query和key
route_mask = routes[h] == route_id
q_route = q[:, route_mask, h, :]
k_route = k[:, route_mask, h, :]
v_route = v[:, route_mask, h, :]
# 路由内局部注意力
attn_route = local_attention(q_route, k_route, v_route)
# ...复杂度分析:假设token均匀分布在 个路由中,则复杂度降为 。
3.3.2 Reformer:局部敏感哈希注意力
Reformer6使用局部敏感哈希(Locality-Sensitive Hashing, LSH)将相似的Query和Key分配到同一”桶”中:
LSH机制:
// 简化的LSH注意力实现
class LSHAttention {
public:
// LSH函数:将向量映射到哈希桶
int lsh_forward(const Vec& x, int num_buckets, int num_hashes) {
// 生成随机投影向量
std::vector<Vec> projections = generate_random_projections(x.dim, num_buckets);
// 计算哈希值(取多个哈希的组合)
std::vector<int> hashes;
for (int i = 0; i < num_hashes; ++i) {
float dot = x.dot(projections[i]);
int bucket = (dot > 0) ? 1 : 0;
hashes.push_back(bucket << i); // 组合多个哈希位
}
return std::accumulate(hashes.begin(), hashes.end(), 0);
}
// LSH注意力计算
Tensor lsh_attention(const Tensor& Q, const Tensor& K, const Tensor& V) {
int num_buckets = 64;
int num_hashes = 8;
// 1. 为每个Query计算哈希桶
std::vector<int> q_buckets = compute_hashes(Q, num_buckets, num_hashes);
// 2. 为每个Key计算哈希桶
std::vector<int> k_buckets = compute_hashes(K, num_buckets, num_hashes);
// 3. 按桶排序以增加哈希冲突概率
auto [sorted_q, q_indices] = argsort(q_buckets);
auto [sorted_k, k_indices] = argsort(k_buckets);
// 4. 在同一桶内计算局部注意力
// 同一桶内的元素大概率相似(LSH特性)
Tensor output = compute_local_attention(sorted_q, sorted_k, sorted_k);
// 5. 恢复到原始顺序
return unsort_by_indices(output, q_indices);
return output;
}
};LSH理论背景:LSH的核心特性是”相似向量大概率映射到同一桶”:
其中 是向量 和 之间的夹角。
3.4 分层稀疏模式
分层稀疏模式通过多尺度/层次化设计,在不同层级使用不同的稀疏策略。
3.4.1 Swin Transformer的移位窗口注意力
Swin Transformer7通过层次化结构和移位窗口机制实现了高效的稀疏注意力。详见Swin Transformer。
核心设计:
Stage 1: 划分非重叠窗口 (W×W)
┌─────┬─────┬─────┐
│ W_1 │ W_2 │ W_3 │
├─────┼─────┼─────┤
│ W_4 │ W_5 │ W_6 │
└─────┴─────┴─────┘
↓ 移位 (⌈W/2⌉, ⌈W/2⌉)
Stage 2: 移位后的窗口划分
┌─────┬─────┬─────┬─────┐
│ A │ B │ │ D │
├─────┼─────┼─────┼─────┤
│ E │ F │ G │ H │
├─────┼─────┼─────┼─────┤
│ │ I │ J │ K │
├─────┼─────┼─────┼─────┤
│ L │ M │ │ O │
└─────┴─────┴─────┴─────┘
移位机制确保相邻窗口之间存在交叉连接,弥补了固定窗口导致的感受野局限。
3.4.2 金字塔结构与多尺度表示
金字塔结构(如Pyramid Vision Transformer、PVT)在不同阶段逐步降低分辨率,实现计算量的有效控制:
| Stage | 分辨率 | 特征维度 | 注意力类型 |
|---|---|---|---|
| Stage 1 | 局部窗口注意力 | ||
| Stage 2 | 稀疏全局注意力 | ||
| Stage 3 | 稀疏全局注意力 | ||
| Stage 4 | 稀疏全局注意力 |
4. 线性注意力变体对比
线性注意力的核心思想是通过数学变换,将 的注意力计算转换为 的线性计算。
4.1 核函数方法
核函数方法通过特征映射 将原始高维空间中的点积运算转换为低维空间中的近似计算。
4.1.1 Performer:Random Feature (FAVOR+)
Performer8使用Random Feature近似方法,将softmax注意力转化为线性形式:
理论基础:利用Bochner定理,softmax核函数可以表示为随机特征期望:
其中 是随机特征映射:
其中 是随机矩阵, 是随机偏置。
// Performer FAVOR+ 机制实现
template <typename scalar_t>
struct PerformerAttention {
// 随机特征映射
static Tensor random_feature(const Tensor& x, int m) {
// 生成随机矩阵和偏置
auto W = randn({x.size(-1), m});
auto b = uniform({m}, 0, 2 * M_PI);
// 指数随机投影
Tensor proj = matmul(x, W) + b;
// 复数形式的cos/sin特征
Tensor cos_feat = cos(proj);
Tensor sin_feat = sin(proj);
// 连接并归一化
return cat({cos_feat, sin_feat}, dim=-1) / sqrt(m);
}
// 线性复杂度的近似注意力
static Tensor favor_attention(const Tensor& Q, const Tensor& K,
const Tensor& V, int m) {
// 映射到随机特征空间
Tensor Q_prime = random_feature(Q, m); // (B, H, N, 2m)
Tensor K_prime = random_feature(K, m); // (B, H, N, 2m)
// 线性注意力:先计算 K'V,再与 Q' 计算
// \phi(K)^T V: (B, H, 2m, d)
Tensor KV_product = matmul(K_prime.transpose(-2, -1), V);
// Q' @ (K'V) / (Q' @ K' @ 1)
Tensor numerator = matmul(Q_prime, KV_product);
Tensor denominator = matmul(Q_prime, K_prime.sum(-2, true));
return numerator / (denominator + 1e-6);
}
};近似误差分析:Performer提供了有界的近似误差保证:
其中 与随机特征维度 相关。
4.1.2 近似误差与收敛性保证
核函数方法的近似质量取决于随机特征的维度 和输入分布。Performer论文给出了以下收敛保证:
中心极限定理视角:当 时,随机特征估计收敛到真实核函数值:
方差估计:设 ,则:
4.2 线性化方法
线性化方法直接修改注意力机制的结构,移除导致二次复杂度的操作。
4.2.1 Linear Transformer
Linear Transformer9将softmax函数替换为特征映射函数:
原始softmax注意力:
Linear Transformer:
其中 (指数线性单元加1)。
class LinearTransformerAttention(nn.Module):
"""
线性注意力机制
核心:将softmax分解为特征映射和加权聚合
"""
def __init__(self, d_model, d_k):
super().__init__()
self.d_k = d_k
self.proj_q = nn.Linear(d_model, d_k)
self.proj_k = nn.Linear(d_model, d_k)
self.proj_v = nn.Linear(d_model, d_k)
self.scale = d_k ** -0.5
def forward(self, Q, K, V):
# 特征映射:elu + 1
Q_prime = F.elu(Q) + 1
K_prime = F.elu(K) + 1
# 分母:每个query的归一化项
K_sum = K_prime.sum(dim=-2, keepdim=True) # (B, 1, H, d)
denom = (Q_prime * K_sum).sum(dim=-1, keepdim=True) # (B, N, H, 1)
# 分子:加权聚合
KV_product = torch.einsum('bnhd,bnhm->bhmd', K_prime, V)
numerator = torch.einsum('bnhd,bhmd->bnhm', Q_prime, KV_product)
return numerator / (denom + 1e-8)4.2.2 位置编码补偿
线性注意力移除了softmax,损失了自归一化特性,需要额外的位置编码来建模序列顺序:
解决方案:
-
相对位置编码:在特征映射中引入位置信息
-
傅里叶位置编码:使用正弦/余弦函数
-
旋转位置编码(RoPE):
详见ALiBi位置编码。
4.3 低秩方法
低秩方法利用注意力矩阵的低秩结构进行近似分解。
4.3.1 Nyströmformer:基于Landmark的近似
Nyströmformer10使用Nyström方法,利用矩阵的低秩近似:
Nyström近似理论:对于半正定矩阵 ,选取 个landmark点 ,则:
其中 是landmark索引集合。
class NystromAttention(nn.Module):
"""
Nyströmformer注意力
使用Nyström方法近似全注意力
"""
def __init__(self, num_landmarks=64):
super().__init__()
self.num_landmarks = num_landmarks
def forward(self, Q, K, V):
seq_len = Q.size(1)
# 1. 选择landmark点(均匀采样)
landmark_indices = torch.linspace(
0, seq_len-1, self.num_landmarks,
dtype=torch.long, device=Q.device
)
# 2. 计算子矩阵
K_landmark = K[:, landmark_indices, :, :] # (B, m, H, d)
# 3. 计算伪逆(使用可学习的kernel矩阵)
# Q @ K^T ≈ Q @ K_landmark^T @ M @ K_landmark @ V
M = torch.softmax(K_landmark @ K_landmark.transpose(-2, -1), dim=-1)
# 4. 计算Nyström近似
# 分子部分
KV = torch.einsum('bnhd,bmhd->bhnm', K_landmark, V)
Z = M @ KV # (B, H, m, d)
numerator = torch.einsum('bnhd,bhmd->bnhm', Q, Z)
# 分母部分
Z = M @ K_landmark # (B, H, m, d)
denominator = torch.einsum('bnhd,bhmd->bnhm', Q, Z)
return numerator / (denominator + 1e-8)4.3.2 矩阵分解视角
从矩阵分解视角看,注意力矩阵 通常具有低秩结构。设 的秩为 ,则:
其中 是低秩分解的基, 是奇异值对角矩阵。
Nyströmformer利用这一特性,通过采样landmarks来估计低秩分解。
5. FlashAttention系列对比
FlashAttention11系列是近年来最重要的注意力优化技术,通过IO感知(IO-aware)设计和算法创新,在不损失精度的前提下显著降低内存占用。
5.1 FlashAttention-1:2D Tiling,IO感知
FlashAttention-1由Stanford大学提出,核心思想是分块计算(Tiling)和算力融合(Kernel Fusion)。
核心创新:
// FlashAttention 核心算法伪代码
template <typename scalar_t>
__global__ void flash_attention_kernel(
const scalar_t* Q, const scalar_t* K, const scalar_t* V,
scalar_t* O, scalar_t* L, // 累积的log-sum-exp
int seq_len, int head_dim, int block_m, int block_n) {
// 1. 为当前query块分配共享内存
__shared__ scalar_t sQ[BLOCK_M][HEAD_DIM];
__shared__ scalar_t sK[BLOCK_N][HEAD_DIM];
__shared__ scalar_t sV[BLOCK_N][HEAD_DIM];
// 2. 外循环:遍历query块
for (int cur_m = blockIdx.x * block_m; cur_m < seq_len; cur_m += block_m) {
// 初始化输出和归一化因子
scalar_t m_i = -INFINITY; // 当前行的最大值
scalar_t l_i = 0; // 当前行的指数和
// 3. 内循环:分块计算注意力
for (int cur_n = 0; cur_n < seq_len; cur_n += block_n) {
// 加载K、V块到共享内存
load_to_shared(sK, K, cur_n, block_n);
load_to_shared(sV, V, cur_n, block_n);
__syncthreads();
// 计算QK^T的块
scalar_t m_hat = -INFINITY;
for (int j = 0; j < block_n; ++j) {
scalar_t qk = dot(sQ[i], sK[j]);
m_hat = max(m_hat, qk);
}
// 安全softmax计算
for (int j = 0; j < block_n; ++j) {
exp_qk = exp(sQ[i] * sK[j] - m_new);
// 更新行归一化因子
}
}
// 写入输出
O[cur_m + i] = l_i_new * o_i;
}
}IO复杂度分析:FlashAttention-1的IO复杂度为 ,其中 是SRAM大小, 是HBM带宽。相比标准实现,内存访问量减少约 倍。
5.2 FlashAttention-2:更好的并行性
FlashAttention-212在v1基础上进行了多项优化:
主要改进:
- 更好的并行策略:v1只在query维度并行,v2同时在query和key维度并行
- 更细粒度的Tiling:减少不必要的内存访问
- 支持序列长度非16/32倍数:更灵活的padding
// FlashAttention-2 并行策略
// v1: grid(batch * num_heads),串行处理key维度
// v2: grid(batch * num_heads * num_stages),并行处理key维度
// 新的tiling策略:减少寄存器压力
template <typename scalar_t>
__global__ void flash_attention_v2_kernel(
const scalar_t* Q, const scalar_t* K, const scalar_t* V,
scalar_t* O, scalar_t* L,
int seq_len, int num_stages) {
// stage 0: 处理前seq_len/2的key
// stage 1: 处理后seq_len/2的key
// 两个stage可以同时执行
for (int stage = 0; stage < num_stages; ++stage) {
// 预取下一阶段的数据
prefetch_KV(stage + 1);
// 并行计算
parallel_attention_block(Q_block, K_block, V_block, O_block);
}
}5.3 FlashAttention-3:3D并行,FP8
FlashAttention-313进一步挖掘现代GPU架构的潜力:
架构级优化:
| 特性 | FlashAttn-1 | FlashAttn-2 | FlashAttn-3 |
|---|---|---|---|
| 并行维度 | 2D (Q, Head) | 2D (Q, Head) | 3D (Q, Head, K) |
| Tensor Core | 不支持 | 部分支持 | 完全支持 |
| 数值精度 | FP16/BF16 | FP16/BF16 | FP16/BF16/FP8 |
| ** warp专门化** | 否 | 否 | 是 |
| 异步执行 | 否 | 否 | 是 |
FP8支持:
// FlashAttention-3 FP8 实现
__global__ void flash_attention_fp8_kernel(
const __fp8_e4m3* Q, const __fp8_e4m3* K, const __fp8_e4m3* V,
float* O, int seq_len) {
// FP8 -> BF16 转换(保留高精度累加)
__bf16* Q_bf16 = convert_fp8_to_bf16(Q);
__bf16* K_bf16 = convert_fp8_to_bf16(K);
__bf16* V_bf16 = convert_fp8_to_bf16(V);
// Tensor Core矩阵乘法(支持FP8输入)
// WGMMA指令:异步矩阵乘
wgmma_ma_cluster_t ma_cluster;
wgmma_wait_group_t wait_group;
// 异步启动矩阵乘法
wgmma_emit_window(ma_cluster, Q_bf16);
wgmma_accumulate(ma_cluster, K_bf16, &wait_group);
// 其他计算...
wgmma_commit_group(wait_group);
wgmma_wait_group(wait_group);
}5.4 对比表
| 特性 | FlashAttention-1 | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|
| 论文年份 | 2022 | 2023 | 2024 |
| H100加速 | ~4× | ~8× | ~16× |
| 内存占用 | |||
| 精度损失 | 无 | 无 | 无(可配) |
| 长序列优势 | 显著 | 更显著 | 极显著 |
| 数值格式 | FP16/BF16 | FP16/BF16 | FP16/BF16/FP8 |
6. 混合注意力架构
混合注意力架构结合不同类型注意力的优势,在效率与表达力之间寻求更好平衡。
6.1 Hybrid Attention-SSM
选择性状态空间模型(Selective SSM)与注意力机制的结合是近年来的重要研究方向。详见状态空间模型。
6.1.1 Mamba:选择性状态空间+注意力机制
Mamba14通过选择性机制(Selectivity Mechanism)使SSM能够像注意力一样”选择性”地处理信息。详见Mamba-2理论。
Mamba vs 注意力对比:
| 维度 | 注意力 | Mamba (SSM) |
|---|---|---|
| 计算复杂度 | ||
| 并行性 | 高度并行 | 受限(递归结构) |
| 选择性 | 可学习 | 输入依赖(选择性) |
| 位置建模 | 需额外编码 | 隐式建模 |
| 内存占用 |
6.1.2 Hyena:隐式长卷积的混合
Hyena15使用多层长卷积(Long Convolution)与注意力结合:
其中 是可学习的滤波器系数,与输入相关。
6.2 滑动窗口+全局注意力
Longformer和BigBird(详见第3节)是滑动窗口+全局注意力模式的典型代表。
设计原则:
- 局部建模:使用滑动窗口注意力捕获局部结构
- 全局交互:少量全局token负责跨距离信息传递
- 效率平衡:局部注意力 ,全局注意力
7. 实践选择指南
7.1 场景驱动的注意力选择
根据序列长度和应用场景选择合适的注意力机制:
| 序列长度 | 推荐方案 | 理由 |
|---|---|---|
| 短序列 () | 标准注意力 | 精度优先,二次复杂度可接受 |
| 中等序列 () | FlashAttention | 内存高效,保持精度 |
| 长序列 () | 稀疏/线性注意力 | 计算和内存约束 |
| 超长序列 () | 分层稀疏+局部注意力 | 多尺度建模能力 |
7.2 资源约束分析
内存受限场景:
- 优先选择FlashAttention系列(内存 vs )
- 考虑稀疏注意力(存储稀疏矩阵而非稠密矩阵)
- Linear Transformer内存需求最小
计算受限场景:
- 稀疏注意力在计算量上优势明显
- 分块稀疏可利用硬件特性加速
- 混合精度(FP8)可提升吞吐量
精度优先场景:
- FlashAttention是唯一在保持精度同时优化内存的方案
- 稀疏注意力的精度损失需要具体评估
- 核函数近似的误差积累需要注意
7.3 精度-效率权衡
不同注意力变体的精度-效率权衡曲线:
精度
^
│ ★ FlashAttention
│ ★ ★
│ ★ ★
│ ★
│ ★ ★ 稀疏注意力
│ ★ ★ ★
│ ★ ★ ★
│★ ★
│__________________________> 计算效率
↑ ↑
低效率 高效率
8. 核心算法对比表
| 注意力类型 | 时间复杂度 | 空间复杂度 | 近似误差 | 代表模型 | 精度保持 |
|---|---|---|---|---|---|
| 标准Attention | 0 | Vanilla Transformer | ✓ | ||
| FlashAttention-1 | 0 | FlashAttention | ✓ | ||
| FlashAttention-2 | 0 | FlashAttention-2 | ✓ | ||
| FlashAttention-3 | 0 | FlashAttention-3 | ✓ | ||
| Performer | 有界 | Performer | ~ | ||
| Linear Transformer | 有界 | Linear Transformer | ~ | ||
| LSH Attention | 有界 | Reformer | ~ | ||
| Nyströmformer | 有界 | Nyströmformer | ~ | ||
| Longformer | 可控 | Longformer | ✓ | ||
| BigBird | 可控 | BigBird | ✓ | ||
| Swin Transformer | 可控 | Swin Transformer | ✓ | ||
| Routing Transformer | 有界 | Routing Transformer | ~ |
注:~ 表示精度损失程度取决于具体任务和配置;✓ 表示精度保持
9. 参考文献
Footnotes
-
Vaswani A, Shazeer N, Parmar N, et al. Attention is All you Need[J]. Advances in Neural Information Processing Systems, 2017. ↩
-
Beltagy I, Peters M E, Cohan A. Longformer: The Long-Document Transformer[J]. arXiv preprint arXiv:2004.05150, 2020. ↩
-
Zaheer M, Guruganesh G, Dubey K A, et al. Big Bird: Transformers for Longer Sequences[J]. Advances in Neural Information Processing Systems, 2020. ↩
-
Child R, Gray S, Radford A, et al. Generating Long Sequences with Sparse Transformers[J]. arXiv preprint arXiv:1904.10509, 2019. ↩
-
Roy A, Saffar M, Vaswani A, et al. Efficient Content-Based Sparse Attention with Routing Transformers[J]. Transactions of the Association for Computational Linguistics, 2021. ↩
-
Kitaev N, Kaiser L, Levskaya A. Reformer: The Efficient Transformer[J]. International Conference on Learning Representations, 2020. ↩
-
Liu Z, Lin Y, Cao Y, et al. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows[C]. IEEE/CVF International Conference on Computer Vision, 2021. ↩
-
Choromanski H, Likhosherstov V, Dohan D, et al. Rethinking Attention with Performers[J]. International Conference on Learning Representations, 2021. ↩
-
Katharopoulos A, Vyas A, Pappas N, et al. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention[C]. International Conference on Machine Learning, 2020. ↩
-
Xiong Y, Zeng Z, Chakraborty R, et al. Nyströmformer: A Nyström-based Algorithm for Approximating Self-Attention[J]. AAAI Conference on Artificial Intelligence, 2021. ↩
-
Dao T, Fu D Y, Ermon S, et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness[J]. Advances in Neural Information Processing Systems, 2022. ↩
-
Dao T. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning[J]. International Conference on Learning Representations, 2024. ↩
-
Shah A, Chen Y, Luo J, et al. FlashAttention-3: Fast and Accurate Attention with A100 GPU Architecture and Irregular Matrices[J]. arXiv preprint arXiv:2407.08608, 2024. ↩
-
Gu A, Dao T. Mamba: Linear-Time Sequence Modeling with Selective State Spaces[J]. arXiv preprint arXiv:2312.00752, 2023. ↩
-
Poli M, Massaroli S, Nguyen E, et al. Hyena Hierarchy: Towards Larger Convolutional Language Models[J]. International Conference on Machine Learning, 2023. ↩