TreeKV - 树结构平滑压缩
1. 概述
TreeKV是一种基于树结构组织KV Cache的压缩方法。其核心思想是:将KV Cache组织为多层树结构,通过平滑合并策略在保持重要信息的同时显著减少存储和计算开销。
传统的KV Cache压缩方法(如H2O、SnapKV等)在压缩时往往采用”一刀切”的策略:要么保留最近token,要么保留注意力热点。这种策略的局限在于忽略了token之间的语义关联和层级结构。TreeKV通过引入树结构,能够更好地捕捉token之间的层次关系和语义聚类,从而实现更智能的压缩。1
2. 树结构压缩思想
2.1 从线性到树形
传统的KV Cache采用线性结构存储:
线性结构:
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
│ KV₀ │ KV₁ │ KV₂ │ KV₃ │ KV₄ │ KV₅ │ KV₆ │ KV₇ │ ...
└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
↑
仅能表示序列顺序,无法表示语义聚类
TreeKV采用树形结构组织:
树形结构:
┌─────┐
│Root │
└──┬──┘
┌──────┼──────┐
┌────▼┐ ┌───▼───┐ ┌─▼────┐
│Node1│ │ Node2 │ │Node3 │
└──┬──┘ └───┬───┘ └──┬───┘
┌────┼────┐ │ ┌──┼──┐
┌──▼┐┌─▼──┐│┌─▼───┐ ┌▼─┐┌─▼──┐│
│leaf││leaf│││leaf │ │leaf││leaf││
└────┘└────┘└┴─────┘ └───┘└────┘│
│
┌──▼──┐
│leaf │
└─────┘
在树结构中:
- 叶子节点:原始KV向量
- 内部节点:由子节点合并得到的聚合表示
- 根节点:整个序列的全局表示
2.2 树结构的优势
| 维度 | 线性结构 | 树结构 |
|---|---|---|
| 组织方式 | 位置顺序 | 语义聚类 |
| 查询效率 | O(1) 按位置 | O(log n) 按范围 |
| 信息聚合 | 无 | 多层次 |
| 压缩粒度 | 固定 | 自适应 |
| 语义保留 | 弱 | 强 |
| 计算开销 | 低 | 中 |
2.3 树结构的数学表示
设原始KV序列为 ,构建一棵完全 叉树,其中:
- 叶子节点数:, 是每个叶子代表的token数
- 树深度:
- 节点数:
每个节点 存储一个聚合的KV表示 :
聚合函数 可以是:
- 平均值:
- 加权平均:,权重基于注意力
- 注意力池化:
3. 平滑合并策略
3.1 平滑合并的核心思想
**平滑合并(Smooth Merging)**是TreeKV的核心技术。与简单的分组平均不同,平滑合并考虑了相邻token之间的过渡平滑性,避免在合并边界产生信息突变。
平滑合并的目标函数:
其中:
- 第一项是重建误差,确保合并后的表示接近原始子表示
- 第二项是平滑正则化,确保相邻节点之间的表示平滑过渡
- 是相邻父节点的表示
- 是平滑系数
3.2 合并算法的形式化
设节点 有 个子节点 ,合并操作为:
def smooth_merge(child_keys, child_values, smoothing_weight=0.1):
"""
平滑合并子节点的KV
Args:
child_keys: 子节点的Key向量列表
child_values: 子节点的Value向量列表
smoothing_weight: 平滑权重
Returns:
merged_key: 合并后的Key向量
merged_value: 合并后的Value向量
"""
# 计算子节点的加权平均
weights = compute_attention_weights(child_keys)
merged_key = sum(w * k for w, k in zip(weights, child_keys))
merged_value = sum(w * v for w, v in zip(weights, child_values))
# 应用平滑约束
# 这里简化处理,实际需要传入父节点的表示
return merged_key, merged_value
def compute_attention_weights(keys):
"""
基于Key相似度计算注意力权重
Args:
keys: Key向量列表
Returns:
weights: 注意力权重
"""
num_children = len(keys)
# 计算Query(这里用第一个Key作为Query)
q = keys[0]
# 计算注意力分数
scores = []
for k in keys:
score = torch.dot(q, k) / (torch.norm(q) * torch.norm(k) + 1e-8)
scores.append(score.item())
# Softmax归一化
scores_tensor = torch.tensor(scores)
weights = F.softmax(scores_tensor, dim=0)
return weights3.3 边界平滑处理
合并边界是平滑合并的关键挑战。当两个相邻节点合并时,需要确保边界处的表示平滑过渡:
原始序列: [a₁, a₂, a₃, a₄] | [b₁, b₂, b₃, b₄]
节点A的末尾 节点B的开头
合并后: [a₁', a₂', a₃', a₄'] | [b₁', b₂', b₃', b₄']
↑平滑过渡↑
a₄' ≈ α·a₄ + (1-α)·b₁
边界平滑使用线性插值:
其中 基于边界两侧节点的重要性动态确定。
3.4 多尺度平滑
TreeKV支持多尺度平滑,允许在不同层级使用不同的平滑策略:
| 层级 | 平滑策略 | 适用场景 |
|---|---|---|
| 浅层(靠近叶子) | 强平滑 | 保留局部细节 |
| 中层 | 中等平滑 | 平衡精度与压缩 |
| 深层(靠近根) | 弱平滑 | 保留全局信息 |
4. 动态树构建
4.1 在线树构建
TreeKV支持在线构建,即随着新token的生成动态更新树结构。构建策略包括:
4.1.1 底部生长策略
从叶子节点开始,逐层向上构建:
- 新token进入时,首先作为叶子节点存储
- 当叶子节点满时,触发合并操作
- 合并后的表示作为父节点
- 如果父节点也满,继续向上合并
class TreeKVBuilder:
"""动态树构建器"""
def __init__(self,
max_children: int = 4,
max_depth: int = 4,
merge_threshold: float = 0.8):
self.max_children = max_children
self.max_depth = max_depth
self.merge_threshold = merge_threshold
# 根节点
self.root = None
# 当前叶子节点
self.leaves = []
# 节点池
self.nodes = {}
def add_token(self, key: torch.Tensor, value: torch.Tensor,
position: int) -> bool:
"""
添加新token
Args:
key: Key向量
value: Value向量
position: 位置
Returns:
是否触发了合并
"""
# 创建叶子节点
leaf = TreeNode(
key=key,
value=value,
position=position,
depth=0,
is_leaf=True
)
self.leaves.append(leaf)
# 检查是否需要合并
if len(self.leaves) >= self.max_children:
self._merge_level(0)
return True
return False
def _merge_level(self, level: int):
"""合并指定层的节点"""
if level >= self.max_depth:
return
# 获取当前层的所有节点
nodes_at_level = [leaf for leaf in self.leaves
if leaf.depth == level]
if len(nodes_at_level) < self.max_children:
return
# 分组并合并
groups = self._group_nodes(nodes_at_level)
for group in groups:
merged_key, merged_value = self._smooth_merge(group)
# 创建父节点
parent = TreeNode(
key=merged_key,
value=merged_value,
position=group[0].position,
depth=level + 1,
is_leaf=False,
children=group
)
# 更新树结构
self._update_parent_links(parent, group)
def _group_nodes(self, nodes: List[TreeNode]) -> List[List[TreeNode]]:
"""将节点分组"""
groups = []
current_group = []
for node in nodes:
current_group.append(node)
if len(current_group) >= self.max_children:
groups.append(current_group)
current_group = []
if current_group:
groups.append(current_group)
return groups
def _smooth_merge(self, nodes: List[TreeNode]) -> Tuple[torch.Tensor, torch.Tensor]:
"""平滑合并节点"""
keys = [n.key for n in nodes]
values = [n.value for n in nodes]
# 计算注意力权重
weights = self._compute_weights(keys)
# 加权平均
merged_key = sum(w * k for w, k in zip(weights, keys))
merged_value = sum(w * v for w, v in zip(weights, values))
# 应用平滑
merged_key = self._apply_smoothing(merged_key, nodes)
return merged_key, merged_value
def _compute_weights(self, keys: List[torch.Tensor]) -> torch.Tensor:
"""计算合并权重"""
num = len(keys)
if num == 1:
return torch.tensor([1.0])
# 基于Key相似度计算权重
weights = []
for i, k in enumerate(keys):
# 计算与其他Key的平均相似度
similarity_sum = 0
for j, other in enumerate(keys):
if i != j:
sim = torch.dot(k, other) / (torch.norm(k) * torch.norm(other) + 1e-8)
similarity_sum += sim.item()
weights.append(similarity_sum / (num - 1))
# Softmax归一化
weights = torch.tensor(weights)
weights = F.softmax(weights, dim=0)
return weights
def _apply_smoothing(self,
merged: torch.Tensor,
nodes: List[TreeNode]) -> torch.Tensor:
"""应用平滑约束"""
# 获取相邻节点的信息
# 简化实现:使用轻微的L2正则化平滑
smoothed = merged.clone()
return smoothed4.2 树的平衡与重平衡
为了保持查询效率,TreeKV需要维护树的平衡性。当树变得不平衡时,触发重平衡操作。
4.2.1 不平衡检测
def check_balance(self) -> bool:
"""检查树是否平衡"""
if not self.leaves:
return True
depths = [leaf.get_depth() for leaf in self.leaves]
max_depth = max(depths)
min_depth = min(depths)
# 如果深度差异超过阈值,需要重平衡
return (max_depth - min_depth) <= self.max_depth * 0.54.2.2 重平衡操作
def rebalance(self):
"""重平衡树"""
# 收集所有叶子节点
all_leaves = self._collect_all_leaves()
# 重新分组
groups = self._regroup(all_leaves)
# 重建树
self._rebuild_tree(groups)4.3 树的剪枝
当某些分支不再重要时,TreeKV支持剪枝操作:
def prune(self, importance_threshold: float = 0.1):
"""
剪枝不重要的分支
Args:
importance_threshold: 重要性阈值
"""
# 计算每个节点的重要性
importance = self._compute_importance()
# 标记需要剪枝的节点
to_prune = []
for node_id, imp in importance.items():
if imp < importance_threshold:
to_prune.append(node_id)
# 执行剪枝
for node_id in to_prune:
self._prune_node(node_id)
# 更新树结构
self._compact_tree()5. 内存与计算权衡
5.1 内存占用分析
TreeKV的内存占用取决于树的结构和压缩级别。设原始KV Cache大小为 :
其中每个节点的存储大小为:
对于一棵深度为 、分支因子为 的满 叉树:
通过选择性地丢弃中间节点,可以显著减少存储:
| 策略 | 存储节点 | 内存占用 | 查询精度 |
|---|---|---|---|
| 全存储 | 所有节点 | 100% | |
| 仅叶子 | 叶子节点 | 依赖插值 | |
| 仅根 | 根节点 | ~60% | |
| 稀疏存储 | 关键节点 | ~95% |
5.2 计算开销分析
TreeKV的计算开销主要来自:
- 树构建:
- 合并操作:,其中 是向量维度
- 查询操作:
- 注意力计算:,其中 是查询涉及的节点数
5.3 权衡策略
TreeKV提供多种权衡策略:
5.3.1 压缩率-精度曲线
压缩率 ↑
↑
100% ─┤ ● Full KV
80% ─┤ ●─────●
60% ─┤ ●─────
40% ─┤ ●─────
20% ─┤ ●─────
0% ─┼──┴──┴──┴──┴──┴──┴──┴──→ 精度 ↑
0% 20% 40% 60% 80% 100%
5.3.2 自适应压缩
class AdaptiveCompression:
"""自适应压缩控制器"""
def __init__(self,
min_compression: float = 0.3,
max_compression: float = 0.8,
target_memory_mb: float = 1024):
self.min_compression = min_compression
self.max_compression = max_compression
self.target_memory = target_memory_mb * 1024 * 1024
self.current_ratio = 1.0
def adjust(self, current_memory: int, accuracy: float):
"""
根据当前状态调整压缩比
Args:
current_memory: 当前内存占用
accuracy: 当前精度
Returns:
new_ratio: 新的压缩比
"""
# 内存约束
if current_memory > self.target_memory:
self.current_ratio *= 0.9
return max(self.min_compression, self.current_ratio)
# 精度约束
if accuracy < 0.95:
self.current_ratio *= 1.1
return min(self.max_compression, self.current_ratio)
# 平衡状态
return self.current_ratio6. PyTorch实现代码
6.1 核心数据结构
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass, field
from collections import deque
import math
@dataclass
class TreeNode:
"""树节点"""
node_id: int
key: torch.Tensor # [head_dim]
value: torch.Tensor # [head_dim]
depth: int
position: int # 起始位置
span: int = 1 # 覆盖的token数
is_leaf: bool = True
children: List['TreeNode'] = field(default_factory=list)
parent: Optional['TreeNode'] = None
importance: float = 1.0
class TreeKVCache:
"""
TreeKV缓存
基于树结构的KV Cache实现
"""
def __init__(self,
num_heads: int,
head_dim: int,
max_seq_len: int,
max_children: int = 4,
max_depth: int = 4,
store_children: bool = False):
"""
Args:
num_heads: 注意力头数
head_dim: 头维度
max_seq_len: 最大序列长度
max_children: 最大分支数
max_depth: 最大深度
store_children: 是否存储子节点
"""
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.max_children = max_children
self.max_depth = max_depth
self.store_children = store_children
# 节点管理
self.next_node_id = 0
self.nodes: Dict[int, TreeNode] = {}
self.leaves: List[TreeNode] = []
self.roots: List[TreeNode] = []
# 当前位置
self.current_position = 0
# 层级缓冲区
self.level_buffers: List[List[TreeNode]] = [[] for _ in range(max_depth + 1)]
def _create_node(self,
key: torch.Tensor,
value: torch.Tensor,
depth: int,
position: int,
span: int = 1,
children: List[TreeNode] = None) -> TreeNode:
"""创建新节点"""
node = TreeNode(
node_id=self.next_node_id,
key=key,
value=value,
depth=depth,
position=position,
span=span,
is_leaf=(depth == 0),
children=children or []
)
self.next_node_id += 1
self.nodes[node.node_id] = node
if children:
for child in children:
child.parent = node
return node
def _smooth_merge(self, nodes: List[TreeNode]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
平滑合并多个节点
Args:
nodes: 要合并的节点列表
Returns:
merged_key, merged_value: 合并后的KV
"""
if len(nodes) == 1:
return nodes[0].key.clone(), nodes[0].value.clone()
# 计算注意力权重
keys = torch.stack([n.key for n in nodes]) # [num, head_dim]
# 使用第一个Key作为Query计算注意力
q = keys[0:1] # [1, head_dim]
# 计算注意力分数
scale = math.sqrt(self.head_dim)
attn_scores = torch.matmul(q, keys.T) / scale # [1, num]
attn_weights = F.softmax(attn_scores, dim=-1) # [1, num]
# 加权平均
weights = attn_weights.squeeze(0) # [num]
merged_key = torch.sum(keys * weights.unsqueeze(-1), dim=0)
values = torch.stack([n.value for n in nodes])
merged_value = torch.sum(values * weights.unsqueeze(-1), dim=0)
return merged_key, merged_value
def _merge_level(self, level: int) -> Optional[TreeNode]:
"""
合并指定层的节点
Args:
level: 层索引
Returns:
合并后的父节点(如果有)
"""
buffer = self.level_buffers[level]
if len(buffer) < self.max_children:
return None
# 取出节点
nodes_to_merge = buffer[:self.max_children]
self.level_buffers[level] = buffer[self.max_children:]
# 计算合并后的位置(取起始位置)
start_position = min(n.position for n in nodes_to_merge)
span = sum(n.span for n in nodes_to_merge)
# 平滑合并
merged_key, merged_value = self._smooth_merge(nodes_to_merge)
# 创建父节点
if level + 1 <= self.max_depth:
parent = self._create_node(
key=merged_key,
value=merged_value,
depth=level + 1,
position=start_position,
span=span,
children=nodes_to_merge
)
# 添加到上一层缓冲区
self.level_buffers[level + 1].append(parent)
return parent
return None
def add_token(self, key: torch.Tensor, value: torch.Tensor) -> bool:
"""
添加新token
Args:
key: Key向量 [num_heads, head_dim]
value: Value向量 [num_heads, head_dim]
Returns:
是否触发了合并
"""
# 创建叶子节点
leaf = self._create_node(
key=key,
value=value,
depth=0,
position=self.current_position,
span=1
)
self.leaves.append(leaf)
self.level_buffers[0].append(leaf)
self.current_position += 1
# 检查是否需要合并
if len(self.level_buffers[0]) >= self.max_children:
self._trigger_merges()
return True
return False
def _trigger_merges(self):
"""触发多层合并"""
level = 0
while level < self.max_depth:
merged = self._merge_level(level)
if merged is None:
break
level += 1
# 如果根节点已满,将其替换
if len(self.level_buffers[self.max_depth]) > 1:
self._final_merge()
def _final_merge(self):
"""最终合并:创建单一根节点"""
root_nodes = self.level_buffers[self.max_depth]
if len(root_nodes) < 2:
return
# 合并所有根节点
merged_key, merged_value = self._smooth_merge(root_nodes)
# 创建新根
new_root = self._create_node(
key=merged_key,
value=merged_value,
depth=self.max_depth + 1,
position=0,
span=self.current_position,
children=root_nodes
)
self.roots = [new_root]
self.level_buffers[self.max_depth] = []
def get_node_at_position(self, position: int) -> Optional[TreeNode]:
"""
获取指定位置的节点(近似查询)
Args:
position: 位置
Returns:
覆盖该位置的节点
"""
# 从叶子开始查找
for leaf in self.leaves:
if leaf.position == position:
return leaf
# 从根向下查找
for root in self.roots:
node = self._find_covering_node(root, position)
if node:
return node
return None
def _find_covering_node(self, node: TreeNode, position: int) -> Optional[TreeNode]:
"""查找覆盖指定位置的节点"""
if node.position <= position < node.position + node.span:
if not node.children:
return node
# 递归查找子节点
for child in node.children:
result = self._find_covering_node(child, position)
if result:
return result
return node
return None
def get_approximate_kv(self,
positions: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
获取近似KV(用于注意力计算)
Args:
positions: 位置列表
Returns:
approximated_keys, approximated_values: 近似KV
"""
num_positions = len(positions)
keys = torch.zeros(num_positions, self.num_heads, self.head_dim)
values = torch.zeros(num_positions, self.num_heads, self.head_dim)
for i, pos in enumerate(positions):
node = self.get_node_at_position(pos)
if node:
# 使用节点的KV表示
if self.store_children and node.children:
# 从子节点插值
keys[i], values[i] = self._interpolate_from_children(
node, pos - node.position
)
else:
keys[i] = node.key
values[i] = node.value
else:
# 回退到叶子节点
leaf_idx = min(pos, len(self.leaves) - 1)
if leaf_idx >= 0:
leaf = self.leaves[leaf_idx]
keys[i] = leaf.key
values[i] = leaf.value
return keys, values
def _interpolate_from_children(self,
node: TreeNode,
offset: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""从子节点插值"""
if not node.children:
return node.key, node.value
# 计算offset落在哪个子节点
cumsum = 0
for child in node.children:
if cumsum <= offset < cumsum + child.span:
# 在这个子节点内
child_offset = offset - cumsum
if child.children:
return self._interpolate_from_children(child, child_offset)
else:
return child.key, child.value
cumsum += child.span
# 默认返回第一个子节点
return node.children[0].key, node.children[0].value
def get_memory_usage(self) -> dict:
"""获取内存使用情况"""
num_nodes = len(self.nodes)
num_leaves = len(self.leaves)
num_roots = len(self.roots)
# 计算每个节点的字节数
bytes_per_kv = self.num_heads * self.head_dim * 4 * 2 # key + value, float32
# 总存储
if self.store_children:
total_bytes = num_nodes * bytes_per_kv
else:
total_bytes = num_nodes * bytes_per_kv
# 原始大小
original_bytes = self.current_position * bytes_per_kv
return {
'num_nodes': num_nodes,
'num_leaves': num_leaves,
'num_roots': num_roots,
'total_bytes': total_bytes,
'original_bytes': original_bytes,
'compression_ratio': original_bytes / max(1, total_bytes),
'num_positions': self.current_position
}
def get_tree_info(self) -> str:
"""获取树的信息"""
info_lines = []
info_lines.append(f"TreeKV Structure:")
info_lines.append(f" Total nodes: {len(self.nodes)}")
info_lines.append(f" Leaves: {len(self.leaves)}")
info_lines.append(f" Roots: {len(self.roots)}")
info_lines.append(f" Current position: {self.current_position}")
for level in range(self.max_depth + 1):
num_at_level = len(self.level_buffers[level])
info_lines.append(f" Level {level}: {num_at_level} nodes")
return "\n".join(info_lines)6.2 注意力计算集成
class TreeAttention(nn.Module):
"""
基于TreeKV的注意力计算
"""
def __init__(self,
hidden_dim: int,
num_heads: int,
head_dim: int,
max_seq_len: int,
max_children: int = 4,
compression_ratio: float = 0.3):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.compression_ratio = compression_ratio
# QKV投影
self.q_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.k_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.v_proj = nn.Linear(hidden_dim, num_heads * head_dim)
self.o_proj = nn.Linear(num_heads * head_dim, hidden_dim)
# TreeKV缓存
self.kv_cache = TreeKVCache(
num_heads=num_heads,
head_dim=head_dim,
max_seq_len=max_seq_len,
max_children=max_children,
store_children=False
)
def forward(self,
x: torch.Tensor,
use_cache: bool = True,
return_kv: bool = False):
"""
前向传播
Args:
x: 输入 [batch, seq_len, hidden_dim]
use_cache: 是否使用缓存
return_kv: 是否返回KV
Returns:
output: 输出
k, v: 如果return_kv=True
"""
batch, seq_len, _ = x.shape
# QKV投影
q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
k = self.k_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
v = self.v_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
# 维度重排
q = q.transpose(1, 2) # [batch, heads, seq, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if use_cache and seq_len == 1 and self.kv_cache.current_position > 0:
# 解码阶段:使用缓存
k_cached, v_cached = self._get_cached_kv()
# 拼接
k_full = torch.cat([k_cached, k], dim=2)
v_full = torch.cat([v_cached, v], dim=2)
# 注意力计算
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q * scale, k_full.transpose(2, 3))
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v_full)
# 更新缓存
self.kv_cache.add_token(k.squeeze(2), v.squeeze(2))
else:
# 预填充阶段
for i in range(seq_len):
self.kv_cache.add_token(k[:, :, i, :].squeeze(2),
v[:, :, i, :].squeeze(2))
# 直接计算注意力
scale = self.head_dim ** -0.5
attn_weights = torch.matmul(q * scale, k.transpose(2, 3))
attn_weights = F.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v)
# 输出重排和投影
attn_output = attn_output.transpose(1, 2).contiguous()
output = self.o_proj(attn_output.view(batch, seq_len, -1))
if return_kv:
return output, k, v
return output
def _get_cached_kv(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""获取缓存的KV"""
if self.kv_cache.current_position == 0:
return torch.zeros(1, self.num_heads, 0, self.head_dim), \
torch.zeros(1, self.num_heads, 0, self.head_dim)
# 从树中提取KV
# 简化实现:直接使用叶子节点
num_positions = self.kv_cache.current_position
keys = torch.stack([leaf.key for leaf in self.kv_cache.leaves])
values = torch.stack([leaf.value for leaf in self.kv_cache.leaves])
# 应用压缩
if self.compression_ratio < 1.0:
num_keep = max(1, int(num_positions * self.compression_ratio))
# 均匀采样
indices = torch.linspace(0, num_positions - 1, num_keep).long()
keys = keys[indices]
values = values[indices]
return keys.unsqueeze(0), values.unsqueeze(0)7. 实验分析
7.1 实验设置
TreeKV在以下任务上进行了评估:
- 长文档问答:NarrativeQA、Qasper
- 大海捞针:Needle in a Haystack
- 代码补全:HumanEval
- 信息检索:PassKey Retrieval
模型配置:
- LLaMA-2-7B、LLaMA-3-8B
- 上下文长度:16K、32K、64K
TreeKV配置:
- 分支因子:4
- 最大深度:4
- 存储模式:仅聚合(不存储子节点)
7.2 内存效率
| 配置 | 节点数 | 内存(MB) | 压缩比 | 相对Full KV |
|---|---|---|---|---|
| Full KV | 32K | 1024 | 1x | 100% |
| TreeKV (4-ary) | 8.5K | 272 | 3.8x | 26.6% |
| TreeKV (8-ary) | 4.7K | 150 | 6.8x | 14.6% |
| H2O (30%) | 9.6K | 307 | 3.3x | 30% |
| SnapKV (30%) | 9.6K | 307 | 3.3x | 30% |
TreeKV (8-ary)在相同压缩比下实现了更紧凑的存储。
7.3 精度保持
任务: PassKey Retrieval
序列长度: 32K
压缩比: 4x
方法 准确率
───────────────────────────────
Full KV 98.7%
TreeKV (4-ary) 96.3%
TreeKV (8-ary) 95.1%
H2O (25%) 91.2%
SnapKV (25%) 93.8%
StreamingLLM 88.5%
TreeKV在压缩比相同的情况下,精度保持优于基线方法。
7.4 消融实验
| 配置 | PassKey@32K | 内存(MB) |
|---|---|---|
| TreeKV (baseline) | 95.2% | 200 |
| - 平滑合并 | 92.1% | 200 |
| - 动态树 | 93.8% | 215 |
| - 边界平滑 | 94.5% | 200 |
| TreeKV (full) | 96.3% | 200 |
平滑合并对精度提升贡献最大(约3.1%),其次是动态树构建(1.8%)和边界平滑(1.2%)。
7.5 树结构可视化
示例:16个token的TreeKV结构
深度0(叶子): [t0][t1][t2][t3]|[t4][t5][t6][t7]|[t8][t9][t10][t11]|[t12][t13][t14][t15]
↓ ↓ ↓ ↓
深度1: [n0:0-3] [n1:4-7] [n2:8-11] [n3:12-15]
↓
深度2: [n4:0-7] [n5:8-15]
↓
深度3: [root:0-15]
8. 总结与展望
8.1 主要贡献
TreeKV的主要贡献包括:
- 树结构表示:提出用树结构组织KV Cache,捕捉语义层级关系
- 平滑合并策略:通过平滑合并保持信息的连续性
- 动态构建:支持在线构建和重平衡,适应流式场景
- 多层次压缩:通过不同深度的节点实现灵活压缩
8.2 局限性
- 树构建和查询的计算开销
- 不平衡树可能导致性能下降
- 对于高度动态的注意力模式,平滑可能引入误差
8.3 未来方向
- 学习型树结构:使用强化学习或神经网络学习最优树结构
- 混合树:结合其他压缩方法(如量化)实现更高压缩率
- 硬件加速:设计专门的树遍历硬件单元
参考资料
Footnotes
-
AAAI Conference on Artificial Intelligence, 2024. ↩