KV Cache压缩技术综述
1. 背景与挑战
1.1 KV Cache的核心作用
在自回归大语言模型(Large Language Models, LLMs)的推理过程中,KV Cache是一种至关重要的缓存技术。Transformer架构的自注意力机制要求模型在生成每个新token时,需要回顾之前所有token的信息。KV Cache通过缓存已经计算好的Key和Value向量,避免了在每一步生成时重复计算,从而显著提升了推理效率。1
考虑一个典型的解码过程:假设我们正在生成一个长度为 的序列。在第 步时,模型需要计算注意力:
其中 是当前token的Query向量, 和 是从第1个token到第 个token的所有Key和Value向量。如果不使用KV Cache,每次生成新token时都需要重新计算所有历史token的Key和Value,这导致了 的计算复杂度。使用KV Cache后,复杂度可以降低到 级别。
1.2 内存瓶颈问题
尽管KV Cache显著加速了推理过程,但它同时也带来了严重的内存问题。随着序列长度的增加,KV Cache的内存占用呈现线性增长趋势。对于一个典型的70亿参数模型(如LLaMA-7B),假设:
- 隐藏维度
- Key/Value投影维度
- 注意力头数
- 批量大小
- 序列长度
每个token需要存储的KV向量大小为:
对于长度为8192的序列,单个样本的KV Cache大小为:
这还只是一个样本的情况。在实际部署中,如果需要同时处理多个请求或使用更大的批量,内存消耗会急剧增加。这导致了所谓的**内存墙(Memory Wall)**问题,成为限制LLM推理吞吐量的主要瓶颈。
1.3 计算与内存的权衡
KV Cache的存在本质上是在计算和内存之间寻求平衡。一方面,缓存可以避免重复计算;另一方面,存储所有历史KV向量需要消耗大量内存。在资源受限的场景下,如何高效地管理KV Cache成为关键问题。
近年来,随着长上下文模型(如GPT-4支持128K上下文、LLaMA-3支持128K上下文)的发展,KV Cache的内存问题变得更加突出。处理长达10万token的上下文意味着需要存储GB级别的KV数据,这远远超出了大多数GPU的显存容量。
2. 压缩方法分类体系
针对KV Cache的内存瓶颈,研究者们提出了多种压缩技术。根据压缩策略的不同,可以将现有方法分为三大类:量化方法(Quantization)、稀疏方法(Sparsification)和低秩方法(Low-Rank Approximation)。2
┌─────────────────────────────────────────────────────────────────┐
│ KV Cache压缩方法分类 │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
│ │ 量化方法 │ │ 稀疏方法 │ │ 低秩方法 │ │
│ ├─────────────────┤ ├─────────────────┤ ├─────────────────┤ │
│ │ • Token级量化 │ │ • 选择性丢弃 │ │ • SVD分解 │ │
│ │ • Group级量化 │ │ • 重要性估计 │ │ • 低秩投影 │ │
│ │ • 混合精度量化 │ │ • 局部性敏感哈希 │ │ • 张量分解 │ │
│ │ • 动态量化 │ │ • 分层压缩 │ │ • NMF分解 │ │
│ └─────────────────┘ └─────────────────┘ └─────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
2.1 量化方法概述
量化方法通过降低KV向量的表示精度来减少内存占用。传统的KV Cache使用FP32(32位浮点数)存储每个元素,通过量化可以压缩到FP16、INT8甚至更低精度。量化方法的核心挑战在于如何在保持模型性能的同时实现高效压缩。
2.2 稀疏方法概述
稀疏方法通过识别并保留最重要的KV向量,同时丢弃不重要的部分来减少缓存大小。这类方法通常需要设计有效的重要性估计机制,以判断哪些token的KV表示应该被保留。
2.3 低秩方法概述
低秩方法利用KV矩阵的低秩特性,通过矩阵分解或投影将高维向量压缩到低维空间。这类方法可以显著减少存储需求,但可能引入重构误差。
3. 量化方法详解
3.1 Token级量化
**Token级量化(Token-level Quantization)**是最直接的压缩方式,它对每个token的KV向量独立进行量化。这种方法假设不同token的KV表示具有相似的数值分布,因此可以使用统一的量化参数。
3.1.1 均匀量化
均匀量化将数值范围均匀划分为 个区间,其中 是目标位数。对于KV向量中的每个元素 ,量化过程为:
其中量化尺度 , 函数将结果限制在有效范围内。
反量化过程为:
均匀量化的优点是实现简单、计算速度快。但它没有考虑数值分布的非均匀性,对于存在离群值的数据,压缩效果有限。
3.1.2 非均匀量化
非均匀量化使用非均匀的量化间隔,通常基于数据分布来设计。一种常见的方法是使用非线性量化,如对数量化:
对数量化可以更好地处理动态范围大的数据,在语音和音频处理中应用广泛。
另一种非均匀量化方法是K-means聚类量化:
- 收集所有KV向量元素
- 使用K-means聚类找到 个聚类中心
- 将每个元素映射到最近的聚类中心
这种方法需要离线训练聚类中心,增加了系统复杂度。
3.2 Group级量化
**Group级量化(Group-level Quantization)**是一种更精细的量化策略,它将KV向量分成多个组(Group),每个组使用独立的量化参数。这种方法可以更好地适应不同组之间的数值差异,提高压缩质量。
3.2.1 分组策略
常见的分组策略包括:
- 按token分组:每个token的KV向量作为一组
- 按层分组:同一层的所有token的KV向量作为一组
- 按头分组:同一个注意力头的所有token的KV向量作为一组
- 混合分组:结合多种分组策略
分组粒度越细,量化精度越高,但需要存储的量化参数也越多。在实际应用中,需要在压缩率和精度之间找到平衡。
3.2.2 量化参数存储
对于Group级量化,需要额外存储每组的量化参数。假设:
- 分组数:
- 每组元素数:
- 量化位数:
- 每个量化参数使用 位存储
原始数据大小: bits
压缩后大小: bits
当 较大时,量化参数的存储开销变得显著。
3.3 混合精度量化
**混合精度量化(Mixed-Precision Quantization)**根据不同元素或组的重要性分配不同的量化精度。重要的元素使用高精度(如INT8),不重要的元素使用低精度(如INT4或INT2)。
3.3.1 敏感性分析
混合精度的关键是正确评估各部分的重要性。常用的方法包括:
- Hessian/梯度敏感性:基于二阶导数分析,Hessian矩阵较大的维度对量化更敏感
- 激活值方差:方差较大的维度通常包含更多信息,需要更高精度
- Attention分布:在Attention计算中贡献较大的维度更重要
3.3.2 优化框架
混合精度量化可以形式化为以下优化问题:
其中 是第 组的量化位数, 是对应的量化函数, 是正则化系数用于控制压缩率。
3.4 动态量化
**动态量化(Dynamic Quantization)**在推理过程中实时进行量化,而非离线预计算。这种方法可以适应输入数据的分布变化,但增加了运行时开销。
动态量化通常与KV Cache更新结合使用:当新的token被生成时,其KV向量会被动态量化后存入缓存;当需要使用缓存时,再进行反量化。
3.5 量化方法实现示例
以下是一个使用PyTorch实现Group级INT8量化的示例:
import torch
import torch.nn.functional as F
class GroupInt8Quantizer:
"""Group级INT8量化器"""
def __init__(self, group_size=128, num_groups=None):
self.group_size = group_size
self.num_groups = num_groups
def quantize(self, x: torch.Tensor) -> tuple:
"""
对输入张量进行Group级INT8量化
Args:
x: 输入张量,形状为 [seq_len, hidden_dim]
Returns:
quantized: INT8量化值
scales: 每组的量化尺度
zero_points: 每组的零点偏移
"""
seq_len, hidden_dim = x.shape
# 计算分组数
if self.num_groups is None:
self.num_groups = hidden_dim // self.group_size
# reshape为分组格式
x_reshaped = x.view(seq_len, self.num_groups, self.group_size)
# 计算每组的最小值和最大值
x_min = x_reshaped.amin(dim=-1, keepdim=True)
x_max = x_reshaped.amax(dim=-1, keepdim=True)
# 计算量化参数
scale = (x_max - x_min) / 255.0
zero_point = (-x_min / scale).round().clamp(0, 255)
# 量化
x_quantized = ((x_reshaped / scale) + zero_point).round().clamp(0, 255)
# 转换回原始形状
quantized = x_quantized.view(seq_len, hidden_dim)
return quantized.to(torch.uint8), scale.squeeze(-1), zero_point.squeeze(-1)
def dequantize(self, quantized: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor) -> torch.Tensor:
"""
反量化恢复原始精度
Args:
quantized: INT8量化值
scales: 量化尺度
zero_points: 零点偏移
Returns:
x: 恢复后的浮点张量
"""
seq_len, hidden_dim = quantized.shape
# reshape为分组格式
q_reshaped = quantized.view(seq_len, self.num_groups, self.group_size)
scales = scales.view(1, self.num_groups, 1)
zero_points = zero_points.view(1, self.num_groups, 1)
# 反量化
x_dequantized = (q_reshaped.float() - zero_points) * scales
return x_dequantized.view(seq_len, hidden_dim)
class KVCacheQuantizer:
"""KV Cache量化管理器"""
def __init__(self, num_heads, head_dim, group_size=64):
self.num_heads = num_heads
self.head_dim = head_dim
self.group_size = group_size
self.quantizers = [
GroupInt8Quantizer(group_size)
for _ in range(num_heads)
]
def quantize_kv(self, k: torch.Tensor, v: torch.Tensor):
"""
对KV向量进行量化
Args:
k: Key向量 [batch, num_heads, seq_len, head_dim]
v: Value向量 [batch, num_heads, seq_len, head_dim]
Returns:
量化后的KV和量化参数
"""
batch, num_heads, seq_len, head_dim = k.shape
k_quantized_all = []
k_scales_all = []
k_zp_all = []
v_quantized_all = []
v_scales_all = []
v_zp_all = []
for h in range(num_heads):
# 对每个头单独量化
k_head = k[:, h, :, :].squeeze(0) # [seq_len, head_dim]
v_head = v[:, h, :, :].squeeze(0)
k_q, k_s, k_zp = self.quantizers[h].quantize(k_head)
v_q, v_s, v_zp = self.quantizers[h].quantize(v_head)
k_quantized_all.append(k_q)
k_scales_all.append(k_s)
k_zp_all.append(k_zp)
v_quantized_all.append(v_q)
v_scales_all.append(v_s)
v_zp_all.append(v_zp)
return {
'k_quantized': torch.stack(k_quantized_all, dim=0), # [num_heads, seq_len, head_dim]
'k_scales': torch.stack(k_scales_all, dim=0),
'k_zp': torch.stack(k_zp_all, dim=0),
'v_quantized': torch.stack(v_quantized_all, dim=0),
'v_scales': torch.stack(v_scales_all, dim=0),
'v_zp': torch.stack(v_zp_all, dim=0),
}
def get_compression_ratio(self, original_bits=32, quantized_bits=8):
"""计算压缩比"""
return original_bits / quantized_bits4. 稀疏方法详解
4.1 选择性丢弃策略
**选择性丢弃(Selective Dropping)**是稀疏方法中最直接的一类。这类方法通过预定义的规则或学习到的策略,决定哪些token的KV表示应该被保留,哪些可以被丢弃。
4.1.1 基于规则的丢弃
最简单的丢弃策略是基于位置或内容的规则:
- 固定窗口丢弃:只保留最近 个token的KV,丢弃更早的token
- 均匀采样丢弃:每隔 个token保留一个,丢弃中间的
- 层级丢弃:根据token的位置分配不同级别的保留优先级
固定窗口丢弃虽然简单,但会丢失远程依赖信息。一种改进策略是多尺度窗口,同时保留多个不同大小的窗口:
4.1.2 基于学习的丢弃
更智能的丢弃策略通过学习来决定保留哪些token。H2O(Heavy-Hitter Oracle)算法3提出了基于”重击者”(Heavy Hitters)的概念:
核心思想:在注意力机制中,只有少数”重击者”token对当前token的预测贡献最大。通过追踪和保留这些重要的历史token,可以以较低的存储成本维持大部分模型性能。
H2O的丢弃策略基于注意力分数的累积:
其中 是token 对token 的注意力权重。在每一步,选择累积分数最低的token进行丢弃。
4.2 重要性估计方法
**重要性估计(Importance Estimation)**是稀疏方法的核心技术。准确估计每个token的重要性,可以帮助我们在保持模型性能的同时最大化压缩率。
4.2.1 基于Attention的重要性
Attention权重本身就是一种自然的重要性度量。在自回归模型中:
-
Query-Key相似度:
-
累积注意力:跟踪每个历史token收到的总注意力:
-
Token重要性分数:使用梯度或扰动分析来估计:
4.2.2 基于预测的重要性
另一种重要性估计方法基于对最终预测的贡献:
-
梯度重要性:计算损失函数对KV向量的梯度:
-
扰动重要性:对每个token的KV添加扰动,观察对输出的影响:
-
互信息:计算token与输出之间的互信息:
4.2.3 轻量级重要性估计器
精确的重要性估计通常计算开销较大。研究人员提出了多种轻量级估计方法:
- 预定义启发式:使用简单的统计量(如L2范数)作为重要性代理
- 辅助网络:训练一个小网络来预测重要性分数
- 在线更新:维护一个滚动的重要性估计,随新token的生成不断更新
4.3 局部性敏感哈希
**局部性敏感哈希(Locality-Sensitive Hashing, LSH)**是一种近似最近邻搜索技术,可以用于高效地查找相似的token并合并它们。
4.3.1 LSH原理
LSH通过哈希函数将相似的输入映射到相同的”桶”中。对于向量 ,LSH使用一组随机投影:
其中 是随机向量。相似的向量在随机投影后大概率落在相同的超平面同一侧,从而被映射到相同的哈希值。
4.3.2 在KV Cache中的应用
在KV Cache压缩中,LSH可以用于识别相似的KV表示并合并它们:
- 对每个新token的KV向量计算LSH哈希值
- 将具有相同哈希值的token分组
- 同一组内的token只保留一个代表(可以是平均值或簇中心)
- 更新Attention计算时,使用组代表代替组内所有token
4.4 分层压缩
**分层压缩(Hierarchical Compression)**采用多层次的压缩策略,在不同粒度上进行压缩。
┌─────────────────────────────────────────┐
│ 原始KV Cache │
│ [seq_len, num_heads, head_dim] │
└─────────────────┬───────────────────────┘
│
┌─────────┴─────────┐
▼ ▼
┌───────────────┐ ┌───────────────┐
│ 细粒度层 │ │ 粗粒度层 │
│ (Token级别) │ │ (块级别) │
└───────────────┘ └───────────────┘
│ │
└─────────┬─────────┘
▼
┌───────────────────┐
│ 聚合表示层 │
│ (层级摘要) │
└───────────────────┘
4.5 稀疏方法实现示例
以下是H2O算法的简化实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class H2OKVCache:
"""
H2O (Heavy-Hitter Oracle) KV Cache实现
核心思想:追踪累积注意力分数,保留"重击者"token,丢弃低贡献token
"""
def __init__(self, max_seq_len, num_heads, head_dim, budget_ratio=0.3):
"""
Args:
max_seq_len: 最大序列长度
num_heads: 注意力头数
head_dim: 每个头的维度
budget_ratio: 保留token的比例
"""
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.budget_ratio = budget_ratio
# 累积注意力分数(每个头单独维护)
self.cumulative_attn = torch.zeros(max_seq_len, num_heads)
# 保留的token索引
self.kept_indices = []
# KV缓存
self.k_cache = torch.zeros(max_seq_len, num_heads, head_dim)
self.v_cache = torch.zeros(max_seq_len, num_heads, head_dim)
self.current_len = 0
def update(self, k_new: torch.Tensor, v_new: torch.Tensor,
attn_weights: torch.Tensor):
"""
更新KV Cache
Args:
k_new: 新Key向量 [batch, num_heads, 1, head_dim]
v_new: 新Value向量 [batch, num_heads, 1, head_dim]
attn_weights: 当前step的注意力权重 [batch, num_heads, seq_len, 1]
"""
batch, num_heads, _, head_dim = k_new.shape
pos = self.current_len
# 存储新的KV
self.k_cache[pos] = k_new.squeeze(2).squeeze(0)
self.v_cache[pos] = v_new.squeeze(2).squeeze(0)
# 更新累积注意力分数
# attn_weights: [batch, num_heads, seq_len, 1]
self.cumulative_attn[:pos+1] += attn_weights.squeeze(0).squeeze(-1).T
self.current_len += 1
# 如果超过预算,执行压缩
if self.current_len > self.max_seq_len * self.budget_ratio:
self._compress()
def _compress(self):
"""压缩:保留累积注意力分数最高的token"""
budget = int(self.max_seq_len * self.budget_ratio)
# 计算每个位置的全局重要性分数
importance = self.cumulative_attn[:self.current_len].mean(dim=1)
# 选择top-k重要位置
_, top_indices = torch.topk(importance, k=min(budget, len(importance)))
top_indices = top_indices.sort()[0] # 保持顺序
# 压缩KV缓存
self.k_cache[:len(top_indices)] = self.k_cache[top_indices]
self.v_cache[:len(top_indices)] = self.v_cache[top_indices]
# 调整累积注意力
self.cumulative_attn[:len(top_indices)] = self.cumulative_attn[top_indices]
self.kept_indices = top_indices.tolist()
self.current_len = len(top_indices)
def get_cache(self):
"""获取当前缓存的KV"""
return (
self.k_cache[:self.current_len].unsqueeze(0),
self.v_cache[:self.current_len].unsqueeze(0)
)
class StreamingAttention(nn.Module):
"""
支持H2O KV Cache的流式注意力层
"""
def __init__(self, hidden_dim, num_heads, head_dim, budget_ratio=0.3):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.q_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.k_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.v_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.o_proj = nn.Linear(num_heads * head_dim, hidden_dim)
self.kv_cache = None
self.budget_ratio = budget_ratio
def forward(self, x: torch.Tensor, use_cache=True):
"""
前向传播
Args:
x: 输入张量 [batch, seq_len, hidden_dim]
use_cache: 是否使用KV Cache
Returns:
output: 输出张量
k_new: 新的Key向量(用于更新cache)
v_new: 新的Value向量
"""
batch, seq_len, _ = x.shape
# 投影得到QKV
q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
if seq_len == 1 and use_cache and self.kv_cache is not None:
# 单步推理,使用KV Cache
k_cache, v_cache = self.kv_cache.get_cache()
# 拼接缓存的KV
k_full = torch.cat([k_cache, k], dim=2)
v_full = torch.cat([v_cache, v], dim=2)
# 计算注意力(使用缓存)
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q.transpose(1, 2) * scale,
k_full.transpose(1, 2).transpose(2, 3))
attn_weights = F.softmax(attn_weights, dim=-1)
# 更新KV Cache
self.kv_cache.update(k, v, attn_weights[:, :, :-1, :])
# 计算输出
attn_output = torch.matmul(attn_weights, v_full.transpose(1, 2))
attn_output = attn_output.transpose(1, 2).contiguous()
else:
# 普通注意力计算
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q.transpose(1, 2) * scale,
k.transpose(1, 2).transpose(2, 3))
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v.transpose(1, 2))
attn_output = attn_output.transpose(1, 2).contiguous()
# 输出投影
attn_output = attn_output.reshape(batch, seq_len, self.num_heads * self.head_dim)
output = self.o_proj(attn_output)
return output, k, v
def init_cache(self, max_seq_len):
"""初始化KV Cache"""
self.kv_cache = H2OKVCache(
max_seq_len, self.num_heads, self.head_dim, self.budget_ratio
)5. 低秩方法详解
5.1 低秩假设与理论基础
低秩方法基于一个关键假设:KV矩阵具有低秩特性。虽然原始KV向量的维度可能很高( 或 ),但其有效信息可以用更少的维度来表示。
考虑KV Cache中存储的Key矩阵 ,其中 是序列长度。如果 的秩 ,那么 可以被分解为:
其中 ,。
存储原始矩阵需要 个参数,而存储分解后的矩阵只需要 个参数。当 时,压缩效果显著。
5.2 SVD分解方法
**奇异值分解(Singular Value Decomposition, SVD)**是最经典的低秩近似方法。对于矩阵 ,其SVD分解为:
其中:
- :左奇异向量矩阵
- :奇异值对角矩阵
- :右奇异向量矩阵
保留前 个最大的奇异值,得到截断SVD:
根据Eckart-Young-Mirsky定理,截断SVD是在Frobenius范数意义下对原矩阵的最佳低秩近似。
5.2.1 在线SVD更新
在流式场景中,KV Cache需要不断更新。在线SVD(Online SVD)可以高效地更新低秩表示:
设已有矩阵 的SVD分解,新到来的行向量为 ,则:
使用QR分解或秩-1更新算法可以高效地更新SVD分解,而无需重新计算。
5.3 低秩投影方法
**低秩投影(Low-rank Projection)**直接在原始KV向量上应用一个投影矩阵,将高维向量映射到低维空间:
其中 是投影矩阵, 是目标维度。
5.3.1 随机投影
最简单的投影方法是使用随机矩阵。Johnson-Lindenstrauss引理保证了随机投影可以近似保持点间距离:
给定 和整数 ,存在一个映射 ,其中 ,使得对于任意 个点,任意两点的距离被保持在一个 因子内。
常用的随机投影包括:
- 高斯随机投影:矩阵元素服从
- 稀疏随机投影:矩阵元素以概率 服从
- 稀疏嵌入(Sparse Encoding):将向量映射到稀疏低维空间
5.3.2 学习式投影
更优的投影方法是通过学习得到的。EliteKV4提出了一种基于RoPE频率选择的学习式低秩投影方法:
- 分析KV向量在不同RoPE频率下的能量分布
- 识别包含最多信息的频率分量
- 对这些重要频率应用投影,保留关键信息
5.4 张量分解方法
高阶张量分解可以更好地捕捉KV Cache的多维结构。假设我们有三维KV张量 ,其中:
- :序列长度
- :注意力头数
- :头维度
Tucker分解将这个张量分解为:
其中 是核心张量, 是各模态的正交因子矩阵。
**CP分解(Canonical Polyadic)**将张量表示为多个秩-1张量的和:
5.5 NMF分解方法
**非负矩阵分解(Non-negative Matrix Factorization, NMF)**假设矩阵元素非负,并将矩阵分解为两个非负矩阵的乘积:
NMF的优势在于其结果具有可解释性:每一列可以视为一个”主题”,每一行表示对应token在该主题上的权重。
5.6 低秩方法实现示例
以下是低秩投影的PyTorch实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class LowRankProjection(nn.Module):
"""
低秩投影KV Cache压缩
通过学习或随机投影将高维KV向量映射到低维空间
"""
def __init__(self, hidden_dim: int, num_heads: int, head_dim: int,
rank_ratio: float = 0.5, learnable: bool = True):
"""
Args:
hidden_dim: 隐藏层维度
num_heads: 注意力头数
head_dim: 每个头的维度
rank_ratio: 低秩维度与原始维度的比例
learnable: 是否使用可学习的投影矩阵
"""
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.rank = int(head_dim * rank_ratio)
if learnable:
# 可学习的投影矩阵
self.k_proj = nn.Linear(head_dim, self.rank, bias=False)
self.v_proj = nn.Linear(head_dim, self.rank, bias=False)
else:
# 随机投影
self.register_buffer('k_proj',
torch.randn(head_dim, self.rank) / (self.rank ** 0.5))
self.register_buffer('v_proj',
torch.randn(head_dim, self.rank) / (self.rank ** 0.5))
def project_k(self, k: torch.Tensor) -> torch.Tensor:
"""投影Key向量"""
# k: [..., head_dim]
if self.k_proj.weight.device != k.device:
self.k_proj = self.k_proj.to(k.device)
return self.k_proj(k)
def project_v(self, v: torch.Tensor) -> torch.Tensor:
"""投影Value向量"""
if self.v_proj.weight.device != v.device:
self.v_proj = self.v_proj.to(v.device)
return self.v_proj(v)
class OnlineSVDCache:
"""
基于在线SVD的KV Cache
增量式更新SVD分解,支持流式推理
"""
def __init__(self, num_heads: int, head_dim: int, rank: int):
"""
Args:
num_heads: 注意力头数
head_dim: 头维度
rank: 保留的奇异值数量
"""
self.num_heads = num_heads
self.head_dim = head_dim
self.rank = rank
# 每个头的SVD分量
self.U = [None for _ in range(num_heads)] # 左奇异向量
self.S = [None for _ in range(num_heads)] # 奇异值
self.Vt = [None for _ in range(num_heads)] # 右奇异向量转置
self.current_len = 0
def update(self, k: torch.Tensor, v: torch.Tensor):
"""
更新SVD分解
Args:
k: Key向量 [batch, num_heads, seq_len, head_dim]
v: Value向量 [batch, num_heads, seq_len, head_dim]
"""
batch, num_heads, seq_len, head_dim = k.shape
for h in range(num_heads):
# 获取当前头的Key向量
k_head = k[:, h, :, :].squeeze(0).T # [head_dim, seq_len]
if self.U[h] is None:
# 首次更新,直接SVD
try:
U, S, Vt = torch.linalg.svd(k_head, full_matrices=False)
self.U[h] = U[:, :self.rank]
self.S[h] = S[:self.rank]
self.Vt[h] = Vt[:self.rank, :]
except:
# SVD失败时,使用幂迭代法
U, S, Vt = self._power_iteration_svd(k_head)
self.U[h] = U
self.S[h] = S
self.Vt[h] = Vt
else:
# 增量更新
self._incremental_update(h, k_head)
self.current_len += seq_len
def _incremental_update(self, head_idx: int, k_new: torch.Tensor):
"""
增量更新SVD
使用秩-1更新算法
Args:
head_idx: 头索引
k_new: 新Key向量 [head_dim, seq_len]
"""
U, S, Vt = self.U[head_idx], self.S[head_idx], self.Vt[head_idx]
# 计算新向量在现有基上的投影系数
coeffs = U.T @ k_new # [rank, new_seq_len]
# 残差
residual = k_new - U @ coeffs # [head_dim, new_seq_len]
# 构建增广矩阵
# [U | residual] 的SVD
try:
U_new, S_new, Vt_new = torch.linalg.svd(residual, full_matrices=False)
# 合并新旧基
U_combined = torch.cat([U, U_new[:, :1]], dim=1)
S_combined = torch.cat([S, S_new[:1]])
Vt_combined = torch.cat([Vt, Vt_new[:1, :]], dim=0)
# 重新截断到目标秩
self.U[head_idx] = U_combined[:, :self.rank]
self.S[head_idx] = S_combined[:self.rank]
self.Vt[head_idx] = Vt_combined[:self.rank, :]
except:
# 如果更新失败,保持现有分解
pass
def _power_iteration_svd(self, A: torch.Tensor, num_iter: int = 10) -> tuple:
"""幂迭代法计算SVD(适用于长宽比极端的矩阵)"""
m, n = A.shape
k = min(self.rank, min(m, n))
# 随机初始化
Q = torch.randn(n, k, device=A.device)
Q, _ = torch.linalg.qr(Q)
for _ in range(num_iter):
Z = A @ Q
Q, _ = torch.linalg.qr(Z)
Z = A.T @ Q
Q, _ = torch.linalg.qr(Z)
# 计算最终奇异值和向量
B = Q.T @ A
U_small, S, Vt = torch.linalg.svd(B, full_matrices=False)
U = Q @ U_small
return U[:, :k], S[:k], Vt[:k, :]
def get_projected_k(self, head_idx: int) -> torch.Tensor:
"""获取投影后的Key矩阵"""
if self.U[head_idx] is None:
return None
# 返回 U @ diag(S) 作为压缩后的表示
return self.U[head_idx] * self.S[head_idx].unsqueeze(0)
def get_cache(self, head_idx: int) -> tuple:
"""获取指定头的低秩分解"""
return self.U[head_idx], self.S[head_idx], self.Vt[head_idx]
class AdaptiveRankCache(nn.Module):
"""
自适应秩的KV Cache
根据局部重要性动态调整各部分的压缩秩
"""
def __init__(self, num_heads: int, head_dim: int, max_rank: int):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.max_rank = max_rank
# 各头的基础秩
self.base_rank = nn.Parameter(
torch.ones(num_heads) * (max_rank // 2)
)
# 秩调整网络
self.rank_adapter = nn.Sequential(
nn.Linear(1, max_rank // 2),
nn.GELU(),
nn.Linear(max_rank // 2, num_heads)
)
# KV缓存
self.k_cache = {}
self.v_cache = {}
def get_effective_rank(self, importance_score: torch.Tensor) -> torch.Tensor:
"""
根据重要性分数调整秩
Args:
importance_score: [num_heads] 重要性分数
Returns:
effective_rank: [num_heads] 各头的有效秩
"""
# 基础调整
delta_rank = self.rank_adapter(importance_score.unsqueeze(-1))
effective_rank = self.base_rank + delta_rank.squeeze(-1)
# 限制在有效范围内
return effective_rank.clamp(1, self.max_rank).long()6. 方法对比
6.1 性能对比矩阵
下表对比了三种主要KV Cache压缩方法的核心特性:
| 特性 | 量化方法 | 稀疏方法 | 低秩方法 |
|---|---|---|---|
| 压缩原理 | 降低表示精度 | 选择性丢弃 | 维度约简 |
| 典型压缩比 | 2x - 8x | 2x - 10x | 2x - 4x |
| 精度损失 | 量化误差 | 信息丢失 | 重构误差 |
| 计算开销 | 低 | 中等 | 中等 |
| 实现复杂度 | 低 | 中等 | 中等 |
| 兼容性 | 通用 | 需要设计丢弃策略 | 需要投影适配 |
| 典型方法 | INT8/FP16量化 | H2O, SnapKV | KVMem, LearnLM |
6.2 各方法优缺点分析
6.2.1 量化方法
优点:
- 实现简单,与模型结构无关
- 压缩比可精确控制
- 可与其他方法结合使用
- 硬件支持良好(INT8加速)
缺点:
- 存在量化误差,可能影响模型精度
- 对于极端值处理困难
- 不能改变KV矩阵的秩
6.2.2 稀疏方法
优点:
- 可以显著减少计算量(不仅节省内存)
- 保留最重要的信息
- 适合长序列场景
缺点:
- 需要设计有效的丢弃策略
- 可能丢失重要的远程依赖
- 实现复杂度较高
6.2.3 低秩方法
优点:
- 理论保证(最优低秩近似)
- 保持语义信息的完整性
- 适合高维数据压缩
缺点:
- 重构误差可能累积
- SVD计算开销较大
- 秩的选择需要调优
6.3 适用场景分析
| 场景 | 推荐方法 | 理由 |
|---|---|---|
| 极长序列(>100K) | 稀疏方法 + 量化 | 稀疏方法显著减少计算量 |
| 资源受限设备 | 混合精度量化 | 硬件加速友好 |
| 高质量生成 | 低秩方法 | 保持语义完整性 |
| 流式推理 | 在线低秩更新 | 支持增量计算 |
| 多模态场景 | 分层压缩 | 适应不同模态特点 |
7. 未来研究方向
7.1 自适应压缩
未来的KV Cache压缩技术将更加自适应,能够根据输入内容和任务动态调整压缩策略:
- 内容感知压缩:识别不同类型的token(如实体、关键词、噪声),采用不同压缩力度
- 任务感知压缩:根据下游任务调整压缩重点(如摘要任务保留首尾,检索任务保留中间)
- 上下文感知压缩:根据当前上下文选择性地保留或压缩特定信息
7.2 动态秩分配
当前的压缩方法通常使用统一的压缩参数(如固定的量化位数或低秩维度)。未来的研究将探索动态秩分配:
- 不同注意力头使用不同的压缩比
- 根据局部重要性动态调整秩
- 学习最优的秩分配策略
7.3 端到端优化
将KV Cache压缩与模型训练/微调过程结合,实现端到端优化:
- 在训练目标中加入压缩正则化项
- 学习专门为压缩设计的投影矩阵
- 通过蒸馏将压缩策略从大模型迁移到小模型
7.4 硬件协同设计
针对KV Cache压缩设计专用的硬件加速器:
- 支持多种压缩格式的混合计算
- 压缩/解压缩的专用流水线
- 近存计算架构优化
7.5 多模态扩展
将KV Cache压缩技术扩展到多模态场景:
- 视觉-语言模型的跨模态压缩
- 视频理解中的时空压缩
- 音频-文本联合建模
7.6 可证明的压缩界限
理论研究方面,探索KV Cache压缩的信息论界限:
- 给定任务性能要求下的最小压缩率
- 不同压缩方法的理论最优性证明
- 压缩误差与下游任务性能的关系
8. 总结
KV Cache压缩是解决大语言模型推理效率问题的关键技术。通过量化、稀疏和低秩三类方法的协同发展,我们可以在内存效率和模型性能之间找到更好的平衡点。
随着长上下文模型和高效推理需求的增长,KV Cache压缩技术将继续演进,朝着自适应、智能化和硬件协同的方向发展。未来的研究将更加注重理论与实践的结合,为构建更高效的LLM系统奠定基础。
参考资料
Footnotes
-
Ainslie J, Lee-Thorp J, de Jong M, et al. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints[J]. arXiv preprint arXiv:2305.13245, 2023. ↩
-
Zhang Z, Sheng Y, Zhou Y, et al. H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models[J]. arXiv preprint arXiv:2310.04625, 2023. ↩
-
Liu Z, Wang J, Dao T, et al. Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time[C]. International Conference on Machine Learning, 2023. ↩
-
Zhou J, Liu T, et al. Full Attention with Low-rank Mechanism for KV Cache Compression in Long-context LLM Inference[J]. arXiv, 2024. ↩