注意力机制变体综合对比

1. 引言:注意力机制的演进

注意力机制(Attention Mechanism)自2017年由Vaswani等人提出以来,已成为现代深度学习尤其是序列建模领域的核心技术。1 原始Transformer架构中的自注意力(Self-Attention)机制通过计算序列中所有位置之间的依赖关系,实现了并行化序列建模,但其核心问题在于计算复杂度与序列长度呈平方关系 ,这严重制约了模型处理长序列的能力。

1.1 从原始Transformer到高效注意力

随着应用场景从短文本向长文档、长视频、基因组序列等扩展,研究者提出了大量高效注意力变体。这些变体在**效率(Efficiency)表达力(Expressivity)**之间寻求不同平衡:

1.2 效率与表达力的基本矛盾

注意力机制的效率与表达力之间存在根本性张力:

维度高效率方案高表达力方案
计算复杂度
内存占用
建模能力受限于稀疏模式全局感受野
适用场景长序列、资源受限短序列、精度敏感

1.3 本文档的目标

本文旨在建立一个系统性分类框架,对现代注意力机制变体进行全面梳理和对比分析。主要内容包括:

  • 注意力机制的统一分类体系
  • 稀疏注意力模式的技术对比
  • 线性注意力变体的理论分析
  • FlashAttention系列的发展演进
  • 混合注意力架构的设计原则
  • 实践场景下的选择指南

2. 注意力机制分类体系

2.1 按计算复杂度分类

根据时间复杂度的渐进增长率,可以将注意力机制分为三类:

类别时间复杂度代表方法
二次复杂度标准注意力、FlashAttention
次二次复杂度LSH Attention、分块稀疏
线性复杂度线性注意力、核方法

标准注意力的核心运算是 Query-Key 矩阵乘积:

这导致了 的注意力矩阵存储和 级别的计算量。

2.2 按近似性质分类

根据是否引入近似,可以分为:

精确注意力:保留原始注意力矩阵的完整信息,包括:

  • 标准Transformer注意力
  • FlashAttention系列(IO感知优化,不损失精度)
  • 局部窗口注意力(在窗口内精确)

近似注意力:通过各种技术近似原始注意力,包括:

  • 核函数近似(Performer)
  • 低秩分解(Nyströmformer)
  • 哈希聚类(Reformer)
  • 随机投影(Random Feature方法)

2.3 按稀疏模式分类

Dense Attention(密集注意力):每个位置与所有其他位置交互,代表即原始Transformer。

Sparse Attention(稀疏注意力):仅计算部分位置对之间的注意力分数:

  • 固定稀疏:预定义稀疏模式(局部窗口、块稀疏)
  • 动态稀疏:数据驱动的稀疏模式学习(路由、聚类)
  • 分层稀疏:多尺度、分层组织的稀疏模式

Linear Attention(线性注意力):通过数学变换将二次复杂度降低到线性:

其中 是核函数,将输入映射到特征空间。


3. 稀疏注意力模式对比

稀疏注意力的核心思想是:并非所有位置对之间的交互都是必需的。通过选择性计算部分关键交互,可以在保持主要建模能力的同时显著降低计算成本。

3.1 Global + Local模式

Global + Local模式结合了全局注意力(捕获长程依赖)和局部注意力(捕获局部结构)两种机制。

3.1.1 Longformer

Longformer2由Allen AI研究院提出,设计了一种组合式稀疏注意力模式:

注意力模式配置

# Longformer 注意力模式示意
class LongformerAttention:
    def __init__(self, window_size=512, num_global_tokens=2):
        self.window_size = window_size
        self.num_global_tokens = num_global_tokens  # CLS + 特殊token
    
    def compute_attention(self, q, k, v):
        # 1. 全局注意力:特殊token与所有位置交互
        global_attn = self.global_attention(q[:self.num_global_tokens], k, v)
        
        # 2. 滑动窗口注意力:每个位置与局部窗口内位置交互
        local_attn = self.sliding_window_attention(q, k, v, self.window_size)
        
        # 3. 扩张滑动窗口:间隔采样增加感受野
        dilated_attn = self.dilated_attention(q, k, v, dilation=2)
        
        return combine_attentions(global_attn, local_attn, dilated_attn)

复杂度分析:对于序列长度 ,窗口大小 ,全局token数量

