CAT压缩注意力Transformer
概述
CAT(Compress & Attend Transformer)是一种概念简洁的架构,通过两个简单要素——密集注意力和压缩——实现可控的效率-质量权衡。1
传统Transformer面临注意力二次复杂度问题,CAT通过压缩+注意力的组合,在单一自适应架构中实现测试时质量-计算权衡控制。
核心思想
问题分析
现有高效Transformer方法存在以下问题:
- 稀疏/滑动窗口注意力:降低质量换取效率
- 线性注意力:表达能力受限
- 混合架构:需要复杂的层组合设计
- 固定权衡:无法根据输入动态调整
CAT框架
CAT的核心洞察:压缩本身就是高效注意力的关键
输入序列 x_1, x_2, ..., x_n
↓
块划分(chunking)
↓
块内压缩 (Compression)
↓
压缩块注意力 (Attention)
↓
逐块解码 (Decoding)
形式化定义
对于序列 ,CAT执行以下操作:
- 块划分:将序列划分为大小为 的块
- 块压缩:每个块压缩为 个表示
- 块间注意力:在压缩表示上执行注意力
其中 是块压缩函数。
压缩机制
块压缩函数
class BlockCompressor(nn.Module):
def __init__(self, d_model, chunk_size, compression_ratio):
super().__init__()
self.chunk_size = chunk_size
self.compression_ratio = compression_ratio
self.c_out = chunk_size // compression_ratio
# 压缩投影
self.compress_proj = nn.Linear(chunk_size * d_model, self.c_out * d_model)
# 重构投影
self.reconstruct_proj = nn.Linear(self.c_out * d_model, chunk_size * d_model)
def compress(self, x):
# x: [B, N, D]
B, N, D = x.shape
# 填充到块大小的倍数
pad_len = (self.chunk_size - N % self.chunk_size) % self.chunk_size
if pad_len > 0:
x = F.pad(x, (0, 0, 0, pad_len))
# 重塑为块
x_reshaped = x.view(B, -1, self.chunk_size, D) # [B, n_chunks, chunk_size, D]
# 展平并压缩
x_flat = x_reshaped.view(B, -1, self.chunk_size * D)
x_compressed = self.compress_proj(x_flat) # [B, n_chunks, c_out * D]
# 重塑为压缩表示
x_compressed = x_compressed.view(B, -1, self.c_out, D)
return x_compressed, N # 返回压缩表示和原始长度
def reconstruct(self, x_compressed, original_len):
B, n_chunks, c_out, D = x_compressed.shape
# 展平并重构
x_flat = x_compressed.view(B, n_chunks, c_out * D)
x_reconstructed = self.reconstruct_proj(x_flat)
# 重塑回块格式
x_reshaped = x_reconstructed.view(B, n_chunks, self.chunk_size, D)
x = x_reshaped.view(B, -1, D)
# 截断到原始长度
return x[:, :original_len]多尺度压缩
CAT支持多块大小训练,实现单一模型支持多种效率-质量权衡:
class MultiScaleCAT(nn.Module):
def __init__(self, d_model, n_heads, chunk_sizes=[16, 32, 64]):
super().__init__()
self.chunk_sizes = chunk_sizes
self.compressors = nn.ModuleDict({
str(B): BlockCompressor(d_model, B, B//4)
for B in chunk_sizes
})
self.attention = MultiHeadAttention(d_model, n_heads)
def forward(self, x, chunk_size=None):
# 训练时随机选择块大小
if chunk_size is None and self.training:
chunk_size = random.choice(self.chunk_sizes)
elif chunk_size is None:
chunk_size = min(self.chunk_sizes)
compressor = self.compressors[str(chunk_size)]
x_compressed, original_len = compressor.compress(x)
# 展平块维度用于注意力
B, n_chunks, c_out, D = x_compressed.shape
x_flat = x_compressed.view(B, n_chunks * c_out, D)
# 注意力
x_attended = self.attention(x_flat, x_flat, x_flat)
return x_attended, original_len训练策略
多块大小联合训练
CAT的核心优势是单一模型支持多种块大小:
- 训练阶段:随机采样块大小
- 推理阶段:根据计算预算选择块大小
- 无需微调:块大小选择完全在测试时决定
损失函数
其中 是块大小分布。
理论分析
效率-质量权衡
对于块大小 ,CAT的计算复杂度为:
相比标准注意力的 :
| 块大小 | 复杂度 | 相对效率 |
|---|---|---|
| 1× | ||
| 1/16 | ||
| 1/64 | ||
| 1× |
表达力分析
压缩会损失高频信息,但保留:
- 低频结构:整体语义、主题
- 注意力模式:token间的主要依赖关系
实验结果
语言建模
| 方法 | 困惑度 | 相对速度 | 相对内存 |
|---|---|---|---|
| Dense Transformer | 18.2 | 1.0× | 1.0× |
| CAT-B16 | 19.1 | 1.4× | 2.1× |
| CAT-B32 | 19.8 | 2.5× | 3.8× |
| CAT-B64 | 20.6 | 3.0× | 6.5× |
长上下文理解
在长上下文任务(如 needle-in-a-haystack)上的表现:
| 序列长度 | Dense | CAT-B16 | CAT-B32 |
|---|---|---|---|
| 4K | 92.3 | 91.8 | 91.2 |
| 16K | 88.1 | 87.5 | 86.9 |
| 64K | 72.4 | 78.2 | 81.5 |
CAT在长序列上反而优于Dense模型,因为压缩减少了注意力噪声。
自适应效率
单一CAT模型在多个块大小上的表现:
块大小选择 → 效率-质量权衡
↓
B=16 → 实时应用(低延迟)
B=32 → 平衡场景
B=64 → 批处理(高质量)
与其他方法的对比
| 方法 | 效率 | 质量 | 灵活性 |
|---|---|---|---|
| Sparse Attention | 高 | 中 | 低 |
| Linear Attention | 最高 | 中 | 低 |
| Hybrid | 高 | 高 | 中 |
| CAT | 高 | 高 | 高 |
应用场景
- 边缘部署:内存受限的设备
- 实时应用:低延迟要求的交互
- 长序列处理:文档理解、代码生成
- 成本敏感场景:根据预算动态调整
参考资料
相关链接
Footnotes
-
“Attention and Compression is all you need for Controllably Efficient Language Models” arXiv:2511.05313 ↩