概述

大语言模型的推理过程中,KV Cache 是主要的内存瓶颈。随着上下文长度增加,KV Cache 的内存占用呈线性增长。本文介绍基于低秩分解的注意力矩阵压缩技术,通过 SVD 等方法显著降低内存开销。1


KV Cache 问题

标准 Transformer 的内存消耗

对于自回归生成模型,每个 token 都需要缓存其 Key 和 Value 向量,KV Cache 的总内存占用约为:

$$\text{Mem}_{\text{KV}} = 2 \times b \times L \times n_{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times \text{bytes}$$

其中 $b$ 为 batch size,$L$ 为序列长度,系数 2 对应 Key 与 Value 两份缓存。

以 LLaMA-7B 为例:

  • 层数:32
  • 注意力头数:32
  • 头维度:128
  • 最大上下文:4096

内存占用

以上述 LLaMA-7B 配置、FP16 存储为例,单条 4096 token 序列的 KV Cache 约为:

$$2 \times 32 \times 32 \times 128 \times 4096 \times 2\ \text{bytes} \approx 2\ \text{GB}$$

对于 70B 级别的模型,在长上下文、大批量推理场景下,KV Cache 可达约 80 GB。
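
可以用一个小函数按上式粗略估算不同配置下的 KV Cache 占用(kv_cache_bytes 为本文演示用的示意函数,默认参数即上面的 LLaMA-7B 配置):

def kv_cache_bytes(n_layers=32, n_heads=32, d_head=128,
                   seq_len=4096, batch=1, bytes_per_elem=2):
    """按 2 * b * L * n_layers * n_heads * d_head * bytes 估算 KV Cache 字节数"""
    return 2 * batch * seq_len * n_layers * n_heads * d_head * bytes_per_elem

# LLaMA-7B, FP16, 4096 上下文:约 2 GiB
print(kv_cache_bytes() / 1024**3, "GiB")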

为什么需要压缩?

  1. 内存带宽瓶颈:GPU HBM 带宽有限
  2. 长上下文需求:越来越长的上下文窗口
  3. 批量推理:多请求并行处理

注意力矩阵的低秩结构

经验观察

研究(2024)发现:2

MHA 的注意力矩阵展现出明显的低秩结构,而 FFN 子层则不然。

这一不对称性为针对性压缩提供了理论基础。

秩分析

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 
def analyze_attention_rank(attention_scores):
    """
    分析注意力矩阵的秩结构
    
    Args:
        attention_scores: (batch, num_heads, seq_len, seq_len)
    Returns:
        每个 (batch, head) 注意力矩阵的有效秩数组
    """
    batch, num_heads, seq_len, _ = attention_scores.shape
    
    ranks = []
    for b in range(batch):
        for h in range(num_heads):
            A = attention_scores[b, h].cpu().numpy()
            
            # SVD 得到奇异值
            U, s, Vt = np.linalg.svd(A, full_matrices=False)
            
            # 有效秩:累积能量(奇异值平方和占比)达到 99% 所需的奇异值个数
            energy = np.cumsum(s**2) / np.sum(s**2)
            rank = np.searchsorted(energy, 0.99) + 1
            
            ranks.append(rank)
    
    return np.array(ranks)
 
def compute_stable_rank(A):
    """
    计算矩阵的稳定秩
    
    sr(A) = ||A||_F^2 / ||A||_2^2
    """
    frobenius_norm = np.linalg.norm(A, 'fro')
    spectral_norm = np.linalg.norm(A, 2)
    
    return (frobenius_norm / spectral_norm) ** 2
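
上面两个函数的一个最小调用示例(用人为构造的低秩注意力矩阵演示,数值仅作示意):

# 构造一个人为的低秩注意力矩阵:rank-8 的得分矩阵经过 softmax
torch.manual_seed(0)
low_rank_scores = torch.randn(1, 4, 256, 8) @ torch.randn(1, 4, 8, 256)
attn = F.softmax(low_rank_scores, dim=-1)

ranks = analyze_attention_rank(attn)
print("有效秩(99% 能量):", ranks)
print("稳定秩:", compute_stable_rank(attn[0, 0].numpy()))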

层间差异

层类型 | Q/K 低秩性 | V 低秩性 | 建议压缩策略
浅层 (L1-12) | - | - | 压缩 Q/K、V
深层 (L13-24) | - | - | 优先压缩 V
输出层 (L25-32) | - | - | 少压缩
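
针对上表的层间差异,可以按层索引选择不同的压缩配置。下面是一个示意性的策略函数(layer_compression_plan 及其中的阈值、秩比例均为本文为演示而设的假设,并非某个框架的 API):

def layer_compression_plan(layer_idx, num_layers=32):
    """根据层深度返回 (是否压缩 Q/K, 是否压缩 V, 保留秩比例) 的示意配置"""
    depth = layer_idx / num_layers
    if depth < 0.4:          # 浅层:Q/K、V 都压缩
        return {'compress_qk': True, 'compress_v': True, 'rank_ratio': 0.5}
    elif depth < 0.8:        # 深层:优先压缩 V
        return {'compress_qk': False, 'compress_v': True, 'rank_ratio': 0.5}
    else:                    # 输出层附近:少压缩
        return {'compress_qk': False, 'compress_v': False, 'rank_ratio': 1.0}

print(layer_compression_plan(3), layer_compression_plan(30))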

SVD 基础回顾

奇异值分解

任意矩阵 $A \in \mathbb{R}^{m \times n}$ 可分解为:

$$A = U \Sigma V^T$$

其中:

  • $U \in \mathbb{R}^{m \times m}$:左奇异向量
  • $V \in \mathbb{R}^{n \times n}$:右奇异向量
  • $\Sigma \in \mathbb{R}^{m \times n}$:对角矩阵,对角元为奇异值 $\sigma_1 \ge \sigma_2 \ge \cdots \ge 0$

低秩近似

Eckart-Young-Mirsky 定理

矩阵 $A$ 的最优秩 $k$ 近似为:

$$A_k = U_k \Sigma_k V_k^T$$

其中 $U_k, \Sigma_k, V_k$ 只保留前 $k$ 个奇异值/向量。

误差界

$$\|A - A_k\|_2 = \sigma_{k+1}, \qquad \|A - A_k\|_F = \sqrt{\sum_{i=k+1}^{r} \sigma_i^2}$$


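可以用几行 NumPy 验证上述误差界(使用随机矩阵,仅作示意):

np.random.seed(0)
A = np.random.randn(64, 32)
U, s, Vt = np.linalg.svd(A, full_matrices=False)

k = 8
A_k = U[:, :k] @ np.diag(s[:k]) @ Vt[:k]

# 谱范数误差恰为第 k+1 个奇异值,Frobenius 误差为剩余奇异值平方和的平方根
print(np.linalg.norm(A - A_k, 2), s[k])
print(np.linalg.norm(A - A_k, 'fro'), np.sqrt(np.sum(s[k:]**2)))
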
Eigen Attention:特征注意力

核心思想

Eigen Attention 将注意力计算放在由 Key/Value 表示的主特征方向张成的低秩子空间中进行,从而可以只缓存低维投影后的 K/V。3

算法

class EigenAttention(nn.Module):
    """
    Eigen Attention: 基于特征分解的注意力优化
    """
    def __init__(self, d_model, num_heads, rank_ratio=0.5):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.rank_ratio = rank_ratio
        
        # 标准 QKV 投影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def eigen_attention(self, Q, K, V, rank=None):
        """
        核心 Eigen Attention 操作
        
        1. 用随机化低秩分解近似 K 的主奇异方向(特征向量)
        2. 将 Q、K、V 投影到该低秩子空间
        3. 在子空间内计算注意力,再投影回原空间
        
        Q, K, V: (batch, num_heads, seq_len, d_head)
        """
        d_k = Q.shape[-1]
        seq_len = K.shape[-2]
        
        # 计算子空间的秩(相对于头维度 d_head)
        if rank is None:
            rank = max(1, int(d_k * self.rank_ratio))
        rank = min(rank, d_k)
        
        # 随机化 range finder:估计 K 的前 rank 个右奇异方向
        with torch.no_grad():
            omega = torch.randn(seq_len, rank, device=K.device, dtype=K.dtype)
            
            # Y = K^T Ω: (batch, heads, d_head, rank)
            Y = K.transpose(-2, -1) @ omega
            
            # 一次幂迭代,提升主方向估计精度
            Y = K.transpose(-2, -1) @ (K @ Y)
            
            # QR 分解得到正交基 P: (batch, heads, d_head, rank)
            P, _ = torch.linalg.qr(Y)
        
        # 在子空间投影 Q, K, V
        Q_proj = Q @ P  # (batch, heads, seq_len, rank)
        K_proj = K @ P
        V_proj = V @ P
        
        # 子空间内计算注意力;P 正交,故 Q_proj K_proj^T ≈ Q K^T,仍按 d_head 缩放
        scale = 1.0 / math.sqrt(d_k)
        scores = torch.matmul(Q_proj, K_proj.transpose(-2, -1)) * scale
        attn = F.softmax(scores, dim=-1)
        
        # 子空间注意力输出
        out_proj = torch.matmul(attn, V_proj)  # (batch, heads, seq_len, rank)
        
        # 投影回原空间
        out = out_proj @ P.transpose(-2, -1)  # (batch, heads, seq_len, d_head)
        
        return out
    
    def forward(self, x, use_eigen=False):
        Q = self.W_q(x).view(-1, x.size(1), self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(-1, x.size(1), self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(-1, x.size(1), self.num_heads, self.d_head).transpose(1, 2)
        
        if use_eigen:
            out = self.eigen_attention(Q, K, V)
        else:
            # 标准注意力
            scale = 1.0 / math.sqrt(self.d_head)
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            attn = F.softmax(scores, dim=-1)
            out = torch.matmul(attn, V)
        
        out = out.transpose(1, 2).contiguous().view(-1, x.size(1), self.d_model)
        return self.W_o(out)
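
沿用上面定义的 EigenAttention,一个最小的前向示例(模型规模与 rank_ratio 仅为演示取值):

torch.manual_seed(0)
attn_layer = EigenAttention(d_model=256, num_heads=4, rank_ratio=0.5)
x = torch.randn(2, 128, 256)  # (batch, seq_len, d_model)

out_full = attn_layer(x, use_eigen=False)
out_eigen = attn_layer(x, use_eigen=True)

# 低秩子空间注意力与完整注意力的输出差异
print(out_eigen.shape, (out_full - out_eigen).abs().mean().item())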

压缩效果

方法 | KV Cache 相对大小 | 性能损失
无压缩 | 1.0× | 0%
压缩 20% | 0.8× | < 1%
压缩 40% | 0.6× | ~ 2%
压缩 60% | 0.4× | ~ 5%

KV Cache 压缩实践

Online SVD 压缩

class KVCacheSVD:
    """
    KV Cache 的在线 SVD 压缩
    
    在生成过程中动态压缩
    """
    def __init__(self, rank=32, update_freq=16):
        self.rank = rank
        self.update_freq = update_freq
        self.cache = {}  # {layer_idx: {'K': ..., 'V': ...}}
        self.token_counts = {}
    
    def compress(self, K, V, layer_idx):
        """
        将新生成的 K, V 追加到缓存,并周期性地做 SVD 低秩重构
        
        Args:
            K: (batch, num_heads, new_len, d_head) 新增 token 的 Key
            V: (batch, num_heads, new_len, d_head) 新增 token 的 Value
            layer_idx: 层索引
        Returns:
            当前层完整的(可能经过低秩重构的)K, V 缓存
        """
        if layer_idx not in self.cache:
            # 首次:完整存储
            self.cache[layer_idx] = {'K': K, 'V': V}
            self.token_counts[layer_idx] = K.shape[2]
            return K, V
        
        # 合并新旧 token
        K_full = torch.cat([self.cache[layer_idx]['K'], K], dim=2)
        V_full = torch.cat([self.cache[layer_idx]['V'], V], dim=2)
        
        # 每累积 update_freq 个新 token 重新压缩一次
        if K_full.shape[2] - self.token_counts[layer_idx] >= self.update_freq:
            K_full, V_full = self._svd_compress(K_full, V_full)
            self.token_counts[layer_idx] = K_full.shape[2]
        
        # 始终把(压缩或未压缩的)完整缓存写回,避免丢失新 token
        self.cache[layer_idx]['K'] = K_full
        self.cache[layer_idx]['V'] = V_full
        
        return K_full, V_full
    
    def _svd_compress(self, K, V):
        """
        对 K, V 逐头做 SVD 低秩重构
        
        注意:这里为了演示直接重构出与原尺寸相同的低秩近似张量;
        若要真正节省显存,应存储分解因子 (U, S, V^T) 而非重构结果。
        """
        batch, num_heads, seq_len, d_head = K.shape
        
        K_compressed = []
        V_compressed = []
        
        for b in range(batch):
            K_head = []
            V_head = []
            for h in range(num_heads):
                # K 的 SVD(先转成 float32 再在 CPU 上分解,避免半精度不被支持)
                K_mat = K[b, h].float().cpu().numpy()
                U_k, s_k, Vt_k = np.linalg.svd(K_mat, full_matrices=False)
                
                # 保留 top-k 奇异值
                k = min(self.rank, len(s_k))
                K_recon = U_k[:, :k] @ np.diag(s_k[:k]) @ Vt_k[:k]
                
                # V 的 SVD
                V_mat = V[b, h].float().cpu().numpy()
                U_v, s_v, Vt_v = np.linalg.svd(V_mat, full_matrices=False)
                
                # 保留 top-k
                V_recon = U_v[:, :k] @ np.diag(s_v[:k]) @ Vt_v[:k]
                
                # 转回原 device / dtype
                K_head.append(torch.from_numpy(K_recon).to(device=K.device, dtype=K.dtype))
                V_head.append(torch.from_numpy(V_recon).to(device=V.device, dtype=V.dtype))
            
            K_compressed.append(torch.stack(K_head, dim=0))
            V_compressed.append(torch.stack(V_head, dim=0))
        
        return torch.stack(K_compressed, dim=0), torch.stack(V_compressed, dim=0)
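
沿用上面的 KVCacheSVD,模拟逐 token 解码时的缓存更新(张量形状与步数仅为演示):

torch.manual_seed(0)
compressor = KVCacheSVD(rank=16, update_freq=8)

K_cache = V_cache = None
for step in range(32):
    # 每步新生成 1 个 token 的 K/V: (batch=1, heads=4, new_len=1, d_head=64)
    K_new = torch.randn(1, 4, 1, 64)
    V_new = torch.randn(1, 4, 1, 64)
    K_cache, V_cache = compressor.compress(K_new, V_new, layer_idx=0)

print(K_cache.shape)  # 缓存中已累积 32 个 token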

H2O: Heavy-Hitter 驱逐策略

class H2OKVCache:
    """
    H2O: Heavy-Hitter Oracle
    只保留对注意力贡献大的 token(heavy hitters)。
    原方法按累计注意力分数打分;此处用访问计数作为简化的重要性代理,
    缓存满时驱逐计数最低的 token。
    """
    def __init__(self, max_cache_size):
        self.max_cache_size = max_cache_size
        self.cache = {}  # {token_pos: (K, V)}
        self.access_counts = {}
    
    def query(self, pos, K, V):
        """
        查询或添加 KV 对
        """
        if pos in self.cache:
            self.access_counts[pos] += 1
            return self.cache[pos]
        
        # 缓存满了,移除最少使用的
        if len(self.cache) >= self.max_cache_size:
            self._evict_lru()
        
        self.cache[pos] = (K, V)
        self.access_counts[pos] = 1
        
        return None  # 未命中
    
    def _evict_lru(self):
        """
        驱逐访问计数最低(近似重要性最低)的 token
        """
        lru_pos = min(self.access_counts, key=self.access_counts.get)
        del self.cache[lru_pos]
        del self.access_counts[lru_pos]
    
    def get_cache_size(self):
        return len(self.cache)
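
一个最小的使用示例(访问序列为人为构造,仅演示驱逐行为):

cache = H2OKVCache(max_cache_size=4)
for pos in [0, 1, 2, 3, 0, 0, 1, 4, 5]:
    K_pos, V_pos = torch.randn(64), torch.randn(64)
    cache.query(pos, K_pos, V_pos)

# 容量为 4:访问次数少的 token 已被驱逐
print(cache.get_cache_size(), sorted(cache.cache.keys()))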

分层 KV Cache

class LayeredKVCache:
    """
    分层 KV Cache
    
    - 短期层:保留最近的 N 个 token
    - 长期层:压缩存储早期 token
    """
    def __init__(self, short_term_size=128, long_term_rank=32):
        self.short_term_size = short_term_size
        self.long_term_rank = long_term_rank
        self.short_term = {}  # 短期缓存
        self.long_term = {}   # 长期缓存(SVD 压缩)
    
    def add(self, layer_idx, K, V):
        """
        添加新的 KV 对
        """
        if layer_idx not in self.short_term:
            self.short_term[layer_idx] = {'K': [], 'V': []}
            self.long_term[layer_idx] = {'U': [], 'S': [], 'Vt': [], 'V_full': []}
        
        # 添加到短期缓存
        self.short_term[layer_idx]['K'].append(K)
        self.short_term[layer_idx]['V'].append(V)
        
        # 如果短期缓存满了,压缩并移到长期缓存
        if len(self.short_term[layer_idx]['K']) > self.short_term_size:
            self._compress_to_long_term(layer_idx)
    
    def _compress_to_long_term(self, layer_idx):
        """
        将短期缓存做低秩压缩后移入长期存储
        """
        K = torch.cat(self.short_term[layer_idx]['K'], dim=2)  # (B, H, S, D)
        V = torch.cat(self.short_term[layer_idx]['V'], dim=2)
        
        # 对 K 做批量 SVD(batch、head 维度并行),只保留前 k 个奇异值/向量
        U, S, Vh = torch.linalg.svd(K.float(), full_matrices=False)
        k = min(self.long_term_rank, S.shape[-1])
        
        self.long_term[layer_idx]['U'].append(U[..., :, :k].to(K.dtype))
        self.long_term[layer_idx]['S'].append(S[..., :k].to(K.dtype))
        self.long_term[layer_idx]['Vt'].append(Vh[..., :k, :].to(K.dtype))
        # V 暂不压缩,按原样存入长期缓存
        self.long_term[layer_idx]['V_full'].append(V)
        
        # 清空短期缓存
        self.short_term[layer_idx]['K'] = []
        self.short_term[layer_idx]['V'] = []
    
    def get_all(self, layer_idx):
        """
        获取完整的 KV cache(长期部分先由低秩因子重构)
        """
        K_parts = []
        V_parts = []
        
        # 长期部分(解压缩:U diag(S) V^T)
        for i in range(len(self.long_term[layer_idx]['U'])):
            U = self.long_term[layer_idx]['U'][i]
            S = self.long_term[layer_idx]['S'][i]
            Vt = self.long_term[layer_idx]['Vt'][i]
            K_recon = (U * S.unsqueeze(-2)) @ Vt  # (B, H, S, D)
            K_parts.append(K_recon)
            V_parts.append(self.long_term[layer_idx]['V_full'][i])
        
        # 短期部分
        if self.short_term[layer_idx]['K']:
            K_parts.extend(self.short_term[layer_idx]['K'])
            V_parts.extend(self.short_term[layer_idx]['V'])
        
        K_full = torch.cat(K_parts, dim=2) if K_parts else None
        V_full = torch.cat(V_parts, dim=2) if V_parts else None
        
        return K_full, V_full
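
一个模拟长序列解码的使用示例(张量形状仅为演示):

torch.manual_seed(0)
layered = LayeredKVCache(short_term_size=8, long_term_rank=4)

for step in range(20):
    K_new = torch.randn(1, 2, 1, 32)  # (batch, heads, 1 token, d_head)
    V_new = torch.randn(1, 2, 1, 32)
    layered.add(layer_idx=0, K=K_new, V=V_new)

K_full, V_full = layered.get_all(layer_idx=0)
print(K_full.shape, V_full.shape)  # 早期 token 为低秩重构,近期 token 保持原样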

量化感知压缩

INT8 量化

class QuantizedKVCache:
    """
    量化 KV Cache
    """
    def __init__(self, num_bits=8):
        self.num_bits = num_bits
        self.scale = {}
    
    def quantize(self, x):
        """
        对张量进行量化
        
        Args:
            x: (..., d) 要量化的张量
        Returns:
            x_q: 量化后的张量
            scale: 缩放因子
        """
        if self.num_bits == 8:
            # 对称 INT8 量化(per-tensor);clamp 防止全零张量导致除零
            scale = x.abs().max().clamp(min=1e-8) / 127.0
            x_q = (x / scale).round().clamp(-128, 127)
            return x_q.to(torch.int8), scale
        else:
            raise NotImplementedError(f"不支持 {self.num_bits} 位量化")
    
    def dequantize(self, x_q, scale):
        """
        反量化
        """
        return x_q.float() * scale
    
    def compress(self, K, V):
        """
        压缩 K, V
        """
        K_q, K_scale = self.quantize(K)
        V_q, V_scale = self.quantize(V)
        
        return {
            'K': K_q,
            'K_scale': K_scale,
            'V': V_q,
            'V_scale': V_scale
        }
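
量化前后的显存占用对比示例(按元素字节数粗略估算,数值仅作示意):

quantizer = QuantizedKVCache(num_bits=8)
K = torch.randn(1, 32, 4096, 128)  # FP32 的 K
V = torch.randn(1, 32, 4096, 128)

packed = quantizer.compress(K, V)
K_restored = quantizer.dequantize(packed['K'], packed['K_scale'])

orig_bytes = K.numel() * K.element_size() * 2                        # K + V, float32
quant_bytes = packed['K'].numel() * packed['K'].element_size() * 2   # K + V, int8
print(orig_bytes / quant_bytes, (K - K_restored).abs().max().item())  # 约 4 倍压缩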

FP8 量化

# FP8(E4M3)量化:PyTorch 2.1+ 提供原生 float8 类型;
# 生产环境中也可以改用 NVIDIA TransformerEngine 的 FP8 支持(API 不同,此处不展开)
if hasattr(torch, 'float8_e4m3fn'):

    class FP8KVCache:
        """
        基于 torch.float8_e4m3fn 的 KV Cache 量化示意
        """
        def quantize(self, x):
            # 按张量最大绝对值缩放到 E4M3 的可表示范围(约 ±448)
            scale = x.abs().max().clamp(min=1e-8) / 448.0
            x_q = (x / scale).to(torch.float8_e4m3fn)
            return x_q, scale
        
        def dequantize(self, x_q, scale):
            return x_q.float() * scale

else:
    print("当前 PyTorch 不支持 float8,退回 INT8 量化")
    FP8KVCache = QuantizedKVCache

综合方案:FlashAttention 集成

class FlashAttentionWithCompression(nn.Module):
    """
    集成压缩的 FlashAttention
    """
    def __init__(self, d_model, num_heads, compression_ratio=0.5):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.compression_ratio = compression_ratio
        
        # QKV 投影
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # KV Cache 压缩器
        self.kv_compressor = KVCacheSVD(rank=int(self.d_head * compression_ratio))
    
    def forward(self, x, kv_cache=None, use_cache=True):
        B, T, D = x.shape
        
        Q = self.W_q(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        
        # KV Cache:拼接历史缓存,并对完整 K/V 做低秩压缩
        if use_cache:
            if kv_cache is not None:
                K_cache, V_cache = kv_cache
                K = torch.cat([K_cache, K], dim=2)
                V = torch.cat([V_cache, V], dim=2)
                # 直接调用压缩器的无状态 _svd_compress,缓存状态由外部 kv_cache 维护
                K, V = self.kv_compressor._svd_compress(K, V)
            # 更新 cache(首个 step 直接缓存原始 K/V)
            kv_cache = (K, V)
        
        # 注意力计算:PyTorch 2.0+ 优先使用融合的 scaled_dot_product_attention
        if hasattr(F, 'scaled_dot_product_attention'):
            out = F.scaled_dot_product_attention(Q, K, V, attn_mask=None)
        else:
            # 手动实现
            scale = 1.0 / math.sqrt(self.d_head)
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            attn = F.softmax(scores, dim=-1)
            out = torch.matmul(attn, V)
        
        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out), kv_cache
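
一个自回归解码风格的调用示例(沿用上面的模块,形状与步数仅为演示):

torch.manual_seed(0)
layer = FlashAttentionWithCompression(d_model=256, num_heads=4, compression_ratio=0.5)

kv_cache = None
x = torch.randn(1, 16, 256)           # prefill 阶段:一次送入 16 个 token
out, kv_cache = layer(x, kv_cache)

for _ in range(8):                    # decode 阶段:逐 token 生成
    x_new = torch.randn(1, 1, 256)
    out, kv_cache = layer(x_new, kv_cache)

print(out.shape, kv_cache[0].shape)   # 输出为当前 token;缓存包含全部历史 K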

实践建议

选择压缩策略

场景 | 推荐策略
内存极度受限 | SVD 压缩 60%+
长上下文 | 分层 + 量化
实时性要求高 | H2O/LRU
精度敏感 | 动态压缩 + 监控

监控指标

def monitor_compression(model, test_inputs):
    """
    监控压缩效果
    
    注意:这里假设模型采用 GPT-2 风格的模块命名(model.transformer.h),
    且注意力模块缓存了 K 张量与最近一次的注意力分数
    (attention.K / attention.last_scores);实际使用时需按自己的模型结构调整。
    """
    stats = {
        'original_size': 0,
        'compressed_size': 0,
        'rank_distribution': {}
    }
    
    for layer_idx, layer in enumerate(model.transformer.h):
        attention = layer.attn
        
        # 原始大小:K + V,float32
        batch, num_heads, seq_len, d_head = attention.K.shape
        original = batch * num_heads * seq_len * d_head * 2 * 4
        stats['original_size'] += original
        
        # 压缩后大小(粗略估计):秩 r 的低秩因子需存 U (seq_len×r) 与 V (r×d_head)
        r = d_head // 2  # 假设保留一半的秩
        compressed = batch * num_heads * (seq_len + d_head) * r * 2 * 4
        stats['compressed_size'] += compressed
        
        # 分析注意力秩分布
        ranks = analyze_attention_rank(attention.last_scores)
        stats['rank_distribution'][layer_idx] = ranks.mean()
    
    stats['compression_ratio'] = stats['compressed_size'] / stats['original_size']
    
    return stats


Footnotes

  1. Xiao, G., et al. (2024). “An Empirical Investigation of Matrix Factorization Methods for Pre-trained Transformers”. arXiv:2406.11307.

  2. “LoRA-RQ: Rotated Quantization for LoRA”. ICML 2024.

  3. “Eigen Attention: Attention in Low-Rank Space for KV Cache Compression”. EMNLP Findings 2024.