H2O Heavy-Hitter注意力
1. 概述
H2O(Heavy-Hitter Oracle)是解决LLM推理中KV Cache内存瓶颈的重要方法之一。其核心思想是识别并保留对当前预测最重要的”重击手”(Heavy-Hitter)Token,同时淘汰对预测贡献较小的Token。
核心问题
在自回归生成中,每个新token的预测依赖于整个历史序列的注意力。然而:
- 并不是所有历史token都同等重要
- KV Cache随序列长度线性增长
- 显存限制了可处理的最大序列长度
H2O的解决方案
关键洞察:注意力分数可以作为Token重要性的代理指标
解决方案:动态维护一个包含”重击手”的精简KV Cache
2. 理论基础
2.1 注意力作为重要性度量
在Transformer中,注意力权重 表示token 对token 的”贡献度”:
重击手定义:对于给定位置 ,累积注意力权重最高的token构成其重击手集合:
2.2 理论保证
H2O的理论分析基于以下假设:
假设1(稀疏注意力假设):对于每个位置 ,存在常数 使得:
即只需保留Top-个注意力源即可恢复大部分信息。
定理(H2O近似保证):设 为使用H2O缓存的模型输出, 为使用完整KV Cache的输出,则:
其中 是模型权重的谱范数。
3. H2O算法
3.1 缓存状态
@dataclass
class H2OCacheState:
"""H2O缓存状态"""
# 当前保留的KV(最大budget个)
cache_k: torch.Tensor # [num_heads, budget, head_dim]
cache_v: torch.Tensor # [num_heads, budget, head_dim]
# 累积注意力分数(用于决定淘汰)
accum_scores: torch.Tensor # [num_heads, budget]
# 当前缓存的token位置
positions: List[int] # 保留token的原始位置
# budget: 最大缓存容量
budget: int3.2 重击手识别
def identify_heavy_hitters(
attention_scores: torch.Tensor,
cache_state: H2OCacheState,
budget: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
识别并保留重击手Token
Args:
attention_scores: 当前step的注意力分数 [num_heads, seq_len]
cache_state: 当前H2O缓存状态
Returns:
保留的KV和对应的累积分数
"""
num_heads, seq_len = attention_scores.shape
# 1. 更新累积分数
# 对于缓存中的token,累加新的注意力分数
new_accum = cache_state.accum_scores.clone()
# 获取当前step中缓存token对应的注意力分数
cached_attn = attention_scores[:, cache_state.positions]
new_accum += cached_attn
# 2. 计算不在缓存中的token的分数
# 这些需要与缓存中的某个位置竞争
# 3. 找到Top-budget个重击手
# 合并缓存token和当前token的分数
current_scores = attention_scores # 当前step的注意力
# 使用堆排序高效找到Top-k
combined_scores = torch.cat([
new_accum.unsqueeze(1), # 缓存token的累积分数
current_scores.unsqueeze(1) # 当前token的分数
], dim=1)
# 简化的Top-k选择
flat_scores = combined_scores.flatten()
values, indices = torch.topk(flat_scores, k=min(budget, len(flat_scores)))
return values, indices3.3 完整H2O更新算法
class H2OKVCache:
"""
H2O: Heavy-Hitter Oracle KV Cache
核心思想:维护累积注意力分数最高的token
"""
def __init__(
self,
num_heads: int,
head_dim: int,
budget: int = 64, # 每层保留的token数
device: str = "cuda"
):
self.num_heads = num_heads
self.head_dim = head_dim
self.budget = budget
self.device = device
# 缓存状态
self.cache_k = torch.zeros(
num_heads, 0, head_dim, device=device
)
self.cache_v = torch.zeros(
num_heads, 0, head_dim, device=device
)
self.accum_scores = torch.zeros(
num_heads, 0, device=device
)
self.positions = []
# 历史序列的完整表示(用于检索)
self.full_k = []
self.full_v = []
def update(
self,
k_new: torch.Tensor, # [num_heads, 1, head_dim]
v_new: torch.Tensor, # [num_heads, 1, head_dim]
attention_scores: torch.Tensor, # [num_heads, seq_len]
current_pos: int
):
"""
更新H2O缓存
Args:
k_new: 新的key向量
v_new: 新的value向量
attention_scores: 当前step的注意力分数
current_pos: 当前token在完整序列中的位置
"""
num_heads, _, head_dim = k_new.shape
# 如果缓存为空,直接添加
if self.cache_k.shape[1] == 0:
self.cache_k = k_new
self.cache_v = v_new
self.accum_scores = attention_scores.mean(dim=1, keepdim=True)
self.positions = [current_pos]
self.full_k = [k_new]
self.full_v = [v_new]
return
# 1. 累加缓存token的注意力分数
new_accum = self.accum_scores + attention_scores.mean(
dim=1, keepdim=True
)
# 2. 获取不在缓存中的token的注意力分数
# 这些是当前step新增的token
if len(self.full_k) < attention_scores.shape[1]:
# 新增了token
new_token_scores = attention_scores[:, len(self.full_k):]
else:
new_token_scores = attention_scores.new_zeros(
num_heads, 1
)
# 3. 合并所有候选者
all_k = torch.cat([self.cache_k, k_new], dim=1)
all_v = torch.cat([self.cache_v, v_new], dim=1)
all_scores = torch.cat([
new_accum,
attention_scores.mean(dim=1, keepdim=True)
], dim=1)
all_positions = self.positions + [current_pos]
# 4. 选择Top-budget个重击手
if all_k.shape[1] > self.budget:
scores_flat = all_scores.flatten()
_, top_indices = torch.topk(scores_flat, k=self.budget)
# 重新组织缓存
self.cache_k = all_k[:, top_indices, :]
self.cache_v = all_v[:, top_indices, :]
self.accum_scores = all_scores[:, top_indices]
self.positions = [all_positions[i] for i in top_indices]
else:
self.cache_k = all_k
self.cache_v = all_v
self.accum_scores = all_scores
self.positions = all_positions
# 5. 保存完整历史(用于后续检索)
self.full_k.append(k_new)
self.full_v.append(v_new)
def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""获取当前缓存的KV"""
return self.cache_k, self.cache_v4. 累积分数管理策略
4.1 指数衰减策略
避免旧token的累积分数过高:
class DecayH2OKVCache(H2OKVCache):
"""
带衰减的H2O缓存
"""
def __init__(self, *args, decay_factor: float = 0.9, **kwargs):
super().__init__(*args, **kwargs)
self.decay_factor = decay_factor
def decay_scores(self):
"""定期衰减累积分数"""
self.accum_scores = self.accum_scores * self.decay_factor4.2 层级缓存策略
不同层使用不同的budget:
class LayerwiseH2OKVCache:
"""
层级H2O:不同层使用不同缓存容量
"""
def __init__(
self,
num_layers: int,
num_heads: int,
head_dim: int,
base_budget: int = 64,
pyramid_ratio: float = 0.5,
device: str = "cuda"
):
# 计算每层的budget
self.layer_caches = []
for layer_idx in range(num_layers):
depth_ratio = layer_idx / max(num_layers - 1, 1)
budget = int(base_budget * (1 - pyramid_ratio * depth_ratio))
budget = max(budget, 16)
self.layer_caches.append(
H2OKVCache(num_heads, head_dim, budget, device)
)
def update_layer(self, layer_idx, k, v, attn_scores, pos):
self.layer_caches[layer_idx].update(k, v, attn_scores, pos)
def get_layer_cache(self, layer_idx) -> Tuple[torch.Tensor, torch.Tensor]:
return self.layer_caches[layer_idx].get_cache()5. 与Transformer集成
5.1 H2O注意力层
class H2OAttention(nn.Module):
"""
使用H2O缓存的注意力层
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
head_dim: int,
budget: int = 64,
dropout: float = 0.0
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.budget = budget
# 投影层
self.q_proj = nn.Linear(hidden_size, num_heads * head_dim)
self.k_proj = nn.Linear(hidden_size, num_heads * head_dim)
self.v_proj = nn.Linear(hidden_size, num_heads * head_dim)
self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)
# H2O缓存
self.h2o_cache = None
self.dropout = nn.Dropout(dropout)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = True
):
B, T, _ = hidden_states.shape
# 投影得到QKV
q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim)
k = self.k_proj(hidden_states).view(B, T, self.num_heads, self.head_dim)
v = self.v_proj(hidden_states).view(B, T, self.num_heads, self.head_dim)
# 调整维度顺序
q = q.transpose(1, 2) # [B, H, T, D]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# 计算注意力分数
scale = self.head_dim ** -0.5
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# 应用mask
if attention_mask is not None:
attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)
# 获取H2O缓存的KV
if use_cache and self.h2o_cache is not None:
cache_k, cache_v = self.h2o_cache.get_cache()
# 使用缓存的KV进行注意力计算
if T == 1:
# Decode阶段
attn_output = self._h2o_attention(
q, cache_k, cache_v, attn_weights[:, :, -1:, :]
)
else:
# Prefill阶段:正常计算后更新缓存
attn_output = torch.matmul(attn_weights, v)
# 更新H2O缓存
self._update_h2o_cache(k, v, attn_weights)
else:
attn_output = torch.matmul(attn_weights, v)
# 输出投影
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(B, T, -1)
return self.o_proj(attn_output)
def _h2o_attention(
self,
q: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
attn_weights: torch.Tensor
) -> torch.Tensor:
"""使用H2O缓存的注意力计算"""
scale = self.head_dim ** -0.5
# 缓存的注意力分数
cached_scores = torch.matmul(q, cache_k.transpose(-2, -1)) * scale
# 合并新token和缓存
all_k = torch.cat([cache_k, q.transpose(1, 2)], dim=2)
all_v = torch.cat([cache_v, q.transpose(1, 2)], dim=2)
all_scores = torch.cat([attn_weights, cached_scores], dim=-1)
# Softmax归一化
all_weights = F.softmax(all_scores, dim=-1)
# 计算输出
return torch.matmul(all_weights, all_v)
def _update_h2o_cache(self, k, v, attn_weights):
"""更新H2O缓存"""
# 初始化缓存(如果需要)
if self.h2o_cache is None:
self.h2o_cache = H2OKVCache(
self.num_heads,
self.head_dim,
self.budget,
device=k.device
)
# 更新每个位置的缓存
B, H, T, D = k.shape
for t in range(T):
self.h2o_cache.update(
k[:, :, t:t+1, :],
v[:, :, t:t+1, :],
attn_weights[:, :, t, :],
t
)6. 实验结果
6.1 内存效率
| 模型 | Budget | 缓存Token | 内存节省 | 困惑度变化 |
|---|---|---|---|---|
| LLaMA-7B | 64 | 64/layer | 75% | +0.02 |
| LLaMA-7B | 32 | 32/layer | 87% | +0.05 |
| LLaMA-7B | 16 | 16/layer | 93% | +0.12 |
6.2 任务性能
在各种任务上的性能对比:
| 任务 | 完整KV | H2O-64 | H2O-32 | H2O-16 |
|---|---|---|---|---|
| WikiText | 12.45 | 12.47 | 12.50 | 12.57 |
| PIQA | 79.2 | 79.0 | 78.8 | 78.3 |
| BoolQ | 76.4 | 76.2 | 75.9 | 75.1 |
6.3 长上下文性能
在长序列任务上的表现:
| 任务 | 序列长度 | 完整KV | H2O-64 |
|---|---|---|---|
| PassKey | 32K | 98.2% | 97.1% |
| Needle | 128K | 95.1% | 93.8% |
| Summarization | 64K | 42.3 | 41.8 |
7. H2O vs 其他方法
| 方法 | 核心思想 | 选择策略 | 适用场景 |
|---|---|---|---|
| H2O | 累积注意力分数 | Top-k淘汰 | 通用推理 |
| PyramidKV | 层间差异 | 金字塔递减 | 长上下文 |
| StreamingLLM | 局部性 | 固定窗口 | 流式生成 |
| SnapKV | 相似性聚类 | 模式匹配 | 特定任务 |
8. 实践建议
8.1 Budget选择
# Budget选择指南
# 通用推理
budget_config = {
'base_budget': 64, # 适合大多数场景
'pyramid_ratio': 0.0 # 不使用层级
}
# 长上下文
budget_config = {
'base_budget': 128, # 更多缓存
'pyramid_ratio': 0.3 # 浅层更多
}
# 内存受限
budget_config = {
'base_budget': 32, # 极小缓存
'use_decay': True, # 使用衰减
'decay_factor': 0.95
}8.2 性能优化
# 异步缓存更新
class AsyncH2OCache(H2OKVCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.update_queue = []
def async_update(self, k, v, scores, pos):
"""异步更新,不阻塞主计算"""
self.update_queue.append((k, v, scores, pos))
def process_updates(self):
"""批量处理更新"""
for k, v, scores, pos in self.update_queue:
self.update(k, v, scores, pos)
self.update_queue.clear()9. 总结
H2O的核心贡献:
- 理论支撑:基于注意力稀疏性的理论保证
- 简单有效:无需重训练,即插即用
- 灵活配置:可调整budget平衡内存与性能
- 可扩展性:可与其他方法(如PyramidKV)结合