MiniMax Sparse Attention (MSA): Sparse Attention for Ultra-Long Contexts

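1. Background

1.1 Why Ultra-Long Context Matters

Ultra-long context capability is increasingly important for frontier LLMs:

Application scenario        Context requirement
Agent workflows             100K+ tokens
Code repository reasoning   500K+ tokens
Persistent memory           1M+ tokens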

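1.2 Challenges of Existing Approaches

Quadratic complexity of softmax attention

Softmax attention compares every token with every other token, so its cost grows quadratically with sequence length. At million-token sequence lengths, the compute and memory overhead becomes prohibitive.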
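
To make the quadratic growth concrete, the sketch below counts the entries of the attention score matrix for a single head and the memory needed to materialize them. The helper name and the 2-byte (16-bit) element size are illustrative assumptions, not figures from MSA.

```python
def score_matrix_cost(seq_len, bytes_per_elem=2):
    """Return (number of score entries, GiB to materialize them) for one head."""
    entries = seq_len * seq_len              # one score per (query, key) pair
    gib = entries * bytes_per_elem / 2**30
    return entries, gib


for n in (100_000, 500_000, 1_000_000):
    entries, gib = score_matrix_cost(n)
    print(f"{n:>9,} tokens: {entries:.1e} scores, {gib:,.1f} GiB per head")
```

Fused kernels avoid storing this matrix, but the number of score computations still grows with the square of the sequence length.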

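Limitations of existing sparse attention

  • KV cache size still grows linearly with sequence length
  • Sparse patterns can drop critical information
  • Implementations are complex and hard to deploy efficiently across different GPUs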
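
The first limitation is easy to quantify: sparse attention skips computation, but the keys and values of every token are still stored. A minimal sketch follows, assuming a hypothetical model shape (48 layers, 8 KV heads, head dimension 128, 16-bit cache) chosen only for illustration:

```python
def kv_cache_gib(seq_len, n_layers, n_kv_heads, d_head, bytes_per_elem=2):
    """KV cache size in GiB for one sequence (keys + values, all layers)."""
    n_bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem
    return n_bytes / 2**30


# Hypothetical model shape, not taken from any MiniMax model
cfg = dict(n_layers=48, n_kv_heads=8, d_head=128)
for n in (100_000, 1_000_000):
    print(f"{n:>9,} tokens: {kv_cache_gib(n, **cfg):.1f} GiB")
```

Doubling the sequence length doubles the cache, regardless of how sparse the attention pattern is.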

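2. Core Design of MiniMax Sparse Attention

2.1 Architecture Overview

MSA is a block-level sparse attention mechanism built on Grouped Query Attention (GQA). Its design principles are simplicity and scalability.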

┌──────────────────────────────────────────────────────────────┐
│                    MSA Architecture                           │
├──────────────────────────────────────────────────────────────┤
│                                                              │
│   Input Tokens ──► Block Division ──► Index Branch          │
│                           │                   │              │
│                           │                   ▼              │
│                           │          Score Key-Value Blocks  │
│                           │                   │              │
│                           ▼                   ▼              │
│                      ┌─────────────────────────┐             │
│                      │   Top-k Block Selection  │             │
│                      │   (Per GQA Group)       │             │
│                      └─────────────────────────┘             │
│                                    │                         │
│                                    ▼                         │
│                      ┌─────────────────────────┐             │
│                      │     Main Branch         │             │
│                      │  (Exact Block-Sparse    │             │
│                      │   Attention)            │             │
│                      └─────────────────────────┘             │
│                                    │                         │
│                                    ▼                         │
│                              Output                           │
└──────────────────────────────────────────────────────────────┘

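2.2 Core Techniques

2.2.1 Lightweight Index Branch

The Index Branch scores the key-value blocks and selects a Top-k subset independently for each GQA group.

Design points

  • Lightweight: uses a small MLP rather than full attention
  • Per-group selection: each GQA group independently selects the blocks it is interested in
  • Preserved diversity: different groups can select different blocks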
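
Per-group selection can be illustrated with a toy example. The scores below are hand-written stand-ins for the output of an index branch (an assumption for illustration); each row is one GQA group ranking six key-value blocks.

```python
import torch

# Hypothetical block scores: 2 GQA groups x 6 key-value blocks
block_scores = torch.tensor([
    [0.9, 0.1, 0.8, 0.2, 0.7, 0.0],
    [0.1, 0.9, 0.0, 0.8, 0.2, 0.7],
])
top_k = 3

# Each group ranks the blocks on its own and keeps its top_k
selected = torch.topk(block_scores, top_k, dim=-1).indices  # [n_groups, top_k]

for g in range(block_scores.shape[0]):
    print(f"group {g}: blocks {sorted(selected[g].tolist())}")
# group 0: blocks [0, 2, 4]
# group 1: blocks [1, 3, 5]
```

Because the ranking runs along the block dimension separately for each row, the two groups end up attending to disjoint blocks in this example.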

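2.2.2 Exact Attention in the Main Branch

The Main Branch performs exact block-sparse attention over the blocks selected by the index branch. In standard attention notation (the symbols follow common usage and are assumed here) this can be written as

$$O = \mathrm{softmax}\!\left(\frac{Q K_{\mathcal{S}}^{\top}}{\sqrt{d}}\right) V_{\mathcal{S}}$$

where $\mathcal{S}$ denotes the Top-k blocks selected by the index branch, and $K_{\mathcal{S}}$, $V_{\mathcal{S}}$ are the keys and values restricted to those blocks.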
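
That the main branch is exact on the selected blocks can be checked numerically: attention over the gathered blocks equals dense attention with every other token masked out. The sketch below uses a single query and a hypothetical selection of two blocks; all sizes are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
seq_len, d_head, block_size = 16, 8, 4
q = torch.randn(d_head)                      # one query vector
k = torch.randn(seq_len, d_head)
v = torch.randn(seq_len, d_head)
selected_blocks = torch.tensor([0, 2])       # hypothetical Top-k result

# Gather only the tokens that belong to the selected blocks
token_idx = (selected_blocks[:, None] * block_size
             + torch.arange(block_size)).flatten()
k_sel, v_sel = k[token_idx], v[token_idx]

# Exact softmax attention restricted to the gathered tokens
weights = torch.softmax(k_sel @ q / d_head**0.5, dim=-1)
out_sparse = weights @ v_sel

# Reference: dense attention with unselected tokens masked to -inf
keep = torch.zeros(seq_len, dtype=torch.bool)
keep[token_idx] = True
dense_scores = (k @ q / d_head**0.5).masked_fill(~keep, float("-inf"))
out_dense = torch.softmax(dense_scores, dim=-1) @ v

print(torch.allclose(out_sparse, out_dense, atol=1e-6))
```

The gathered version touches only `top_k * block_size` keys, which is where the compute saving comes from.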

2.3 块级稀疏模式

标准注意力 vs MSA稀疏注意力:

标准注意力(序列长度=16):
┌────────────────────────┐
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
└────────────────────────┘

MSA稀疏注意力(Top-4块,块大小=4):
┌────────────────────────┐
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │  ← 仅关注选中块
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ □ □ □ □ │
│ □ □ □ □ □ □ □ □ □ □ □ □ │
│ ■ □ □ □ ■ □ □ □ □ □ □ □ │
│ ■ □ □ □ ■ □ □ □ □ □ □ □ │
└────────────────────────┘
□ = 稀疏(不计算)
■ = 密集(执行注意力)
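
A block-sparse pattern of this kind can be generated by expanding a block-level selection into a token-level mask. The sketch below assumes 16 tokens, blocks of 4, and a hand-picked selection of 2 key blocks per query block; all of these values are illustrative.

```python
import torch

seq_len, block_size = 16, 4
n_blocks = seq_len // block_size

# Hypothetical selection: each query block keeps 2 of the 4 key blocks
selected = torch.tensor([[0, 2], [1, 2], [0, 3], [1, 3]])

block_mask = torch.zeros(n_blocks, n_blocks, dtype=torch.bool)
block_mask[torch.arange(n_blocks)[:, None], selected] = True

# Expand every block-level entry into a block_size x block_size tile
token_mask = (block_mask
              .repeat_interleave(block_size, dim=0)
              .repeat_interleave(block_size, dim=1))

for row in token_mask:
    print(" ".join("■" if x else "□" for x in row.tolist()))
print("fraction computed:", token_mask.float().mean().item())
```

With 2 of 4 blocks kept per row, half of the score entries are computed; at long context with a fixed Top-k, the computed fraction shrinks as the number of blocks grows.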

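3. GPU Co-Design

3.1 Efficient Execution Path

MSA is co-designed with the GPU execution path to deliver practical speedups.

3.1.1 Exp-Free Top-k Selection

Conventional sparse attention computes a softmax before selecting, but the exponential (exp) adds cost and can overflow numerically. Because softmax preserves the ordering of scores, ranking the raw scores yields the same Top-k, so MSA selects Top-k without any exp: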

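import torch

def top_k_selection_without_exp(scores, k):
    """
    Select the Top-k entries without any exp operation.

    Conventional softmax needs: exp(x_i) / Σ exp(x_j)
    MSA instead: ranks the raw scores directly
    """
    # Rank the raw scores directly (no exp needed)
    _, indices = torch.topk(scores, k, dim=-1)

    # Build the sparse mask along the same (last) dimension used by topk
    mask = torch.zeros_like(scores)
    mask.scatter_(-1, indices, 1.0)

    return mask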
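
Skipping exp does not change which entries are selected, because softmax is a monotonic transform within each row. A quick check on random scores (the shapes are arbitrary assumptions):

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 32, dtype=torch.float64)  # hypothetical block scores
k = 8

idx_raw = torch.topk(scores, k, dim=-1).indices
idx_softmax = torch.topk(torch.softmax(scores, dim=-1), k, dim=-1).indices

# Compare the selected sets row by row
same = all(
    set(a.tolist()) == set(b.tolist())
    for a, b in zip(idx_raw, idx_softmax)
)
print(same)
```

The raw-score path therefore gives the same selection while avoiding one exp per score.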

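3.1.2 KV-Outer Sparse Attention

KV-outer sparse attention improves Tensor Core utilization:

  • The block-level access pattern matches the GPU memory hierarchy
  • Fewer global memory accesses
  • Higher compute density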
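
The loop structure of the actual kernel is not described here. One plausible reading of "KV-outer", offered only as an assumption, is that the outer loop runs over key-value blocks: each block is loaded once and multiplied against all queries that selected it, with an online softmax merging the partial results. A CPU-side sketch of that idea, with illustrative sizes and a random selection:

```python
import torch

torch.manual_seed(0)
n_q, n_blocks, block_size, d, top_k = 32, 8, 4, 16, 3
q = torch.randn(n_q, d)
k = torch.randn(n_blocks, block_size, d)
v = torch.randn(n_blocks, block_size, d)

# Hypothetical selection: each query picks top_k distinct blocks
selected = torch.rand(n_q, n_blocks).topk(top_k, dim=-1).indices
block_mask = torch.zeros(n_q, n_blocks, dtype=torch.bool)
block_mask[torch.arange(n_q)[:, None], selected] = True

# Running softmax statistics per query (online softmax)
run_max = torch.full((n_q,), float("-inf"))
run_sum = torch.zeros(n_q)
acc = torch.zeros(n_q, d)

for j in range(n_blocks):                          # outer loop over KV blocks
    rows = block_mask[:, j].nonzero().squeeze(-1)  # queries that chose block j
    if rows.numel() == 0:
        continue
    s = q[rows] @ k[j].T / d**0.5                  # dense [n_rows, block_size]
    new_max = torch.maximum(run_max[rows], s.max(dim=-1).values)
    rescale = torch.exp(run_max[rows] - new_max)   # fix up earlier partials
    p = torch.exp(s - new_max[:, None])
    run_sum[rows] = run_sum[rows] * rescale + p.sum(dim=-1)
    acc[rows] = acc[rows] * rescale[:, None] + p @ v[j]
    run_max[rows] = new_max

out = acc / run_sum[:, None]

# Reference: dense attention with unselected blocks masked out
token_mask = block_mask.repeat_interleave(block_size, dim=1)
ref_scores = (q @ k.reshape(-1, d).T / d**0.5).masked_fill(
    ~token_mask, float("-inf"))
ref = torch.softmax(ref_scores, dim=-1) @ v.reshape(-1, d)
print(torch.allclose(out, ref, atol=1e-5))
```

Grouping queries by key-value block turns scattered lookups into small dense matrix multiplications, which is the kind of work Tensor Cores handle well.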

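3.2 Performance Metrics

Results on a 109B-parameter model (with native multimodal training):

Metric                          MSA          GQA baseline
Attention compute reduction     28.4×
Prefill speedup (1M context)    14.2×
Decode speedup (H800)           7.6×
Accuracy loss                   Negligible   -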

4. PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class MiniMaxSparseAttention(nn.Module):
    """
    MiniMax Sparse Attention (MSA), simplified reference version.
    Block-wise sparse attention based on Grouped Query Attention.
    Non-causal; block selection is emulated with a dense mask, so this
    shows the semantics only and gives no speedup.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        block_size: int = 64,
        top_k: int = 8,
        score_hidden_dim: int = 128
    ):
        super().__init__()
        assert d_model % n_heads == 0
        assert n_heads % n_kv_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.group_size = n_heads // n_kv_heads  # query heads per KV head
        self.d_head = d_model // n_heads
        self.block_size = block_size
        self.top_k = top_k

        # Projections for main attention (in GQA, K/V use n_kv_heads heads)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, n_kv_heads * self.d_head)
        self.v_proj = nn.Linear(d_model, n_kv_heads * self.d_head)
        self.o_proj = nn.Linear(d_model, d_model)

        # Lightweight Index Branch: small MLP over per-head block summaries
        self.score_net = nn.Sequential(
            nn.Linear(self.d_head, score_hidden_dim),
            nn.GELU(),
            nn.Linear(score_hidden_dim, 1, bias=False)
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            mask: Optional key-padding mask [batch, seq_len], nonzero = attend
        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        batch_size, seq_len, _ = x.shape

        # Compute Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
        k = self.k_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
        v = self.v_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)

        # Determine number of blocks
        n_blocks = (seq_len + self.block_size - 1) // self.block_size

        # Index Branch: score every (query block, key block) pair per group
        block_scores = self._score_blocks(q, k, n_blocks)

        # Select Top-k key blocks per query block and GQA group
        selected_blocks = self._select_top_k_blocks(block_scores)

        # Main Branch: exact attention restricted to the selected blocks
        output = self._sparse_attention(q, k, v, selected_blocks, mask)

        # Output projection
        return self.o_proj(output)

    def _score_blocks(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        n_blocks: int
    ) -> torch.Tensor:
        """
        Score each key block for each query block using the Index Branch.
        Returns: [batch, n_kv_heads, n_blocks, n_blocks]
        """
        batch_size, seq_len, n_heads, d_head = q.shape
        block_size = self.block_size

        # Pad the sequence dimension (dim 1) to a multiple of block_size.
        # F.pad lists pads from the last dimension backwards.
        pad_len = block_size * n_blocks - seq_len
        if pad_len > 0:
            q = F.pad(q, (0, 0, 0, 0, 0, pad_len))
            k = F.pad(k, (0, 0, 0, 0, 0, pad_len))

        # Reshape to blocks
        q_blocks = q.view(batch_size, n_blocks, block_size, n_heads, d_head)
        k_blocks = k.view(batch_size, n_blocks, block_size, self.n_kv_heads, d_head)

        # Query summary: first token of each block, averaged over the
        # query heads that share one KV head (one GQA group)
        q_block_agg = q_blocks[:, :, 0].reshape(
            batch_size, n_blocks, self.n_kv_heads, self.group_size, d_head
        ).mean(dim=3).transpose(1, 2)  # [B, n_kv, n_blocks, D]

        # Key summary: mean over the real (unpadded) tokens of each block
        valid = torch.arange(n_blocks * block_size, device=k.device) < seq_len
        counts = valid.view(n_blocks, block_size).sum(dim=1).to(k.dtype)
        k_block_agg = k_blocks.sum(dim=2) / counts.view(1, n_blocks, 1, 1)
        k_block_agg = k_block_agg.transpose(1, 2)  # [B, n_kv, n_blocks, D]

        # Dot-product relevance between query-block and key-block summaries
        scores = torch.matmul(q_block_agg, k_block_agg.transpose(-2, -1))

        # Index Branch MLP scoring on the pairwise combined summaries
        # (O(n_blocks^2) memory; acceptable for a reference implementation)
        combined = q_block_agg.unsqueeze(3) + k_block_agg.unsqueeze(2)
        mlp_scores = self.score_net(combined).squeeze(-1)

        # Combine scores: [B, n_kv, n_blocks, n_blocks]
        return scores + mlp_scores

    def _select_top_k_blocks(
        self,
        block_scores: torch.Tensor
    ) -> torch.Tensor:
        """
        Select Top-k key blocks for each query block and GQA group.
        Returns a boolean mask with the same shape as block_scores.
        """
        n_blocks = block_scores.shape[-1]

        # Get top-k indices along the key-block dimension (no exp needed)
        _, top_k_indices = torch.topk(
            block_scores,
            min(self.top_k, n_blocks),
            dim=-1
        )

        # Create selection mask
        selected = torch.zeros_like(block_scores)
        selected.scatter_(-1, top_k_indices, 1.0)

        return selected.bool()

    def _sparse_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        selected_blocks: torch.Tensor,
        mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute attention only on selected blocks.
        Simplified: full attention masked by the selection.
        """
        batch_size, seq_len, n_heads, d_head = q.shape
        block_size = self.block_size

        # Expand the block-level selection to token level: [B, n_kv, L, L]
        token_mask = selected_blocks.repeat_interleave(block_size, dim=2)
        token_mask = token_mask.repeat_interleave(block_size, dim=3)
        token_mask = token_mask[:, :, :seq_len, :seq_len]

        # Share each KV head and its selection across its query heads
        token_mask = token_mask.repeat_interleave(self.group_size, dim=1)
        q_flat = q.transpose(1, 2)  # [B, H, L, D]
        k_flat = k.transpose(1, 2).repeat_interleave(self.group_size, dim=1)
        v_flat = v.transpose(1, 2).repeat_interleave(self.group_size, dim=1)

        # Compute attention scores
        scale = d_head ** -0.5
        scores = torch.matmul(q_flat, k_flat.transpose(-2, -1)) * scale

        # Apply selection mask
        scores = scores.masked_fill(~token_mask, float('-inf'))

        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        # Rows whose selected keys are all padded produce NaN; zero them
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
        output = torch.matmul(attn_weights, v_flat)

        return output.transpose(1, 2).reshape(batch_size, seq_len, -1)
 
 
class MSATransformerBlock(nn.Module):
    """
    Transformer block with MiniMax Sparse Attention.
    """
    def __init__(
        self,
        d_model: int,
        n_heads: int,
        n_kv_heads: int,
        block_size: int = 64,
        top_k: int = 8,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1
    ):
        super().__init__()
        
        self.attention = MiniMaxSparseAttention(
            d_model, n_heads, n_kv_heads, block_size, top_k
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        
        mlp_hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Pre-norm attention with residual
        x = x + self.attention(self.norm1(x), mask)
        x = x + self.mlp(self.norm2(x))
        return x

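5. Comparison with Other Sparse Attention Methods

5.1 Method Comparison

Method      Selection strategy       GQA support   Sparsity granularity   Implementation complexity
H2O         Importance scoring                     Token-level
PyramidKV   Layer-wise compression                 Layer-level
SnapKV      Dynamic selection                      Token-level
MSA         Block-level Top-k                      Block-level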

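5.2 Key Advantages

  1. Simplicity: a simple design that is easy to implement efficiently on a wide range of GPUs
  2. Scalability: block-level sparsity naturally supports long contexts
  3. Native GQA support: fits modern LLM architectures
  4. Practical speedups: 14.2× prefill speedup and 7.6× decode speedup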

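6. Application Scenarios

6.1 Agent Workflows

  • Managing multi-turn conversation history
  • Remembering tool-call chains
  • Integrating knowledge across sessions

6.2 Code Repository Reasoning

  • Understanding dependencies in large codebases
  • Cross-file contextual reasoning
  • Code completion and editing

6.3 Persistent Memory

  • Long-term memory systems
  • Personalized context
  • Continuous knowledge-base updates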

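7. Summary and Outlook

7.1 Key Contributions

  1. Block-level sparsity mechanism: efficient GQA-based block selection
  2. GPU co-design: exp-free Top-k selection and KV-outer sparse attention
  3. Industrial-grade implementation: open-source code and production-grade models have been released

7.2 Future Directions

  • Adaptive Top-k selection
  • Combination with other efficient architectures
  • Optimization for more hardware platforms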

References