StreamingLLM冷热Token分离
1. 概述
StreamingLLM是解决LLM流式推理(Streaming Inference)问题的关键方法。其核心思想是将Token分为“冷Token”(Cold Tokens)和“热Token”(Hot Tokens),只保留热Token和序列开头少量的Sink Token,丢弃其余的冷Token,从而实现对无限长度序列的流式生成。
问题背景
传统LLM推理面临的挑战:
- 内存爆炸:KV Cache随序列长度线性增长
- 上下文窗口限制:最大位置编码长度限制
- 延迟累积:处理长序列时延迟增加
StreamingLLM的解决方案
关键洞察:在语言模型中,存在“注意力汇聚”(Attention Sink)现象
解决方案:保留Sink Token + 最近Token,丢弃中间Token
2. 注意力汇聚现象
2.1 什么是Attention Sink
语言模型在生成时会将大量注意力分配给少数几个“锚点”Token,而中间Token获得的注意力很少:
- 首个Token(通常为BOS):作为全局锚点
- 最近的Token:包含最新上下文信息
- 中间Token:信息已被传递,可丢弃
def analyze_attention_sink(model, tokenizer, prompts):
    """Measure how much attention the first, last and middle tokens receive.

    Each prompt is run through the model once with ``output_attentions=True``.
    The attention maps are averaged over all layers and all heads, and the
    resulting ``[query, key]`` matrix of the first batch element is reduced
    per *key* position (column), i.e. by the attention a token receives.
    Reducing per query row would be meaningless: every softmax row sums to 1.

    Args:
        model: Callable used as ``model(**inputs, output_attentions=True)``.
            It must expose ``.device`` and return an object whose
            ``.attentions`` is a sequence with one tensor per layer. Each
            tensor is assumed to be ``[batch, heads, query, key]`` (the
            Hugging Face convention) -- TODO confirm for custom models.
        tokenizer: Callable used as ``tokenizer(prompt, return_tensors="pt")``
            whose result supports ``.to(device)`` and ``**`` unpacking.
        prompts: Iterable of prompt strings.

    Returns:
        Dict mapping every prompt to a dict with:
            ``first_token_attn``: total attention received by key position 0,
                summed over all query positions.
            ``last_token_attn``: total attention received by the last key
                position, summed over all query positions.
            ``middle_avg_attn``: mean attention weight over all entries whose
                key lies strictly between the first and last position
                (``0.0`` when the sequence has no middle tokens).
    """
    attention_patterns = {}
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)

        # Stacked shape: [layers, batch, heads, query, key]. Average over
        # layers (dim 0) and heads (dim 2), then take the first batch element.
        stacked = torch.stack(tuple(outputs.attentions))
        avg_attn = stacked.mean(dim=(0, 2))[0]

        # Columns are key positions, so a column holds the attention a token
        # receives from every query.
        middle = avg_attn[:, 1:-1]
        attention_patterns[prompt] = {
            'first_token_attn': avg_attn[:, 0].sum().item(),
            'last_token_attn': avg_attn[:, -1].sum().item(),
            # mean() of an empty slice is NaN; report 0.0 for short prompts.
            'middle_avg_attn': middle.mean().item() if middle.numel() else 0.0,
        }
    return attention_patterns


# 2.2 Attention Sink的形成原因
| 原因 | 解释 |
|---|---|
| 位置便利 | 首个Token可以被所有后续Token注意,无需学习 |
| 语义锚点 | BOS token积累全局信息 |
| 梯度平滑 | 模型学会依赖固定的“汇点” |
3. 冷热Token分类
3.1 Token分类定义
@dataclass
class TokenCategory:
    """String labels for the three token roles used by StreamingLLM.

    The class declares no annotated fields, so the attributes below are plain
    class-level constants shared by every instance.
    """

    # Sink token: attention-sink anchor, must always be kept.
    SINK_TOKENS = "sink"
    # Hot token: carries the most recent context, must be kept exactly.
    HOT_TOKENS = "hot"
    # Cold token: its information has already been passed on, may be dropped.
    COLD_TOKENS = "cold"
