MiniMax Sparse Attention (MSA) 超长上下文稀疏注意力
1. 问题背景
1.1 超长上下文的必要性
超长上下文能力对于前沿LLM越来越重要:
| 应用场景 | 上下文需求 |
|---|---|
| Agent工作流 | 10万+ tokens |
| 代码仓库推理 | 50万+ tokens |
| 持久记忆 | 100万+ tokens |
1.2 现有方法的挑战
Softmax注意力的二次复杂度:
当序列长度达到百万级时,计算和内存开销变得不可接受。
现有稀疏注意力的局限:
- KV Cache容量仍随序列长度线性增长
- 稀疏模式可能导致关键信息丢失
- 实现复杂度高,难以在各种GPU上高效部署
2. MiniMax Sparse Attention核心设计
2.1 架构概述
MSA是一种基于分组查询注意力(GQA)的块级稀疏注意力,其设计原则是简洁性和可扩展性:
┌──────────────────────────────────────────────────────────────┐
│ MSA Architecture │
├──────────────────────────────────────────────────────────────┤
│ │
│ Input Tokens ──► Block Division ──► Index Branch │
│ │ │ │
│ │ ▼ │
│ │ Score Key-Value Blocks │
│ │ │ │
│ ▼ ▼ │
│ ┌─────────────────────────┐ │
│ │ Top-k Block Selection │ │
│ │ (Per GQA Group) │ │
│ └─────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────┐ │
│ │ Main Branch │ │
│ │ (Exact Block-Sparse │ │
│ │ Attention) │ │
│ └─────────────────────────┘ │
│ │ │
│ ▼ │
│ Output │
└──────────────────────────────────────────────────────────────┘
2.2 核心技术
2.2.1 轻量级索引分支
索引分支(Index Branch)负责对键值块进行评分,并独立为每个GQA组选择Top-k子集:
设计要点:
- 轻量化:使用小型MLP而非完整注意力
- 分组选择:每个GQA组独立选择感兴趣块
- 保持多样性:不同组可选择不同块
2.2.2 主分支精确注意力
主分支(Main Branch)在索引分支选择的块上执行精确的块稀疏注意力:
其中 表示索引分支选中的Top-k块。
2.3 块级稀疏模式
标准注意力 vs MSA稀疏注意力:
标准注意力(序列长度=16):
┌────────────────────────┐
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
│ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ │
└────────────────────────┘
MSA稀疏注意力(Top-4块,块大小=4):
┌────────────────────────┐
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │ ← 仅关注选中块
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │
│ ■ □ □ □ ■ □ □ □ □ □ ■ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ ■ □ □ □ │
│ □ □ □ □ □ □ □ □ □ □ □ □ │
│ □ □ □ □ □ □ □ □ □ □ □ □ │
│ ■ □ □ □ ■ □ □ □ □ □ □ □ │
│ ■ □ □ □ ■ □ □ □ □ □ □ □ │
└────────────────────────┘
□ = 稀疏(不计算)
■ = 密集(执行注意力)
3. GPU协同设计
3.1 高效执行路径
MSA与GPU执行路径协同设计,实现实用加速:
3.1.1 无指数运算的Top-k选择
传统稀疏注意力需要计算Softmax,但指数运算(exp)会导致数值溢出。MSA采用无exp的Top-k选择:
def top_k_selection_without_exp(scores, k):
"""
选择Top-k而不使用exp运算
传统Softmax需要: exp(x_i) / Σ exp(x_j)
MSA使用: 直接对原始分数排序选择
"""
# 直接对分数排序(无需exp)
_, indices = torch.topk(scores, k, dim=-1)
# 创建稀疏掩码
mask = torch.zeros_like(scores)
mask.scatter_(1, indices, 1.0)
return mask3.1.2 KV外稀疏注意力
KV外稀疏(KV-outer Sparse)提高Tensor Core利用率:
- 块级访问模式匹配GPU内存层次结构
- 减少全局内存访问
- 提高计算密度
3.2 性能指标
在109B参数模型(带原生多模态训练)上的表现:
| 指标 | MSA | GQA基线 |
|---|---|---|
| 注意力计算减少 | 28.4× | 1× |
| 预填充加速(1M上下文) | 14.2× | 1× |
| 解码加速(H800) | 7.6× | 1× |
| 精度损失 | 可忽略 | - |
4. PyTorch实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class MiniMaxSparseAttention(nn.Module):
"""
MiniMax Sparse Attention (MSA)
Block-wise sparse attention based on Grouped Query Attention.
"""
def __init__(
self,
d_model: int,
n_heads: int,
n_kv_heads: int,
block_size: int = 64,
top_k: int = 8,
score_hidden_dim: int = 128
):
super().__init__()
assert d_model % n_heads == 0
assert n_heads % n_kv_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_groups = n_heads // n_kv_heads
self.d_head = d_model // n_heads
self.block_size = block_size
self.top_k = top_k
# Projections for main attention
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.o_proj = nn.Linear(d_model, d_model)
# Lightweight Index Branch
self.score_net = nn.Sequential(
nn.Linear(d_model, score_hidden_dim),
nn.GELU(),
nn.Linear(score_hidden_dim, 1, bias=False)
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
x: Input tensor [batch, seq_len, d_model]
mask: Optional attention mask
Returns:
Output tensor [batch, seq_len, d_model]
"""
batch_size, seq_len, _ = x.shape
# Compute Q, K, V
q = self.q_proj(x).view(batch_size, seq_len, self.n_heads, self.d_head)
k = self.k_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
v = self.v_proj(x).view(batch_size, seq_len, self.n_kv_heads, self.d_head)
# Determine number of blocks
n_blocks = (seq_len + self.block_size - 1) // self.block_size
# Index Branch: Score each block
block_scores = self._score_blocks(q, k, n_blocks)
# Select Top-k blocks per group
selected_blocks = self._select_top_k_blocks(block_scores)
# Main Branch: Compute sparse attention
output = self._sparse_attention(
q, k, v, selected_blocks, mask
)
# Output projection
output = self.o_proj(output)
return output
def _score_blocks(
self,
q: torch.Tensor,
k: torch.Tensor,
n_blocks: int
) -> torch.Tensor:
"""
Score each block using the Index Branch.
"""
batch_size, seq_len, n_heads, d_head = q.shape
block_size = self.block_size
# Reshape for block-wise computation
# Pad if necessary
pad_len = block_size * n_blocks - seq_len
if pad_len > 0:
q = F.pad(q, (0, 0, 0, pad_len))
k = F.pad(k, (0, 0, 0, pad_len))
# Reshape to blocks
q_blocks = q.view(batch_size, n_blocks, block_size, n_heads, d_head)
k_blocks = k.view(batch_size, n_blocks, block_size, self.n_kv_heads, d_head)
# Aggregate queries per block (use first token of each block)
q_block_agg = q_blocks[:, :, 0] # [B, n_blocks, n_heads, D]
# Expand for KV heads
q_expanded = q_block_agg.unsqueeze(3).expand(
-1, -1, self.n_kv_heads, self.n_heads // self.n_kv_heads, -1
)
q_expanded = q_expanded.reshape(
batch_size, n_blocks, self.n_kv_heads, n_heads // self.n_kv_heads, d_head
).transpose(2, 3) # [B, n_blocks, n_groups, d_head]
# Compute block importance scores
# Use mean aggregation for KV
k_block_agg = k_blocks.mean(dim=2) # [B, n_blocks, n_kv, D]
# Compute relevance score
scores = (q_block_agg * k_block_agg.mean(dim=2, keepdim=True)).sum(dim=-1)
# Index Branch MLP scoring
combined = q_block_agg + k_block_agg.mean(dim=2, keepdim=True)
mlp_scores = self.score_net(combined).squeeze(-1) # [B, n_blocks, n_heads]
# Combine scores
final_scores = scores + mlp_scores
return final_scores
def _select_top_k_blocks(
self,
block_scores: torch.Tensor
) -> torch.Tensor:
"""
Select Top-k blocks for each GQA group.
"""
batch_size, n_blocks, n_heads = block_scores.shape
# Get top-k indices
_, top_k_indices = torch.topk(
block_scores,
min(self.top_k, n_blocks),
dim=1
)
# Create selection mask
selected = torch.zeros_like(block_scores)
selected.scatter_(1, top_k_indices, 1.0)
return selected
def _sparse_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
selected_blocks: torch.Tensor,
mask: Optional[torch.Tensor]
) -> torch.Tensor:
"""
Compute attention only on selected blocks.
"""
batch_size, seq_len, n_heads, d_head = q.shape
n_kv_heads = self.n_kv_heads
block_size = self.block_size
n_blocks = (seq_len + block_size - 1) // block_size
# Reshape Q, K, V to blocks
pad_len = block_size * n_blocks - seq_len
if pad_len > 0:
q = F.pad(q, (0, 0, 0, pad_len))
k = F.pad(k, (0, 0, 0, pad_len))
v = F.pad(v, (0, 0, 0, pad_len))
q_blocks = q.view(batch_size, n_blocks, block_size, n_heads, d_head)
k_blocks = k.view(batch_size, n_blocks, block_size, n_kv_heads, d_head)
v_blocks = v.view(batch_size, n_blocks, block_size, n_kv_heads, d_head)
# Expand selected mask to block granularity
selected_mask = selected_blocks.unsqueeze(2) # [B, n_blocks, 1, n_heads]
selected_mask = selected_mask.expand(-1, -1, block_size, -1) # [B, n_blocks, block, n_heads]
selected_mask = selected_mask.reshape(batch_size, -1, n_heads)[:, :seq_len]
# For KV heads, aggregate across groups
kv_selected_mask = selected_mask.unfold(2, n_heads // n_kv_heads, n_heads // n_kv_heads).mean(dim=-1)
# Compute attention on selected positions
# Simplified: use full attention masked by selection
q_flat = q[:, :seq_len].transpose(1, 2) # [B, H, L, D]
k_flat = k[:, :seq_len].transpose(1, 2) # [B, H, L, D]
v_flat = v[:, :seq_len].transpose(1, 2) # [B, H, L, D]
# Compute attention scores
scale = d_head ** -0.5
scores = torch.matmul(q_flat, k_flat.transpose(-2, -1)) * scale
# Apply selection mask
extended_mask = selected_mask.unsqueeze(1).expand(-1, n_heads, -1, -1)
scores = scores.masked_fill(extended_mask == 0, float('-inf'))
if mask is not None:
scores = scores.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v_flat)
return output.transpose(1, 2).reshape(batch_size, seq_len, -1)
class MSATransformerBlock(nn.Module):
"""
Transformer block with MiniMax Sparse Attention.
"""
def __init__(
self,
d_model: int,
n_heads: int,
n_kv_heads: int,
block_size: int = 64,
top_k: int = 8,
mlp_ratio: float = 4.0,
dropout: float = 0.1
):
super().__init__()
self.attention = MiniMaxSparseAttention(
d_model, n_heads, n_kv_heads, block_size, top_k
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
mlp_hidden = int(d_model * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(d_model, mlp_hidden),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden, d_model),
nn.Dropout(dropout)
)
def forward(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
# Pre-norm attention with residual
x = x + self.attention(self.norm1(x), mask)
x = x + self.mlp(self.norm2(x))
return x5. 与其他稀疏注意力方法的对比
5.1 方法对比
| 方法 | 选择策略 | GQA支持 | 稀疏粒度 | 实现复杂度 |
|---|---|---|---|---|
| H2O | 重要性评分 | 否 | Token级 | 高 |
| PyramidKV | 层级压缩 | 否 | 层级 | 中 |
| SnapKV | 动态选择 | 否 | Token级 | 中 |
| MSA | 块级Top-k | 是 | 块级 | 低 |
5.2 核心优势
- 简洁性:设计简单,易于在各类GPU上高效实现
- 可扩展性:块级稀疏天然支持长上下文
- GQA原生支持:适配现代LLM架构
- 实用加速:14.2×预填充加速,7.6×解码加速
6. 应用场景
6.1 Agent工作流
- 多轮对话历史管理
- 工具调用链记忆
- 跨会话知识整合
6.2 代码仓库推理
- 理解大型代码库的依赖关系
- 跨文件上下文推理
- 代码补全与修改
6.3 持久记忆
- 长期记忆系统
- 个性化上下文
- 知识库的持续更新
7. 总结与展望
7.1 核心贡献
- 块级稀疏机制:基于GQA的高效块选择
- GPU协同设计:无exp的Top-k选择,KV外稀疏
- 工业级实现:已发布开源代码和生产级模型
7.2 未来方向
- 自适应Top-k选择
- 与其他高效架构的结合
- 在更多硬件平台上的优化