时,总复杂度接近线性。

设计特点

  • 全局注意力应用于特殊token(如[CLS])和可学习的[global] token
  • 滑动窗口用于大多数普通token
  • 膨胀窗口(Dilated Window)进一步扩大感受野而不增加计算量

3.1.2 BigBird

BigBird3由Google Research提出,证明了稀疏注意力可以逼近全注意力(Universal Approximation)。其注意力模式包含三个组件:

稀疏模式组成

组件描述示意图
全局注意力所有token与预设的全局token交互
随机注意力每个token与随机选取的token交互
滑动窗口注意力每个token与局部窗口内token交互
# BigBird 稀疏注意力实现
class BigBirdSparseAttention:
    def __init__(self, num_random=3, window_size=3):
        self.num_random = num_random
        self.window_size = window_size
    
    def build_sparse_mask(self, seq_len, num_global=2):
        # 1. 全局token(首尾)
        global_indices = list(range(num_global)) + [seq_len - 1]
        
        # 2. 滑动窗口(中心位置)
        window_indices = [create_window(i, self.window_size) for i in range(seq_len)]
        
        # 3. 随机连接
        random_indices = [random.sample(range(seq_len), self.num_random) 
                          for _ in range(seq_len)]
        
        # 合并并去重
        sparse_mask = [set(global_indices + w + r) 
                      for w, r in zip(window_indices, random_indices)]
        return sparse_mask

理论保证:BigBird论文证明了这种稀疏模式在图灵机模拟能力上与全注意力等价,为稀疏注意力的有效性提供了理论支撑。

3.2 固定稀疏模式

固定稀疏模式使用预定义的稀疏结构,不依赖数据或学习。

3.2.1 块稀疏注意力(Block Sparse Attention)

块稀疏注意力4将注意力矩阵划分为固定大小的块,只计算部分块内的注意力:

其中 表示位置 所属的稀疏块集合。

// 块稀疏注意力计算伪代码
void block_sparse_attention(
    const Tensor& Q,      // (batch, heads, N, d_k)
    const Tensor& K,      // (batch, heads, N, d_k)
    const Tensor& V,      // (batch, heads, N, d_v)
    const Mask& mask,     // 块稀疏掩码
    Tensor& output) {
    
    int block_size = 64;
    int num_blocks = N / block_size;
    
    for (int b = 0; b < batch; ++b) {
        for (int h = 0; h < num_heads; ++h) {
            for (int i = 0; i < num_blocks; ++i) {
                for (int j = 0; j < num_blocks; ++j) {
                    if (mask[i][j]) {  // 稀疏掩码
                        // 计算块内注意力
                        auto S_block = matmul(
                            Q[b][h][i*block_size:(i+1)*block_size],
                            K[b][h][j*block_size:(j+1)*block_size].transpose()
                        ) / sqrt(d_k);
                        
                        auto attn_block = softmax(S_block);
                        auto out_block = matmul(attn_block, 
                            V[b][h][j*block_size:(j+1)*block_size]);
                        
                        output[b][h][i*block_size:(i+1)*block_size] += out_block;
                    }
                }
            }
        }
    }
}

3.2.2 局部窗口 + 膨胀的组合

通过膨胀(Dilation)机制,可以在不增加参数量的前提下指数级扩大感受野:

其中 是膨胀率, 表示从索引0开始、步长为 采样的子序列。

3.3 动态稀疏模式

动态稀疏模式根据输入数据动态决定注意力连接,具有更强的自适应能力。

3.3.1 Routing Transformer

Routing Transformer5使用 -means聚类将token路由到不同的”桶”中:

核心机制

class RoutingAttention:
    def __init__(self, num_routes, num_heads):
        self.num_routes = num_routes
        self.num_heads = num_heads
    
    def route_tokens(self, keys, queries):
        # 对keys进行k-means聚类
        routes = []
        for h in range(self.num_heads):
            # 计算每个token属于哪个聚类中心
            cluster_ids = kmeans(
                keys[:, h, :, :],  # (batch, seq, d_k)
                k=self.num_routes,
                centroids=learnable_centroids[h]
            )
            routes.append(cluster_ids)
        return routes
    
    def attend(self, q, k, v, routes):
        # 在路由层面计算注意力
        for h in range(self.num_heads):
            for route_id in range(self.num_routes):
                # 获取该路由的query和key
                route_mask = routes[h] == route_id
                q_route = q[:, route_mask, h, :]
                k_route = k[:, route_mask, h, :]
                v_route = v[:, route_mask, h, :]
                
                # 路由内局部注意力
                attn_route = local_attention(q_route, k_route, v_route)
                # ...

