EliteKV - RoPE频率选择与低秩投影

1. 概述

EliteKV是一种针对 Rotary Position Embedding(RoPE)优化的KV Cache压缩方法。其核心思想是:通过对RoPE编码的频率特性进行分析,识别出包含关键信息的频率分量,然后仅保留这些”精英”(Elite)频率,从而实现高效的KV Cache压缩。

在现代大语言模型中,RoPE已成为主流的位置编码方案。与绝对位置编码和相对位置编码相比,RoPE通过旋转操作将位置信息编码到Query和Key向量中,能够自然地处理任意长度的序列。然而,RoPE的独特性质也为KV Cache压缩带来了新的机遇和挑战。1

2. RoPE位置编码问题分析

2.1 RoPE的基本原理

RoPE的核心思想是将位置信息通过旋转矩阵编码到词嵌入中。对于位置 处的Query向量 和位置 处的Key向量 ,它们的旋转编码为:

其中旋转矩阵 定义为:

其中旋转频率 的选择遵循:

这个设计确保了不同维度使用不同的旋转频率,低频维度捕获长程依赖,高频维度捕获细节信息。

2.2 RoPE在注意力计算中的作用

在应用RoPE后,注意力分数的计算变为:

展开为:

这个公式揭示了一个关键洞察:注意力分数仅依赖于Query和Key之间的位置差 ,而非它们的绝对位置。这是RoPE能够处理任意长度序列的根本原因。

2.3 RoPE压缩的独特挑战

传统的KV Cache压缩方法(如直接量化或稀疏化)往往忽略了一个重要事实:RoPE编码使得不同频率的维度承载不同类型的信息。

频率类型维度范围携带信息压缩敏感度
高频分量局部细节、细粒度模式
中频分量中程依赖、语义结构
低频分量长程依赖、全局信息

如果对所有维度采用统一的压缩策略,可能导致关键信息的丢失。例如,对高频维度进行过度压缩会损害模型捕获局部模式的能力。

3. 频率选择机制

3.1 频率能量分析

EliteKV首先对KV向量进行频率能量分析。具体来说,对于已缓存的Key向量 ,计算各频率分量的能量:

其中 是第 个token在第 个频率对的分量。

通过分析 的分布,可以识别出:

  1. 高能量频率:携带较多信息的频率
  2. 低能量频率:可能是噪声或冗余信息

3.2 动态频率阈值

EliteKV采用动态频率阈值来确定保留哪些频率分量:

其中 是可调阈值参数。这种基于中值的自适应方法可以适应不同的输入分布。

更精细的阈值选择可以基于累积能量:

即选择能够覆盖95%总能量的最小频率集合。

3.3 频率分组策略

为了进一步提升效率,EliteKV将相邻频率对分组,并在组级别进行选择:

频率索引:    0  1  2  3  4  5  6  7  8  9  10 11 ...
频率对:     (0,1) (2,3) (4,5) (6,7) (8,9) (10,11) ...

组划分:     [  组0  ] [  组1  ] [  组2  ] [  组3  ] ...
组能量:      E_g0    E_g1    E_g2    E_g3    ...

选择结果:     ✓        ✗        ✓        ✓

组内所有频率要么全部保留,要么全部丢弃,这简化了压缩后的注意力计算。

3.4 基于注意力的频率重要性

除了能量分析,EliteKV还考虑各频率对注意力计算的贡献:

量化了第 个频率对最终注意力输出的影响程度。高 的频率对应该被优先保留。

4. 联合低秩投影设计

4.1 频率感知投影

基于频率选择的结果,EliteKV设计了频率感知的低秩投影。对于第 个频率对,定义投影权重:

其中 是衰减系数。低频分量使用较大的衰减系数,高频分量使用较小的衰减系数。

投影后的向量表示为:

4.2 结构化低秩约束

为了更好地利用低秩特性,EliteKV引入了结构化约束。假设原始KV矩阵 ,我们希望找到一个低秩近似

考虑频率分解:

其中 是第 个频率对对应的子矩阵, 是由旋转角度 定义的块对角矩阵。