def classify_tokens(
    positions: List[int],
    window_size: int,
    sink_position: int = 0
) -> Dict[str, List[int]]:
    """Split token positions into sink, hot and cold groups.

    A position equal to ``sink_position`` is a sink, even when it also lies
    inside the hot window. Of the remaining positions, those within the last
    ``window_size`` position values (``max(positions) - window_size + 1`` up
    to ``max(positions)``) are hot; everything else is cold. The window is
    measured in position values, not in list entries, so sparse position
    lists may yield fewer than ``window_size`` hot tokens.

    Args:
        positions: Token positions to classify. May be empty.
        window_size: Number of most recent positions to treat as hot.
        sink_position: Position of the sink token (defaults to 0).

    Returns:
        Dict with the keys ``'sink'``, ``'hot'`` and ``'cold'``, each mapping
        to the matching positions in their original order. All three lists
        are empty when ``positions`` is empty.
    """
    buckets: Dict[str, List[int]] = {'sink': [], 'hot': [], 'cold': []}
    if not positions:
        # max() of an empty sequence raises; nothing to classify anyway.
        return buckets

    # First position that belongs to the hot window. The "+ 1" makes the
    # window hold exactly window_size positions instead of window_size + 1.
    hot_start = max(positions) - window_size + 1

    for pos in positions:
        if pos == sink_position:
            buckets['sink'].append(pos)
        elif pos >= hot_start:
            buckets['hot'].append(pos)
        else:
            buckets['cold'].append(pos)
    return buckets


