Trellis:学习压缩Key-Value记忆
1. 问题背景
1.1 KV Cache的困境
在自回归语言模型的推理过程中,KV Cache 是加速生成的关键技术。它缓存了已处理token的Key和Value向量,避免重复计算。然而,KV Cache面临严重的可扩展性问题:
1.1.1 内存随序列长度线性增长
对于序列长度为 的生成:
- 每个token的KV向量大小:(FP32)
- 总内存占用:
- 示例:LLaMA-7B (seq_len=32K) → ~16GB KV Cache
1.1.2 注意力计算随序列长度平方增长
当 很大时,即使有KV Cache,注意力计算仍是主要瓶颈。
1.2 现有方法的局限
| 方法 | 策略 | 问题 |
|---|---|---|
| StreamingLLM | 保留汇聚token | 丢失细粒度信息 |
| H2O | 动态驱逐 | 决策粗糙,可能丢失关键信息 |
| PyramidKV | 金字塔式缓存 | 固定压缩率,难以适应 |
| KV Quantization | 低精度存储 | 有损压缩,精度损失 |
1.3 Trellis的核心洞察
Trellis 提出了一个根本性的范式转变:
用固定大小的”记忆”替代随序列增长的KV Cache
核心思想:
- 维护一个固定大小的记忆缓冲区(与输入长度无关)
- 学习一个两遍递归压缩机制
- 模型自己决定何时压缩和如何压缩
2. 技术详解
2.1 形式化定义
2.1.1 标准KV Cache
标准Transformer中,KV Cache存储为:
2.1.2 Trellis记忆缓冲区
Trellis使用固定大小的记忆缓冲区:
其中 是固定的记忆容量,与输入序列长度无关。
2.1.3 两遍压缩机制
Trellis采用两遍递归压缩(Two-Pass Recursive Compression):
第一遍(收集阶段):
遍历所有KV对 → 评估压缩收益 → 识别待压缩的KV
第二遍(执行阶段):
对识别出的KV进行压缩 → 更新记忆缓冲区
2.2 压缩决策网络
Trellis的核心是一个压缩决策网络,决定哪些KV应该被压缩。
2.2.1 压缩收益评估
对于第 个token的KV对 ,定义压缩收益:
其中 是压缩后的KV表示。
2.2.2 决策网络架构
class CompressionDecisionNetwork(nn.Module):
"""
决定何时压缩的决策网络
"""
def __init__(self, d_model: int, d_k: int):
super().__init__()
# 评估网络:判断单个KV对的重要性
self.evaluator = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
nn.Sigmoid()
)
# 压缩网络:生成压缩后的表示
self.compressor = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, d_k)
)
def forward(self, k: torch.Tensor, v: torch.Tensor,
memory: tuple) -> tuple:
"""
Args:
k: Key向量 [d_k]
v: Value向量 [d_v]
memory: 当前记忆缓冲区 (M_k, M_v)
Returns:
compress: 是否压缩的决策
compressed_k: 压缩后的Key
compressed_v: 压缩后的Value
"""
# 评估压缩收益
gain = self.evaluator(k)
# 决策阈值(可学习)
threshold = 0.5
# 如果收益低于阈值,则压缩
compress = (gain < threshold).float()
# 生成压缩表示
kv_concat = torch.cat([k, v], dim=-1)
compressed = self.compressor(kv_concat)
compressed_k, compressed_v = compressed.chunk(2, dim=-1)
return compress, compressed_k, compressed_v2.3 两遍递归压缩算法
def trellis_compress(kv_pairs: list, memory: tuple,
memory_size: int, capacity: float) -> tuple:
"""
Trellis两遍递归压缩算法
Args:
kv_pairs: 所有KV对 [(k1,v1), (k2,v2), ...]
memory: 初始记忆缓冲区
memory_size: 缓冲区大小
capacity: 压缩触发容量(0到1之间)
Returns:
M_k, M_v: 最终记忆缓冲区
"""
M_k, M_v = memory
# ========== 第一遍:收集阶段 ==========
# 计算每个KV的重要性分数
importance_scores = []
for k, v in kv_pairs:
# 评估与当前记忆的关联度
score = compute_importance(k, v, M_k, M_v)
importance_scores.append(score)
# 识别需要压缩的KV
capacity_threshold = int(len(kv_pairs) * capacity)
_, compress_indices = torch.topk(
torch.tensor(importance_scores),
capacity_threshold,
largest=False # 重要性最低的将被压缩
)
# ========== 第二遍:执行阶段 ==========
new_kv_pairs = []
for i, (k, v) in enumerate(kv_pairs):
if i in compress_indices:
# 压缩这个KV
compressed_k, compressed_v = compress_kv(k, v)
new_kv_pairs.append((compressed_k, compressed_v))
else:
# 保留原KV
new_kv_pairs.append((k, v))
# 更新记忆缓冲区(可能需要进一步聚合)
M_k_new, M_v_new = update_memory(new_kv_pairs, M_k, M_v)
# 如果超过固定大小,进行递归压缩
if len(M_k_new) > memory_size:
return trellis_compress(
[(M_k_new[i], M_v_new[i]) for i in range(len(M_k_new))],
(M_k_new, M_v_new),
memory_size,
capacity * 0.8 # 降低容量触发阈值
)
return M_k_new, M_v_new
def update_memory(kv_pairs: list, M_k: torch.Tensor,
M_v: torch.Tensor) -> tuple:
"""
更新记忆缓冲区
使用门控机制决定是添加新KV还是聚合现有KV
"""
new_entries = []
for k, v in kv_pairs:
# 计算与现有记忆的相似度
similarities = torch.matmul(M_k, k.unsqueeze(-1)).squeeze(-1)
max_sim, best_match = similarities.max(dim=0)
if max_sim > 0.9: # 高度相似,聚合
# 门控更新
gate = torch.sigmoid(nn.Linear(2 * d_model, 1)(torch.cat([k, M_k[best_match]], dim=-1)))
updated_v = gate * v + (1 - gate) * M_v[best_match]
new_entries.append((k, updated_v))
else: # 不相似,添加新条目
new_entries.append((k, v))
# 合并现有记忆和新条目
all_k = torch.cat([M_k] + [k.unsqueeze(0) for k, _ in new_entries], dim=0)
all_v = torch.cat([M_v] + [v.unsqueeze(0) for _, v in new_entries], dim=0)
return all_k, all_v2.4 与注意力的集成
Trellis的压缩记忆直接用于注意力计算:
注意:这与标准注意力的形式完全一致,只是将 替换为压缩后的 。
3. 理论分析
3.1 表达能力保证
定理(表达能力保持):在适当条件下,Trellis的记忆表示保持了原始KV序列的表达能力。
证明概要:
- 假设压缩决策网络能够正确识别冗余KV
- 使用信息瓶颈理论分析压缩损失
- 门控机制确保关键信息不被丢弃
3.2 计算复杂度
| 阶段 | 标准KV Cache | Trellis |
|---|---|---|
| 存储复杂度 | (固定) | |
| 注意力计算 | ||
| 压缩开销 | 无 |
其中 是决策网络的复杂度,通常 。
3.3 内存节省
设记忆容量 ,不同序列长度的内存节省:
| 序列长度 | 标准KV Cache | Trellis | 节省比例 |
|---|---|---|---|
| 1K | 1× | 0.26× | 74% |
| 4K | 4× | 0.07× | 93% |
| 16K | 16× | 0.02× | 98% |
| 32K | 32× | 0.01× | 99% |
4. PyTorch实现
4.1 核心Trellis模块
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class TrellisKVCache(nn.Module):
"""
Trellis: 学习压缩的KV Cache
使用固定大小记忆替代标准KV Cache
"""
def __init__(
self,
d_model: int,
num_heads: int,
memory_size: int = 256,
compression_threshold: float = 0.5,
):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.memory_size = memory_size
self.compression_threshold = compression_threshold
# 投影层
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
# 记忆缓冲区(可学习参数)
self.register_buffer(
'memory_k',
torch.zeros(memory_size, num_heads, self.d_k)
)
self.register_buffer(
'memory_v',
torch.zeros(memory_size, num_heads, self.d_k)
)
self.register_buffer('memory_ptr', torch.tensor(0))
# 压缩决策网络
self.decision_net = CompressionDecisionNetwork(
d_model, self.d_k, memory_size
)
def reset_memory(self):
"""重置记忆缓冲区"""
self.memory_k.zero_()
self.memory_v.zero_()
self.memory_ptr.fill_(0)
def decide_compress(self, k: torch.Tensor, v: torch.Tensor) -> tuple:
"""
决定是否压缩单个KV对
Args:
k: Key向量 [batch, num_heads, d_k]
v: Value向量 [batch, num_heads, d_v]
Returns:
compress: 是否压缩的决策
compressed_k, compressed_v: 压缩后的向量
"""
# 计算与当前记忆的关联度
current_memory = self.memory_k[:self.memory_ptr]
if self.memory_ptr > 0:
# 相似度计算
similarities = torch.einsum('bhd,mhd->bhm', k, current_memory)
max_sim = similarities.max(dim=-1)[0].mean()
else:
max_sim = torch.tensor(0.0, device=k.device)
# 决策阈值
should_compress = (max_sim < self.compression_threshold).float()
# 生成压缩表示(简化的线性压缩)
# 实际实现中应使用更复杂的压缩网络
compressed_k = k.mean(dim=1, keepdim=True).expand_as(k)
compressed_v = v.mean(dim=1, keepdim=True).expand_as(v)
return should_compress, compressed_k, compressed_v
def add_to_memory(self, k: torch.Tensor, v: torch.Tensor):
"""
将KV对添加到记忆缓冲区
Args:
k: Key向量
v: Value向量
"""
batch_size = k.shape[0]
# 如果缓冲区已满,先压缩
if self.memory_ptr + batch_size > self.memory_size:
self._compress_memory()
# 添加到记忆
end_ptr = min(
self.memory_ptr + batch_size,
self.memory_size
)
self.memory_k[self.memory_ptr:end_ptr] = k[:end_ptr-self.memory_ptr]
self.memory_v[self.memory_ptr:end_ptr] = v[:end_ptr-self.memory_ptr]
self.memory_ptr = end_ptr
def _compress_memory(self):
"""
压缩记忆缓冲区
使用聚类合并相似条目
"""
current_size = self.memory_ptr.item()
if current_size <= self.memory_size // 2:
return
# 简单的两两合并
new_size = current_size // 2
new_k = torch.zeros(new_size, self.num_heads, self.d_k, device=self.memory_k.device)
new_v = torch.zeros(new_size, self.num_heads, self.d_v, device=self.memory_v.device)
for i in range(new_size):
# 合并相邻的两个条目
new_k[i] = (self.memory_k[2*i] + self.memory_k[2*i+1]) / 2
new_v[i] = (self.memory_v[2*i] + self.memory_v[2*i+1]) / 2
self.memory_k[:new_size] = new_k
self.memory_v[:new_size] = new_v
self.memory_ptr.fill_(new_size)
def forward(
self,
q: torch.Tensor,
use_memory: bool = True,
attention_mask: torch.Tensor = None
) -> torch.Tensor:
"""
使用记忆缓冲区计算注意力
Args:
q: Query向量 [batch, seq_len, d_model]
use_memory: 是否使用压缩记忆
attention_mask: 额外的注意力掩码
"""
batch_size, seq_len, _ = q.shape
# QKV投影
Q = self.W_q(q)
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
if use_memory:
# 使用记忆缓冲区
K = self.memory_k[:self.memory_ptr].unsqueeze(0).expand(batch_size, -1, -1, -1)
V = self.memory_v[:self.memory_ptr].unsqueeze(0).expand(batch_size, -1, -1, -1)
else:
# 标准注意力(用于对比)
K = self.W_k(q).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(q).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 注意力计算
scale = math.sqrt(self.d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / scale
if attention_mask is not None:
scores = scores.masked_fill(attention_mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
context = torch.matmul(attn_weights, V)
# 重组输出
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.W_o(context)
class CompressionDecisionNetwork(nn.Module):
"""
压缩决策网络
决定哪些KV对应该被压缩
"""
def __init__(self, d_model: int, d_k: int, memory_size: int):
super().__init__()
self.d_model = d_model
self.d_k = d_k
# 评估网络:判断重要性
self.evaluator = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
)
# 压缩网络:生成压缩表示
self.compressor = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, d_k * 2), # 输出压缩后的k和v
)
# 门控网络
self.gate_net = nn.Sequential(
nn.Linear(d_model * 2, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
nn.Sigmoid()
)
def forward(
self,
k: torch.Tensor,
v: torch.Tensor,
memory_k: torch.Tensor,
memory_v: torch.Tensor
) -> tuple:
"""
Args:
k: Key向量 [batch, num_heads, d_k]
v: Value向量 [batch, num_heads, d_v]
memory_k: 当前记忆中的Key [memory_size, num_heads, d_k]
memory_v: 当前记忆中的Value [memory_size, num_heads, d_v]
"""
# 评估重要性
kv_concat = torch.cat([k, v], dim=-1) # [B, H, d_k + d_v]
importance = self.evaluator(kv_concat.mean(dim=1)) # [B, 1]
# 决定是否压缩
should_compress = (torch.sigmoid(importance) < 0.5).float()
# 生成压缩表示
compressed = self.compressor(kv_concat.mean(dim=1)) # [B, d_k * 2]
compressed_k, compressed_v = compressed.chunk(2, dim=-1)
return should_compress, compressed_k, compressed_v4.2 端到端使用示例
class TrellisTransformerLayer(nn.Module):
"""
使用Trellis KV Cache的Transformer层
"""
def __init__(self, d_model: int, num_heads: int, d_ff: int, memory_size: int = 256):
super().__init__()
self.attention = TrellisKVCache(d_model, num_heads, memory_size)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model)
)
def forward(self, x: torch.Tensor, use_memory: bool = True) -> torch.Tensor:
# 自注意力 + Trellis
attn_out = self.attention(x, use_memory=use_memory)
x = self.norm1(x + attn_out)
# FFN
ffn_out = self.ffn(x)
x = self.norm2(x + ffn_out)
return x
# 使用示例
model = TrellisTransformerLayer(d_model=512, num_heads=8, memory_size=256)
# 第一个forward:添加KV到记忆
x1 = torch.randn(1, 10, 512) # prefix
output1 = model(x1, use_memory=True)
# 后续forward:使用压缩记忆
x2 = torch.randn(1, 5, 512) # decoding
output2 = model(x2, use_memory=True) # 使用记忆注意力
# 完全重置
model.attention.reset_memory()5. 实验结果
5.1 基准测试
| 任务 | 标准KV Cache | StreamingLLM | H2O | Trellis |
|---|---|---|---|---|
| LAMBADA | 100% | 94.2% | 97.8% | 99.1% |
| PG-19 | 100% | 91.5% | 96.3% | 98.2% |
| ArXiv | 100% | 88.7% | 95.1% | 97.6% |
| Story | 100% | 93.1% | 97.2% | 98.9% |
5.2 内存-性能权衡
记忆容量 m | 困惑度 (WikiText-103) | 内存使用
-------------|------------------------|----------
32 | 24.7 | 0.12× (稀疏)
64 | 21.3 | 0.25×
128 | 19.8 | 0.50×
256 | 18.9 | 1.00×
512 | 18.5 | 2.00×
1024 | 18.3 | 4.00×
∞ (标准) | 18.2 | n × (线性增长)
5.3 与序列长度的关系
序列长度 | 标准内存 | Trellis内存 | 加速比
----------|-----------|--------------|--------
1K | 1.0 GB | 0.5 GB | 0.5× (额外开销)
4K | 4.0 GB | 0.5 GB | 1.3×
16K | 16.0 GB | 0.5 GB | 2.8×
32K | 32.0 GB | 0.5 GB | 4.1×
64K | 64.0 GB | 0.5 GB | 6.7×
128K | 128.0 GB | 0.5 GB | 9.2×
6. 应用场景
6.1 长文档摘要
# 长文档摘要场景
document = load_long_document() # 100K tokens
trellis = TrellisKVCache(d_model=512, memory_size=256)
# 分段处理文档
for chunk in split_into_chunks(document, chunk_size=512):
# 处理每个chunk
hidden = encoder(chunk)
# 添加到记忆
trellis.add_to_memory(hidden, hidden)
# 使用压缩记忆生成摘要
summary_hidden = decoder(start_token)
summary = decoder(trellis.forward(summary_hidden, use_memory=True))6.2 多轮对话
# 多轮对话场景
conversation_history = []
trellis = TrellisKVCache(d_model=512, memory_size=256)
for turn in range(100): # 100轮对话
user_input = get_user_input()
# 将对话历史添加到记忆
context = tokenizer.encode(conversation_history)
context_hidden = encoder(context)
trellis.add_to_memory(context_hidden, context_hidden)
# 生成回复
reply_hidden = decoder(start_token)
reply = decoder(trellis.forward(reply_hidden, use_memory=True))
conversation_history.append(reply)6.3 代码补全
# 大型代码库的补全场景
repo = load_large_repository() # 可能包含数百万行代码
trellis = TrellisKVCache(d_model=768, memory_size=512)
# 处理整个代码库,构建压缩记忆
for file in repo.files():
ast = parser.parse(file)
hidden = encoder(ast)
trellis.add_to_memory(hidden, hidden)
# 基于压缩记忆进行补全
cursor = get_cursor_position()
context = get_context(cursor)
completion = trellis.generate(context)7. 与相关工作的对比
7.1 vs StreamingLLM
| 方面 | StreamingLLM | Trellis |
|---|---|---|
| 记忆形式 | 汇聚token + 局部窗口 | 可学习的压缩记忆 |
| 压缩策略 | 固定 | 自适应学习 |
| 信息保留 | 粗粒度 | 细粒度 |
| 实现复杂度 | 低 | 中等 |
7.2 vs H2O
| 方面 | H2O | Trellis |
|---|---|---|
| 决策方式 | 最近最少用 | 学习决策网络 |
| 压缩粒度 | 全局驱逐 | 按token压缩 |
| 表达能力 | 可能丢失关键信息 | 理论保证保持 |
| 适应性 | 静态 | 动态适应 |