EliteKV的优化目标是:

其中 是核范数(鼓励低秩), 是各频率对应的因子。

4.3 投影矩阵的端到端学习

EliteKV还探索了端到端学习的投影矩阵。与随机投影或SVD分解不同,投影矩阵通过下游任务的损失函数进行优化:

其中:

  • 是标准的语言模型负对数似然损失
  • 是压缩正则化项

压缩正则化项定义为:

其中 是投影矩阵。通过调整 ,可以控制压缩强度。

4.4 投影矩阵的初始化

投影矩阵的初始化对最终性能有重要影响。EliteKV提出了两种初始化策略:

4.4.1 基于SVD的初始化

  1. 对历史KV矩阵进行SVD分解
  2. 选择前 个奇异值对应的奇异向量
  3. 使用这些向量初始化投影矩阵
def svd_init(K, rank):
    """基于SVD的投影矩阵初始化"""
    U, S, Vt = torch.linalg.svd(K, full_matrices=False)
    
    # 选择top-r个分量
    P = Vt[:rank, :]  # [rank, d]
    
    # 归一化
    P = P / P.norm(dim=-1, keepdim=True)
    
    return P

4.4.2 基于频率能量分布的初始化

  1. 计算各频率分量的能量
  2. 将能量归一化为概率分布
  3. 根据 分配投影矩阵的权重
def energy_based_init(K, rank):
    """基于能量分布的投影矩阵初始化"""
    d = K.shape[-1]
    num_pairs = d // 2
    
    # 计算各频率对的能量
    energies = []
    for i in range(num_pairs):
        pair_energy = (K[:, :, 2*i]**2 + K[:, :, 2*i+1]**2).mean()
        energies.append(pair_energy.item())
    
    # 归一化为概率
    probs = torch.tensor(energies) / sum(energies)
    
    # 根据能量分配权重到不同秩
    # 高能量频率获得更大的权重
    P = torch.zeros(rank, d)
    for i in range(rank):
        freq_idx = i % num_pairs
        P[i, 2*freq_idx] = probs[freq_idx].sqrt()
        P[i, 2*freq_idx + 1] = probs[freq_idx].sqrt()
    
    return P

5. 与标准压缩方法对比

5.1 量化方法对比

特性标准量化(INT8)标准量化(INT4)EliteKV
压缩比4x8x4x - 8x
精度损失轻微中等最小
频率感知
硬件支持优秀良好良好
RoPE兼容性被动影响被动影响主动优化

EliteKV相比标准量化方法的优势在于:它能够识别并保留对RoPE编码最重要的频率分量,而不仅仅是降低数值精度。

5.2 稀疏方法对比

特性SnapKVH2OStreamingLLMEliteKV
压缩策略键值聚类累积注意力局部性保留频率选择
压缩粒度Token级Token级Token级频率级
RoPE适配间接间接直接深度适配
信息保留语义相似注意力热点局部模式频率能量

H2O和StreamingLLM通过保留特定位置的token来实现压缩,而EliteKV通过保留特定频率的分量来实现压缩。这两种策略可以互补使用。

5.3 其他低秩方法对比

特性KVMemLearnLM标准SVDEliteKV
压缩基础键值关联对比学习矩阵分解频率分析
投影方式键相关任务驱动无监督频率感知
计算开销低-中
精度保持优秀优秀良好优秀

6. 实验结果:长上下文任务

6.1 实验设置

EliteKV在多个长上下文基准任务上进行了评估:

  1. PassKey Retrieval:在长文档中检索隐藏的密钥
  2. NarrativeQA:阅读理解任务
  3. LongBench:综合长上下文基准
  4. Needle in a Haystack:大海捞针测试

模型选择:LLaMA-2-7B-Chat、LLaMA-3-8B-Instruct

基线对比:H2O、SnapKV、StreamingLLM、全量KV Cache

6.2 PassKey Retrieval结果

在PassKey Retrieval任务中,模型需要在包含多个无关段落的文档中找到隐藏的密码。