# 3.2 StreamingLLM策略
输入序列: [BOS] The quick brown fox jumps [over] [lazy] [dog] .
完整注意力:
[BOS] ████████████████████████████ (高注意力 - Sink)
The ██░░░░░░░░░░░░░░░░░░░░░░░░
quick ██░░░░░░░░░░░░░░░░░░░░░░░░
brown ███░░░░░░░░░░░░░░░░░░░░░░░
fox ██░░░░░░░░░░░░░░░░░░░░░░░░
jumps ██░░░░░░░░░░░░░░░░░░░░░░░░
over █░░░░░░░░░░░░░░░░░░░░░░░░░
lazy █░░░░░░░░░░░░░░░░░░░░░░░░░
dog ████░░░░░░░░░░░░░░░░░░░░░░
. ████████████████████████████ (高注意力 - 热Token)
StreamingLLM保留:
┌──────────────────────────────────────────┐
│ [BOS] │ [over] [lazy] [dog] . │
│ Sink Token │ 热Token窗口 │
└──────────────────────────────────────────┘
丢弃所有中间Token ↑
4. StreamingLLM实现
4.1 核心实现
class StreamingLLMCache:
    """KV cache that keeps a few attention-sink tokens plus a sliding window.

    The first ``sink_size`` tokens ever passed to :meth:`update` are stored
    once as sinks and are never overwritten afterwards. Every later token goes
    into a window that is trimmed to the most recent ``window_size`` entries,
    so the cache never holds more than ``sink_size + window_size`` tokens and
    no token is stored twice.

    All tensors are laid out as ``[num_heads, seq_len, head_dim]``. The four
    buffers (``sink_k``, ``sink_v``, ``window_k``, ``window_v``) start with a
    sequence length of 0 and grow as tokens arrive.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        window_size: int = 512,
        sink_size: int = 4,  # number of leading tokens kept as sinks
        device: str = "cuda"
    ):
        """Create an empty cache.

        Args:
            num_heads: Number of attention heads.
            head_dim: Size of each head.
            window_size: Maximum number of recent tokens kept in the window.
            sink_size: Number of leading tokens kept permanently as sinks.
            device: Device on which the (initially empty) buffers are created.

        Raises:
            ValueError: If ``window_size`` or ``sink_size`` is negative.
        """
        if window_size < 0 or sink_size < 0:
            raise ValueError("window_size and sink_size must be non-negative")
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.sink_size = sink_size
        self.device = device
        self.reset()

    def reset(self) -> None:
        """Drop every cached token, e.g. before starting a new sequence."""
        shape = (self.num_heads, 0, self.head_dim)
        # Sink KV storage.
        self.sink_k = torch.zeros(shape, device=self.device)
        self.sink_v = torch.zeros(shape, device=self.device)
        # Window KV storage.
        self.window_k = torch.zeros(shape, device=self.device)
        self.window_v = torch.zeros(shape, device=self.device)

    def update(
        self,
        k_new: torch.Tensor,  # [num_heads, seq_len, head_dim]
        v_new: torch.Tensor   # [num_heads, seq_len, head_dim]
    ):
        """Append new keys/values, filling the sinks first.

        Args:
            k_new: Keys of the new tokens, ``[num_heads, seq_len, head_dim]``.
            v_new: Values of the new tokens, same shape as ``k_new``.

        Raises:
            ValueError: If ``k_new`` and ``v_new`` differ in shape.
        """
        if k_new.shape != v_new.shape:
            raise ValueError("k_new and v_new must have the same shape")
        self._match_layout(k_new)

        # Sinks are the first tokens of the stream: fill the free slots once
        # and never touch them again, regardless of later chunk sizes.
        free_slots = self.sink_size - self.sink_k.shape[1]
        if free_slots > 0:
            self.sink_k = torch.cat([self.sink_k, k_new[:, :free_slots, :]], dim=1)
            self.sink_v = torch.cat([self.sink_v, v_new[:, :free_slots, :]], dim=1)
            # Tokens stored as sinks must not be duplicated in the window.
            k_new = k_new[:, free_slots:, :]
            v_new = v_new[:, free_slots:, :]

        self.window_k = self._slide(self.window_k, k_new)
        self.window_v = self._slide(self.window_v, v_new)

    def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the merged ``(keys, values)``: sinks followed by the window."""
        k = torch.cat([self.sink_k, self.window_k], dim=1)
        v = torch.cat([self.sink_v, self.window_v], dim=1)
        return k, v

    def get_memory_usage(self) -> float:
        """Return the size of the cached tensors in MiB.

        The element size is read from each tensor, so the figure is correct
        for any dtype rather than assuming fp16.
        """
        buffers = (self.sink_k, self.sink_v, self.window_k, self.window_v)
        total_bytes = sum(buf.numel() * buf.element_size() for buf in buffers)
        return total_bytes / (1024 ** 2)  # MB

    def _match_layout(self, ref: torch.Tensor) -> None:
        """Re-create still-empty buffers with the dtype and device of ``ref``.

        Without this, concatenating the fp32 placeholders with e.g. fp16
        inputs would upcast the cache, and a device mismatch would raise.
        """
        for name in ("sink_k", "sink_v", "window_k", "window_v"):
            buf = getattr(self, name)
            differs = buf.dtype != ref.dtype or buf.device != ref.device
            if buf.shape[1] == 0 and differs:
                setattr(self, name, ref.new_zeros(buf.shape))

    def _slide(self, buf: torch.Tensor, new: torch.Tensor) -> torch.Tensor:
        """Append ``new`` to ``buf`` and keep only the last ``window_size`` tokens."""
        merged = torch.cat([buf, new], dim=1)
        overflow = merged.shape[1] - self.window_size
        if overflow > 0:
            # Index from the front: a "[-window_size:]" slice would keep
            # everything when window_size == 0. clone() releases the storage
            # of the dropped tokens instead of keeping it alive via a view.
            merged = merged[:, overflow:, :].clone()
        return merged


