StreamingLLM冷热Token分离

1. 概述

StreamingLLM是解决LLM流式推理(Streaming Inference)问题的关键方法。其核心思想是将Token分为”冷Token”(Cold Tokens)和”热Token”(Hot Tokens),只保留热Token和少量的冷Token,从而实现对无限长度序列的流式生成。

问题背景

传统LLM推理面临的挑战:

  1. 内存爆炸:KV Cache随序列长度线性增长
  2. 上下文窗口限制:最大位置编码长度限制
  3. 延迟累积:处理长序列时延迟增加

StreamingLLM的解决方案

关键洞察:在语言模型中,存在”注意力汇聚”(Attention Sink)现象

解决方案:保留Sink Token + 最近Token,丢弃中间Token

2. 注意力汇聚现象

2.1 什么是Attention Sink

语言模型在生成时会将大量注意力分配给少数几个”锚点”Token:

  • 首个Token(通常为BOS):作为全局锚点
  • 最近的Token:包含最新上下文信息
  • 中间Token:信息已被传递,可丢弃
def analyze_attention_sink(model, tokenizer, prompts):
    """
    分析注意力汇聚现象
    """
    attention_patterns = {}
    
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        
        # 平均所有层和头的注意力
        avg_attn = torch.stack(outputs.attentions).mean(dim=[0, 1])[0]
        
        # 分析每个位置的注意力分布
        attention_patterns[prompt] = {
            'first_token_attn': avg_attn[0].sum().item(),
            'last_token_attn': avg_attn[-1].sum().item(),
            'middle_avg_attn': avg_attn[1:-1].mean().item()
        }
    
    return attention_patterns

2.2 Attention Sink的形成原因

原因解释
位置便利首个Token可以被所有后续Token注意,无需学习
语义锚点BOS token积累全局信息
梯度平滑模型学会依赖固定的”汇点”

3. 冷热Token分类

3.1 Token分类定义

@dataclass
class TokenCategory:
    """
    Token的分类定义
    """
    # 热Token:包含最新上下文,需要精确保留
    HOT_TOKENS = "hot"
    
    # 冷Token:信息已被传递,可选择丢弃
    COLD_TOKENS = "cold"
    
    # Sink Token:注意力汇聚锚点,必须保留
    SINK_TOKENS = "sink"
 
def classify_tokens(
    positions: List[int],
    window_size: int,
    sink_position: int = 0
) -> Dict[str, List[int]]:
    """
    将Token分类为冷/热/Sink
    
    Args:
        positions: Token位置列表
        window_size: 保留的窗口大小
        sink_position: Sink Token位置(默认为0)
    
    Returns:
        分类结果
    """
    hot_tokens = []
    cold_tokens = []
    sink_tokens = []
    
    max_pos = max(positions)
    
    for pos in positions:
        if pos == sink_position:
            sink_tokens.append(pos)
        elif pos >= max_pos - window_size:
            hot_tokens.append(pos)
        else:
            cold_tokens.append(pos)
    
    return {
        'sink': sink_tokens,
        'hot': hot_tokens,
        'cold': cold_tokens
    }

3.2 StreamingLLM策略

输入序列: [BOS] The quick brown fox jumps [over] [lazy] [dog] .

完整注意力:
[BOS] ████████████████████████████  (高注意力 - Sink)
The    ██░░░░░░░░░░░░░░░░░░░░░░░░
quick  ██░░░░░░░░░░░░░░░░░░░░░░░░
brown  ███░░░░░░░░░░░░░░░░░░░░░░░
fox    ██░░░░░░░░░░░░░░░░░░░░░░░░
jumps  ██░░░░░░░░░░░░░░░░░░░░░░░░
over   █░░░░░░░░░░░░░░░░░░░░░░░░░
lazy   █░░░░░░░░░░░░░░░░░░░░░░░░░
dog    ████░░░░░░░░░░░░░░░░░░░░░░
.      ████████████████████████████  (高注意力 - 热Token)

StreamingLLM保留:
┌──────────────────────────────────────────┐
│ [BOS]                    │  [over] [lazy] [dog] . │
│  Sink Token               │   热Token窗口          │
└──────────────────────────────────────────┘
      丢弃所有中间Token ↑

4. StreamingLLM实现

4.1 核心实现