序列长度Full KVH2O (30%)SnapKV (30%)EliteKV (30%)
16K99.2%94.1%96.8%98.7%
32K98.7%89.3%93.2%97.9%
64K97.1%78.6%87.4%95.2%
128K92.3%61.2%72.8%91.4%

EliteKV在所有序列长度上都显著优于基线方法。在128K长度下,EliteKV相比H2O提升了30个百分点以上。

6.3 频率保留分析

为了验证频率选择的有效性,我们分析了不同压缩方法下各频率分量的保留程度:

频率索引:  0  2  4  6  8  10 12 14 16 18 20 22 24 26 28 30
           |  |  |  |  |  |  |  |  |  |  |  |  |  |  |  |
Full KV:   ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■  ■
H2O:       ■  ■  ■  ■  □  □  □  □  □  □  □  □  □  □  □  □
EliteKV:   ■  ■  ■  ■  ■  ■  □  □  ■  ■  □  □  ■  ■  □  □

■ = 高保真度保留   □ = 低保真度或丢弃

EliteKV保留了包含关键信息的低频和高能量频率分量,同时丢弃了噪声主导的高频分量。

6.4 内存与吞吐量权衡

方法压缩率内存占用吞吐量提升精度损失
Full KV1x100%1x0%
INT8量化4x25%1.8x~1%
H2O (30%)3.3x30%2.5x~5%
SnapKV (30%)3.3x30%2.3x~3%
EliteKV (50%)2x50%2.0x~0.5%
EliteKV (30%)3.3x30%2.8x~1.5%

EliteKV在相同压缩率下实现了最佳的精度-效率权衡。

7. PyTorch实现

7.1 核心实现

以下是EliteKV的完整PyTorch实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import math
 
 
class FrequencyAnalyzer:
    """
    频率分析器
    
    分析KV向量的频率分布,识别重要频率分量
    """
    
    def __init__(self, head_dim: int, base_freq: float = 10000.0):
        """
        Args:
            head_dim: 头的维度
            base_freq: 基础旋转频率
        """
        self.head_dim = head_dim
        self.num_pairs = head_dim // 2
        self.base_freq = base_freq
        
        # 预计算旋转频率
        self.theta = self._compute_thetas()
    
    def _compute_thetas(self) -> torch.Tensor:
        """计算各频率对的旋转角度"""
        thetas = []
        for i in range(self.num_pairs):
            theta_i = self.base_freq * math.exp(-2 * math.log(self.base_freq) * i / self.head_dim)
            thetas.append(theta_i)
        return torch.tensor(thetas)
    
    def compute_frequency_energy(self, k: torch.Tensor) -> torch.Tensor:
        """
        计算各频率对的能量
        
        Args:
            k: Key向量 [batch, num_heads, seq_len, head_dim]
        
        Returns:
            energies: 各频率对的能量 [batch, num_heads, num_pairs]
        """
        batch, num_heads, seq_len, head_dim = k.shape
        
        # 重塑为频率对格式
        k_pairs = k.view(batch, num_heads, seq_len, self.num_pairs, 2)
        
        # 计算每个频率对的能量
        energies = (k_pairs ** 2).sum(dim=-1)  # [batch, num_heads, seq_len, num_pairs]
        
        # 在序列维度上求和
        energies = energies.sum(dim=2)  # [batch, num_heads, num_pairs]
        
        return energies
    
    def select_elite_frequencies(self, energies: torch.Tensor, 
                                   threshold_ratio: float = 0.5) -> torch.Tensor:
        """
        选择精英频率
        
        Args:
            energies: 频率能量 [batch, num_heads, num_pairs]
            threshold_ratio: 阈值比例
        
        Returns:
            selected_mask: 选择掩码 [batch, num_heads, num_pairs]
        """
        # 计算阈值
        median_energy = energies.median(dim=-1, keepdim=True)[0]
        threshold = median_energy * threshold_ratio
        
        # 选择能量高于阈值的频率
        selected_mask = (energies > threshold).float()
        
        return selected_mask
    
    def get_frequency_weights(self, selected_mask: torch.Tensor, 
                               decay_rate: float = 0.1) -> torch.Tensor:
        """
        生成频率权重
        
        Args:
            selected_mask: 选择掩码
            decay_rate: 未选中频率的衰减率
        
        Returns:
            weights: 频率权重
        """
        # 未选中频率给予小的非零权重,保留部分信息
        weights = selected_mask + (1 - selected_mask) * decay_rate
        return weights
 
 