复杂度分析:假设token均匀分布在 个路由中,则复杂度降为

3.3.2 Reformer:局部敏感哈希注意力

Reformer6使用局部敏感哈希(Locality-Sensitive Hashing, LSH)将相似的Query和Key分配到同一”桶”中:

LSH机制

// 简化的LSH注意力实现
class LSHAttention {
public:
    // LSH函数:将向量映射到哈希桶
    int lsh_forward(const Vec& x, int num_buckets, int num_hashes) {
        // 生成随机投影向量
        std::vector<Vec> projections = generate_random_projections(x.dim, num_buckets);
        
        // 计算哈希值(取多个哈希的组合)
        std::vector<int> hashes;
        for (int i = 0; i < num_hashes; ++i) {
            float dot = x.dot(projections[i]);
            int bucket = (dot > 0) ? 1 : 0;
            hashes.push_back(bucket << i);  // 组合多个哈希位
        }
        return std::accumulate(hashes.begin(), hashes.end(), 0);
    }
    
    // LSH注意力计算
    Tensor lsh_attention(const Tensor& Q, const Tensor& K, const Tensor& V) {
        int num_buckets = 64;
        int num_hashes = 8;
        
        // 1. 为每个Query计算哈希桶
        std::vector<int> q_buckets = compute_hashes(Q, num_buckets, num_hashes);
        
        // 2. 为每个Key计算哈希桶
        std::vector<int> k_buckets = compute_hashes(K, num_buckets, num_hashes);
        
        // 3. 按桶排序以增加哈希冲突概率
        auto [sorted_q, q_indices] = argsort(q_buckets);
        auto [sorted_k, k_indices] = argsort(k_buckets);
        
        // 4. 在同一桶内计算局部注意力
        // 同一桶内的元素大概率相似(LSH特性)
        Tensor output = compute_local_attention(sorted_q, sorted_k, sorted_k);
        
        // 5. 恢复到原始顺序
        return unsort_by_indices(output, q_indices);
        
        return output;
    }
};

LSH理论背景:LSH的核心特性是”相似向量大概率映射到同一桶”:

其中 是向量 之间的夹角。

3.4 分层稀疏模式

分层稀疏模式通过多尺度/层次化设计,在不同层级使用不同的稀疏策略。

3.4.1 Swin Transformer的移位窗口注意力

Swin Transformer7通过层次化结构和移位窗口机制实现了高效的稀疏注意力。详见Swin Transformer

核心设计

Stage 1: 划分非重叠窗口 (W×W)
┌─────┬─────┬─────┐
│ W_1 │ W_2 │ W_3 │
├─────┼─────┼─────┤
│ W_4 │ W_5 │ W_6 │
└─────┴─────┴─────┘
           ↓ 移位 (⌈W/2⌉, ⌈W/2⌉)
Stage 2: 移位后的窗口划分
┌─────┬─────┬─────┬─────┐
│ A   │ B   │     │ D   │
├─────┼─────┼─────┼─────┤
│ E   │ F   │ G   │ H   │
├─────┼─────┼─────┼─────┤
│     │ I   │ J   │ K   │
├─────┼─────┼─────┼─────┤
│ L   │ M   │     │ O   │
└─────┴─────┴─────┴─────┘

移位机制确保相邻窗口之间存在交叉连接,弥补了固定窗口导致的感受野局限。

3.4.2 金字塔结构与多尺度表示

金字塔结构(如Pyramid Vision Transformer、PVT)在不同阶段逐步降低分辨率,实现计算量的有效控制:

Stage分辨率特征维度注意力类型
Stage 1局部窗口注意力
Stage 2稀疏全局注意力
Stage 3稀疏全局注意力
Stage 4稀疏全局注意力

4. 线性注意力变体对比

线性注意力的核心思想是通过数学变换,将 的注意力计算转换为 的线性计算。

4.1 核函数方法

核函数方法通过特征映射 将原始高维空间中的点积运算转换为低维空间中的近似计算。