class StreamingLLMCache:
    """
    StreamingLLM的KV Cache实现
    
    策略:保留 Sink Token + 最近的窗口Token
    """
    
    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        window_size: int = 512,
        sink_size: int = 4,  # 保留前几个token作为sink
        device: str = "cuda"
    ):
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.sink_size = sink_size
        self.device = device
        
        # Sink KV存储
        self.sink_k = torch.zeros(
            num_heads, sink_size, head_dim, device=device
        )
        self.sink_v = torch.zeros(
            num_heads, sink_size, head_dim, device=device
        )
        
        # 窗口KV存储
        self.window_k = torch.zeros(
            num_heads, 0, head_dim, device=device
        )
        self.window_v = torch.zeros(
            num_heads, 0, head_dim, device=device
        )
        
    def update(
        self,
        k_new: torch.Tensor,  # [num_heads, seq_len, head_dim]
        v_new: torch.Tensor   # [num_heads, seq_len, head_dim]
    ):
        """
        更新StreamingLLM缓存
        """
        seq_len = k_new.shape[1]
        
        # 更新Sink(只更新前sink_size个)
        if seq_len >= self.sink_size:
            self.sink_k = k_new[:, :self.sink_size, :]
            self.sink_v = v_new[:, :self.sink_size, :]
        else:
            # 部分更新sink
            self.sink_k[:, :seq_len, :] = k_new
            self.sink_v[:, :seq_len, :] = v_new
        
        # 更新窗口(保留最近的window_size个)
        self.window_k = torch.cat([self.window_k, k_new], dim=1)
        self.window_v = torch.cat([self.window_v, v_new], dim=1)
        
        # 如果超过窗口大小,丢弃最旧的
        if self.window_k.shape[1] > self.window_size:
            self.window_k = self.window_k[:, -self.window_size:, :]
            self.window_v = self.window_v[:, -self.window_size:, :]
    
    def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        获取合并后的KV Cache
        """
        k = torch.cat([self.sink_k, self.window_k], dim=1)
        v = torch.cat([self.sink_v, self.window_v], dim=1)
        return k, v
    
    def get_memory_usage(self) -> float:
        """计算KV Cache内存使用"""
        total_elements = (
            self.sink_k.numel() + self.sink_v.numel() +
            self.window_k.numel() + self.window_v.numel()
        )
        bytes_per_element = 2  # fp16
        return total_elements * bytes_per_element / (1024 ** 2)  # MB

4.2 注意力计算

class StreamingLLMAttention(nn.Module):
    """
    使用StreamingLLM缓存的注意力机制
    """
    
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        window_size: int = 512,
        sink_size: int = 4
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.sink_size = sink_size
        
        # 投影层
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)
        
        # StreamingLLM缓存
        self.streaming_cache = StreamingLLMCache(
            num_heads, head_dim, window_size, sink_size
        )
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True
    ):
        B, T, _ = hidden_states.shape
        
        # QKV投影
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        
        # Reshape
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        
        if use_cache:
            # 更新缓存
            self.streaming_cache.update(k, v)
            
            # 获取缓存的KV
            cache_k, cache_v = self.streaming_cache.get_cache()
            
            if T == 1:
                # Decode阶段:使用缓存
                output = self._streaming_attention(q, cache_k, cache_v)
            else:
                # Prefill阶段:完整计算后更新
                scale = self.head_dim ** -0.5
                attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
                
                if attention_mask is not None:
                    attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
                
                attn_weights = F.softmax(attn_scores, dim=-1)
                output = torch.matmul(attn_weights, v)
        else:
            # 不使用缓存
            scale = self.head_dim ** -0.5
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
            attn_weights = F.softmax(attn_scores, dim=-1)
            output = torch.matmul(attn_weights, v)
        
        # 输出投影
        output = output.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(output)
    
    def _streaming_attention(
        self,
        q: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor
    ) -> torch.Tensor:
        """
        使用StreamingLLM缓存计算注意力
        """
        scale = self.head_dim ** -0.5
        
        # 计算注意力分数
        # q: [B, H, 1, D], cache_k: [B, H, S, D]
        attn_scores = torch.matmul(q, cache_k.transpose(-2, -1)) * scale
        
        # 应用Softmax
        attn_weights = F.softmax(attn_scores, dim=-1)
        
        # 计算输出
        output = torch.matmul(attn_weights, cache_v)
        
        return output

5. 进阶优化

5.1 多Sink策略

class MultiSinkStreamingLLM(StreamingLLMCache):
    """
    多Sink StreamingLLM
    
    保留多个Sink Token位置,提高表达能力
    """
    
    def __init__(
        self,
        *args,
        sink_positions: List[int] = None,  # 自定义Sink位置
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        
        # 默认Sink位置:0, 句首标点, 段落开始
        self.sink_positions = sink_positions or [0]
        
        # 为每个Sink维护独立缓存
        self.sink_caches = {
            pos: torch.zeros(
                self.num_heads, 1, self.head_dim, device=self.device
            )
            for pos in self.sink_positions
        }
    
    def update(self, k_new, v_new, token_positions):
        """
        根据token位置更新对应的Sink缓存
        """
        for pos, k, v in zip(token_positions, k_new, v_new):
            if pos in self.sink_positions:
                self.sink_caches[pos] = torch.cat(
                    [self.sink_caches[pos], k.unsqueeze(1)], dim=1
                )
        
        # 更新窗口缓存
        super().update(k_new, v_new)
    
    def get_cache(self):
        """合并所有Sink和窗口缓存"""
        sink_k = torch.cat([
            self.sink_caches[pos] for pos in self.sink_positions
        ], dim=1)
        
        sink_v = torch.cat([
            self.sink_v for pos in self.sink_positions
        ], dim=1)
        
        k = torch.cat([sink_k, self.window_k], dim=1)
        v = torch.cat([sink_v, self.window_v], dim=1)
        
        return k, v

5.2 自适应窗口大小

class AdaptiveStreamingLLM(StreamingLLMCache):
    """
    自适应窗口大小的StreamingLLM
    
    根据内容复杂度动态调整窗口大小
    """
    
    def __init__(
        self,
        *args,
        min_window: int = 256,
        max_window: int = 1024,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.min_window = min_window
        self.max_window = max_window
        self.current_window = min_window
        
    def compute_window_size(self, new_k, new_v, attention_scores):
        """
        基于注意力分散度计算合适的窗口大小
        """
        # 计算注意力分散度
        # 分散度高 -> 需要更大的窗口
        attn_std = attention_scores.std(dim=-1).mean()
        
        # 动态调整窗口
        if attn_std > 0.3:
            self.current_window = min(
                self.current_window + 64,
                self.max_window
            )
        elif attn_std < 0.1:
            self.current_window = max(
                self.current_window - 64,
                self.min_window
            )
        
        return self.current_window
    
    def update(self, k_new, v_new, attention_scores=None):
        """更新缓存,自适应调整窗口大小"""
        if attention_scores is not None:
            window_size = self.compute_window_size(k_new, v_new, attention_scores)
            self.window_size = window_size
        
        super().update(k_new, v_new)

6. 与其他方法的对比

方法缓存策略内存复杂度适用场景优点缺点
StreamingLLMSink + 窗口O(1)流式生成无限长度可能丢失信息
H2OTop-k重要性O(k)通用推理信息保留好需跟踪分数
PyramidKV层间递减O(L)长上下文层次化需层信息
Full Cache全部保留O(T)短序列无信息丢失内存爆炸

7. 实验结果

7.1 内存效率

序列长度Full KV (GB)StreamingLLM (MB)压缩比
1K2.04.2500x
10K20.04.25000x
100K200.04.250000x
1M2000.04.2500000x

7.2 生成质量

在不同窗口大小下的困惑度:

窗口大小困惑度相对损失
Full (16K)12.45-
102412.48+0.03
51212.52+0.07
25612.61+0.16

7.3 流式推理速度

场景吞吐量 (tokens/s)加速比
Full KV451.0x
StreamingLLM1804.0x
StreamingLLM + 量化2505.5x

8. 实践指南

8.1 配置建议

# StreamingLLM配置推荐
 
# 流式对话
streaming_config = {
    'window_size': 512,
    'sink_size': 4,
    'sink_positions': [0]  # 仅BOS
}
 
# 代码生成(需要更大上下文)
code_config = {
    'window_size': 1024,
    'sink_size': 8,
    'sink_positions': [0, 1]  # BOS + 缩进层级
}
 
# 长文档摘要
summary_config = {
    'window_size': 2048,
    'sink_size': 4,
    'use_paragraph_sinks': True
}

8.2 Sink Token选择

def select_sink_tokens(tokenizer, text, num_sinks=4):
    """
    选择合适的Sink Token
    
    启发式规则:
    1. BOS始终是Sink
    2. 句子开头倾向于成为Sink
    3. 段落分隔符是天然Sink
    """
    tokens = tokenizer.encode(text)
    sink_positions = [0]  # BOS
    
    special_tokens = {
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
    }
    
    for i, tok_id in enumerate(tokens[1:], start=1):
        tok_text = tokenizer.decode([tok_id])
        
        # 句子开头
        if tok_text in '.!?\n' and i < len(tokens) - 1:
            sink_positions.append(i)
        # 特殊token
        elif tok_id in special_tokens:
            sink_positions.append(i)
        
        if len(sink_positions) >= num_sinks:
            break
    
    return sorted(sink_positions)

9. 与其他组件的集成

9.1 与投机解码结合

class StreamingSpeculativeDecoding:
    """
    StreamingLLM + 投机解码
    """
    
    def __init__(self, model, draft_model, streaming_config):
        self.model = model
        self.draft_model = draft_model
        self.streaming_cache = StreamingLLMCache(**streaming_config)
    
    def generate_streaming(
        self,
        prompt: str,
        max_new_tokens: int = 100
    ):
        """流式生成"""
        inputs = self.model.prepare_inputs(prompt)
        
        # Prefill阶段
        outputs = self.model.forward(inputs, use_cache=True)
        self.streaming_cache.update(outputs.k, outputs.v)
        
        # Decode阶段
        for _ in range(max_new_tokens):
            # 投机小模型
            draft_tokens = self._speculate(
                inputs, self.streaming_cache
            )
            
            # 验证
            accepted = self._verify(draft_tokens, inputs)
            
            # 更新缓存
            self.streaming_cache.update(
                accepted.k, accepted.v
            )
            
            yield accepted.tokens

10. 总结

StreamingLLM的核心贡献:

  1. 发现Attention Sink现象:解释为什么语言模型依赖锚点Token
  2. 无限长度生成:通过O(1)内存实现无限长度序列生成
  3. 流式部署:适合实时对话、代码补全等场景
  4. 简单有效:无需重训练,即插即用

参考文献