EliteKV - RoPE频率选择与低秩投影
1. 概述
EliteKV是一种针对 Rotary Position Embedding(RoPE)优化的KV Cache压缩方法。其核心思想是:通过对RoPE编码的频率特性进行分析,识别出包含关键信息的频率分量,然后仅保留这些”精英”(Elite)频率,从而实现高效的KV Cache压缩。
在现代大语言模型中,RoPE已成为主流的位置编码方案。与绝对位置编码和相对位置编码相比,RoPE通过旋转操作将位置信息编码到Query和Key向量中,能够自然地处理任意长度的序列。然而,RoPE的独特性质也为KV Cache压缩带来了新的机遇和挑战。1
2. RoPE位置编码问题分析
2.1 RoPE的基本原理
RoPE的核心思想是将位置信息通过旋转矩阵编码到词嵌入中。对于位置 处的Query向量 和位置 处的Key向量 ,它们的旋转编码为:
其中旋转矩阵 定义为:
其中旋转频率 的选择遵循:
这个设计确保了不同维度使用不同的旋转频率,低频维度捕获长程依赖,高频维度捕获细节信息。
2.2 RoPE在注意力计算中的作用
在应用RoPE后,注意力分数的计算变为:
展开为:
这个公式揭示了一个关键洞察:注意力分数仅依赖于Query和Key之间的位置差 ,而非它们的绝对位置。这是RoPE能够处理任意长度序列的根本原因。
2.3 RoPE压缩的独特挑战
传统的KV Cache压缩方法(如直接量化或稀疏化)往往忽略了一个重要事实:RoPE编码使得不同频率的维度承载不同类型的信息。
| 频率类型 | 维度范围 | 携带信息 | 压缩敏感度 |
|---|---|---|---|
| 高频分量 | 局部细节、细粒度模式 | 高 | |
| 中频分量 | 中程依赖、语义结构 | 中 | |
| 低频分量 | 长程依赖、全局信息 | 低 |
如果对所有维度采用统一的压缩策略,可能导致关键信息的丢失。例如,对高频维度进行过度压缩会损害模型捕获局部模式的能力。
3. 频率选择机制
3.1 频率能量分析
EliteKV首先对KV向量进行频率能量分析。具体来说,对于已缓存的Key向量 ,计算各频率分量的能量:
其中 和 是第 个token在第 个频率对的分量。
通过分析 的分布,可以识别出:
- 高能量频率:携带较多信息的频率
- 低能量频率:可能是噪声或冗余信息
3.2 动态频率阈值
EliteKV采用动态频率阈值来确定保留哪些频率分量:
其中 是可调阈值参数。这种基于中值的自适应方法可以适应不同的输入分布。
更精细的阈值选择可以基于累积能量:
即选择能够覆盖95%总能量的最小频率集合。
3.3 频率分组策略
为了进一步提升效率,EliteKV将相邻频率对分组,并在组级别进行选择:
频率索引: 0 1 2 3 4 5 6 7 8 9 10 11 ...
频率对: (0,1) (2,3) (4,5) (6,7) (8,9) (10,11) ...
组划分: [ 组0 ] [ 组1 ] [ 组2 ] [ 组3 ] ...
组能量: E_g0 E_g1 E_g2 E_g3 ...
选择结果: ✓ ✗ ✓ ✓
组内所有频率要么全部保留,要么全部丢弃,这简化了压缩后的注意力计算。
3.4 基于注意力的频率重要性
除了能量分析,EliteKV还考虑各频率对注意力计算的贡献:
量化了第 个频率对最终注意力输出的影响程度。高 的频率对应该被优先保留。
4. 联合低秩投影设计
4.1 频率感知投影
基于频率选择的结果,EliteKV设计了频率感知的低秩投影。对于第 个频率对,定义投影权重:
其中 是衰减系数。低频分量使用较大的衰减系数,高频分量使用较小的衰减系数。
投影后的向量表示为:
4.2 结构化低秩约束
为了更好地利用低秩特性,EliteKV引入了结构化约束。假设原始KV矩阵 ,我们希望找到一个低秩近似 。
考虑频率分解:
其中 是第 个频率对对应的子矩阵, 是由旋转角度 定义的块对角矩阵。
EliteKV的优化目标是:
其中 是核范数(鼓励低秩), 是各频率对应的因子。
4.3 投影矩阵的端到端学习
EliteKV还探索了端到端学习的投影矩阵。与随机投影或SVD分解不同,投影矩阵通过下游任务的损失函数进行优化:
其中:
- 是标准的语言模型负对数似然损失
- 是压缩正则化项
压缩正则化项定义为:
其中 是投影矩阵。通过调整 ,可以控制压缩强度。
4.4 投影矩阵的初始化
投影矩阵的初始化对最终性能有重要影响。EliteKV提出了两种初始化策略:
4.4.1 基于SVD的初始化
- 对历史KV矩阵进行SVD分解
- 选择前 个奇异值对应的奇异向量
- 使用这些向量初始化投影矩阵
def svd_init(K, rank):
"""基于SVD的投影矩阵初始化"""
U, S, Vt = torch.linalg.svd(K, full_matrices=False)
# 选择top-r个分量
P = Vt[:rank, :] # [rank, d]
# 归一化
P = P / P.norm(dim=-1, keepdim=True)
return P4.4.2 基于频率能量分布的初始化
- 计算各频率分量的能量
- 将能量归一化为概率分布
- 根据 分配投影矩阵的权重
def energy_based_init(K, rank):
"""基于能量分布的投影矩阵初始化"""
d = K.shape[-1]
num_pairs = d // 2
# 计算各频率对的能量
energies = []
for i in range(num_pairs):
pair_energy = (K[:, :, 2*i]**2 + K[:, :, 2*i+1]**2).mean()
energies.append(pair_energy.item())
# 归一化为概率
probs = torch.tensor(energies) / sum(energies)
# 根据能量分配权重到不同秩
# 高能量频率获得更大的权重
P = torch.zeros(rank, d)
for i in range(rank):
freq_idx = i % num_pairs
P[i, 2*freq_idx] = probs[freq_idx].sqrt()
P[i, 2*freq_idx + 1] = probs[freq_idx].sqrt()
return P5. 与标准压缩方法对比
5.1 量化方法对比
| 特性 | 标准量化(INT8) | 标准量化(INT4) | EliteKV |
|---|---|---|---|
| 压缩比 | 4x | 8x | 4x - 8x |
| 精度损失 | 轻微 | 中等 | 最小 |
| 频率感知 | ❌ | ❌ | ✅ |
| 硬件支持 | 优秀 | 良好 | 良好 |
| RoPE兼容性 | 被动影响 | 被动影响 | 主动优化 |
EliteKV相比标准量化方法的优势在于:它能够识别并保留对RoPE编码最重要的频率分量,而不仅仅是降低数值精度。
5.2 稀疏方法对比
| 特性 | SnapKV | H2O | StreamingLLM | EliteKV |
|---|---|---|---|---|
| 压缩策略 | 键值聚类 | 累积注意力 | 局部性保留 | 频率选择 |
| 压缩粒度 | Token级 | Token级 | Token级 | 频率级 |
| RoPE适配 | 间接 | 间接 | 直接 | 深度适配 |
| 信息保留 | 语义相似 | 注意力热点 | 局部模式 | 频率能量 |
H2O和StreamingLLM通过保留特定位置的token来实现压缩,而EliteKV通过保留特定频率的分量来实现压缩。这两种策略可以互补使用。
5.3 其他低秩方法对比
| 特性 | KVMem | LearnLM | 标准SVD | EliteKV |
|---|---|---|---|---|
| 压缩基础 | 键值关联 | 对比学习 | 矩阵分解 | 频率分析 |
| 投影方式 | 键相关 | 任务驱动 | 无监督 | 频率感知 |
| 计算开销 | 高 | 高 | 中 | 低-中 |
| 精度保持 | 优秀 | 优秀 | 良好 | 优秀 |
6. 实验结果:长上下文任务
6.1 实验设置
EliteKV在多个长上下文基准任务上进行了评估:
- PassKey Retrieval:在长文档中检索隐藏的密钥
- NarrativeQA:阅读理解任务
- LongBench:综合长上下文基准
- Needle in a Haystack:大海捞针测试
模型选择:LLaMA-2-7B-Chat、LLaMA-3-8B-Instruct
基线对比:H2O、SnapKV、StreamingLLM、全量KV Cache
6.2 PassKey Retrieval结果
在PassKey Retrieval任务中,模型需要在包含多个无关段落的文档中找到隐藏的密码。
| 序列长度 | Full KV | H2O (30%) | SnapKV (30%) | EliteKV (30%) |
|---|---|---|---|---|
| 16K | 99.2% | 94.1% | 96.8% | 98.7% |
| 32K | 98.7% | 89.3% | 93.2% | 97.9% |
| 64K | 97.1% | 78.6% | 87.4% | 95.2% |
| 128K | 92.3% | 61.2% | 72.8% | 91.4% |
EliteKV在所有序列长度上都显著优于基线方法。在128K长度下,EliteKV相比H2O提升了30个百分点以上。
6.3 频率保留分析
为了验证频率选择的有效性,我们分析了不同压缩方法下各频率分量的保留程度:
频率索引: 0 2 4 6 8 10 12 14 16 18 20 22 24 26 28 30
| | | | | | | | | | | | | | | |
Full KV: ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■
H2O: ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □
EliteKV: ■ ■ ■ ■ ■ ■ □ □ ■ ■ □ □ ■ ■ □ □
■ = 高保真度保留 □ = 低保真度或丢弃
EliteKV保留了包含关键信息的低频和高能量频率分量,同时丢弃了噪声主导的高频分量。
6.4 内存与吞吐量权衡
| 方法 | 压缩率 | 内存占用 | 吞吐量提升 | 精度损失 |
|---|---|---|---|---|
| Full KV | 1x | 100% | 1x | 0% |
| INT8量化 | 4x | 25% | 1.8x | ~1% |
| H2O (30%) | 3.3x | 30% | 2.5x | ~5% |
| SnapKV (30%) | 3.3x | 30% | 2.3x | ~3% |
| EliteKV (50%) | 2x | 50% | 2.0x | ~0.5% |
| EliteKV (30%) | 3.3x | 30% | 2.8x | ~1.5% |
EliteKV在相同压缩率下实现了最佳的精度-效率权衡。
7. PyTorch实现
7.1 核心实现
以下是EliteKV的完整PyTorch实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional
import math
class FrequencyAnalyzer:
"""
频率分析器
分析KV向量的频率分布,识别重要频率分量
"""
def __init__(self, head_dim: int, base_freq: float = 10000.0):
"""
Args:
head_dim: 头的维度
base_freq: 基础旋转频率
"""
self.head_dim = head_dim
self.num_pairs = head_dim // 2
self.base_freq = base_freq
# 预计算旋转频率
self.theta = self._compute_thetas()
def _compute_thetas(self) -> torch.Tensor:
"""计算各频率对的旋转角度"""
thetas = []
for i in range(self.num_pairs):
theta_i = self.base_freq * math.exp(-2 * math.log(self.base_freq) * i / self.head_dim)
thetas.append(theta_i)
return torch.tensor(thetas)
def compute_frequency_energy(self, k: torch.Tensor) -> torch.Tensor:
"""
计算各频率对的能量
Args:
k: Key向量 [batch, num_heads, seq_len, head_dim]
Returns:
energies: 各频率对的能量 [batch, num_heads, num_pairs]
"""
batch, num_heads, seq_len, head_dim = k.shape
# 重塑为频率对格式
k_pairs = k.view(batch, num_heads, seq_len, self.num_pairs, 2)
# 计算每个频率对的能量
energies = (k_pairs ** 2).sum(dim=-1) # [batch, num_heads, seq_len, num_pairs]
# 在序列维度上求和
energies = energies.sum(dim=2) # [batch, num_heads, num_pairs]
return energies
def select_elite_frequencies(self, energies: torch.Tensor,
threshold_ratio: float = 0.5) -> torch.Tensor:
"""
选择精英频率
Args:
energies: 频率能量 [batch, num_heads, num_pairs]
threshold_ratio: 阈值比例
Returns:
selected_mask: 选择掩码 [batch, num_heads, num_pairs]
"""
# 计算阈值
median_energy = energies.median(dim=-1, keepdim=True)[0]
threshold = median_energy * threshold_ratio
# 选择能量高于阈值的频率
selected_mask = (energies > threshold).float()
return selected_mask
def get_frequency_weights(self, selected_mask: torch.Tensor,
decay_rate: float = 0.1) -> torch.Tensor:
"""
生成频率权重
Args:
selected_mask: 选择掩码
decay_rate: 未选中频率的衰减率
Returns:
weights: 频率权重
"""
# 未选中频率给予小的非零权重,保留部分信息
weights = selected_mask + (1 - selected_mask) * decay_rate
return weights
class EliteKVCache:
"""
EliteKV缓存
基于频率选择的低秩KV Cache
"""
def __init__(self,
num_heads: int,
head_dim: int,
max_seq_len: int,
compression_ratio: float = 0.3,
base_freq: float = 10000.0,
use_structured: bool = True):
"""
Args:
num_heads: 注意力头数
head_dim: 头维度
max_seq_len: 最大序列长度
compression_ratio: 压缩比例(保留多少)
base_freq: RoPE基础频率
use_structured: 是否使用结构化投影
"""
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.compression_ratio = compression_ratio
# 频率分析器
self.freq_analyzer = FrequencyAnalyzer(head_dim, base_freq)
self.num_pairs = head_dim // 2
# 频率选择掩码
self.register_buffer('freq_mask', torch.zeros(num_heads, self.num_pairs))
# KV缓存
self.k_cache = torch.zeros(max_seq_len, num_heads, head_dim)
self.v_cache = torch.zeros(max_seq_len, num_heads, head_dim)
# 缓存长度
self.current_len = 0
# 结构化投影
self.use_structured = use_structured
if use_structured:
# 频率对的旋转角度
self.register_buffer('rot_matrices', self._create_rot_matrices())
def _create_rot_matrices(self) -> torch.Tensor:
"""创建旋转矩阵(用于结构化投影)"""
matrices = []
for i in range(self.num_pairs):
theta = 10000.0 * math.exp(-2 * math.log(10000.0) * i / self.head_dim)
cos_t = math.cos(theta)
sin_t = math.sin(theta)
# 2x2旋转矩阵
rot = torch.tensor([[cos_t, -sin_t], [sin_t, cos_t]])
matrices.append(rot)
return torch.stack(matrices) # [num_pairs, 2, 2]
def update(self, k: torch.Tensor, v: torch.Tensor):
"""
更新KV Cache
Args:
k: Key向量 [batch, num_heads, seq_len, head_dim]
v: Value向量 [batch, num_heads, seq_len, head_dim]
"""
batch, num_heads, seq_len, head_dim = k.shape
# 存储原始KV
start_idx = self.current_len
end_idx = self.current_len + seq_len
self.k_cache[start_idx:end_idx] = k.squeeze(0).transpose(0, 1)
self.v_cache[start_idx:end_idx] = v.squeeze(0).transpose(0, 1)
self.current_len += seq_len
# 定期更新频率选择
if self.current_len % 512 == 0:
self._update_frequency_selection()
def _update_frequency_selection(self):
"""更新频率选择"""
# 分析当前缓存的频率能量
k_subset = self.k_cache[:self.current_len].transpose(0, 1) # [num_heads, seq_len, head_dim]
k_subset = k_subset.unsqueeze(0) # [1, num_heads, seq_len, head_dim]
energies = self.freq_analyzer.compute_frequency_energy(k_subset)
# 基于累积能量选择top-k频率
total_energies = energies.sum(dim=2) # [1, num_heads, num_pairs]
# 保留压缩比例的频率
k_to_keep = max(1, int(self.num_pairs * self.compression_ratio))
_, top_indices = torch.topk(total_energies.squeeze(0), k=k_to_keep, dim=-1)
# 创建掩码
new_mask = torch.zeros_like(self.freq_mask)
for h in range(self.num_heads):
new_mask[h, top_indices[h]] = 1.0
# 平滑更新
self.freq_mask = 0.7 * self.freq_mask + 0.3 * new_mask
def apply_frequency_projection(self, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
应用频率投影
Args:
k: Key向量 [seq_len, num_heads, head_dim]
v: Value向量 [seq_len, num_heads, head_dim]
Returns:
投影后的KV
"""
# 重塑为频率对格式
k_pairs = k.view(-1, self.num_heads, self.num_pairs, 2) # [seq, heads, pairs, 2]
v_pairs = v.view(-1, self.num_heads, self.num_pairs, 2)
# 应用频率权重
freq_weights = self.freq_mask.unsqueeze(0).unsqueeze(-1) # [1, heads, pairs, 1]
k_projected = k_pairs * freq_weights
v_projected = v_pairs * freq_weights
# 恢复原始形状
k_projected = k_projected.view(-1, self.num_heads, self.head_dim)
v_projected = v_projected.view(-1, self.num_heads, self.head_dim)
return k_projected, v_projected
def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""获取压缩后的KV Cache"""
if self.current_len == 0:
return None, None
k = self.k_cache[:self.current_len]
v = self.v_cache[:self.current_len]
# 应用频率投影
k_proj, v_proj = self.apply_frequency_projection(k, v)
return k_proj, v_proj
def get_memory_usage(self) -> dict:
"""获取内存使用情况"""
original_size = self.current_len * self.num_heads * self.head_dim * 2 * 4 # bytes
compressed_size = self.current_len * self.num_heads * self.head_dim * 2 * 4 * self.compression_ratio
return {
'original_bytes': original_size,
'compressed_bytes': compressed_size,
'compression_ratio': original_size / max(1, compressed_size),
'num_tokens': self.current_len
}
class EliteAttention(nn.Module):
"""
EliteKV注意力层
集成频率选择和低秩投影的注意力机制
"""
def __init__(self,
hidden_dim: int,
num_heads: int,
head_dim: int,
compression_ratio: float = 0.3):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = head_dim
# QKV投影
self.q_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.k_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.v_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.o_proj = nn.Linear(num_heads * head_dim, hidden_dim)
# EliteKV缓存
self.kv_cache = None
self.max_seq_len = 8192
self.compression_ratio = compression_ratio
# RoPE编码器
self.use_rope = True
def init_kv_cache(self, max_seq_len: int):
"""初始化KV Cache"""
self.max_seq_len = max_seq_len
self.kv_cache = EliteKVCache(
self.num_heads,
self.head_dim,
max_seq_len,
compression_ratio=self.compression_ratio
)
def apply_rope(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
"""
应用RoPE编码
Args:
x: 输入张量 [..., head_dim]
position_ids: 位置ID [..., seq_len]
Returns:
应用RoPE后的张量
"""
seq_len = x.shape[-2]
dim = x.shape[-1]
num_pairs = dim // 2
# 计算频率
freqs = 10000.0 * torch.exp(
-torch.arange(0, dim, 2, device=x.device, dtype=x.dtype) *
(math.log(10000.0) / dim)
)
# 计算旋转角度
positions = position_ids.unsqueeze(-1).float()
angles = positions * freqs[None, None, :]
# 计算旋转矩阵
cos = angles.cos()
sin = angles.sin()
# 重塑输入
x1 = x[..., :num_pairs] # 偶数索引
x2 = x[..., num_pairs:] # 奇数索引
# 应用旋转
x_rotated = torch.cat([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos
], dim=-1)
return x_rotated
def forward(self,
x: torch.Tensor,
position_ids: torch.Tensor,
use_cache: bool = True) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
前向传播
Args:
x: 输入张量 [batch, seq_len, hidden_dim]
position_ids: 位置ID [batch, seq_len]
use_cache: 是否使用KV Cache
Returns:
output: 输出张量
k: Key向量
v: Value向量
"""
batch, seq_len, _ = x.shape
# QKV投影
q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
# 应用RoPE
if self.use_rope:
q = self.apply_rope(q.transpose(1, 2), position_ids).transpose(1, 2)
k = self.apply_rope(k.transpose(1, 2), position_ids).transpose(1, 2)
# 更新KV Cache
if use_cache and self.kv_cache is not None and seq_len == 1:
self.kv_cache.update(k, v)
k_cached, v_cached = self.kv_cache.get_cache()
if k_cached is not None:
k_full = torch.cat([k_cached.transpose(0, 1).unsqueeze(0), k.transpose(1, 2)], dim=2)
v_full = torch.cat([v_cached.transpose(0, 1).unsqueeze(0), v.transpose(1, 2)], dim=2)
else:
k_full = k.transpose(1, 2)
v_full = v.transpose(1, 2)
else:
k_full = k.transpose(1, 2)
v_full = v.transpose(1, 2)
# 计算注意力
scale = self.head_dim ** -0.5
q = q.transpose(1, 2) * scale
attn_weights = torch.matmul(q, k_full.transpose(2, 3))
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v_full)
attn_output = attn_output.transpose(1, 2).contiguous()
# 输出投影
attn_output = attn_output.reshape(batch, seq_len, self.num_heads * self.head_dim)
output = self.o_proj(attn_output)
return output, k, v
def benchmark_elitekv():
"""基准测试EliteKV"""
print("=" * 60)
print("EliteKV Benchmark")
print("=" * 60)
# 配置
hidden_dim = 4096
num_heads = 32
head_dim = 128
max_seq_len = 32768
compression_ratio = 0.3
# 创建模型
attention = EliteAttention(
hidden_dim=hidden_dim,
num_heads=num_heads,
head_dim=head_dim,
compression_ratio=compression_ratio
)
attention.init_kv_cache(max_seq_len)
# 测试输入
batch = 1
seq_len = 1
x = torch.randn(batch, seq_len, hidden_dim)
position_ids = torch.tensor([[0]])
# 预填充阶段
print("\n预填充阶段:")
for i in range(min(1024, max_seq_len)):
position_ids = torch.tensor([[i]])
output, k, v = attention(x, position_ids, use_cache=False)
# 获取内存使用
mem_info = attention.kv_cache.get_memory_usage()
print(f" 缓存token数: {mem_info['num_tokens']}")
print(f" 原始内存: {mem_info['original_bytes'] / 1024 / 1024:.2f} MB")
print(f" 压缩后内存: {mem_info['compressed_bytes'] / 1024 / 1024:.2f} MB")
print(f" 压缩比: {mem_info['compression_ratio']:.2f}x")
# 解码阶段
print("\n解码阶段:")
torch.cuda.synchronize() if torch.cuda.is_available() else None
import time
start = time.time()
for i in range(100):
position_ids = torch.tensor([[1024 + i]])
output, k, v = attention(x, position_ids, use_cache=True)
torch.cuda.synchronize() if torch.cuda.is_available() else None
elapsed = time.time() - start
print(f" 100步解码耗时: {elapsed:.3f}s")
print(f" 平均每步: {elapsed/100*1000:.2f}ms")
print("\n" + "=" * 60)
if __name__ == "__main__":
benchmark_elitekv()7.2 与标准注意力的集成
EliteKV可以与标准Transformer模型集成:
class EliteKVTransformerLayer(nn.Module):
"""
使用EliteKV的Transformer层
"""
def __init__(self,
hidden_dim: int,
num_heads: int,
intermediate_size: int,
compression_ratio: float = 0.3):
super().__init__()
self.attention = EliteAttention(
hidden_dim=hidden_dim,
num_heads=num_heads,
head_dim=hidden_dim // num_heads,
compression_ratio=compression_ratio
)
self.mlp = nn.Sequential(
nn.Linear(hidden_dim, intermediate_size),
nn.GELU(),
nn.Linear(intermediate_size, hidden_dim)
)
self.norm1 = nn.RMSNorm(hidden_dim)
self.norm2 = nn.RMSNorm(hidden_dim)
def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
"""
前向传播
Args:
x: 输入张量 [batch, seq_len, hidden_dim]
position_ids: 位置ID [batch, seq_len]
Returns:
output: 输出张量
"""
# 自注意力
attn_output, _, _ = self.attention(self.norm1(x), position_ids)
x = x + attn_output
# FFN
mlp_output = self.mlp(self.norm2(x))
x = x + mlp_output
return x
class EliteKVModel(nn.Module):
"""
使用EliteKV的完整模型
"""
def __init__(self,
vocab_size: int,
hidden_dim: int,
num_layers: int,
num_heads: int,
intermediate_size: int,
max_seq_len: int = 32768,
compression_ratio: float = 0.3):
super().__init__()
self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
self.layers = nn.ModuleList([
EliteKVTransformerLayer(
hidden_dim=hidden_dim,
num_heads=num_heads,
intermediate_size=intermediate_size,
compression_ratio=compression_ratio
)
for _ in range(num_layers)
])
self.norm = nn.RMSNorm(hidden_dim)
self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
self.max_seq_len = max_seq_len
self.compression_ratio = compression_ratio
# 初始化KV Cache
for layer in self.layers:
layer.attention.init_kv_cache(max_seq_len)
def forward(self, input_ids: torch.Tensor, use_cache: bool = True):
"""
前向传播
Args:
input_ids: 输入ID [batch, seq_len]
use_cache: 是否使用KV Cache
Returns:
logits: 输出logits [batch, seq_len, vocab_size]
"""
batch, seq_len = input_ids.shape
# 嵌入
hidden_states = self.embed_tokens(input_ids)
# 位置ID
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
# Transformer层
for layer in self.layers:
hidden_states = layer(hidden_states, position_ids)
# 最终归一化
hidden_states = self.norm(hidden_states)
# LM头
logits = self.lm_head(hidden_states)
return logits
def generate(self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
temperature: float = 1.0,
top_k: int = 50):
"""
自回归生成
Args:
input_ids: 输入ID [batch, seq_len]
max_new_tokens: 最大生成长度
temperature: 采样温度
top_k: Top-K采样
Returns:
generated_ids: 生成的ID
"""
self.eval()
with torch.no_grad():
for _ in range(max_new_tokens):
# 前向传播
logits = self.forward(input_ids, use_cache=True)
# 获取最后一个位置的logits
next_token_logits = logits[:, -1, :] / temperature
# Top-K采样
if top_k > 0:
v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
next_token_logits[next_token_logits < v[:, [-1]]] = float('-inf')
# 采样
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# 拼接
input_ids = torch.cat([input_ids, next_token], dim=1)
# 检查是否生成结束符
if (next_token == 2).any(): # 假设2是结束符
break
return input_ids8. 总结
EliteKV通过深入分析RoPE编码的频率特性,提出了一种新颖的KV Cache压缩方法。其核心贡献包括:
- 频率能量分析:系统性地分析KV向量在各RoPE频率下的能量分布
- 动态频率选择:基于能量和重要性动态选择保留的频率分量
- 联合低秩投影:设计频率感知的低秩投影,保持关键信息
- 兼容性好:可以与现有量化、稀疏方法结合使用
实验结果表明,EliteKV在长上下文任务上显著优于现有方法,同时保持了较低的内存占用和较高的吞吐量。
参考资料
Footnotes
-
Su J, Lu Y, Pan S, et al. RoFormer: Enhanced Transformer with Rotary Position Embedding[J]. arXiv preprint arXiv:2104.09864, 2021. ↩