4.1.1 Performer:Random Feature (FAVOR+)

Performer8使用Random Feature近似方法,将softmax注意力转化为线性形式:

理论基础:利用Bochner定理,softmax核函数可以表示为随机特征期望:

其中 是随机特征映射:

其中 是随机矩阵, 是随机偏置。

// Performer FAVOR+ 机制实现
template <typename scalar_t>
struct PerformerAttention {
    // 随机特征映射
    static Tensor random_feature(const Tensor& x, int m) {
        // 生成随机矩阵和偏置
        auto W = randn({x.size(-1), m});
        auto b = uniform({m}, 0, 2 * M_PI);
        
        // 指数随机投影
        Tensor proj = matmul(x, W) + b;
        
        // 复数形式的cos/sin特征
        Tensor cos_feat = cos(proj);
        Tensor sin_feat = sin(proj);
        
        // 连接并归一化
        return cat({cos_feat, sin_feat}, dim=-1) / sqrt(m);
    }
    
    // 线性复杂度的近似注意力
    static Tensor favor_attention(const Tensor& Q, const Tensor& K, 
                                  const Tensor& V, int m) {
        // 映射到随机特征空间
        Tensor Q_prime = random_feature(Q, m);  // (B, H, N, 2m)
        Tensor K_prime = random_feature(K, m);  // (B, H, N, 2m)
        
        // 线性注意力:先计算 K'V,再与 Q' 计算
        // \phi(K)^T V: (B, H, 2m, d)
        Tensor KV_product = matmul(K_prime.transpose(-2, -1), V);
        
        // Q' @ (K'V) / (Q' @ K' @ 1)
        Tensor numerator = matmul(Q_prime, KV_product);
        Tensor denominator = matmul(Q_prime, K_prime.sum(-2, true));
        
        return numerator / (denominator + 1e-6);
    }
};

近似误差分析:Performer提供了有界的近似误差保证:

其中 与随机特征维度 相关。

4.1.2 近似误差与收敛性保证

核函数方法的近似质量取决于随机特征的维度 和输入分布。Performer论文给出了以下收敛保证:

中心极限定理视角:当 时,随机特征估计收敛到真实核函数值:

方差估计:设 ,则:

4.2 线性化方法

线性化方法直接修改注意力机制的结构,移除导致二次复杂度的操作。

4.2.1 Linear Transformer

Linear Transformer9将softmax函数替换为特征映射函数:

原始softmax注意力

Linear Transformer

其中 (指数线性单元加1)。

class LinearTransformerAttention(nn.Module):
    """
    线性注意力机制
    核心:将softmax分解为特征映射和加权聚合
    """
    def __init__(self, d_model, d_k):
        super().__init__()
        self.d_k = d_k
        self.proj_q = nn.Linear(d_model, d_k)
        self.proj_k = nn.Linear(d_model, d_k)
        self.proj_v = nn.Linear(d_model, d_k)
        self.scale = d_k ** -0.5
    
    def forward(self, Q, K, V):
        # 特征映射:elu + 1
        Q_prime = F.elu(Q) + 1
        K_prime = F.elu(K) + 1
        
        # 分母:每个query的归一化项
        K_sum = K_prime.sum(dim=-2, keepdim=True)  # (B, 1, H, d)
        denom = (Q_prime * K_sum).sum(dim=-1, keepdim=True)  # (B, N, H, 1)
        
        # 分子:加权聚合
        KV_product = torch.einsum('bnhd,bnhm->bhmd', K_prime, V)
        numerator = torch.einsum('bnhd,bhmd->bnhm', Q_prime, KV_product)
        
        return numerator / (denom + 1e-8)

4.2.2 位置编码补偿

线性注意力移除了softmax,损失了自归一化特性,需要额外的位置编码来建模序列顺序:

解决方案

  1. 相对位置编码:在特征映射中引入位置信息

  2. 傅里叶位置编码:使用正弦/余弦函数

  3. 旋转位置编码(RoPE)

详见ALiBi位置编码

4.3 低秩方法

低秩方法利用注意力矩阵的低秩结构进行近似分解。

4.3.1 Nyströmformer:基于Landmark的近似

Nyströmformer10使用Nyström方法,利用矩阵的低秩近似:

Nyström近似理论:对于半正定矩阵 ,选取 个landmark点 ,则:

其中 是landmark索引集合。

class NystromAttention(nn.Module):
    """
    Nyströmformer注意力
    使用Nyström方法近似全注意力
    """
    def __init__(self, num_landmarks=64):
        super().__init__()
        self.num_landmarks = num_landmarks
    
    def forward(self, Q, K, V):
        seq_len = Q.size(1)
        
        # 1. 选择landmark点(均匀采样)
        landmark_indices = torch.linspace(
            0, seq_len-1, self.num_landmarks, 
            dtype=torch.long, device=Q.device
        )
        
        # 2. 计算子矩阵
        K_landmark = K[:, landmark_indices, :, :]  # (B, m, H, d)
        
        # 3. 计算伪逆(使用可学习的kernel矩阵)
        # Q @ K^T ≈ Q @ K_landmark^T @ M @ K_landmark @ V
        M = torch.softmax(K_landmark @ K_landmark.transpose(-2, -1), dim=-1)
        
        # 4. 计算Nyström近似
        # 分子部分
        KV = torch.einsum('bnhd,bmhd->bhnm', K_landmark, V)
        Z = M @ KV  # (B, H, m, d)
        numerator = torch.einsum('bnhd,bhmd->bnhm', Q, Z)
        
        # 分母部分
        Z = M @ K_landmark  # (B, H, m, d)
        denominator = torch.einsum('bnhd,bhmd->bnhm', Q, Z)
        
        return numerator / (denominator + 1e-8)

4.3.2 矩阵分解视角

从矩阵分解视角看,注意力矩阵 通常具有低秩结构。设 的秩为 ,则:

其中 是低秩分解的基, 是奇异值对角矩阵。

Nyströmformer利用这一特性,通过采样landmarks来估计低秩分解。


5. FlashAttention系列对比

FlashAttention11系列是近年来最重要的注意力优化技术,通过IO感知(IO-aware)设计和算法创新,在不损失精度的前提下显著降低内存占用。

5.1 FlashAttention-1:2D Tiling,IO感知

FlashAttention-1由Stanford大学提出,核心思想是分块计算(Tiling)算力融合(Kernel Fusion)

核心创新

// FlashAttention 核心算法伪代码
template <typename scalar_t>
__global__ void flash_attention_kernel(
    const scalar_t* Q, const scalar_t* K, const scalar_t* V,
    scalar_t* O, scalar_t* L,  // 累积的log-sum-exp
    int seq_len, int head_dim, int block_m, int block_n) {
    
    // 1. 为当前query块分配共享内存
    __shared__ scalar_t sQ[BLOCK_M][HEAD_DIM];
    __shared__ scalar_t sK[BLOCK_N][HEAD_DIM];
    __shared__ scalar_t sV[BLOCK_N][HEAD_DIM];
    
    // 2. 外循环:遍历query块
    for (int cur_m = blockIdx.x * block_m; cur_m < seq_len; cur_m += block_m) {
        
        // 初始化输出和归一化因子
        scalar_t m_i = -INFINITY;  // 当前行的最大值
        scalar_t l_i = 0;          // 当前行的指数和
        
        // 3. 内循环:分块计算注意力
        for (int cur_n = 0; cur_n < seq_len; cur_n += block_n) {
            // 加载K、V块到共享内存
            load_to_shared(sK, K, cur_n, block_n);
            load_to_shared(sV, V, cur_n, block_n);
            __syncthreads();
            
            // 计算QK^T的块
            scalar_t m_hat = -INFINITY;
            for (int j = 0; j < block_n; ++j) {
                scalar_t qk = dot(sQ[i], sK[j]);
                m_hat = max(m_hat, qk);
            }
            
            // 安全softmax计算
            for (int j = 0; j < block_n; ++j) {
                exp_qk = exp(sQ[i] * sK[j] - m_new);
                // 更新行归一化因子
            }
        }
        
        // 写入输出
        O[cur_m + i] = l_i_new * o_i;
    }
}

IO复杂度分析:FlashAttention-1的IO复杂度为 ,其中 是SRAM大小, 是HBM带宽。相比标准实现,内存访问量减少约 倍。

5.2 FlashAttention-2:更好的并行性

FlashAttention-212在v1基础上进行了多项优化:

主要改进

  1. 更好的并行策略:v1只在query维度并行,v2同时在query和key维度并行
  2. 更细粒度的Tiling:减少不必要的内存访问
  3. 支持序列长度非16/32倍数:更灵活的padding
// FlashAttention-2 并行策略
// v1: grid(batch * num_heads),串行处理key维度
// v2: grid(batch * num_heads * num_stages),并行处理key维度
 
// 新的tiling策略:减少寄存器压力
template <typename scalar_t>
__global__ void flash_attention_v2_kernel(
    const scalar_t* Q, const scalar_t* K, const scalar_t* V,
    scalar_t* O, scalar_t* L,
    int seq_len, int num_stages) {
    
    // stage 0: 处理前seq_len/2的key
    // stage 1: 处理后seq_len/2的key
    // 两个stage可以同时执行
    
    for (int stage = 0; stage < num_stages; ++stage) {
        // 预取下一阶段的数据
        prefetch_KV(stage + 1);
        
        // 并行计算
        parallel_attention_block(Q_block, K_block, V_block, O_block);
    }
}

5.3 FlashAttention-3:3D并行,FP8

FlashAttention-313进一步挖掘现代GPU架构的潜力:

架构级优化

特性FlashAttn-1FlashAttn-2FlashAttn-3
并行维度2D (Q, Head)2D (Q, Head)3D (Q, Head, K)
Tensor Core不支持部分支持完全支持
数值精度FP16/BF16FP16/BF16FP16/BF16/FP8
** warp专门化**
异步执行

FP8支持

// FlashAttention-3 FP8 实现
__global__ void flash_attention_fp8_kernel(
    const __fp8_e4m3* Q, const __fp8_e4m3* K, const __fp8_e4m3* V,
    float* O, int seq_len) {
    
    // FP8 -> BF16 转换(保留高精度累加)
    __bf16* Q_bf16 = convert_fp8_to_bf16(Q);
    __bf16* K_bf16 = convert_fp8_to_bf16(K);
    __bf16* V_bf16 = convert_fp8_to_bf16(V);
    
    // Tensor Core矩阵乘法(支持FP8输入)
    // WGMMA指令:异步矩阵乘
    wgmma_ma_cluster_t ma_cluster;
    wgmma_wait_group_t wait_group;
    
    // 异步启动矩阵乘法
    wgmma_emit_window(ma_cluster, Q_bf16);
    wgmma_accumulate(ma_cluster, K_bf16, &wait_group);
    
    // 其他计算...
    wgmma_commit_group(wait_group);
    wgmma_wait_group(wait_group);
}

5.4 对比表

特性FlashAttention-1FlashAttention-2FlashAttention-3
论文年份202220232024
H100加速~4×~8×~16×
内存占用
精度损失无(可配)
长序列优势显著更显著极显著
数值格式FP16/BF16FP16/BF16FP16/BF16/FP8

6. 混合注意力架构

混合注意力架构结合不同类型注意力的优势,在效率与表达力之间寻求更好平衡。

6.1 Hybrid Attention-SSM

选择性状态空间模型(Selective SSM)与注意力机制的结合是近年来的重要研究方向。详见状态空间模型

6.1.1 Mamba:选择性状态空间+注意力机制

Mamba14通过选择性机制(Selectivity Mechanism)使SSM能够像注意力一样”选择性”地处理信息。详见Mamba-2理论

Mamba vs 注意力对比

维度注意力Mamba (SSM)
计算复杂度
并行性高度并行受限(递归结构)
选择性可学习输入依赖(选择性)
位置建模需额外编码隐式建模
内存占用

6.1.2 Hyena:隐式长卷积的混合

Hyena15使用多层长卷积(Long Convolution)与注意力结合:

其中 是可学习的滤波器系数,与输入相关。

6.2 滑动窗口+全局注意力

Longformer和BigBird(详见第3节)是滑动窗口+全局注意力模式的典型代表。

设计原则

  1. 局部建模:使用滑动窗口注意力捕获局部结构
  2. 全局交互:少量全局token负责跨距离信息传递
  3. 效率平衡:局部注意力 ,全局注意力

7. 实践选择指南

7.1 场景驱动的注意力选择

根据序列长度和应用场景选择合适的注意力机制:

序列长度推荐方案理由
短序列 ()标准注意力精度优先,二次复杂度可接受
中等序列 ()FlashAttention内存高效,保持精度
长序列 ()稀疏/线性注意力计算和内存约束
超长序列 ()分层稀疏+局部注意力多尺度建模能力

7.2 资源约束分析

内存受限场景

  • 优先选择FlashAttention系列(内存 vs
  • 考虑稀疏注意力(存储稀疏矩阵而非稠密矩阵)
  • Linear Transformer内存需求最小

计算受限场景

  • 稀疏注意力在计算量上优势明显
  • 分块稀疏可利用硬件特性加速
  • 混合精度(FP8)可提升吞吐量

精度优先场景

  • FlashAttention是唯一在保持精度同时优化内存的方案
  • 稀疏注意力的精度损失需要具体评估
  • 核函数近似的误差积累需要注意

7.3 精度-效率权衡

不同注意力变体的精度-效率权衡曲线:

精度
  ^
  │                    ★ FlashAttention
  │               ★     ★
  │            ★  ★
  │         ★
  │      ★         ★ 稀疏注意力
  │   ★    ★    ★
  │ ★  ★  ★
  │★ ★
  │__________________________> 计算效率
       ↑              ↑
    低效率         高效率

8. 核心算法对比表

注意力类型时间复杂度空间复杂度近似误差代表模型精度保持
标准Attention0Vanilla Transformer
FlashAttention-10FlashAttention
FlashAttention-20FlashAttention-2
FlashAttention-30FlashAttention-3
Performer有界Performer~
Linear Transformer有界Linear Transformer~
LSH Attention有界Reformer~
Nyströmformer有界Nyströmformer~
Longformer可控Longformer
BigBird可控BigBird
Swin Transformer可控Swin Transformer
Routing Transformer有界Routing Transformer~

注:~ 表示精度损失程度取决于具体任务和配置;✓ 表示精度保持


9. 参考文献

Footnotes

  1. Vaswani A, Shazeer N, Parmar N, et al. Attention is All you Need[J]. Advances in Neural Information Processing Systems, 2017.

  2. Beltagy I, Peters M E, Cohan A. Longformer: The Long-Document Transformer[J]. arXiv preprint arXiv:2004.05150, 2020.

  3. Zaheer M, Guruganesh G, Dubey K A, et al. Big Bird: Transformers for Longer Sequences[J]. Advances in Neural Information Processing Systems, 2020.

  4. Child R, Gray S, Radford A, et al. Generating Long Sequences with Sparse Transformers[J]. arXiv preprint arXiv:1904.10509, 2019.

  5. Roy A, Saffar M, Vaswani A, et al. Efficient Content-Based Sparse Attention with Routing Transformers[J]. Transactions of the Association for Computational Linguistics, 2021.

  6. Kitaev N, Kaiser L, Levskaya A. Reformer: The Efficient Transformer[J]. International Conference on Learning Representations, 2020.

  7. Liu Z, Lin Y, Cao Y, et al. Swin Transformer: Hierarchical Vision Transformer using Shifted Windows[C]. IEEE/CVF International Conference on Computer Vision, 2021.

  8. Choromanski H, Likhosherstov V, Dohan D, et al. Rethinking Attention with Performers[J]. International Conference on Learning Representations, 2021.

  9. Katharopoulos A, Vyas A, Pappas N, et al. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention[C]. International Conference on Machine Learning, 2020.

  10. Xiong Y, Zeng Z, Chakraborty R, et al. Nyströmformer: A Nyström-based Algorithm for Approximating Self-Attention[J]. AAAI Conference on Artificial Intelligence, 2021.

  11. Dao T, Fu D Y, Ermon S, et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness[J]. Advances in Neural Information Processing Systems, 2022.

  12. Dao T. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning[J]. International Conference on Learning Representations, 2024.

  13. Shah A, Chen Y, Luo J, et al. FlashAttention-3: Fast and Accurate Attention with A100 GPU Architecture and Irregular Matrices[J]. arXiv preprint arXiv:2407.08608, 2024.

  14. Gu A, Dao T. Mamba: Linear-Time Sequence Modeling with Selective State Spaces[J]. arXiv preprint arXiv:2312.00752, 2023.

  15. Poli M, Massaroli S, Nguyen E, et al. Hyena Hierarchy: Towards Larger Convolutional Language Models[J]. International Conference on Machine Learning, 2023.