# 4.2 注意力计算
class StreamingLLMAttention(nn.Module):
    """Multi-head attention whose decode step reads from a StreamingLLMCache.

    With ``use_cache=True`` every call appends the new keys/values to the
    cache. A single-token call (decode) attends over the cached sinks plus
    window; a multi-token call (prefill) attends over the current chunk only.

    The cache stores one sequence laid out as ``[num_heads, seq_len,
    head_dim]``, so cached calls require a batch size of 1. No causal mask is
    applied implicitly; pass ``attention_mask`` when one is needed.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        head_dim: int,
        window_size: int = 512,
        sink_size: int = 4
    ):
        """Build the projections and an empty streaming cache.

        Args:
            hidden_size: Size of the input/output features.
            num_heads: Number of attention heads.
            head_dim: Size of each head.
            window_size: Number of recent tokens kept by the cache.
            sink_size: Number of leading tokens kept by the cache as sinks.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.sink_size = sink_size
        # Projection layers.
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.k_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.v_proj = nn.Linear(hidden_size, num_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size)
        # StreamingLLM cache, created next to the weights instead of on a
        # hard-coded device so that construction also works without CUDA.
        self.streaming_cache = StreamingLLMCache(
            num_heads, head_dim, window_size, sink_size,
            device=self.q_proj.weight.device,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = True
    ):
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``[batch, seq_len, hidden_size]``.
            attention_mask: Optional mask applied to the ``[batch, heads,
                seq_len, seq_len]`` scores of the non-decode path; positions
                where it equals 0 are excluded. It is not used by the
                single-token decode path, which attends to the whole cache.
            use_cache: Whether to update (and, for single-token input, read)
                the streaming cache.

        Returns:
            Tensor of shape ``[batch, seq_len, hidden_size]``.

        Raises:
            ValueError: If ``use_cache`` is true and the batch size is not 1.
        """
        B, T, _ = hidden_states.shape
        q = self._split_heads(self.q_proj(hidden_states), B, T)
        k = self._split_heads(self.k_proj(hidden_states), B, T)
        v = self._split_heads(self.v_proj(hidden_states), B, T)

        if use_cache:
            if B != 1:
                raise ValueError(
                    "StreamingLLMAttention caches a single sequence; "
                    f"got batch size {B} with use_cache=True"
                )
            self._move_cache(k.device)
            # The cache expects [num_heads, seq_len, head_dim]: drop the batch axis.
            self.streaming_cache.update(k[0], v[0])

        if use_cache and T == 1:
            # Decode: attend over sinks + window (the new token is included,
            # because the cache was updated above).
            cache_k, cache_v = self.streaming_cache.get_cache()
            output = self._streaming_attention(
                q,
                cache_k.unsqueeze(0).to(q.dtype),
                cache_v.unsqueeze(0).to(q.dtype),
            )
        else:
            # Prefill or uncached call: attention within the current chunk.
            # The mask is honoured on both paths.
            output = self._attend(q, k, v, attention_mask)

        # Output projection.
        output = output.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(output)

    def _split_heads(self, x: torch.Tensor, batch: int, seq_len: int) -> torch.Tensor:
        """Reshape ``[B, T, H * D]`` into ``[B, H, T, D]``."""
        return x.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def _move_cache(self, device: torch.device) -> None:
        """Move the cache buffers to ``device``.

        The module may have been moved with ``.to()`` after construction,
        which does not touch the cache because it is not a registered buffer.
        """
        cache = self.streaming_cache
        for name in ("sink_k", "sink_v", "window_k", "window_v"):
            setattr(cache, name, getattr(cache, name).to(device))
        cache.device = device

    def _attend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Scaled dot-product attention with an optional 0/non-0 mask."""
        scale = self.head_dim ** -0.5
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        if attention_mask is not None:
            attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        return torch.matmul(attn_weights, v)

    def _streaming_attention(
        self,
        q: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor
    ) -> torch.Tensor:
        """Attend with ``q`` over the cached keys/values.

        Args:
            q: Queries, ``[B, H, 1, D]``.
            cache_k: Cached keys, ``[B, H, S, D]``.
            cache_v: Cached values, ``[B, H, S, D]``.

        Returns:
            Attention output of shape ``[B, H, 1, D]``.
        """
        return self._attend(q, cache_k, cache_v)


