引言

KV Cache(Key-Value Cache)是Transformer模型推理优化的核心技术,用于存储注意力计算中的中间结果,避免重复计算。然而,随着序列长度增加和模型规模增大,KV Cache的内存消耗成为推理效率的主要瓶颈。

KV Cache的挑战

# KV cache memory footprint analysis
def kv_cache_memory_analysis(
    batch_size: int,
    seq_len: int,
    num_layers: int,
    num_heads: int,
    head_dim: int,
    dtype: str = "float16",
    num_kv_heads: "int | None" = None,
) -> dict:
    """
    Estimate the memory used by a decoder's KV cache.

    Every cached token stores one key and one value vector per layer and per
    KV head, so one token costs
    ``2 * num_layers * num_kv_heads * head_dim * bytes_per_element`` bytes.

    Example (LLaMA-7B: num_layers=32, num_heads=32, head_dim=128, float16):
    512 KiB per token, i.e. 1 GiB for a single 2048-token sequence.

    Args:
        batch_size: number of sequences cached at the same time.
        seq_len: cached tokens per sequence.
        num_layers: number of transformer layers.
        num_heads: number of attention heads; also used as the KV head count
            when ``num_kv_heads`` is not given (standard multi-head attention).
        head_dim: dimension of each head.
        dtype: storage type of the cache: "float32", "float16", "bfloat16",
            "float8", "int8" or "int4".
        num_kv_heads: number of key/value heads for grouped-query or
            multi-query attention models; defaults to ``num_heads``.

    Returns:
        dict with "per_token_bytes", "per_token_gb", "total_gb" and
        "per_layer_gb" (GB here means GiB, 1024**3 bytes).

    Raises:
        KeyError: if ``dtype`` is not one of the supported names.
    """
    # Storage width in bits; int4 packs two values per byte.
    bits_per_element = {
        "float32": 32,
        "float16": 16,
        "bfloat16": 16,
        "float8": 8,
        "int8": 8,
        "int4": 4,
    }
    if dtype not in bits_per_element:
        raise KeyError(
            f"Unsupported dtype {dtype!r}; expected one of {sorted(bits_per_element)}"
        )

    kv_heads = num_heads if num_kv_heads is None else num_kv_heads

    # Size of one token's K and V across all layers. The leading factor 2 (K and V)
    # times a bit width that is a multiple of 4 makes this divisible by 8,
    # so the integer division is exact.
    kv_per_token = 2 * num_layers * kv_heads * head_dim * bits_per_element[dtype] // 8

    # Size of the whole batch of sequences
    kv_total = kv_per_token * batch_size * seq_len

    # Convert to GiB
    kv_gb = kv_total / (1024 ** 3)

    return {
        "per_token_bytes": kv_per_token,
        "per_token_gb": kv_per_token / (1024 ** 3),
        "total_gb": kv_gb,
        "per_layer_gb": kv_gb / num_layers
    }
 
 
# LLaMA-7B: how the KV cache grows with sequence length
configs = [(1, length, "LLaMA-7B") for length in (2048, 4096, 8192, 32768)]

# LLaMA-7B attention geometry shared by every row of the table
llama7b_shape = dict(num_layers=32, num_heads=32, head_dim=128)

for batch, length, name in configs:
    report = kv_cache_memory_analysis(batch_size=batch, seq_len=length, **llama7b_shape)
    print(f"{name} seq_len={length}: KV Cache = {report['total_gb']:.2f} GB")

输出示例:

LLaMA-7B seq_len=2048: KV Cache = 1.00 GB
LLaMA-7B seq_len=4096: KV Cache = 2.00 GB
LLaMA-7B seq_len=8192: KV Cache = 4.00 GB
LLaMA-7B seq_len=32768: KV Cache = 16.00 GB

1. KV Cache量化压缩

1.1 量化基础

import hashlib
import os
import tempfile
import time
import uuid
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
 
class KVCacheQuantizer:
    """
    Per-tensor symmetric quantizer for KV cache tensors.

    Supported ``num_bits``:
    - 8:  INT8, a single scale for the whole tensor, values in [-128, 127].
    - 4:  INT4 value range [-8, 7]. PyTorch has no ``int4`` dtype, so the
          values are stored one per byte in an ``int8`` tensor (not bit-packed;
          a separate packing step is needed to actually halve memory vs INT8).
    - 16: plain FP16 cast (scale 1.0).

    All modes are symmetric, so the returned zero point is always None.
    """

    def __init__(self, num_bits=8):
        self.num_bits = num_bits

    def quantize(self, tensor: torch.Tensor) -> tuple:
        """
        Quantize ``tensor``.

        Args:
            tensor: KV tensor, e.g. [batch, num_heads, seq_len, head_dim]

        Returns:
            quantized: the quantized tensor
            scale: quantization scale (float32 scalar tensor)
            zero_point: zero point for asymmetric schemes (always None here)

        Raises:
            ValueError: if ``num_bits`` is not 4, 8 or 16.
        """
        if self.num_bits == 8:
            return self.quantize_int8(tensor)
        elif self.num_bits == 4:
            return self.quantize_int4(tensor)
        elif self.num_bits == 16:
            return self.quantize_fp16(tensor)
        else:
            raise ValueError(f"Unsupported num_bits: {self.num_bits}")

    @staticmethod
    def _symmetric_scale(tensor, qmax):
        """Return max(|x|) / qmax as a float32 scalar, floored at 1e-8.

        The floor keeps an all-zero tensor from producing 0/0 = NaN; it is
        applied in float32 because 1e-8 underflows to 0 in float16.
        """
        return (tensor.abs().max().float() / qmax).clamp(min=1e-8)

    def quantize_int8(self, tensor):
        """Symmetric INT8 quantization: scale = max(|x|) / 127."""
        scale = self._symmetric_scale(tensor, 127.0)
        # Divide in float32 so the scale is not rounded back to fp16 (where 1e-8 == 0)
        quantized = torch.round(tensor.float() / scale).clamp(-128, 127).to(torch.int8)

        return quantized, scale, None

    def quantize_int4(self, tensor):
        """Symmetric INT4 quantization: scale = max(|x|) / 7, stored unpacked in int8."""
        scale = self._symmetric_scale(tensor, 7.0)
        # torch has no int4 dtype; int8 holds the [-8, 7] range without packing
        quantized = torch.round(tensor.float() / scale).clamp(-8, 7).to(torch.int8)

        return quantized, scale, None

    def quantize_fp16(self, tensor):
        """FP16 "quantization": a plain cast; scale 1.0 keeps dequantize() uniform."""
        return tensor.half(), torch.tensor(1.0), None

    def dequantize(self, quantized, scale, zero_point=None):
        """Map quantized values back to float32 (asymmetric when zero_point is given)."""
        if zero_point is None:
            # Symmetric dequantization
            return quantized.float() * scale
        else:
            # Asymmetric dequantization
            return (quantized.float() - zero_point) * scale

1.2 逐Token量化

class PerTokenKVQuantizer:
    """
    Per-token symmetric quantization of a KV cache.

    Every token position gets its own scale (computed over head_dim), which
    is more accurate than a single per-tensor scale.

    Quantized values are stored in the smallest integer dtype that can hold
    the range: int8 for up to 8 bits (4-bit values are not bit-packed),
    int16 for 9-16 bits.
    """

    def __init__(self, num_bits=8):
        """
        Args:
            num_bits: quantization width, 2-16.

        Raises:
            ValueError: if num_bits is outside 2-16.
        """
        if not 2 <= num_bits <= 16:
            raise ValueError(f"num_bits must be in [2, 16], got {num_bits}")
        self.num_bits = num_bits
        self.max_val = 2 ** (num_bits - 1) - 1
        # int8 would silently wrap values beyond [-128, 127]
        self.storage_dtype = torch.int8 if num_bits <= 8 else torch.int16

    def quantize(self, kv_cache: Dict[str, torch.Tensor]) -> Dict:
        """
        Quantize a KV cache.

        Args:
            kv_cache: {"k": [B, H, L, D], "v": [B, H, L, D]}

        Returns:
            {"quantized_cache": {"k", "v"} integer tensors,
             "scales": {"k", "v"} float32 tensors of shape [B, H, L, 1]}
        """
        quantized_cache = {}
        scales = {}

        for key in ["k", "v"]:
            # Work in float32: for fp16 inputs the 1e-8 floor below would
            # underflow to 0 and all-zero rows would become 0/0 = NaN.
            values = kv_cache[key].float()

            # One scale per token, reduced over head_dim -> [B, H, L, 1]
            scale = values.abs().amax(dim=-1, keepdim=True) / self.max_val
            scale = scale + 1e-8  # avoid division by zero

            q_tensor = torch.round(values / scale).clamp(
                -self.max_val - 1, self.max_val
            ).to(self.storage_dtype)

            quantized_cache[key] = q_tensor
            scales[key] = scale

        return {
            "quantized_cache": quantized_cache,
            "scales": scales
        }

    def dequantize(self, quantized_cache, scales):
        """Return {"k", "v"} float32 tensors reconstructed from quantize()'s output."""
        return {
            key: quantized_cache[key].float() * scales[key]
            for key in ["k", "v"]
        }

1.3 混合精度KV Cache

class HybridPrecisionKVCache:
    """
    Mixed-precision KV cache.

    Strategy:
    - recent tokens: higher precision (e.g. FP16)
    - older tokens: lower precision (e.g. INT4)

    Tensors are [B, H, L, D] with positions ordered oldest -> newest along L
    (new tokens are appended at the end), so the recent tokens are the LAST
    ``int(seq_len * recent_ratio)`` positions.
    """

    def __init__(self, recent_ratio=0.3, recent_bits=16, old_bits=4):
        self.recent_ratio = recent_ratio
        self.recent_bits = recent_bits
        self.old_bits = old_bits

    def quantize(self, kv_cache, seq_len):
        """
        Split the cache into old and recent parts and quantize each separately.

        Args:
            kv_cache: {"k": [B, H, L, D], "v": [B, H, L, D]}
            seq_len: number of cached positions L.

        Returns:
            {"recent": ..., "old": ..., "recent_len": int}. After
            dequantization, concatenating old then recent along L restores
            the original position order.

        Raises:
            ValueError: if recent_bits or old_bits is not 4, 8 or 16.
        """
        recent_len = int(seq_len * self.recent_ratio)
        old_len = seq_len - recent_len

        # Split at old_len instead of using a negative index: [..., -0:, :]
        # would select the whole sequence when recent_len == 0.
        old_kv = {name: kv_cache[name][..., :old_len, :] for name in ("k", "v")}
        recent_kv = {name: kv_cache[name][..., old_len:, :] for name in ("k", "v")}

        # Recent tokens: high precision
        recent_quantized = self._quantize_with_bits(recent_kv, self.recent_bits)

        # Older tokens: low precision
        old_quantized = self._quantize_with_bits(old_kv, self.old_bits)

        return {
            "recent": recent_quantized,
            "old": old_quantized,
            "recent_len": recent_len
        }

    def _quantize_with_bits(self, kv_cache, bits):
        """
        Quantize {"k", "v"} at the given precision.

        16 returns fp16 tensors {"k", "v"}; 8 and 4 return
        PerTokenKVQuantizer output {"quantized_cache", "scales"}.

        Raises:
            ValueError: for any other bit width.
        """
        if bits == 16:
            return {"k": kv_cache["k"].half(), "v": kv_cache["v"].half()}
        if bits in (8, 4):
            return PerTokenKVQuantizer(bits).quantize(kv_cache)
        raise ValueError(f"Unsupported bits: {bits}")

2. KV Cache淘汰策略

2.1 问题定义

序列位置:    [0    1    2    3    4    5    6    7    8    ...   1000]
              ├──────────────────┤├──────────┤
                 重要token          中等重要    不重要

并非所有历史token对当前预测同等重要,需要智能淘汰。

2.2 H2O: Heavy-Hitter Oracle

class H2OKVCacheManager:
    """
    KV cache eviction manager inspired by H2O (Heavy-Hitter Oracle).

    Core idea: use attention scores to identify important tokens ("heavy
    hitters") and keep only the most important ones.

    Paper: H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models

    The cache is held as Python lists with one entry per ``update`` call.
    Entries are ranked by ``_compute_bridge_scores`` and the top ones kept.

    NOTE(review): the H2O paper ranks tokens by accumulated attention score
    (together with a window of recent tokens); the "bridge" weighting used
    here does not appear to come from the paper — confirm this is intended.
    """
    
    def __init__(self, cache_size_ratio=0.3):
        """
        Args:
            cache_size_ratio: fraction of entries to keep; the limit computed
                by ``_calculate_max_cache_len`` is never below 100.
        """
        self.cache_size_ratio = cache_size_ratio
        self.kv_cache = {"k": [], "v": []}  # parallel lists, one entry per update() call
        self.attention_weights = []  # squeezed attn_weights per entry, same order as kv_cache
    
    def update(self, new_k, new_v, attn_weights):
        """
        Append one entry to the KV cache and evict when it exceeds the limit.

        Args:
            new_k: new key [1, num_heads, 1, head_dim]
            new_v: new value [1, num_heads, 1, head_dim]
            attn_weights: attention weights with respect to historical tokens
                (stored squeezed; ``_evict`` only uses the sum of each entry)
        """
        # Keep the attention weights for later eviction decisions
        self.attention_weights.append(attn_weights.squeeze())
        
        # Append the new KV entry
        self.kv_cache["k"].append(new_k)
        self.kv_cache["v"].append(new_v)
        
        # Evict when the cache exceeds its limit
        total_len = len(self.kv_cache["k"])
        max_len = self._calculate_max_cache_len()
        
        if total_len > max_len:
            self._evict()
    
    def _calculate_max_cache_len(self):
        """
        Return the cache length limit: max(100, int(entries * cache_size_ratio)).

        NOTE(review): "entries" is len(self.attention_weights), which _evict
        trims, not the number of tokens seen so far. Since update() adds one
        entry at a time, the first eviction happens at 101 entries and trims
        back to 100. For cache_size_ratio below ~0.99 the cache therefore
        stays at about 100 entries and the ratio has no effect — confirm
        this is intended.
        """
        # Lower bound of 100 entries; otherwise a fraction of current entries
        return max(100, int(len(self.attention_weights) * self.cache_size_ratio))
    
    def _evict(self):
        """Drop the lowest-ranked entries so that _calculate_max_cache_len() remain."""
        # 1. Aggregate one score per cached entry.
        # NOTE(review): each entry is the attn_weights tensor passed to update()
        # at that step, and w.sum() adds up everything in it. If that tensor is
        # the new token's attention over history (as the update() docstring
        # suggests), this measures attention *paid by* step i rather than
        # attention *received by* token i, which is what H2O accumulates — verify.
        cumulative_scores = torch.stack([
            w.sum() for w in self.attention_weights
        ])
        
        # 2. Weight each entry by a "bridge" term (see _compute_bridge_scores)
        bridge_scores = self._compute_bridge_scores(cumulative_scores)
        
        # 3. Keep the top-scoring entries; the kept count equals
        #    _calculate_max_cache_len(). Sorting restores original position order.
        num_to_evict = len(self.kv_cache["k"]) - self._calculate_max_cache_len()
        keep_indices = torch.topk(bridge_scores, len(self.kv_cache["k"]) - num_to_evict).indices
        keep_indices = sorted(keep_indices.tolist())
        
        # 4. Rebuild the parallel lists with only the kept entries
        self.kv_cache["k"] = [self.kv_cache["k"][i] for i in keep_indices]
        self.kv_cache["v"] = [self.kv_cache["v"][i] for i in keep_indices]
        self.attention_weights = [self.attention_weights[i] for i in keep_indices]
    
    def _compute_bridge_scores(self, cumulative_scores):
        """
        Compute a "bridge" score for each entry.

        Intent: a heavy hitter is not only attended to itself but also helps
        connect other tokens.

        Despite their names, prefix_max / suffix_max are cumulative SUMS.
        left_importance + right_importance is the sum of every other entry's
        score, so bridge_scores[i] == s_i * (total - s_i).

        NOTE(review): that product peaks at s_i == total / 2 and decreases
        above it. An entry holding most of the attention mass can therefore
        rank below a moderate one — confirm this is the desired ranking.
        """
        n = len(cumulative_scores)
        
        # Running sums from the left and from the right
        prefix_max = cumulative_scores.cumsum(dim=0)
        suffix_max = cumulative_scores.flip(0).cumsum(dim=0).flip(0)
        
        # Bridge score: importance of position i as a link between both sides
        bridge_scores = torch.zeros(n)
        for i in range(n):
            if i > 0:
                left_importance = prefix_max[i - 1]  # sum of scores before i
            else:
                left_importance = 0
            
            if i < n - 1:
                right_importance = suffix_max[i + 1]  # sum of scores after i
            else:
                right_importance = 0
            
            # Own score times the total score on both sides it connects
            bridge_scores[i] = cumulative_scores[i] * (left_importance + right_importance)
        
        return bridge_scores

2.3 PyramidKV: 金字塔式缓存

class PyramidKVCacheManager:
    """
    PyramidKV-style eviction.

    Core idea: each layer gets its own cache budget, forming a pyramid.
    Shallow layers keep more tokens; deeper layers are compressed harder.

    Paper: PyramidKV: Dynamic KV Cache Compression based on Pyramidal Information Funneling
    """

    def __init__(self, num_layers, base_cache_len=512, compression_schedule=None):
        """
        Args:
            num_layers: number of layers to manage.
            base_cache_len: token budget of a layer whose ratio is 1.0.
            compression_schedule: optional per-layer budget ratios (one per
                layer); defaults to a linear decay from 1.0 to 0.2.

        Raises:
            ValueError: if compression_schedule does not have num_layers entries.
        """
        self.num_layers = num_layers
        self.base_cache_len = base_cache_len

        # Per-layer budget as a fraction of base_cache_len
        if compression_schedule is None:
            # Default: shallow layers keep more, deep layers keep less
            self.compression_schedule = self._default_schedule()
        else:
            if len(compression_schedule) != num_layers:
                raise ValueError(
                    f"compression_schedule has {len(compression_schedule)} entries, "
                    f"expected {num_layers}"
                )
            self.compression_schedule = list(compression_schedule)

        self.layer_caches = [
            {"k": [], "v": []}
            for _ in range(num_layers)
        ]

    def _default_schedule(self):
        """Linear decay: ratio 1.0 at the first layer down to exactly 0.2 at the last."""
        if self.num_layers <= 1:
            return [1.0] * self.num_layers
        last = self.num_layers - 1
        # Written as 0.2 + ... so the endpoints are exact (1.0 - 0.8 would give 0.19999...)
        return [0.2 + 0.8 * (last - i) / last for i in range(self.num_layers)]

    def _layer_max_len(self, layer_idx):
        """Token budget of one layer (at least 1, so a layer is never emptied entirely)."""
        return max(1, int(self.base_cache_len * self.compression_schedule[layer_idx]))

    def update(self, layer_idx, new_k, new_v, importance_scores=None):
        """
        Append one entry to a layer's cache and evict if it exceeds its budget.

        Args:
            importance_scores: optional per-token scores; the last
                len(cache) values must align with the cached entries.
        """
        cache = self.layer_caches[layer_idx]

        cache["k"].append(new_k)
        cache["v"].append(new_v)

        if len(cache["k"]) > self._layer_max_len(layer_idx):
            self._evict_layer(layer_idx, importance_scores)

    def _evict_layer(self, layer_idx, importance_scores):
        """
        Shrink one layer's cache to its budget, preserving position order.

        Raises:
            ValueError: if fewer importance scores than cached entries are given.
        """
        cache = self.layer_caches[layer_idx]
        max_len = self._layer_max_len(layer_idx)
        total_len = len(cache["k"])

        if importance_scores is None:
            # No scores: keep a uniformly random subset
            keep = torch.randperm(total_len)[:max_len].sort().values.tolist()
        else:
            # Keep the highest-scoring entries; the trailing scores align with the cache
            scores = torch.as_tensor(importance_scores[-total_len:], dtype=torch.float32)
            if scores.numel() != total_len:
                # Fewer scores than entries would silently map scores to the wrong tokens
                raise ValueError(
                    f"need {total_len} importance scores, got {scores.numel()}"
                )
            keep = sorted(torch.topk(scores, max_len).indices.tolist())

        cache["k"] = [cache["k"][i] for i in keep]
        cache["v"] = [cache["v"][i] for i in keep]

2.4 StreamingLLM式缓存

class StreamingLLMCache:
    """
    StreamingLLM cache policy.

    Core idea: keep the attention sinks (the first few tokens) plus a sliding
    window of the most recent tokens.

    Observations:
    1. Initial tokens (e.g. [CLS] or the first token) receive high attention
       from all positions.
    2. Recent tokens matter most for the current prediction.

    Each update() call is one entry; sink and recent entries never overlap.
    """

    def __init__(self, num_sink_tokens=4, num_recent_tokens=512):
        self.num_sink_tokens = num_sink_tokens
        self.num_recent_tokens = num_recent_tokens
        self.sink_k = []
        self.sink_v = []
        self.recent_k = []
        self.recent_v = []

    def update(self, new_k, new_v):
        """Add one entry: the first num_sink_tokens become sinks, later ones go to the window."""
        if len(self.sink_k) < self.num_sink_tokens:
            # Sinks are permanent; not adding them to the window as well
            # prevents get_cache() from returning them twice.
            self.sink_k.append(new_k)
            self.sink_v.append(new_v)
            return

        self.recent_k.append(new_k)
        self.recent_v.append(new_v)

        # Trim the window from the front; del with an explicit count also
        # handles num_recent_tokens == 0 (a [-0:] slice would keep everything).
        overflow = len(self.recent_k) - self.num_recent_tokens
        if overflow > 0:
            del self.recent_k[:overflow]
            del self.recent_v[:overflow]

    def get_cache(self):
        """Return {"k", "v"}: sinks followed by the recent window, concatenated along dim 2."""
        all_k = self.sink_k + self.recent_k
        all_v = self.sink_v + self.recent_v

        if not all_k:
            return {"k": torch.empty(0), "v": torch.empty(0)}

        return {"k": torch.cat(all_k, dim=2), "v": torch.cat(all_v, dim=2)}

3. 前缀缓存

3.1 概念与动机

class PrefixCacheManager:
    """
    Prefix cache manager.

    Use cases:
    - a system prompt that is identical for every request
    - few-shot examples that may be identical for similar requests
    - sharing the KV cache of such common prefixes across requests

    Benefits:
    1. less recomputation
    2. lower time-to-first-token
    3. memory saved through sharing

    Entries are keyed by a hash of the exact token-id sequence and evicted
    in least-recently-used order.
    """

    def __init__(self, max_cache_size_gb=16):
        self.max_cache_size = max_cache_size_gb * 1024 ** 3  # bytes
        self.current_size = 0
        self.prefix_hashes = {}  # hash -> kv_cache
        self.prefix_metadata = {}  # hash -> metadata

    def hash_prefix(self, tokens: List[int]) -> str:
        """
        Return a 16-hex-character SHA-256 digest of the token-id sequence.

        Ids are serialized as comma-separated decimal text. ``bytes(tokens)``
        only accepts values 0-255, which real vocabularies far exceed.
        """
        payload = ",".join(str(int(token)) for token in tokens).encode("ascii")
        return hashlib.sha256(payload).hexdigest()[:16]

    def get_cached_prefix(self, tokens: List[int]) -> Optional[Dict]:
        """Return the cached KV for exactly these tokens, or None; a hit refreshes its LRU data."""
        prefix_hash = self.hash_prefix(tokens)

        if prefix_hash not in self.prefix_hashes:
            return None

        meta = self.prefix_metadata[prefix_hash]
        meta["last_access"] = time.time()
        meta["access_count"] += 1
        return self.prefix_hashes[prefix_hash]

    def cache_prefix(self, tokens: List[int], kv_cache: Dict, metadata: Dict = None):
        """
        Cache kv_cache under the given token prefix, replacing any existing entry.

        Least-recently-used entries are evicted first when the new entry would
        not fit. An entry larger than the whole budget is still stored
        (best effort), temporarily exceeding max_cache_size.
        """
        prefix_hash = self.hash_prefix(tokens)
        cache_size = self._estimate_cache_size(kv_cache)

        # Replacing an existing prefix: drop the old entry so its size is not counted twice
        if prefix_hash in self.prefix_hashes:
            self._remove(prefix_hash)

        if self.current_size + cache_size > self.max_cache_size:
            self._evict_lru(cache_size)

        now = time.time()
        self.prefix_hashes[prefix_hash] = kv_cache
        self.prefix_metadata[prefix_hash] = {
            "size": cache_size,
            "token_len": len(tokens),
            "created": now,
            "last_access": now,
            "access_count": 0,
            "metadata": metadata or {}
        }
        self.current_size += cache_size

    def _remove(self, prefix_hash):
        """Delete one entry and release its size from the usage counter."""
        del self.prefix_hashes[prefix_hash]
        self.current_size -= self.prefix_metadata.pop(prefix_hash)["size"]

    def _evict_lru(self, incoming_size=0):
        """
        Evict least-recently-used entries until usage, including an incoming
        entry of ``incoming_size`` bytes, is at most 80% of capacity.
        The 20% headroom avoids evicting again on the very next insert.
        """
        target = self.max_cache_size * 0.8
        # sorted() returns a snapshot list, so removing entries while iterating is safe
        by_age = sorted(
            self.prefix_metadata,
            key=lambda h: self.prefix_metadata[h]["last_access"]
        )
        for lru_hash in by_age:
            if self.current_size + incoming_size <= target:
                break
            self._remove(lru_hash)

    def _estimate_cache_size(self, kv_cache: Dict) -> int:
        """Return the byte size of the "k" and "v" tensors, using each tensor's own element size."""
        total_bytes = 0
        for key in ("k", "v"):
            tensor = kv_cache.get(key)
            if tensor is not None:
                total_bytes += tensor.numel() * tensor.element_size()
        return total_bytes

3.2 PagedAttention与vLLM缓存

class PagedKVCache:
    """
    PagedAttention KV Cache管理
    
    将KV Cache分页存储,类似操作系统的虚拟内存
    优势:
    1. 动态分配,无需预分配大内存
    2. 碎片少,利用率高
    3. 支持更长的上下文
    """
    
    def __init__(self, block_size=16, num_blocks=1024):
        self.block_size = block_size
        self.num_blocks = num_blocks
        
        # 物理块存储
        self.blocks = [
            {"k": torch.zeros(1, num_heads, block_size, head_dim),
             "v": torch.zeros(1, num_heads, block_size, head_dim)}
            for _ in range(num_blocks)
        ]
        
        # 逻辑块映射
        self.block_mapping = {}  # sequence_id -> [block_ids]
        self.block_refcount = {}  # block_id -> refcount
        self.available_blocks = set(range(num_blocks))
    
    def allocate_sequence(self, sequence_id: int, max_length: int) -> bool:
        """为新序列分配逻辑块"""
        num_blocks_needed = (max_length + self.block_size - 1) // self.block_size
        
        # 检查是否有足够的物理块
        if len(self.available_blocks) < num_blocks_needed:
            return False
        
        # 分配块
        allocated_blocks = []
        for _ in range(num_blocks_needed):
            block_id = self.available_blocks.pop()
            allocated_blocks.append(block_id)
            self.block_refcount[block_id] = 1
        
        self.block_mapping[sequence_id] = allocated_blocks
        return True
    
    def append_tokens(self, sequence_id: int, new_k: torch.Tensor, new_v: torch.Tensor):
        """追加token到序列"""
        if sequence_id not in self.block_mapping:
            raise ValueError(f"Sequence {sequence_id} not allocated")
        
        block_ids = self.block_mapping[sequence_id]
        current_len = len(block_ids) * self.block_size
        
        # 计算应该写入的块和位置
        for i, token_idx in enumerate(range(new_k.shape[2])):
            block_idx = (current_len + i) // self.block_size
            offset = (current_len + i) % self.block_size
            
            block_id = block_ids[block_idx]
            block = self.blocks[block_id]
            
            block["k"][..., offset:offset+1, :] = new_k[..., i:i+1, :]
            block["v"][..., offset:offset+1, :] = new_v[..., i:i+1, :]
    
    def get_sequence_cache(self, sequence_id: int) -> Dict:
        """获取序列的完整KV Cache"""
        if sequence_id not in self.block_mapping:
            return {"k": None, "v": None}
        
        block_ids = self.block_mapping[sequence_id]
        
        # 收集所有块
        all_k = [self.blocks[bid]["k"] for bid in block_ids]
        all_v = [self.blocks[bid]["v"] for bid in block_ids]
        
        return {
            "k": torch.cat(all_k, dim=2),
            "v": torch.cat(all_v, dim=2)
        }
    
    def free_sequence(self, sequence_id: int):
        """释放序列占用的块"""
        if sequence_id not in self.block_mapping:
            return
        
        for block_id in self.block_mapping[sequence_id]:
            self.block_refcount[block_id] -= 1
            if self.block_refcount[block_id] == 0:
                self.available_blocks.add(block_id)
        
        del self.block_mapping[sequence_id]

4. 内存优化技术

4.1 内存池管理

class KVCacheMemoryPool:
    """
    Pre-allocated KV cache memory pool.

    One K and one V float16 buffer shaped
    [max_batch_size, num_layers, num_heads, max_seq_len, head_dim] are
    allocated up front. Each sequence gets one batch slot, which avoids
    per-step allocation overhead.
    """

    def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim, pin_memory=None):
        """
        Args:
            max_batch_size: number of slots (concurrent sequences).
            max_seq_len: token capacity of each slot.
            num_layers, num_heads, head_dim: cache geometry.
            pin_memory: allocate page-locked host memory for faster CPU->GPU
                copies. None (default) pins only when CUDA is available,
                because pinning needs an accelerator and errors otherwise.
        """
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers

        if pin_memory is None:
            pin_memory = torch.cuda.is_available()

        pool_shape = (max_batch_size, num_layers, num_heads, max_seq_len, head_dim)
        # Allocate V explicitly too: torch.zeros_like does not carry over pinned allocation
        self.k_pool = torch.zeros(pool_shape, dtype=torch.float16, pin_memory=pin_memory)
        self.v_pool = torch.zeros(pool_shape, dtype=torch.float16, pin_memory=pin_memory)

        # sequence_id -> {"slot": int, "length": tokens written, "max_length": int}
        self.allocations = {}
        self.available_slots = set(range(max_batch_size))

    def allocate(self, sequence_id: int) -> bool:
        """Assign a free slot to the sequence; return False when the pool is full."""
        if not self.available_slots:
            return False

        slot = self.available_slots.pop()
        self.allocations[sequence_id] = {
            "slot": slot,
            "length": 0,
            "max_length": self.max_seq_len
        }
        return True

    def append(self, sequence_id: int, new_k, new_v):
        """
        Write new tokens after the sequence's current length.

        new_k / new_v are expected as [num_layers, num_heads, T, head_dim].

        Raises:
            ValueError: if the sequence is not allocated or T tokens do not fit.
        """
        if sequence_id not in self.allocations:
            raise ValueError(f"Sequence {sequence_id} not allocated")

        alloc = self.allocations[sequence_id]
        slot = alloc["slot"]
        offset = alloc["length"]
        num_new = new_k.shape[2]

        if offset + num_new > alloc["max_length"]:
            raise ValueError(
                f"Sequence {sequence_id}: {offset} + {num_new} tokens exceeds "
                f"max_seq_len {alloc['max_length']}"
            )

        self.k_pool[slot, :, :, offset:offset + num_new, :] = new_k
        self.v_pool[slot, :, :, offset:offset + num_new, :] = new_v

        alloc["length"] += num_new

    def get(self, sequence_id: int) -> Dict:
        """Return views of the sequence's written K/V, or None if it is not allocated."""
        if sequence_id not in self.allocations:
            return None

        alloc = self.allocations[sequence_id]
        slot = alloc["slot"]
        length = alloc["length"]

        return {
            "k": self.k_pool[slot, :, :, :length, :],
            "v": self.v_pool[slot, :, :, :length, :]
        }

    def free(self, sequence_id: int):
        """Return the sequence's slot to the pool (contents are left as-is)."""
        if sequence_id not in self.allocations:
            return

        slot = self.allocations[sequence_id]["slot"]
        self.available_slots.add(slot)
        del self.allocations[sequence_id]

4.2 卸载与重载策略

class KVColdStorageManager:
    """
    Hot/cold tiered storage for KV caches.

    Inactive KV caches are demoted in least-recently-used order: from GPU to
    CPU memory, and from CPU to disk (the "nvme" tier). Each sequence_id is
    held in exactly one tier at a time.

    A KV cache may be a tensor, a dict of tensors (e.g. {"k": ..., "v": ...})
    or a tuple/list of tensors. Its size is the total byte size of those tensors.
    """

    _TIERS = ("gpu", "cpu", "nvme")

    def __init__(self,
                 gpu_cache_gb=16,
                 cpu_cache_gb=64,
                 nvme_cache_gb=512,
                 nvme_dir=None):
        """
        Args:
            gpu_cache_gb: capacity of the GPU tier (GiB).
            cpu_cache_gb: capacity of the CPU tier (GiB).
            nvme_cache_gb: capacity of the disk tier (GiB); recorded but not enforced.
            nvme_dir: directory for disk-tier files; a temporary directory is
                created on first use when None.
        """
        self.gpu_cache_limit = gpu_cache_gb * 1024**3
        self.cpu_cache_limit = cpu_cache_gb * 1024**3
        self.nvme_cache_limit = nvme_cache_gb * 1024**3
        self.nvme_dir = nvme_dir

        # Storage tiers
        self.gpu_cache = {}  # sequence_id -> kv_cache
        self.cpu_cache = {}  # sequence_id -> kv_cache
        self.nvme_cache = {}  # sequence_id -> file_path

        # Bookkeeping
        self.access_log = {}  # sequence_id -> last_access_time
        self.sizes = {}  # sequence_id -> byte size of its stored cache
        self.current_gpu_size = 0
        self.current_cpu_size = 0
        self.current_nvme_size = 0

    # ---- public API ---------------------------------------------------------

    def store(self, sequence_id, kv_cache, storage="gpu"):
        """
        Store kv_cache in the given tier, replacing any copy held elsewhere.

        Raises:
            ValueError: if storage is not "gpu", "cpu" or "nvme".
        """
        self._check_tier(storage)
        self._discard(sequence_id)
        self._place(sequence_id, kv_cache, storage)
        self.access_log[sequence_id] = time.time()

    def retrieve(self, sequence_id, target="gpu"):
        """
        Return the sequence's KV cache, migrating it into tier ``target`` if needed.

        Returns None when the sequence is not stored in any tier. Retrieval
        counts as an access for LRU purposes.

        Raises:
            ValueError: if target is not "gpu", "cpu" or "nvme".
        """
        self._check_tier(target)
        source = self._tier_of(sequence_id)
        if source is None:
            return None

        self.access_log[sequence_id] = time.time()

        if source == target:
            if target == "nvme":
                return self._load_from_nvme(sequence_id)
            return self.gpu_cache[sequence_id] if target == "gpu" else self.cpu_cache[sequence_id]

        kv_cache = self._take(sequence_id, source)
        return self._place(sequence_id, kv_cache, target)

    # ---- tier bookkeeping ---------------------------------------------------

    def _check_tier(self, tier):
        """Raise ValueError for an unknown tier name."""
        if tier not in self._TIERS:
            raise ValueError(f"Unknown storage tier: {tier!r}")

    def _tier_of(self, sequence_id):
        """Return the tier currently holding sequence_id, or None."""
        for tier, store in (("gpu", self.gpu_cache), ("cpu", self.cpu_cache), ("nvme", self.nvme_cache)):
            if sequence_id in store:
                return tier
        return None

    def _place(self, sequence_id, kv_cache, tier):
        """Store into a tier (evicting as needed); return the object now held there."""
        size = self._estimate_size(kv_cache)
        if tier == "gpu":
            return self._store_to_gpu(sequence_id, kv_cache, size)
        if tier == "cpu":
            return self._store_to_cpu(sequence_id, kv_cache, size)
        return self._store_to_nvme(sequence_id, kv_cache, size)

    def _take(self, sequence_id, tier, load=True):
        """Remove sequence_id from a tier, update size counters, and return its cache.

        For the nvme tier the file is loaded first unless load is False,
        in which case None is returned.
        """
        size = self.sizes.pop(sequence_id, 0)
        if tier == "gpu":
            self.current_gpu_size -= size
            return self.gpu_cache.pop(sequence_id)
        if tier == "cpu":
            self.current_cpu_size -= size
            return self.cpu_cache.pop(sequence_id)
        kv_cache = self._load_from_nvme(sequence_id) if load else None
        self._remove_from_nvme(sequence_id)
        self.current_nvme_size -= size
        return kv_cache

    def _discard(self, sequence_id):
        """Drop any stored copy of sequence_id without loading it."""
        tier = self._tier_of(sequence_id)
        if tier is not None:
            self._take(sequence_id, tier, load=False)

    # ---- per-tier storage ---------------------------------------------------

    def _store_to_gpu(self, sequence_id, kv_cache, size):
        """Store on the GPU tier, demoting LRU GPU entries to CPU until it fits."""
        while self.gpu_cache and self.current_gpu_size + size > self.gpu_cache_limit:
            self._evict_to_cpu()

        kv_cache = self._to_device(kv_cache, "cuda" if torch.cuda.is_available() else None)
        self.gpu_cache[sequence_id] = kv_cache
        self.sizes[sequence_id] = size
        self.current_gpu_size += size
        return kv_cache

    def _store_to_cpu(self, sequence_id, kv_cache, size):
        """Store on the CPU tier, demoting LRU CPU entries to disk until it fits."""
        while self.cpu_cache and self.current_cpu_size + size > self.cpu_cache_limit:
            self._evict_to_nvme()

        kv_cache = self._to_device(kv_cache, "cpu")
        self.cpu_cache[sequence_id] = kv_cache
        self.sizes[sequence_id] = size
        self.current_cpu_size += size
        return kv_cache

    def _store_to_nvme(self, sequence_id, kv_cache, size):
        """Serialize the cache to a file in nvme_dir and record its path."""
        if self.nvme_dir is None:
            self.nvme_dir = tempfile.mkdtemp(prefix="kv_nvme_")
        os.makedirs(self.nvme_dir, exist_ok=True)

        # Random file name: sequence ids may not be valid path components
        path = os.path.join(self.nvme_dir, f"{uuid.uuid4().hex}.pt")
        torch.save(self._to_device(kv_cache, "cpu"), path)
        self.nvme_cache[sequence_id] = path
        self.sizes[sequence_id] = size
        self.current_nvme_size += size
        return kv_cache

    def _load_from_nvme(self, sequence_id):
        """Load a disk-tier cache into CPU memory (file written by _store_to_nvme)."""
        return torch.load(self.nvme_cache[sequence_id], map_location="cpu")

    def _remove_from_nvme(self, sequence_id):
        """Forget a disk-tier entry and delete its file if it still exists."""
        path = self.nvme_cache.pop(sequence_id)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def _evict_to_cpu(self):
        """Demote the least-recently-used GPU-resident entry to the CPU tier."""
        if not self.gpu_cache:
            return

        # LRU among entries actually on the GPU (access_log covers every tier)
        lru_seq = min(self.gpu_cache, key=lambda sid: self.access_log.get(sid, 0))
        kv_cache = self._take(lru_seq, "gpu")
        self._store_to_cpu(lru_seq, kv_cache, self._estimate_size(kv_cache))

    def _evict_to_nvme(self):
        """Demote the least-recently-used CPU-resident entry to the disk tier."""
        if not self.cpu_cache:
            return

        lru_seq = min(self.cpu_cache, key=lambda sid: self.access_log.get(sid, 0))
        kv_cache = self._take(lru_seq, "cpu")
        self._store_to_nvme(lru_seq, kv_cache, self._estimate_size(kv_cache))

    # ---- helpers ------------------------------------------------------------

    @staticmethod
    def _tensors(kv_cache):
        """List the tensors inside a tensor / dict / tuple / list KV cache."""
        if isinstance(kv_cache, torch.Tensor):
            return [kv_cache]
        items = kv_cache.values() if isinstance(kv_cache, dict) else kv_cache
        return [t for t in items if isinstance(t, torch.Tensor)]

    def _estimate_size(self, kv_cache):
        """Total byte size of the cache's tensors."""
        return sum(t.numel() * t.element_size() for t in self._tensors(kv_cache))

    @staticmethod
    def _to_device(kv_cache, device):
        """Move every tensor in the cache to device (no-op when device is None)."""
        if device is None:
            return kv_cache
        if isinstance(kv_cache, torch.Tensor):
            return kv_cache.to(device)
        if isinstance(kv_cache, dict):
            return {
                name: value.to(device) if isinstance(value, torch.Tensor) else value
                for name, value in kv_cache.items()
            }
        return type(kv_cache)(
            value.to(device) if isinstance(value, torch.Tensor) else value
            for value in kv_cache
        )

5. 实践与评估

5.1 量化策略对比

class KVCacheComparison:
    """
    Reference figures for the KV cache optimisation strategies discussed here.

    Each entry maps a strategy key to its memory use (GB), latency (ms),
    relative accuracy (1.0 = full FP16 cache) and a short description.

    NOTE(review): these numbers are illustrative constants; nothing in this
    class measures or derives them.
    """

    strategies = {
        "fp16_full": dict(
            memory_gb=16.0, latency_ms=100, accuracy=1.0,
            description="完整FP16缓存",
        ),
        "int8_dynamic": dict(
            memory_gb=8.0, latency_ms=95, accuracy=0.995,
            description="INT8动态量化",
        ),
        "int4_hybrid": dict(
            memory_gb=4.0, latency_ms=90, accuracy=0.98,
            description="INT4混合精度(近期FP16,远期INT4)",
        ),
        "h2o_30pct": dict(
            memory_gb=4.8, latency_ms=88, accuracy=0.975,
            description="H2O保留30%缓存",
        ),
        "pyramid_kv": dict(
            memory_gb=5.2, latency_ms=92, accuracy=0.985,
            description="PyramidKV金字塔压缩",
        ),
        "streamingllm": dict(
            memory_gb=2.0, latency_ms=85, accuracy=0.96,
            description="StreamingLLM(4 sink + 512 recent)",
        ),
    }

5.2 最佳实践建议

def recommend_kv_cache_strategy(
    max_seq_len: int,
    available_memory_gb: float,
    accuracy_requirement: float = 0.98
) -> List[Dict]:
    """
    Recommend KV cache optimisation strategies.

    Args:
        max_seq_len: longest expected sequence; above 1000 tokens prefix
            caching is suggested first.
        available_memory_gb: memory budget for the KV cache in GB.
        accuracy_requirement: minimum acceptable relative accuracy
            (e.g. 0.995 for near-lossless).

    Returns:
        A list (not a dict) of recommendation dicts ordered by "priority"
        (1 = apply first). Each has "strategy", "priority" and "saving", and
        most also have "implementation".
    """
    recommendations = []

    # 1. Prefix caching first for long inputs
    if max_seq_len > 1000:
        recommendations.append({
            "strategy": "prefix_caching",
            "priority": 1,
            "saving": "50-80%",
            "implementation": "相对简单,推荐优先实施"
        })

    # 2. Pick a compression strategy from the accuracy and memory constraints
    if accuracy_requirement >= 0.995:
        if available_memory_gb >= 16:
            recommendations.append({
                "strategy": "fp16_full",
                "priority": 2,
                "saving": "baseline"
            })
        else:
            recommendations.append({
                "strategy": "int8_dynamic",
                "priority": 2,
                "saving": "50%",
                "implementation": "Per-Token INT8量化"
            })
    elif accuracy_requirement >= 0.98:
        if available_memory_gb < 8:
            recommendations.append({
                "strategy": "pyramid_kv",
                "priority": 2,
                "saving": "60-70%",
                "implementation": "分层压缩"
            })
        else:
            recommendations.append({
                "strategy": "int4_hybrid",
                "priority": 2,
                "saving": "75%",
                "implementation": "近期FP16 + 远期INT4"
            })
    else:
        recommendations.append({
            "strategy": "h2o + int4",
            "priority": 2,
            "saving": "80%+",
            "implementation": "H2O淘汰 + INT4量化"
        })

    # 3. PagedAttention is always applicable
    recommendations.append({
        "strategy": "paged_attention",
        "priority": 3,
        "saving": "20-40%",
        "implementation": "vLLM等框架支持"
    })

    return recommendations

6. 总结

KV Cache优化是LLM推理优化的关键环节:

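| 技术 | 内存节省 | 精度影响 | 复杂度 |
|------|----------|----------|--------|
| Per-Token INT8 | 50% | <1% | |
| INT4量化 | 75% | 2-5% | |
| H2O淘汰 | 50-70% | 1-5% | |
| PyramidKV | 60-70% | 1-3% | |
| 前缀缓存 | 50-80% | 0% | |
| PagedAttention | 20-40% | 0% | |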

实际应用中通常组合使用多种技术以达到最优效果。


参考资料