Overview
During large language model inference, the KV Cache is the dominant memory bottleneck. As the context length grows, its memory footprint grows linearly. This article introduces attention-matrix compression techniques based on low-rank decomposition, using SVD and related methods to substantially reduce memory overhead. [1]
The KV Cache Problem
Memory consumption in a standard Transformer
For an autoregressive generation model, the Key and Value vectors of every token must be cached:

$$\text{Mem}_{KV} = 2 \times n_{\text{layers}} \times n_{\text{heads}} \times d_{\text{head}} \times L \times B \times \text{bytes per element}$$

where the factor 2 accounts for both K and V, $L$ is the sequence length, and $B$ is the batch size.
Take LLaMA-7B as an example:
- Layers: 32
- Attention heads: 32
- Head dimension: 128
- Maximum context: 4096
Memory footprint (FP16, batch size 1):

$$2 \times 32 \times 32 \times 128 \times 4096 \times 2\ \text{bytes} \approx 2\ \text{GB}$$

For a 70B model, the KV Cache can reach 80 GB!
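To make the arithmetic reproducible, here is a small helper (a sketch; the function and parameter names are illustrative):

import torch

def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """Bytes needed to cache K and V for every layer (the factor 2 is K + V).
    n_kv_heads equals the number of attention heads for MHA; it is smaller for GQA/MQA."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem

# LLaMA-7B, FP16, 4096-token context, batch size 1  ->  2.0 GiB
print(kv_cache_bytes(32, 32, 128, 4096) / 2**30)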
Why compress?
- Memory bandwidth bottleneck: GPU HBM bandwidth is limited
- Long-context demand: context windows keep getting longer
- Batched inference: many requests are served in parallel
Low-Rank Structure of Attention Matrices
Empirical observations
Research from 2024 finds: [2]
The attention matrices of MHA exhibit a pronounced low-rank structure, whereas the FFN sublayers do not.
This asymmetry provides the theoretical basis for targeted compression.
Rank analysis
import torch
import numpy as np

def analyze_attention_rank(attention_scores):
    """
    Analyze the rank structure of attention matrices.

    Args:
        attention_scores: (batch, num_heads, seq_len, seq_len)
    Returns:
        np.ndarray of effective ranks, one per (batch, head) pair.
    """
    batch, num_heads, seq_len, _ = attention_scores.shape
    ranks = []
    for b in range(batch):
        for h in range(num_heads):
            A = attention_scores[b, h].detach().cpu().numpy()
            U, s, Vt = np.linalg.svd(A, full_matrices=False)
            # Effective rank: smallest k whose singular values capture 99%
            # of the spectral energy (residual energy below 0.01).
            energy = np.cumsum(s ** 2) / np.sum(s ** 2)
            rank = int(np.searchsorted(energy, 0.99)) + 1
            ranks.append(rank)
    return np.array(ranks)
def compute_stable_rank(A):
    """
    Stable rank of a matrix:
        sr(A) = ||A||_F^2 / ||A||_2^2
    """
    frobenius_norm = np.linalg.norm(A, 'fro')
    spectral_norm = np.linalg.norm(A, 2)
    return (frobenius_norm / spectral_norm) ** 2
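The two helpers can be smoke-tested on synthetic, deliberately low-rank attention scores (sizes are illustrative):

# Build synthetic attention scores with an artificially low rank.
torch.manual_seed(0)
low_rank = torch.randn(1, 4, 256, 16) @ torch.randn(1, 4, 16, 256)
scores = torch.softmax(low_rank, dim=-1)

print(analyze_attention_rank(scores).mean())        # mean effective rank per head
print(compute_stable_rank(scores[0, 0].numpy()))    # stable rank of one head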
Layer-wise differences

| Layer type | Q/K low-rankness | V low-rankness | Suggested compression strategy |
|---|---|---|---|
| Shallow (L1-12) | High | Medium | Compress Q/K and V |
| Deep (L13-24) | Medium | High | Compress V first |
| Output (L25-32) | Low | Low | Compress lightly |
SVD Refresher
Singular value decomposition
Any matrix $A \in \mathbb{R}^{m \times n}$ can be decomposed as

$$A = U \Sigma V^T$$

where:
- $U \in \mathbb{R}^{m \times m}$: left singular vectors (orthonormal columns)
- $V \in \mathbb{R}^{n \times n}$: right singular vectors (orthonormal columns)
- $\Sigma \in \mathbb{R}^{m \times n}$: diagonal matrix with $\sigma_1 \ge \sigma_2 \ge \dots \ge 0$
Low-rank approximation
Eckart-Young-Mirsky theorem: the best rank-$k$ approximation of a matrix $A$ is

$$A_k = \sum_{i=1}^{k} \sigma_i u_i v_i^T$$

where only the top $k$ singular values/vectors are kept.
Error bounds:

$$\|A - A_k\|_2 = \sigma_{k+1}, \qquad \|A - A_k\|_F = \sqrt{\textstyle\sum_{i>k} \sigma_i^2}$$
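A quick numerical check of the theorem (a sketch with an arbitrary random matrix):

import numpy as np

np.random.seed(0)
A = np.random.randn(64, 32)
U, s, Vt = np.linalg.svd(A, full_matrices=False)

k = 8
A_k = U[:, :k] @ np.diag(s[:k]) @ Vt[:k]

# Spectral error equals the (k+1)-th singular value;
# Frobenius error equals the l2 norm of the discarded tail.
print(np.linalg.norm(A - A_k, 2), s[k])
print(np.linalg.norm(A - A_k, 'fro'), np.sqrt(np.sum(s[k:] ** 2)))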
Eigen Attention
Core idea
Eigen Attention uses the leading eigenvectors of the attention matrices to construct a low-rank subspace. [3]
Algorithm
import math
import torch.nn as nn
import torch.nn.functional as F

class EigenAttention(nn.Module):
    """
    Eigen Attention: attention computed in a low-rank subspace
    obtained from a (randomized) decomposition of the keys.
    """
    def __init__(self, d_model, num_heads, rank_ratio=0.5):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.rank_ratio = rank_ratio
        # Standard QKV projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def eigen_attention(self, Q, K, V, rank=None):
        """
        Core Eigen Attention operation:
        1. estimate an orthonormal basis of the dominant key directions
           (randomized range finder over the head dimension);
        2. project Q, K, V onto that low-rank subspace;
        3. compute attention inside the subspace and map the output back.

        Q, K, V: (batch, num_heads, seq_len, d_head)
        """
        seq_len = Q.shape[-2]
        d_k = self.d_head

        if rank is None:
            rank = max(1, int(d_k * self.rank_ratio))

        # Randomized estimate of the dominant directions of the key space.
        with torch.no_grad():
            n_random = min(rank * 2, d_k)
            omega = torch.randn(seq_len, n_random, device=Q.device, dtype=Q.dtype)
            K_T = K.transpose(-2, -1)            # (batch, heads, d_k, seq_len)
            Y = K_T @ omega                      # (batch, heads, d_k, n_random)
            Y = K_T @ (K @ Y)                    # one power-iteration step sharpens the estimate
            # Orthonormal basis: top left singular directions of Y.
            U, _, _ = torch.linalg.svd(Y, full_matrices=False)
            Q_orth = U[..., :rank]               # (batch, heads, d_k, rank)

        # Project Q, K, V into the subspace.
        Q_proj = Q @ Q_orth                      # (batch, heads, seq_len, rank)
        K_proj = K @ Q_orth
        V_proj = V @ Q_orth

        # Attention inside the subspace. Q_proj K_proj^T approximates Q K^T,
        # so we keep the original 1/sqrt(d_k) temperature.
        scale = 1.0 / math.sqrt(d_k)
        scores = Q_proj @ K_proj.transpose(-2, -1) * scale
        attn = F.softmax(scores, dim=-1)

        out_proj = attn @ V_proj                 # (batch, heads, seq_len, rank)
        # Map the output back to the original head dimension.
        out = out_proj @ Q_orth.transpose(-2, -1)  # (batch, heads, seq_len, d_k)
        return out

    def forward(self, x, use_eigen=False):
        B, T, _ = x.shape
        Q = self.W_q(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        if use_eigen:
            out = self.eigen_attention(Q, K, V)
        else:
            # Standard attention
            scale = 1.0 / math.sqrt(self.d_head)
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            attn = F.softmax(scores, dim=-1)
            out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.W_o(out)
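A minimal smoke test of the module above (shapes and hyperparameters are illustrative):

model = EigenAttention(d_model=512, num_heads=8, rank_ratio=0.5)
x = torch.randn(2, 128, 512)

out_full = model(x, use_eigen=False)
out_low_rank = model(x, use_eigen=True)

print(out_full.shape, out_low_rank.shape)       # both (2, 128, 512)
print((out_full - out_low_rank).abs().mean())   # approximation gap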
Compression results

| Method | Relative KV Cache size | Performance loss |
|---|---|---|
| No compression | 1× | 0% |
| 20% compression | 0.8× | < 1% |
| 40% compression | 0.6× | ~ 2% |
| 60% compression | 0.4× | ~ 5% |
KV Cache Compression in Practice
Online SVD compression
class KVCacheSVD:
    """
    Online SVD compression of the KV Cache.

    New tokens are appended to an internal per-layer cache; every
    `update_freq` tokens the cached K/V are replaced by a rank-`rank`
    SVD reconstruction. `compress` returns the full (possibly
    recompressed) K/V to use in attention.

    Note: this sketch stores the reconstruction densely, so it shows the
    accuracy impact of truncation; the memory saving comes from storing
    the factors instead, as in LayeredKVCache below.
    """
    def __init__(self, rank=32, update_freq=16):
        self.rank = rank
        self.update_freq = update_freq
        self.cache = {}         # {layer_idx: {'K': ..., 'V': ...}}
        self.token_counts = {}  # tokens seen at the last recompression

    def compress(self, K, V, layer_idx):
        """
        Append new K, V for one layer and periodically recompress.

        Args:
            K, V: (batch, num_heads, new_seq_len, d_head)
            layer_idx: layer index
        Returns:
            full K, V for this layer: (batch, num_heads, total_seq_len, d_head)
        """
        if layer_idx not in self.cache:
            # First call: store as-is.
            self.cache[layer_idx] = {'K': K, 'V': V}
            self.token_counts[layer_idx] = K.shape[2]
            return K, V

        # Append the new tokens to the cached sequence.
        K_full = torch.cat([self.cache[layer_idx]['K'], K], dim=2)
        V_full = torch.cat([self.cache[layer_idx]['V'], V], dim=2)

        # Recompress every `update_freq` new tokens.
        if K_full.shape[2] - self.token_counts[layer_idx] >= self.update_freq:
            K_full, V_full = self._svd_compress(K_full, V_full)
            self.token_counts[layer_idx] = K_full.shape[2]

        self.cache[layer_idx]['K'] = K_full
        self.cache[layer_idx]['V'] = V_full
        return K_full, V_full

    def _svd_compress(self, K, V):
        """
        Replace each (seq_len, d_head) K/V slice with its rank-k SVD reconstruction.
        """
        batch, num_heads, seq_len, d_head = K.shape
        K_out = torch.empty_like(K)
        V_out = torch.empty_like(V)
        for b in range(batch):
            for h in range(num_heads):
                for mat, out in ((K, K_out), (V, V_out)):
                    M = mat[b, h].detach().cpu().numpy()
                    U, s, Vt = np.linalg.svd(M, full_matrices=False)
                    k = min(self.rank, len(s))
                    recon = U[:, :k] @ np.diag(s[:k]) @ Vt[:k]
                    out[b, h] = torch.from_numpy(recon).to(dtype=mat.dtype, device=mat.device)
        return K_out, V_out
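A decode-loop sketch of how the compressor would be driven (shapes are illustrative):

compressor = KVCacheSVD(rank=32, update_freq=16)

for step in range(64):
    # One new token per step: (batch=1, heads=8, new_tokens=1, d_head=64)
    K_new = torch.randn(1, 8, 1, 64)
    V_new = torch.randn(1, 8, 1, 64)
    K_all, V_all = compressor.compress(K_new, V_new, layer_idx=0)

print(K_all.shape)   # (1, 8, 64, 64): full sequence, periodically recompressed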
H2O-Style Eviction

class H2OKVCache:
    """
    Simplified H2O-style KV Cache manager.

    The original H2O (Heavy-Hitter Oracle) keeps the tokens whose
    accumulated attention scores are largest; this sketch approximates
    that with a per-position access count and evicts the least-used
    position when the cache is full.
    """
    def __init__(self, max_cache_size):
        self.max_cache_size = max_cache_size
        self.cache = {}          # {token_pos: (K, V)}
        self.access_counts = {}  # {token_pos: hit count}

    def query(self, pos, K, V):
        """
        Look up a KV pair; on a miss, insert it (evicting if necessary).
        Returns the cached (K, V) on a hit, None on a miss.
        """
        if pos in self.cache:
            self.access_counts[pos] += 1
            return self.cache[pos]

        # Cache full: evict the least-used position.
        if len(self.cache) >= self.max_cache_size:
            self._evict_lru()

        self.cache[pos] = (K, V)
        self.access_counts[pos] = 1
        return None  # miss

    def _evict_lru(self):
        """
        Evict the position with the smallest access count
        (a frequency-based stand-in for H2O's heavy-hitter score).
        """
        lru_pos = min(self.access_counts, key=self.access_counts.get)
        del self.cache[lru_pos]
        del self.access_counts[lru_pos]

    def get_cache_size(self):
        return len(self.cache)
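A small usage sketch of the eviction policy (token positions and tensors are dummies):

cache = H2OKVCache(max_cache_size=4)

for pos in [0, 1, 2, 3, 0, 0, 4]:   # positions 0-3 fill the cache, 4 forces an eviction
    hit = cache.query(pos, torch.zeros(64), torch.zeros(64))
    print(pos, "hit" if hit is not None else "miss")

print(cache.get_cache_size())        # still 4: one entry was evicted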
Layered KV Cache

class LayeredKVCache:
    """
    Two-tier KV Cache:
    - short-term tier: the most recent blocks (typically one token per add() call), stored exactly;
    - long-term tier: older tokens, with K stored as truncated SVD factors
      (V is kept uncompressed in this sketch).
    """
    def __init__(self, short_term_size=128, long_term_rank=32):
        self.short_term_size = short_term_size
        self.long_term_rank = long_term_rank
        self.short_term = {}  # {layer_idx: {'K': [...], 'V': [...]}}
        self.long_term = {}   # {layer_idx: {'K_factors': [...], 'V_full': [...]}}
        self._meta = {}       # dtype/device per layer, remembered at compression time

    def add(self, layer_idx, K, V):
        """
        Append a new KV block. K, V: (batch, num_heads, new_tokens, d_head).
        """
        if layer_idx not in self.short_term:
            self.short_term[layer_idx] = {'K': [], 'V': []}
            self.long_term[layer_idx] = {'K_factors': [], 'V_full': []}

        self.short_term[layer_idx]['K'].append(K)
        self.short_term[layer_idx]['V'].append(V)

        # Short-term tier full: compress it into the long-term tier.
        if len(self.short_term[layer_idx]['K']) > self.short_term_size:
            self._compress_to_long_term(layer_idx)

    def _compress_to_long_term(self, layer_idx):
        """
        Move the short-term tier into long-term storage (K via truncated SVD).
        """
        K = torch.cat(self.short_term[layer_idx]['K'], dim=2)
        V = torch.cat(self.short_term[layer_idx]['V'], dim=2)
        batch, num_heads, seq_len, d_head = K.shape
        k = min(self.long_term_rank, seq_len, d_head)

        factors = []  # one (U, s, Vt) triple per (batch, head) slice
        for b in range(batch):
            for h in range(num_heads):
                K_mat = K[b, h].detach().cpu().numpy()
                U_k, s_k, Vt_k = np.linalg.svd(K_mat, full_matrices=False)
                factors.append((U_k[:, :k], s_k[:k], Vt_k[:k]))

        self.long_term[layer_idx]['K_factors'].append((K.shape, factors))
        self.long_term[layer_idx]['V_full'].append(V)  # V kept uncompressed in this sketch
        self._meta[layer_idx] = (K.dtype, K.device)

        # Clear the short-term tier.
        self.short_term[layer_idx]['K'] = []
        self.short_term[layer_idx]['V'] = []

    def get_all(self, layer_idx):
        """
        Return the full K, V (long-term reconstruction + short-term exact part).
        """
        dtype, device = self._meta.get(layer_idx, (torch.float32, 'cpu'))
        K_parts, V_parts = [], []

        # Long-term tier: rebuild each compressed block from its SVD factors.
        for (shape, factors), V_block in zip(self.long_term[layer_idx]['K_factors'],
                                             self.long_term[layer_idx]['V_full']):
            _, num_heads, _, _ = shape
            K_block = torch.empty(shape, dtype=dtype, device=device)
            for idx, (U, s, Vt) in enumerate(factors):
                b, h = divmod(idx, num_heads)
                K_block[b, h] = torch.from_numpy((U * s) @ Vt).to(dtype=dtype, device=device)
            K_parts.append(K_block)
            V_parts.append(V_block)

        # Short-term tier: exact tensors.
        K_parts.extend(self.short_term[layer_idx]['K'])
        V_parts.extend(self.short_term[layer_idx]['V'])

        K_full = torch.cat(K_parts, dim=2) if K_parts else None
        V_full = torch.cat(V_parts, dim=2) if V_parts else None
        return K_full, V_full
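Rough storage arithmetic for the long-term tier (illustrative numbers): a dense (seq_len × d_head) key block costs seq_len·d_head floats, while rank-k factors cost seq_len·k + k + k·d_head.

seq_len, d_head, k = 128, 128, 32
dense = seq_len * d_head                   # 16384 floats
factors = seq_len * k + k + k * d_head     # 8224 floats
print(dense, factors, dense / factors)     # roughly a 2x saving at rank 32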
Quantization-Aware Compression
INT8 quantization

class QuantizedKVCache:
    """
    Quantized KV Cache (per-tensor symmetric INT8).
    """
    def __init__(self, num_bits=8):
        self.num_bits = num_bits

    def quantize(self, x):
        """
        Quantize a tensor.

        Args:
            x: tensor to quantize
        Returns:
            x_q: quantized tensor (int8)
            scale: per-tensor scale factor
        """
        if self.num_bits == 8:
            # Symmetric INT8 quantization with a single per-tensor scale.
            scale = x.abs().max().clamp(min=1e-8) / 127.0
            x_q = (x / scale).round().clamp(-128, 127)
            return x_q.to(torch.int8), scale
        raise NotImplementedError(f"{self.num_bits}-bit quantization is not supported")

    def dequantize(self, x_q, scale):
        """
        Dequantize back to float.
        """
        return x_q.float() * scale

    def compress(self, K, V):
        """
        Quantize K and V, keeping their scales for dequantization.
        """
        K_q, K_scale = self.quantize(K)
        V_q, V_scale = self.quantize(V)
        return {
            'K': K_q,
            'K_scale': K_scale,
            'V': V_q,
            'V_scale': V_scale,
        }
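A round-trip check of the quantizer above (random data, illustrative):

quant = QuantizedKVCache(num_bits=8)
K = torch.randn(1, 8, 256, 64)

packed = quant.compress(K, K.clone())
K_hat = quant.dequantize(packed['K'], packed['K_scale'])

print((K - K_hat).abs().max())   # worst-case error, bounded by ~scale/2
print(packed['K'].dtype)          # torch.int8: 4x smaller than float32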
FP8 quantization

# FP8 (E4M3) storage of the KV Cache. Recent PyTorch (>= 2.1) exposes the
# torch.float8_e4m3fn dtype; production FP8 pipelines usually go through
# libraries such as NVIDIA TransformerEngine, but this sketch uses the
# native dtype to keep the example self-contained.
if hasattr(torch, "float8_e4m3fn"):
    class FP8KVCache:
        def quantize(self, x):
            # Scale into the representable E4M3 range (max magnitude ~448), then cast.
            scale = x.abs().max().clamp(min=1e-8) / 448.0
            return (x / scale).to(torch.float8_e4m3fn), scale

        def dequantize(self, x_q, scale):
            return x_q.float() * scale
else:
    print("FP8 dtype not available; falling back to INT8 quantization")
    FP8KVCache = QuantizedKVCache
Putting It Together: FlashAttention Integration

class FlashAttentionWithCompression(nn.Module):
    """
    Attention layer with a compressed KV Cache.

    Uses PyTorch's scaled_dot_product_attention when available (which
    dispatches to FlashAttention kernels on supported hardware) and lets
    the KVCacheSVD compressor above own the per-layer cache.
    """
    def __init__(self, d_model, num_heads, compression_ratio=0.5, layer_idx=0):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.layer_idx = layer_idx

        # QKV projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # KV Cache compressor (rank is a fraction of the head dimension)
        self.kv_compressor = KVCacheSVD(rank=max(1, int(self.d_head * compression_ratio)))

    def forward(self, x, use_cache=True):
        B, T, D = x.shape
        Q = self.W_q(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        if use_cache:
            # The compressor appends the new K/V to its internal cache,
            # periodically recompresses, and returns the full sequence.
            K, V = self.kv_compressor.compress(K, V, self.layer_idx)

        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch 2.0+: fused (Flash) attention kernel.
            out = F.scaled_dot_product_attention(Q, K, V, attn_mask=None)
        else:
            # Manual fallback.
            scale = 1.0 / math.sqrt(self.d_head)
            scores = torch.matmul(Q, K.transpose(-2, -1)) * scale
            attn = F.softmax(scores, dim=-1)
            out = torch.matmul(attn, V)

        out = out.transpose(1, 2).contiguous().view(B, T, D)
        return self.W_o(out)
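A prefill-plus-decode sketch of the layer above (dimensions are illustrative):

layer = FlashAttentionWithCompression(d_model=512, num_heads=8, compression_ratio=0.5)

prompt = torch.randn(1, 32, 512)   # prefill: 32 tokens at once
out = layer(prompt, use_cache=True)

for _ in range(8):                  # decode: one token at a time
    next_token = torch.randn(1, 1, 512)
    out = layer(next_token, use_cache=True)

print(out.shape)                    # (1, 1, 512)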
Practical Recommendations
Choosing a compression strategy

| Scenario | Recommended strategy |
|---|---|
| Severely memory-constrained | SVD compression, 60%+ |
| Long context | Layered cache + quantization |
| Tight latency requirements | H2O / LRU eviction |
| Accuracy-sensitive | Dynamic compression + monitoring |
Monitoring metrics

def monitor_compression(model, test_inputs):
    """
    Monitor compression behavior.

    Assumes a GPT-2-style module layout (model.transformer.h) in which each
    attention module exposes its cached K and its last attention scores.
    """
    stats = {
        'original_size': 0,
        'compressed_size': 0,
        'rank_distribution': {},
    }

    # Run the model once so that caches and attention scores are populated
    # (assumes the attention modules record K and last_scores during forward).
    model(test_inputs)

    for layer_idx, layer in enumerate(model.transformer.h):
        attention = layer.attn

        # Original size: K + V in float32.
        batch, num_heads, seq_len, d_head = attention.K.shape
        original = batch * num_heads * seq_len * d_head * 2 * 4
        stats['original_size'] += original

        # Rough estimate after low-rank factorization
        # (assumes on the order of d_head x d_head of storage per head).
        compressed = batch * num_heads * d_head * d_head * 2 * 4
        stats['compressed_size'] += compressed

        # Effective-rank distribution of the attention maps.
        ranks = analyze_attention_rank(attention.last_scores)
        stats['rank_distribution'][layer_idx] = ranks.mean()

    stats['compression_ratio'] = stats['compressed_size'] / stats['original_size']
    return stats

References
Related reading
- Attention Rank Collapse Theory — spectral analysis of attention matrices
- Mathematical Foundations of Transformers — derivations for the Transformer architecture
- Matrix Norms and Neural Networks — spectral normalization techniques