CAT压缩注意力Transformer

概述

CAT(Compress & Attend Transformer)是一种概念简洁的架构,通过两个简单要素——密集注意力压缩——实现可控的效率-质量权衡。1

传统Transformer面临注意力二次复杂度问题,CAT通过压缩+注意力的组合,在单一自适应架构中实现测试时质量-计算权衡控制。

核心思想

问题分析

现有高效Transformer方法存在以下问题:

  1. 稀疏/滑动窗口注意力:降低质量换取效率
  2. 线性注意力:表达能力受限
  3. 混合架构:需要复杂的层组合设计
  4. 固定权衡:无法根据输入动态调整

CAT框架

CAT的核心洞察:压缩本身就是高效注意力的关键

输入序列 x_1, x_2, ..., x_n
         ↓
    块划分(chunking)
         ↓
    块内压缩 (Compression)
         ↓
    压缩块注意力 (Attention)
         ↓
    逐块解码 (Decoding)

形式化定义

对于序列 ,CAT执行以下操作:

  1. 块划分:将序列划分为大小为 的块
  2. 块压缩:每个块压缩为 个表示
  3. 块间注意力:在压缩表示上执行注意力

其中 是块压缩函数。

压缩机制

块压缩函数

class BlockCompressor(nn.Module):
    def __init__(self, d_model, chunk_size, compression_ratio):
        super().__init__()
        self.chunk_size = chunk_size
        self.compression_ratio = compression_ratio
        self.c_out = chunk_size // compression_ratio
        
        # 压缩投影
        self.compress_proj = nn.Linear(chunk_size * d_model, self.c_out * d_model)
        
        # 重构投影
        self.reconstruct_proj = nn.Linear(self.c_out * d_model, chunk_size * d_model)
        
    def compress(self, x):
        # x: [B, N, D]
        B, N, D = x.shape
        
        # 填充到块大小的倍数
        pad_len = (self.chunk_size - N % self.chunk_size) % self.chunk_size
        if pad_len > 0:
            x = F.pad(x, (0, 0, 0, pad_len))
        
        # 重塑为块
        x_reshaped = x.view(B, -1, self.chunk_size, D)  # [B, n_chunks, chunk_size, D]
        
        # 展平并压缩
        x_flat = x_reshaped.view(B, -1, self.chunk_size * D)
        x_compressed = self.compress_proj(x_flat)  # [B, n_chunks, c_out * D]
        
        # 重塑为压缩表示
        x_compressed = x_compressed.view(B, -1, self.c_out, D)
        
        return x_compressed, N  # 返回压缩表示和原始长度
    
    def reconstruct(self, x_compressed, original_len):
        B, n_chunks, c_out, D = x_compressed.shape
        
        # 展平并重构
        x_flat = x_compressed.view(B, n_chunks, c_out * D)
        x_reconstructed = self.reconstruct_proj(x_flat)
        
        # 重塑回块格式
        x_reshaped = x_reconstructed.view(B, n_chunks, self.chunk_size, D)
        x = x_reshaped.view(B, -1, D)
        
        # 截断到原始长度
        return x[:, :original_len]

多尺度压缩

CAT支持多块大小训练,实现单一模型支持多种效率-质量权衡:

class MultiScaleCAT(nn.Module):
    def __init__(self, d_model, n_heads, chunk_sizes=[16, 32, 64]):
        super().__init__()
        self.chunk_sizes = chunk_sizes
        self.compressors = nn.ModuleDict({
            str(B): BlockCompressor(d_model, B, B//4) 
            for B in chunk_sizes
        })
        self.attention = MultiHeadAttention(d_model, n_heads)
        
    def forward(self, x, chunk_size=None):
        # 训练时随机选择块大小
        if chunk_size is None and self.training:
            chunk_size = random.choice(self.chunk_sizes)
        elif chunk_size is None:
            chunk_size = min(self.chunk_sizes)
            
        compressor = self.compressors[str(chunk_size)]
        x_compressed, original_len = compressor.compress(x)
        
        # 展平块维度用于注意力
        B, n_chunks, c_out, D = x_compressed.shape
        x_flat = x_compressed.view(B, n_chunks * c_out, D)
        
        # 注意力
        x_attended = self.attention(x_flat, x_flat, x_flat)
        
        return x_attended, original_len

训练策略

多块大小联合训练

CAT的核心优势是单一模型支持多种块大小

  1. 训练阶段:随机采样块大小
  2. 推理阶段:根据计算预算选择块大小
  3. 无需微调:块大小选择完全在测试时决定

损失函数

其中 是块大小分布。

理论分析

效率-质量权衡

对于块大小 ,CAT的计算复杂度为:

相比标准注意力的

块大小复杂度相对效率
1/16
1/64

表达力分析

压缩会损失高频信息,但保留:

  • 低频结构:整体语义、主题
  • 注意力模式:token间的主要依赖关系

实验结果

语言建模

方法困惑度相对速度相对内存
Dense Transformer18.21.0×1.0×
CAT-B1619.11.4×2.1×
CAT-B3219.82.5×3.8×
CAT-B6420.63.0×6.5×

长上下文理解

在长上下文任务(如 needle-in-a-haystack)上的表现:

序列长度DenseCAT-B16CAT-B32
4K92.391.891.2
16K88.187.586.9
64K72.478.281.5

CAT在长序列上反而优于Dense模型,因为压缩减少了注意力噪声。

自适应效率

单一CAT模型在多个块大小上的表现:

块大小选择 → 效率-质量权衡
     ↓
   B=16 → 实时应用(低延迟)
   B=32 → 平衡场景
   B=64 → 批处理(高质量)

与其他方法的对比

方法效率质量灵活性
Sparse Attention
Linear Attention最高
Hybrid
CAT

应用场景

  1. 边缘部署:内存受限的设备
  2. 实时应用:低延迟要求的交互
  3. 长序列处理:文档理解、代码生成
  4. 成本敏感场景:根据预算动态调整

参考资料

相关链接

Footnotes

  1. “Attention and Compression is all you need for Controllably Efficient Language Models” arXiv:2511.05313