# 5. 进阶优化
5.1 多Sink策略
class MultiSinkStreamingLLM(StreamingLLMCache):
    """StreamingLLM cache with sinks at caller-chosen token positions.

    Keys and values of tokens whose position is listed in ``sink_positions``
    are stored permanently, one slot per position. All other tokens go into
    the sliding window inherited from the base class. The base-class sink
    buffers (``sink_k`` / ``sink_v``) are not used by this subclass.
    """

    def __init__(
        self,
        *args,
        sink_positions: Optional[List[int]] = None,  # custom sink positions
        **kwargs
    ):
        """Create the cache.

        Args:
            *args: Forwarded to ``StreamingLLMCache.__init__``.
            sink_positions: Token positions to keep as sinks. Defaults to
                ``[0]`` (the first token only).
            **kwargs: Forwarded to ``StreamingLLMCache.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.sink_positions = sink_positions or [0]
        # One key slot and one value slot per sink position. A slot has
        # sequence length 0 until the token at that position has been seen.
        self.sink_caches = {pos: self._empty_slot() for pos in self.sink_positions}
        self.sink_value_caches = {pos: self._empty_slot() for pos in self.sink_positions}

    def update(self, k_new, v_new, token_positions):
        """Store new keys/values, routing sink positions to their own slots.

        Args:
            k_new: Keys of the new tokens, ``[num_heads, seq_len, head_dim]``.
            v_new: Values of the new tokens, same shape as ``k_new``.
            token_positions: Position of each new token, one entry per token
                along the ``seq_len`` axis.

        Raises:
            ValueError: If ``k_new`` and ``v_new`` differ in shape or the
                number of positions does not match ``seq_len``.
        """
        if k_new.shape != v_new.shape:
            raise ValueError("k_new and v_new must have the same shape")
        positions = [int(pos) for pos in token_positions]
        if len(positions) != k_new.shape[1]:
            raise ValueError(
                f"got {len(positions)} token positions for {k_new.shape[1]} tokens"
            )

        # Walk along the sequence axis (dim 1), not the head axis.
        window_idx = []
        for idx, pos in enumerate(positions):
            if pos in self.sink_caches:
                self.sink_caches[pos] = k_new[:, idx:idx + 1, :].clone()
                self.sink_value_caches[pos] = v_new[:, idx:idx + 1, :].clone()
            else:
                window_idx.append(idx)

        if not window_idx:
            return
        if len(window_idx) != len(positions):
            # Sink tokens must not be duplicated in the window.
            index = torch.tensor(window_idx, dtype=torch.long, device=k_new.device)
            k_new = k_new.index_select(1, index)
            v_new = v_new.index_select(1, index)

        # Update the window cache.
        self.window_k = self._slide_window(self.window_k, k_new)
        self.window_v = self._slide_window(self.window_v, v_new)

    def get_cache(self):
        """Return ``(keys, values)``: the filled sink slots, then the window."""
        k_parts = list(self.sink_caches.values()) + [self.window_k]
        v_parts = list(self.sink_value_caches.values()) + [self.window_v]
        # Skip empty placeholders so they cannot influence the result dtype.
        k_parts = [part for part in k_parts if part.shape[1] > 0]
        v_parts = [part for part in v_parts if part.shape[1] > 0]
        if not k_parts:
            return self.window_k, self.window_v
        return torch.cat(k_parts, dim=1), torch.cat(v_parts, dim=1)

    def _empty_slot(self):
        """Return a zero-length ``[num_heads, 0, head_dim]`` placeholder."""
        return torch.zeros(self.num_heads, 0, self.head_dim, device=self.device)

    def _slide_window(self, buf, new):
        """Append ``new`` to ``buf`` and keep the last ``window_size`` tokens."""
        merged = new if buf.shape[1] == 0 else torch.cat([buf, new], dim=1)
        start = max(merged.shape[1] - self.window_size, 0)
        # clone() so the cache neither aliases the caller's tensor nor keeps
        # the storage of dropped tokens alive through a view.
        return merged[:, start:, :].clone()


# 5.2 自适应窗口大小
class AdaptiveStreamingLLM(StreamingLLMCache):
    """StreamingLLM cache whose window size adapts to the attention spread.

    The window grows when the attention scores are widely spread and shrinks
    when they are concentrated, always staying within
    ``[min_window, max_window]``.
    """

    def __init__(
        self,
        *args,
        min_window: int = 256,
        max_window: int = 1024,
        step: int = 64,
        grow_threshold: float = 0.3,
        shrink_threshold: float = 0.1,
        **kwargs
    ):
        """Create the cache.

        Args:
            *args: Forwarded to ``StreamingLLMCache.__init__``.
            min_window: Smallest window size the cache may shrink to.
            max_window: Largest window size the cache may grow to.
            step: Number of tokens added or removed per adjustment.
            grow_threshold: Spread above which the window grows.
            shrink_threshold: Spread below which the window shrinks.
            **kwargs: Forwarded to ``StreamingLLMCache.__init__``.

        Raises:
            ValueError: If ``min_window`` is greater than ``max_window``.
        """
        super().__init__(*args, **kwargs)
        if min_window > max_window:
            raise ValueError("min_window must not exceed max_window")
        self.min_window = min_window
        self.max_window = max_window
        self.step = step
        self.grow_threshold = grow_threshold
        self.shrink_threshold = shrink_threshold
        # Start from the configured window (clamped into the allowed range)
        # so the first adaptive update does not jump to an unrelated size.
        self.current_window = min(max(self.window_size, min_window), max_window)
        self.window_size = self.current_window

    def compute_window_size(self, new_k, new_v, attention_scores):
        """Pick a window size from the spread of ``attention_scores``.

        The spread is the standard deviation along the last axis, averaged
        over all remaining axes.

        Args:
            new_k: Unused; kept for interface compatibility.
            new_v: Unused; kept for interface compatibility.
            attention_scores: Tensor of attention scores.

        Returns:
            The (possibly adjusted) current window size.
        """
        if attention_scores.shape[-1] < 2:
            # The standard deviation of a single element is undefined (NaN).
            return self.current_window

        # Spread of the attention: a high spread calls for a larger window.
        # A NaN spread fails both comparisons and leaves the window unchanged.
        attn_std = attention_scores.float().std(dim=-1).mean().item()

        # Adjust the window.
        if attn_std > self.grow_threshold:
            self.current_window = min(
                self.current_window + self.step,
                self.max_window
            )
        elif attn_std < self.shrink_threshold:
            self.current_window = max(
                self.current_window - self.step,
                self.min_window
            )
        return self.current_window

    def update(self, k_new, v_new, attention_scores=None):
        """Update the cache, adapting the window size first when possible.

        Args:
            k_new: Keys of the new tokens.
            v_new: Values of the new tokens.
            attention_scores: Optional attention scores; when given, the
                window size is recomputed before the tokens are appended.
        """
        if attention_scores is not None:
            self.window_size = self.compute_window_size(k_new, v_new, attention_scores)
        super().update(k_new, v_new)


# 6. 与其他方法的对比
| 方法 | 缓存策略 | 内存复杂度 | 适用场景 | 优点 | 缺点 |
|---|---|---|---|---|---|
| StreamingLLM | Sink + 窗口 | O(1) | 流式生成 | 无限长度 | 可能丢失信息 |
| H2O | Top-k重要性 | O(k) | 通用推理 | 信息保留好 | 需跟踪分数 |
| PyramidKV | 层间递减 | O(L) | 长上下文 | 层次化 | 需层信息 |
| Full Cache | 全部保留 | O(T) | 短序列 | 无信息丢失 | 内存爆炸 |
7. 实验结果
7.1 内存效率
| 序列长度 | Full KV (GB) | StreamingLLM (MB) | 压缩比 |
|---|---|---|---|
| 1K | 2.0 | 4.2 | 500x |
| 10K | 20.0 | 4.2 | 5000x |
| 100K | 200.0 | 4.2 | 50000x |
| 1M | 2000.0 | 4.2 | 500000x |
7.2 生成质量
在不同窗口大小下的困惑度:
| 窗口大小 | 困惑度 | 相对损失 |
|---|---|---|
| Full (16K) | 12.45 | - |
| 1024 | 12.48 | +0.03 |
| 512 | 12.52 | +0.07 |
| 256 | 12.61 | +0.16 |
7.3 流式推理速度
| 场景 | 吞吐量 (tokens/s) | 加速比 |
|---|---|---|
| Full KV | 45 | 1.0x |
| StreamingLLM | 180 | 4.0x |
| StreamingLLM + 量化 | 250 | 5.5x |
8. 实践指南
8.1 配置建议
# Recommended StreamingLLM configurations.
# NOTE(review): 'sink_positions' and 'use_paragraph_sinks' are not parameters
# of StreamingLLMCache.__init__ as defined in this file, and 'num_heads' /
# 'head_dim' are missing, so these dicts cannot be splatted into it as-is.
# Streaming chat
streaming_config = {
    'window_size': 512,
    'sink_size': 4,
    'sink_positions': [0]  # BOS only
}
# Code generation (needs a larger context)
code_config = {
    'window_size': 1024,
    'sink_size': 8,
    'sink_positions': [0, 1]  # BOS + indentation level
}
# Long-document summarization
summary_config = {
    'window_size': 2048,
    'sink_size': 4,
    'use_paragraph_sinks': True
}


# 8.2 Sink Token选择
def select_sink_tokens(tokenizer, text, num_sinks=4):
    """Pick token positions of ``text`` to use as attention sinks.

    Heuristics:
        1. Position 0 is always a sink (assumed to be BOS -- this holds only
           if ``tokenizer.encode`` prepends one).
        2. A token that decodes to sentence punctuation (``.``, ``!``, ``?``)
           or a newline is a sink, unless it is the last token.
        3. EOS/BOS/PAD tokens are sinks.

    Args:
        tokenizer: Object providing ``encode(text)``, ``decode(ids)`` and the
            attributes ``eos_token_id``, ``bos_token_id``, ``pad_token_id``.
        text: Text to tokenize.
        num_sinks: Maximum number of sink positions to return. Position 0 is
            returned even when this is smaller than 1.

    Returns:
        Sink positions in ascending order, or an empty list when ``text``
        produces no tokens.
    """
    tokens = tokenizer.encode(text)
    if not tokens:
        # There is no position 0 to report.
        return []

    special_ids = {
        tokenizer.eos_token_id,
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
    }
    special_ids.discard(None)  # tokenizers without e.g. a pad token report None

    last_index = len(tokens) - 1
    sink_positions = [0]  # BOS
    for i, tok_id in enumerate(tokens[1:], start=1):
        # Check the budget before adding, so it is never exceeded.
        if len(sink_positions) >= num_sinks:
            break
        tok_text = tokenizer.decode([tok_id])
        # The substring test alone would accept "" (an empty string is a
        # substring of everything), so require a non-empty token text.
        is_boundary = tok_text != "" and tok_text in '.!?\n' and i < last_index
        if is_boundary or tok_id in special_ids:
            sink_positions.append(i)
    # Positions were appended in increasing order, so the list is sorted.
    return sink_positions


# 9. 与其他组件的集成
9.1 与投机解码结合
class StreamingSpeculativeDecoding:
    """
    StreamingLLM combined with speculative decoding.

    NOTE(review): `_speculate` and `_verify` are called by
    `generate_streaming` but are not defined in the class as shown here;
    they have to be provided elsewhere.
    """
    def __init__(self, model, draft_model, streaming_config):
        # model: target model; must offer prepare_inputs() and forward().
        self.model = model
        # draft_model: stored here; not referenced by the code shown below
        # (presumably used inside _speculate -- verify).
        self.draft_model = draft_model
        # Every key of streaming_config is forwarded as a keyword argument,
        # so it may only contain parameters StreamingLLMCache accepts.
        self.streaming_cache = StreamingLLMCache(**streaming_config)
    def generate_streaming(
        self,
        prompt: str,
        max_new_tokens: int = 100
    ):
        """Generate tokens as a stream.

        This is a generator: each loop iteration yields the `tokens`
        attribute of whatever `_verify` returned.

        NOTE(review): the loop runs `max_new_tokens` times and yields one
        batch of accepted tokens per iteration, so more than
        `max_new_tokens` tokens may be produced if a batch holds several
        tokens -- confirm the intended budget. `inputs` is also passed
        unchanged to every iteration; presumably the helpers advance the
        sequence through the cache -- verify.
        """
        inputs = self.model.prepare_inputs(prompt)
        # Prefill phase
        outputs = self.model.forward(inputs, use_cache=True)
        self.streaming_cache.update(outputs.k, outputs.v)
        # Decode phase
        for _ in range(max_new_tokens):
            # Draft tokens with the small speculative model
            draft_tokens = self._speculate(
                inputs, self.streaming_cache
            )
            # Verify
            accepted = self._verify(draft_tokens, inputs)
            # Update the cache
            self.streaming_cache.update(
                accepted.k, accepted.v
            )
            yield accepted.tokens


# 10. 总结
StreamingLLM的核心贡献:
- 发现Attention Sink现象:解释为什么语言模型依赖锚点Token
- 无限长度生成:通过O(1)内存实现无限长度序列生成
- 流式部署:适合实时对话、代码补全等场景
- 简单有效:无需重训练,即插即用