class EliteKVCache:
    """
    EliteKV缓存
    
    基于频率选择的低秩KV Cache
    """
    
    def __init__(self, 
                 num_heads: int, 
                 head_dim: int,
                 max_seq_len: int,
                 compression_ratio: float = 0.3,
                 base_freq: float = 10000.0,
                 use_structured: bool = True):
        """
        Args:
            num_heads: 注意力头数
            head_dim: 头维度
            max_seq_len: 最大序列长度
            compression_ratio: 压缩比例(保留多少)
            base_freq: RoPE基础频率
            use_structured: 是否使用结构化投影
        """
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.compression_ratio = compression_ratio
        
        # 频率分析器
        self.freq_analyzer = FrequencyAnalyzer(head_dim, base_freq)
        self.num_pairs = head_dim // 2
        
        # 频率选择掩码
        self.register_buffer('freq_mask', torch.zeros(num_heads, self.num_pairs))
        
        # KV缓存
        self.k_cache = torch.zeros(max_seq_len, num_heads, head_dim)
        self.v_cache = torch.zeros(max_seq_len, num_heads, head_dim)
        
        # 缓存长度
        self.current_len = 0
        
        # 结构化投影
        self.use_structured = use_structured
        if use_structured:
            # 频率对的旋转角度
            self.register_buffer('rot_matrices', self._create_rot_matrices())
    
    def _create_rot_matrices(self) -> torch.Tensor:
        """创建旋转矩阵(用于结构化投影)"""
        matrices = []
        for i in range(self.num_pairs):
            theta = 10000.0 * math.exp(-2 * math.log(10000.0) * i / self.head_dim)
            cos_t = math.cos(theta)
            sin_t = math.sin(theta)
            # 2x2旋转矩阵
            rot = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]])
            matrices.append(rot)
        return torch.stack(matrices)  # [num_pairs, 2, 2]
    
    def update(self, k: torch.Tensor, v: torch.Tensor):
        """
        更新KV Cache
        
        Args:
            k: Key向量 [batch, num_heads, seq_len, head_dim]
            v: Value向量 [batch, num_heads, seq_len, head_dim]
        """
        batch, num_heads, seq_len, head_dim = k.shape
        
        # 存储原始KV
        start_idx = self.current_len
        end_idx = self.current_len + seq_len
        
        self.k_cache[start_idx:end_idx] = k.squeeze(0).transpose(0, 1)
        self.v_cache[start_idx:end_idx] = v.squeeze(0).transpose(0, 1)
        
        self.current_len += seq_len
        
        # 定期更新频率选择
        if self.current_len % 512 == 0:
            self._update_frequency_selection()
    
    def _update_frequency_selection(self):
        """更新频率选择"""
        # 分析当前缓存的频率能量
        k_subset = self.k_cache[:self.current_len].transpose(0, 1)  # [num_heads, seq_len, head_dim]
        k_subset = k_subset.unsqueeze(0)  # [1, num_heads, seq_len, head_dim]
        
        energies = self.freq_analyzer.compute_frequency_energy(k_subset)
        
        # 基于累积能量选择top-k频率
        total_energies = energies.sum(dim=2)  # [1, num_heads, num_pairs]
        
        # 保留压缩比例的频率
        k_to_keep = max(1, int(self.num_pairs * self.compression_ratio))
        
        _, top_indices = torch.topk(total_energies.squeeze(0), k=k_to_keep, dim=-1)
        
        # 创建掩码
        new_mask = torch.zeros_like(self.freq_mask)
        for h in range(self.num_heads):
            new_mask[h, top_indices[h]] = 1.0
        
        # 平滑更新
        self.freq_mask = 0.7 * self.freq_mask + 0.3 * new_mask
    
    def apply_frequency_projection(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        应用频率投影
        
        Args:
            k: Key向量 [seq_len, num_heads, head_dim]
            v: Value向量 [seq_len, num_heads, head_dim]
        
        Returns:
            投影后的KV
        """
        # 重塑为频率对格式
        k_pairs = k.view(-1, self.num_heads, self.num_pairs, 2)  # [seq, heads, pairs, 2]
        v_pairs = v.view(-1, self.num_heads, self.num_pairs, 2)
        
        # 应用频率权重
        freq_weights = self.freq_mask.unsqueeze(0).unsqueeze(-1)  # [1, heads, pairs, 1]
        
        k_projected = k_pairs * freq_weights
        v_projected = v_pairs * freq_weights
        
        # 恢复原始形状
        k_projected = k_projected.view(-1, self.num_heads, self.head_dim)
        v_projected = v_projected.view(-1, self.num_heads, self.head_dim)
        
        return k_projected, v_projected
    
    def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取压缩后的KV Cache"""
        if self.current_len == 0:
            return None, None
        
        k = self.k_cache[:self.current_len]
        v = self.v_cache[:self.current_len]
        
        # 应用频率投影
        k_proj, v_proj = self.apply_frequency_projection(k, v)
        
        return k_proj, v_proj
    
    def get_memory_usage(self) -> dict:
        """获取内存使用情况"""
        original_size = self.current_len * self.num_heads * self.head_dim * 2 * 4  # bytes
        compressed_size = self.current_len * self.num_heads * self.head_dim * 2 * 4 * self.compression_ratio
        
        return {
            'original_bytes': original_size,
            'compressed_bytes': compressed_size,
            'compression_ratio': original_size / max(1, compressed_size),
            'num_tokens': self.current_len
        }
 
 
class EliteAttention(nn.Module):
    """
    EliteKV注意力层
    
    集成频率选择和低秩投影的注意力机制
    """
    
    def __init__(self, 
                 hidden_dim: int, 
                 num_heads: int, 
                 head_dim: int,
                 compression_ratio: float = 0.3):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        # QKV投影
        self.q_proj = nn.Linear(hidden_dim, num_heads * head_dim)
        self.k_proj = nn.Linear(hidden_dim, num_heads * head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_dim)
        
        # EliteKV缓存
        self.kv_cache = None
        self.max_seq_len = 8192
        self.compression_ratio = compression_ratio
        
        # RoPE编码器
        self.use_rope = True
    
    def init_kv_cache(self, max_seq_len: int):
        """初始化KV Cache"""
        self.max_seq_len = max_seq_len
        self.kv_cache = EliteKVCache(
            self.num_heads,
            self.head_dim,
            max_seq_len,
            compression_ratio=self.compression_ratio
        )
    
    def apply_rope(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """
        应用RoPE编码
        
        Args:
            x: 输入张量 [..., head_dim]
            position_ids: 位置ID [..., seq_len]
        
        Returns:
            应用RoPE后的张量
        """
        seq_len = x.shape[-2]
        dim = x.shape[-1]
        num_pairs = dim // 2
        
        # 计算频率
        freqs = 10000.0 * torch.exp(
            -torch.arange(0, dim, 2, device=x.device, dtype=x.dtype) * 
            (math.log(10000.0) / dim)
        )
        
        # 计算旋转角度
        positions = position_ids.unsqueeze(-1).float()
        angles = positions * freqs[None, None, :]
        
        # 计算旋转矩阵
        cos = angles.cos()
        sin = angles.sin()
        
        # 重塑输入
        x1 = x[..., :num_pairs]  # 偶数索引
        x2 = x[..., num_pairs:]  # 奇数索引
        
        # 应用旋转
        x_rotated = torch.cat([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1)
        
        return x_rotated
    
    def forward(self, 
                x: torch.Tensor, 
                position_ids: torch.Tensor,
                use_cache: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        前向传播
        
        Args:
            x: 输入张量 [batch, seq_len, hidden_dim]
            position_ids: 位置ID [batch, seq_len]
            use_cache: 是否使用KV Cache
        
        Returns:
            output: 输出张量
            k: Key向量
            v: Value向量
        """
        batch, seq_len, _ = x.shape
        
        # QKV投影
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        
        # 应用RoPE
        if self.use_rope:
            q = self.apply_rope(q.transpose(1, 2), position_ids).transpose(1, 2)
            k = self.apply_rope(k.transpose(1, 2), position_ids).transpose(1, 2)
        
        # 更新KV Cache
        if use_cache and self.kv_cache is not None and seq_len == 1:
            self.kv_cache.update(k, v)
            k_cached, v_cached = self.kv_cache.get_cache()
            
            if k_cached is not None:
                k_full = torch.cat([k_cached.transpose(0, 1).unsqueeze(0), k.transpose(1, 2)], dim=2)
                v_full = torch.cat([v_cached.transpose(0, 1).unsqueeze(0), v.transpose(1, 2)], dim=2)
            else:
                k_full = k.transpose(1, 2)
                v_full = v.transpose(1, 2)
        else:
            k_full = k.transpose(1, 2)
            v_full = v.transpose(1, 2)
        
        # 计算注意力
        scale = self.head_dim ** -0.5
        q = q.transpose(1, 2) * scale
        
        attn_weights = torch.matmul(q, k_full.transpose(2, 3))
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        attn_output = torch.matmul(attn_weights, v_full)
        attn_output = attn_output.transpose(1, 2).contiguous()
        
        # 输出投影
        attn_output = attn_output.reshape(batch, seq_len, self.num_heads * self.head_dim)
        output = self.o_proj(attn_output)
        
        return output, k, v
 
 
def benchmark_elitekv():
    """基准测试EliteKV"""
    print("=" * 60)
    print("EliteKV Benchmark")
    print("=" * 60)
    
    # 配置
    hidden_dim = 4096
    num_heads = 32
    head_dim = 128
    max_seq_len = 32768
    compression_ratio = 0.3
    
    # 创建模型
    attention = EliteAttention(
        hidden_dim=hidden_dim,
        num_heads=num_heads,
        head_dim=head_dim,
        compression_ratio=compression_ratio
    )
    attention.init_kv_cache(max_seq_len)
    
    # 测试输入
    batch = 1
    seq_len = 1
    x = torch.randn(batch, seq_len, hidden_dim)
    position_ids = torch.tensor([[0]])
    
    # 预填充阶段
    print("\n预填充阶段:")
    for i in range(min(1024, max_seq_len)):
        position_ids = torch.tensor([[i]])
        output, k, v = attention(x, position_ids, use_cache=False)
    
    # 获取内存使用
    mem_info = attention.kv_cache.get_memory_usage()
    print(f"  缓存token数: {mem_info['num_tokens']}")
    print(f"  原始内存: {mem_info['original_bytes'] / 1024 / 1024:.2f} MB")
    print(f"  压缩后内存: {mem_info['compressed_bytes'] / 1024 / 1024:.2f} MB")
    print(f"  压缩比: {mem_info['compression_ratio']:.2f}x")
    
    # 解码阶段
    print("\n解码阶段:")
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    import time
    start = time.time()
    
    for i in range(100):
        position_ids = torch.tensor([[1024 + i]])
        output, k, v = attention(x, position_ids, use_cache=True)
    
    torch.cuda.synchronize() if torch.cuda.is_available() else None
    elapsed = time.time() - start
    
    print(f"  100步解码耗时: {elapsed:.3f}s")
    print(f"  平均每步: {elapsed/100*1000:.2f}ms")
    
    print("\n" + "=" * 60)
 
 
if __name__ == "__main__":
    benchmark_elitekv()

7.2 与标准注意力的集成

EliteKV可以与标准Transformer模型集成:

class EliteKVTransformerLayer(nn.Module):
    """
    使用EliteKV的Transformer层
    """
    
    def __init__(self, 
                 hidden_dim: int, 
                 num_heads: int, 
                 intermediate_size: int,
                 compression_ratio: float = 0.3):
        super().__init__()
        
        self.attention = EliteAttention(
            hidden_dim=hidden_dim,
            num_heads=num_heads,
            head_dim=hidden_dim // num_heads,
            compression_ratio=compression_ratio
        )
        
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_dim)
        )
        
        self.norm1 = nn.RMSNorm(hidden_dim)
        self.norm2 = nn.RMSNorm(hidden_dim)
    
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        前向传播
        
        Args:
            x: 输入张量 [batch, seq_len, hidden_dim]
            position_ids: 位置ID [batch, seq_len]
        
        Returns:
            output: 输出张量
        """
        # 自注意力
        attn_output, _, _ = self.attention(self.norm1(x), position_ids)
        x = x + attn_output
        
        # FFN
        mlp_output = self.mlp(self.norm2(x))
        x = x + mlp_output
        
        return x
 
 
class EliteKVModel(nn.Module):
    """
    使用EliteKV的完整模型
    """
    
    def __init__(self,
                 vocab_size: int,
                 hidden_dim: int,
                 num_layers: int,
                 num_heads: int,
                 intermediate_size: int,
                 max_seq_len: int = 32768,
                 compression_ratio: float = 0.3):
        super().__init__()
        
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        
        self.layers = nn.ModuleList([
            EliteKVTransformerLayer(
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                intermediate_size=intermediate_size,
                compression_ratio=compression_ratio
            )
            for _ in range(num_layers)
        ])
        
        self.norm = nn.RMSNorm(hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
        
        self.max_seq_len = max_seq_len
        self.compression_ratio = compression_ratio
        
        # 初始化KV Cache
        for layer in self.layers:
            layer.attention.init_kv_cache(max_seq_len)
    
    def forward(self, input_ids: torch.Tensor, use_cache: bool = True):
        """
        前向传播
        
        Args:
            input_ids: 输入ID [batch, seq_len]
            use_cache: 是否使用KV Cache
        
        Returns:
            logits: 输出logits [batch, seq_len, vocab_size]
        """
        batch, seq_len = input_ids.shape
        
        # 嵌入
        hidden_states = self.embed_tokens(input_ids)
        
        # 位置ID
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        
        # Transformer层
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        
        # 最终归一化
        hidden_states = self.norm(hidden_states)
        
        # LM头
        logits = self.lm_head(hidden_states)
        
        return logits
    
    def generate(self, 
                  input_ids: torch.Tensor, 
                  max_new_tokens: int = 100,
                  temperature: float = 1.0,
                  top_k: int = 50):
        """
        自回归生成
        
        Args:
            input_ids: 输入ID [batch, seq_len]
            max_new_tokens: 最大生成长度
            temperature: 采样温度
            top_k: Top-K采样
        
        Returns:
            generated_ids: 生成的ID
        """
        self.eval()
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # 前向传播
                logits = self.forward(input_ids, use_cache=True)
                
                # 获取最后一个位置的logits
                next_token_logits = logits[:, -1, :] / temperature
                
                # Top-K采样
                if top_k > 0:
                    v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
                    next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
                
                # 采样
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                
                # 拼接
                input_ids = torch.cat([input_ids, next_token], dim=1)
                
                # 检查是否生成结束符
                if (next_token == 2).any():  # 假设2是结束符
                    break
        
        return input_ids

8. 总结

EliteKV通过深入分析RoPE编码的频率特性,提出了一种新颖的KV Cache压缩方法。其核心贡献包括:

  1. 频率能量分析:系统性地分析KV向量在各RoPE频率下的能量分布
  2. 动态频率选择:基于能量和重要性动态选择保留的频率分量
  3. 联合低秩投影:设计频率感知的低秩投影,保持关键信息
  4. 兼容性好:可以与现有量化、稀疏方法结合使用

实验结果表明,EliteKV在长上下文任务上显著优于现有方法,同时保持了较低的内存占用和较高的吞吐量。

参考资料

Footnotes

  1. Su J, Lu Y, Pan S, et al. RoFormer: Enhanced Transformer with Rotary Position Embedding[J]. arXiv preprint arXiv:2104.09864, 2021.