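H2O Heavy-Hitter Attention

1. Overview

H2O (Heavy-Hitter Oracle) is one of the key methods for addressing the KV Cache memory bottleneck in LLM inference. Its core idea is to identify and keep the "heavy-hitter" tokens that matter most for the current prediction, while evicting tokens that contribute little to it.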

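Core Problem

In autoregressive generation, predicting each new token depends on attention over the entire history. However:

  • Not all past tokens are equally important
  • The KV Cache grows linearly with sequence length
  • GPU memory limits the maximum sequence length that can be processed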
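To see how quickly the cache grows: a full KV Cache stores K and V for every layer, head and token, so its size is 2 × layers × heads × head_dim × seq_len × bytes per element. The helper below is a minimal sketch of that arithmetic (the name `kv_cache_bytes` is ours); the example shape, 32 layers, 32 heads, head_dim 128, fp16, is the LLaMA-7B configuration.

```python
def kv_cache_bytes(num_layers: int, num_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """Size of a full KV Cache for one sequence.

    The leading factor 2 accounts for storing both K and V;
    bytes_per_elem = 2 corresponds to fp16/bf16.
    """
    return 2 * num_layers * num_heads * head_dim * seq_len * bytes_per_elem


# LLaMA-7B-like shape: 32 layers, 32 heads, head_dim 128, fp16
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
for n in (1024, 4096, 32768):
    print(n, kv_cache_bytes(32, 32, 128, n) / 2**30, "GiB")
```

At 0.5 MiB per token, a 4K-token context already needs 2 GiB per sequence, and the cost scales linearly with `seq_len`.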

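H2O's Solution

Key insight: attention scores can serve as a proxy for token importance.

Solution: dynamically maintain a compact KV Cache that holds the heavy hitters.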

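2. Theoretical Foundations

2.1 Attention as an Importance Measure

In a Transformer, the attention weight $A_{i,j}$ measures the "contribution" of token $j$ to token $i$:

$$A_{i,j} = \mathrm{softmax}_j\left(\frac{q_i k_j^\top}{\sqrt{d}}\right)$$

Heavy-hitter definition: for a given position $i$, the tokens with the highest accumulated attention weight form its heavy-hitter set:

$$S_j = \sum_{t \le i} A_{t,j}, \qquad \mathcal{H}_i = \operatorname{TopK}_{j \le i}\left(S_j\right)$$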
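A toy example makes the definition concrete: sum a causal attention matrix over its query rows to get each position's accumulated score $S_j$, then take the top-$k$. This is a minimal sketch with a hand-picked 4×4 matrix; the function name `heavy_hitter_set` is ours.

```python
import torch


def heavy_hitter_set(attn: torch.Tensor, k: int) -> torch.Tensor:
    """Heavy hitters of a causal attention matrix.

    attn: [T, T], row t = attention of query t over positions 0..t
          (entries above the diagonal are 0).
    Returns the k positions with the largest accumulated attention
    S_j = sum_t attn[t, j], in ascending position order.
    """
    accum = attn.sum(dim=0)                 # S_j, shape [T]
    top = torch.topk(accum, k).indices
    return torch.sort(top).values


attn = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.6, 0.1, 0.3, 0.0],
    [0.5, 0.1, 0.1, 0.3],
])
print(attn.sum(dim=0))              # accumulated scores S_j
print(heavy_hitter_set(attn, 2))    # tensor([0, 2])
```

Position 0 dominates partly because every query can attend to it: earlier tokens are attended by more queries and therefore have more chances to accumulate score.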

2.2 Theoretical Guarantee

The theoretical analysis of H2O rests on the following assumption:

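Assumption 1 (sparse attention): for each position $i$ there exists a constant $k \ll i$ such that, for some small $\epsilon > 0$, the top-$k$ attention weights carry almost all of the attention mass:

$$\sum_{j \in \operatorname{TopK}(A_{i,\cdot})} A_{i,j} \ge 1 - \epsilon$$

That is, keeping only the top-$k$ attention sources is enough to recover most of the information.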
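Assumption 1 can be checked empirically by measuring, for each query, how much attention mass its top-k keys carry. A minimal sketch (the function name is ours); in practice it would be applied to attention weights collected from a real forward pass rather than the toy row used here.

```python
import torch


def topk_attention_mass(attn_row: torch.Tensor, k: int) -> float:
    """Fraction of one query's attention captured by its k largest weights.

    attn_row: [seq_len] softmax weights for one query (sums to 1).
    Under Assumption 1 this is >= 1 - epsilon for a small k.
    """
    return torch.topk(attn_row, k).values.sum().item()


row = torch.tensor([0.5, 0.3, 0.1, 0.1])
print(topk_attention_mass(row, 2))
```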

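Theorem (H2O approximation guarantee): let $\hat{y}$ be the model output computed with the H2O cache and $y$ the output computed with the full KV Cache. Then

$$\|\hat{y} - y\| \le \epsilon \, \|W\|_2$$

where $\|W\|_2$ is the spectral norm of the model weights.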
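The theorem is stated at the model level. A one-step, single-head version is easy to verify directly and shows where the sparsity assumption enters; this derivation is our own illustration, not the theorem's proof. Let $a_j$ be the softmax weights of one query and $K$ the kept set with mass $m = \sum_{j \in K} a_j$. Softmax restricted to $K$ yields weights $a_j / m$, so

$$\Big\|\sum_j a_j v_j - \sum_{j \in K} \tfrac{a_j}{m} v_j\Big\| \le (1-m)\max_j\|v_j\| + (1-m)\max_j\|v_j\| = 2(1-m)\max_j\|v_j\|.$$

With $m \ge 1 - \epsilon$ from Assumption 1, the per-step error is at most $2\epsilon \max_j \|v_j\|$. The sketch below checks this bound numerically on random data (function name ours).

```python
import torch


def keep_error_and_bound(scores: torch.Tensor, v: torch.Tensor, k: int):
    """One query, one head.

    scores: [N] pre-softmax logits; v: [N, D] values.
    Returns (error, bound): the distance between full attention and attention
    restricted to the top-k keys, and 2 * (1 - m) * max_j ||v_j||.
    """
    a = torch.softmax(scores, dim=0)
    keep = torch.topk(a, k).indices
    full = a @ v
    sub = torch.softmax(scores[keep], dim=0) @ v[keep]  # renormalized over K
    m = a[keep].sum()
    bound = 2 * (1 - m) * v.norm(dim=1).max()
    return (full - sub).norm().item(), bound.item()


torch.manual_seed(0)
scores = torch.randn(64) * 3   # sharper logits -> more concentrated attention
v = torch.randn(64, 16)
err, bound = keep_error_and_bound(scores, v, k=8)
print(err, bound)
```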

3. The H2O Algorithm

3.1 Cache State

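from dataclasses import dataclass
from typing import List

import torch


@dataclass
class H2OCacheState:
    """H2O cache state"""
    # Currently kept K/V (at most `budget` tokens)
    cache_k: torch.Tensor  # [num_heads, budget, head_dim]
    cache_v: torch.Tensor  # [num_heads, budget, head_dim]
    
    # Accumulated attention scores (used to decide evictions)
    accum_scores: torch.Tensor  # [num_heads, budget]
    
    # Original sequence positions of the kept tokens
    positions: List[int]
    
    # Maximum cache capacity
    budget: int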

3.2 Heavy-Hitter Identification

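from typing import Tuple

import torch


def identify_heavy_hitters(
    attention_scores: torch.Tensor,
    cache_state: H2OCacheState,
    budget: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Identify the heavy-hitter tokens to keep.
    
    Args:
        attention_scores: this step's attention weights indexed by absolute
            position, [num_heads, seq_len]; the newest token is at seq_len - 1
        cache_state: current H2O cache state
        budget: maximum number of tokens to keep
        
    Returns:
        keep_idx: indices into the candidates (cached tokens in cache order,
            followed by the newest token), ascending, [k]
        kept_scores: updated accumulated scores of the kept tokens, [num_heads, k]
    """
    num_heads, seq_len = attention_scores.shape
    
    # 1. Cached tokens add this step's attention to their running totals
    pos = torch.tensor(cache_state.positions, dtype=torch.long,
                       device=attention_scores.device)
    new_accum = cache_state.accum_scores + attention_scores.index_select(1, pos)
    
    # 2. The newest token starts from the attention it receives this step
    newest = attention_scores[:, seq_len - 1:]               # [H, 1]
    
    # 3. Candidates = cached tokens + newest token
    all_scores = torch.cat([new_accum, newest], dim=1)      # [H, n + 1]
    
    # 4. Keep the top-budget candidates. All heads share one kept set,
    #    ranked by the score summed over heads (a per-head variant would
    #    call topk on each row separately).
    k = min(budget, all_scores.shape[1])
    keep_idx = torch.topk(all_scores.sum(dim=0), k=k).indices
    keep_idx = torch.sort(keep_idx).values  # preserve temporal order
    
    return keep_idx, all_scores[:, keep_idx]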
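The function above shares one kept set across all heads. A per-head variant, where each head keeps its own top-budget tokens, can be written with `torch.gather`. The sketch below is our illustration (the function name and shapes are assumptions); note that kept positions then differ per head, so they have to be stored as a [num_heads, budget] tensor rather than a single list.

```python
import torch


def per_head_evict(k: torch.Tensor, v: torch.Tensor,
                   scores: torch.Tensor, budget: int):
    """Keep the top-`budget` tokens independently for each head.

    k, v: [H, N, D] candidate keys/values; scores: [H, N] accumulated scores.
    Returns kept k, v ([H, budget, D]), kept scores and kept indices ([H, budget]).
    """
    idx = torch.topk(scores, budget, dim=-1).indices         # [H, budget]
    idx = torch.sort(idx, dim=-1).values                     # temporal order per head
    kv_idx = idx.unsqueeze(-1).expand(-1, -1, k.shape[-1])   # [H, budget, D]
    return (torch.gather(k, 1, kv_idx), torch.gather(v, 1, kv_idx),
            torch.gather(scores, 1, idx), idx)


H, N, D = 2, 4, 1
k = (torch.arange(N).float() + 10 * torch.arange(H).float()[:, None]).unsqueeze(-1)
v = k.clone()
scores = torch.tensor([[0.1, 0.9, 0.5, 0.2],
                       [0.8, 0.1, 0.1, 0.7]])
k_kept, v_kept, s_kept, idx = per_head_evict(k, v, scores, budget=2)
print(idx)
```

Here head 0 keeps positions 1 and 2 while head 1 keeps positions 0 and 3, which a single shared position list could not express.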

3.3 The Complete H2O Update

from typing import List, Tuple

import torch


class H2OKVCache:
    """
    H2O: Heavy-Hitter Oracle KV Cache (one sequence, no batch dimension)
    
    Core idea: keep the tokens with the highest accumulated attention scores.
    All heads share one set of kept tokens, ranked by the accumulated score
    summed over heads.
    """
    
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        budget: int = 64,  # number of tokens kept per layer
        device: str = "cuda"
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.budget = budget
        self.device = device
        
        # Cache state
        self.cache_k = torch.zeros(num_heads, 0, head_dim, device=device)
        self.cache_v = torch.zeros(num_heads, 0, head_dim, device=device)
        self.accum_scores = torch.zeros(num_heads, 0, device=device)
        self.positions: List[int] = []  # original positions of kept tokens
        # Evicted tokens are dropped for good: keeping a full copy of the
        # history would cancel out the memory savings.
        
    def update(
        self,
        k_new: torch.Tensor,   # [num_heads, 1, head_dim]
        v_new: torch.Tensor,   # [num_heads, 1, head_dim]
        attention_scores: torch.Tensor,  # [num_heads, seq_len]
        current_pos: int
    ):
        """
        Update the H2O cache with one new token.
        
        Args:
            k_new: key vector of the new token
            v_new: value vector of the new token
            attention_scores: this step's attention weights indexed by
                absolute position (entries of evicted tokens are ignored)
            current_pos: position of the new token in the full sequence
        """
        # 1. Cached tokens accumulate this step's attention; the new token
        #    starts from the attention it receives itself
        pos = torch.tensor(self.positions, dtype=torch.long,
                           device=attention_scores.device)
        new_accum = self.accum_scores + attention_scores.index_select(1, pos)
        new_score = attention_scores[:, current_pos:current_pos + 1]
        
        # 2. Merge all candidates (cached tokens followed by the new token)
        all_k = torch.cat([self.cache_k, k_new], dim=1)
        all_v = torch.cat([self.cache_v, v_new], dim=1)
        all_scores = torch.cat([new_accum, new_score], dim=1)
        all_positions = self.positions + [current_pos]
        
        # 3. Over budget: keep only the top-budget heavy hitters
        if all_k.shape[1] > self.budget:
            # Rank by head-summed score; sort indices to keep temporal order
            top = torch.topk(all_scores.sum(dim=0), k=self.budget).indices
            top = torch.sort(top).values
            
            self.cache_k = all_k[:, top, :]
            self.cache_v = all_v[:, top, :]
            self.accum_scores = all_scores[:, top]
            self.positions = [all_positions[i] for i in top.tolist()]
        else:
            self.cache_k = all_k
            self.cache_v = all_v
            self.accum_scores = all_scores
            self.positions = all_positions
        
    def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the currently cached K and V"""
        return self.cache_k, self.cache_v
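To see the update rule in action without tensors, the pure-Python trace below replays it on a hand-picked attention sequence, using a single (head-summed) score per token. It is a simplified illustration with names of our own choosing, not a drop-in replacement for the class.

```python
from typing import Dict, List


def h2o_trace(attn_steps: List[List[float]], budget: int) -> List[List[int]]:
    """Replay H2O eviction.

    attn_steps[t][j] is the attention query t pays to position j (j <= t).
    Returns the kept positions after each step.
    """
    kept: List[int] = []
    scores: Dict[int, float] = {}
    history = []
    for t, row in enumerate(attn_steps):
        for j in kept:               # cached tokens accumulate attention
            scores[j] += row[j]
        kept.append(t)               # the new token enters with its own score
        scores[t] = row[t]
        if len(kept) > budget:       # evict the lowest accumulated scores
            best = sorted(kept, key=lambda j: scores[j], reverse=True)[:budget]
            kept = sorted(best)
        history.append(list(kept))
    return history


steps = [
    [1.0],
    [0.8, 0.2],
    [0.6, 0.05, 0.35],
    [0.5, 0.0, 0.1, 0.4],
]
print(h2o_trace(steps, budget=2))    # [[0], [0, 1], [0, 2], [0, 2]]
```

Token 3 is evicted as soon as it arrives: a newcomer starts with one step of attention while incumbents have several. This bias toward older tokens is what the decay strategy in Section 4.1 is meant to counteract.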

4. Accumulated-Score Management Strategies

4.1 Exponential Decay

Prevents old tokens from building up excessively high accumulated scores:

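class DecayH2OKVCache(H2OKVCache):
    """
    H2O cache whose accumulated scores decay exponentially
    """
    
    def __init__(self, *args, decay_factor: float = 0.9, **kwargs):
        super().__init__(*args, **kwargs)
        self.decay_factor = decay_factor
    
    def decay_scores(self):
        """Decay the accumulated scores of all cached tokens"""
        self.accum_scores = self.accum_scores * self.decay_factor
    
    def update(self, k_new, v_new, attention_scores, current_pos):
        # Decay once per step before accumulating, so each score becomes an
        # exponentially weighted sum: S <- decay_factor * S + new attention
        self.decay_scores()
        super().update(k_new, v_new, attention_scores, current_pos)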

4.2 Layer-wise Budgets

Use a different budget for each layer:

from typing import List, Tuple

import torch


class LayerwiseH2OKVCache:
    """
    Layer-wise H2O: each layer gets its own cache capacity
    """
    
    def __init__(
        self,
        num_layers: int,
        num_heads: int,
        head_dim: int,
        base_budget: int = 64,
        pyramid_ratio: float = 0.5,
        device: str = "cuda"
    ):
        # Per-layer budget shrinks linearly with depth: base_budget at
        # layer 0, base_budget * (1 - pyramid_ratio) at the last layer,
        # and never below 16
        self.layer_caches: List[H2OKVCache] = []
        for layer_idx in range(num_layers):
            depth_ratio = layer_idx / max(num_layers - 1, 1)
            budget = int(base_budget * (1 - pyramid_ratio * depth_ratio))
            budget = max(budget, 16)
            
            self.layer_caches.append(
                H2OKVCache(num_heads, head_dim, budget, device)
            )
    
    def update_layer(self, layer_idx: int, k, v, attn_scores, pos: int):
        """Update the cache of a single layer"""
        self.layer_caches[layer_idx].update(k, v, attn_scores, pos)
    
    def get_layer_cache(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.layer_caches[layer_idx].get_cache()
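The per-layer budget in `LayerwiseH2OKVCache` decreases linearly with depth, from `base_budget` at layer 0 to `base_budget * (1 - pyramid_ratio)` at the last layer, clipped below at 16. The helper below is our extraction of that same formula, which makes the schedule easy to inspect.

```python
from typing import List


def layer_budgets(num_layers: int, base_budget: int = 64,
                  pyramid_ratio: float = 0.5, min_budget: int = 16) -> List[int]:
    """Per-layer budgets using the same linear schedule as LayerwiseH2OKVCache."""
    budgets = []
    for layer_idx in range(num_layers):
        depth_ratio = layer_idx / max(num_layers - 1, 1)
        budget = int(base_budget * (1 - pyramid_ratio * depth_ratio))
        budgets.append(max(budget, min_budget))
    return budgets


print(layer_budgets(5))                     # [64, 56, 48, 40, 32]
print(sum(layer_budgets(32, 128, 0.3)))     # total tokens kept across layers
```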

5. Integration with the Transformer

5.1 H2O Attention Layer

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class H2OAttention(nn.Module):
    """
    Attention layer backed by an H2O cache (sketch for batch size 1).
    
    Prefill runs full causal attention and builds the cache token by token;
    each decode step attends only over the cached heavy hitters plus the new
    token, then updates the cache.
    """
    
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        budget: int = 64,
        dropout: float = 0.0
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.budget = budget
        
        # Projections
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)
        
        # H2O cache and the number of tokens processed so far
        self.h2o_cache = None
        self.seen_tokens = 0
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True
    ):
        B, T, _ = hidden_states.shape
        assert B == 1, "this sketch handles one sequence at a time"
        
        # Project to Q/K/V and reorder to [B, H, T, D]
        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        if use_cache and self.h2o_cache is not None and T == 1:
            # Decode: attend over the cached heavy hitters + the new token
            attn_output = self._h2o_decode(q, k, v)
        else:
            # Prefill (or no cache): full causal attention
            scale = self.head_dim ** -0.5
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            causal = torch.tril(torch.ones(T, T, device=q.device)).bool()
            attn_scores = attn_scores.masked_fill(~causal, float('-inf'))
            if attention_mask is not None:
                attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_output = torch.matmul(self.dropout(attn_weights), v)
            if use_cache:
                self._prefill_h2o_cache(k, v, attn_weights)
        
        # Output projection: [B, H, T, D] -> [B, T, H * D]
        attn_output = attn_output.transpose(1, 2).reshape(B, T, -1)
        return self.o_proj(attn_output)
    
    def _h2o_decode(self, q, k_new, v_new) -> torch.Tensor:
        """Decode step over the H2O cache (attention_mask is not applied here)"""
        cache_k, cache_v = self.h2o_cache.get_cache()       # [H, N, D]
        all_k = torch.cat([cache_k, k_new[0]], dim=1)       # [H, N + 1, D]
        all_v = torch.cat([cache_v, v_new[0]], dim=1)
        scale = self.head_dim ** -0.5
        scores = torch.matmul(q[0], all_k.transpose(-2, -1)) * scale
        weights = F.softmax(scores, dim=-1)                 # [H, 1, N + 1]
        out = torch.matmul(self.dropout(weights), all_v)    # [H, 1, D]
        
        # Scatter the weights back to absolute positions (evicted tokens get 0)
        pos = self.seen_tokens
        idx = torch.tensor(self.h2o_cache.positions + [pos], device=q.device)
        full = weights.new_zeros(self.num_heads, pos + 1)
        full[:, idx] = weights[:, 0, :].detach()
        self.h2o_cache.update(k_new[0].detach(), v_new[0].detach(), full, pos)
        self.seen_tokens += 1
        return out.unsqueeze(0)                             # [1, H, 1, D]
    
    def _prefill_h2o_cache(self, k, v, attn_weights):
        """Prefill starts a new sequence: rebuild the cache from the prompt"""
        self.h2o_cache = H2OKVCache(
            self.num_heads,
            self.head_dim,
            self.budget,
            device=k.device
        )
        # Insert tokens in order; row t holds token t's attention over 0..t
        T = k.shape[2]
        for t in range(T):
            self.h2o_cache.update(
                k[0, :, t:t+1, :].detach(),
                v[0, :, t:t+1, :].detach(),
                attn_weights[0, :, t, :].detach(),
                t
            )
        self.seen_tokens = T

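6. Experimental Results

6.1 Memory Efficiency

Model | Budget | Cached tokens | Memory saving | Perplexity change
LLaMA-7B | 64 | 64/layer | 75% | +0.02
LLaMA-7B | 32 | 32/layer | 87% | +0.05
LLaMA-7B | 16 | 16/layer | 93% | +0.12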
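The memory-saving column is consistent with a saving of 1 − budget/L at a full sequence length of L = 256 tokens (75%, 87.5% and 93.75%, shown truncated). The table does not state L, so treat 256 as our inference; as the sketch shows, the saving grows with the actual context length.

```python
def kv_memory_saving(budget: int, seq_len: int) -> float:
    """Fraction of KV Cache memory saved by keeping `budget` of `seq_len` tokens."""
    return 1.0 - min(budget, seq_len) / seq_len


for budget in (64, 32, 16):
    print(budget, kv_memory_saving(budget, 256), kv_memory_saving(budget, 4096))
```

At 4,096 tokens a 64-token budget saves about 98.4% of KV memory; whether quality holds at that ratio depends on how concentrated the attention actually is.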

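6.2 Task Performance

Performance across tasks (WikiText: perplexity, lower is better; PIQA and BoolQ: accuracy in %, higher is better):

Task | Full KV | H2O-64 | H2O-32 | H2O-16
WikiText | 12.45 | 12.47 | 12.50 | 12.57
PIQA | 79.2 | 79.0 | 78.8 | 78.3
BoolQ | 76.4 | 76.2 | 75.9 | 75.1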

6.3 Long-Context Performance

Results on long-sequence tasks:

Task | Sequence length | Full KV | H2O-64
PassKey | 32K | 98.2% | 97.1%
Needle | 128K | 95.1% | 93.8%
Summarization | 64K | 42.3 | 41.8

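7. H2O vs. Other Methods

Method | Core idea | Selection strategy | Typical use case
H2O | Accumulated attention scores | Top-k eviction | General inference
PyramidKV | Differences across layers | Pyramid-shaped (decreasing) budgets | Long context
StreamingLLM | Locality | Fixed window | Streaming generation
SnapKV | Similarity clustering | Pattern matching | Specific tasks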
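The "selection strategy" column can be made concrete with two small index-selection functions (illustrative; names are ours). The fixed-window rule ignores scores entirely, while H2O ranks tokens by accumulated attention. StreamingLLM additionally keeps a few initial "attention sink" tokens alongside its window, exposed here as `num_sink`.

```python
from typing import List, Sequence


def fixed_window_keep(seq_len: int, window: int, num_sink: int = 0) -> List[int]:
    """Fixed-window policy: the first `num_sink` positions plus the last
    `window` positions; attention scores are not consulted."""
    sinks = list(range(min(num_sink, seq_len)))
    recent = list(range(max(num_sink, seq_len - window), seq_len))
    return sinks + recent


def h2o_keep(accum_scores: Sequence[float], budget: int) -> List[int]:
    """Score-based policy: the `budget` positions with the highest
    accumulated attention, returned in position order."""
    ranked = sorted(range(len(accum_scores)),
                    key=lambda j: accum_scores[j], reverse=True)
    return sorted(ranked[:budget])


scores = [5.0, 0.1, 0.2, 3.0, 0.1, 0.4]
print(fixed_window_keep(len(scores), 3))              # [3, 4, 5]
print(fixed_window_keep(len(scores), 3, num_sink=1))  # [0, 3, 4, 5]
print(h2o_keep(scores, 3))                            # [0, 3, 5]
```

Without sinks, the fixed window drops position 0 even though it carries most of the accumulated attention; this is the kind of token H2O's score-based selection is designed to keep.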

8. Practical Recommendations

8.1 Choosing the Budget

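# Budget selection guide

# General-purpose inference
general_config = {
    'base_budget': 64,       # suits most scenarios
    'pyramid_ratio': 0.0     # no layer-wise scaling
}

# Long context
long_context_config = {
    'base_budget': 128,      # larger cache
    'pyramid_ratio': 0.3     # shallow layers keep more tokens
}

# Memory-constrained
low_memory_config = {
    'base_budget': 32,       # very small cache
    'use_decay': True,       # use score decay (DecayH2OKVCache)
    'decay_factor': 0.95
}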
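To translate these configurations into memory, one can combine the layer-wise schedule from Section 4.2 with the per-token KV size. The sketch below assumes a LLaMA-7B-like shape (32 layers, 32 heads, head_dim 128, fp16) and ignores the small bookkeeping tensors (accumulated scores, positions); the function name is ours.

```python
def config_kv_bytes(cfg: dict, num_layers: int = 32, num_heads: int = 32,
                    head_dim: int = 128, bytes_per_elem: int = 2) -> int:
    """KV bytes kept under an H2O budget config, using the layer-wise schedule
    budget = max(int(base * (1 - ratio * depth)), 16)."""
    base = cfg["base_budget"]
    ratio = cfg.get("pyramid_ratio", 0.0)
    total_tokens = 0
    for layer_idx in range(num_layers):
        depth = layer_idx / max(num_layers - 1, 1)
        total_tokens += max(int(base * (1 - ratio * depth)), 16)
    per_token_per_layer = 2 * num_heads * head_dim * bytes_per_elem  # K and V
    return total_tokens * per_token_per_layer


for name, cfg in [("general", {"base_budget": 64, "pyramid_ratio": 0.0}),
                  ("long_context", {"base_budget": 128, "pyramid_ratio": 0.3}),
                  ("low_memory", {"base_budget": 32})]:
    print(name, config_kv_bytes(cfg) / 2**20, "MiB")
```

Under these assumptions the general configuration keeps 32 MiB of KV per sequence regardless of context length, compared with 2 GiB for a full 4K-token cache of the same model.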

8.2 Performance Optimization

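# Deferred (batched) cache updates
class AsyncH2OCache(H2OKVCache):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.update_queue = []
    
    def async_update(self, k, v, scores, pos):
        """Queue an update instead of applying it immediately.
        This only defers the work (same thread, no concurrency); the cache
        stays stale until process_updates() runs."""
        self.update_queue.append((k, v, scores, pos))
    
    def process_updates(self):
        """Apply all queued updates in arrival order"""
        for k, v, scores, pos in self.update_queue:
            self.update(k, v, scores, pos)
        self.update_queue.clear()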

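9. Summary

H2O's core contributions:

  1. Theoretical grounding: guarantees based on attention sparsity
  2. Simple and effective: no retraining required, plug-and-play
  3. Flexible configuration: the budget can be tuned to trade memory against quality
  4. Extensibility: can be combined with other methods (such as PyramidKV)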
