Speculative Decoding理论:LLM推理加速
概述
自回归生成的瓶颈问题
大语言模型(Large Language Model, LLM)的推理过程本质上是一个自回归生成(Autoregressive Generation)过程。在每个解码步骤中,模型需要:
- 将已生成的 token 序列作为输入
- 计算注意力机制(Attention)
- 输出下一个 token 的概率分布
这一过程面临严峻的效率瓶颈:
| 瓶颈类型 | 具体表现 |
|---|---|
| 计算瓶颈 | 每个解码步骤都需要完整的前向传播, 的注意力计算无法并行 |
| 内存带宽瓶颈 | KV Cache的读写成为主要延迟来源1 |
| 自回归串行化 | 下一个token依赖前一个token的生成,无法像prefill阶段那样充分并行 |
Speculative Decoding基本思想
Speculative Decoding的核心思想是将生成过程分解为两个阶段:
- Draft阶段:使用一个轻量级的「草稿模型」(Draft Model)快速生成多个候选token
- Verification阶段:使用目标模型(Target Model)并行验证这些候选token的正确性
这种「投机取巧」的方法利用了以下观察:
- 验证多个token的正确性可以完全并行化
- 轻量级模型生成的token有相当一部分是正确的
- 即使需要回退(Rollback),也比完全自回归生成更高效
与传统的自回归生成相比:
理论框架
形式化定义
设目标模型为 ,draft模型为 ,输入序列为 。
定义(目标分布):目标模型在位置 的输出分布为
其中 是logits, 是温度参数( 时退化为贪婪解码)。
定义(Draft分布):Draft模型生成的分布为
核心假设:Draft模型是目标模型的一阶近似,即对于大多数token:
接受率分析
Speculative Decoding的正确性保证来自于Hyndman定理(也称为拒绝采样的接受准则)。2
定理(Hyndman Acceptance Criterion):
设 为提议分布(Proposal Distribution), 为目标分布,若对所有 满足:
其中 是常数,则以下采样-接受算法得到来自 的样本:
- 从 采样
- 以概率 接受
应用于Speculative Decoding:
在位置 ,我们希望验证draft模型采样的token 。定义接受概率:
期望接受率:
当 时,。
期望加速比推导
设每轮Draft阶段生成 个token,验证阶段的接受率为 ,则:
每轮生成的token数期望:
实际计算量分析:
- 自回归方式生成 个token需要 次完整前向传播
- Speculative Decoding需要 1 次 draft前向 + 1 次 target前向( token并行验证)
加速比:
其中 和 分别是目标模型和draft模型单次前向的时间。
更精确的模型:考虑回退开销,定义有效接受率 :
核心算法
Draft模型选择标准
Draft模型的选择对整体性能至关重要。理想模型应满足:
| 标准 | 说明 |
|---|---|
| 质量匹配 | 与 分布接近 |
| 推理速度快 | 单次前向传播时间远小于目标模型 |
| 参数量小 | 适合部署在有限计算资源下 |
常见选择:
- 同系列小模型:如 Llama-7B 作为 Llama-70B 的draft
- SSM架构:如 Mamba 模型,适合快速生成
- Speculative Heads:在目标模型上附加轻量级预测头3
- Medusa结构:共享backbone,多个并行解码头4
验证机制
Greedy Verification
在贪婪解码()场景下,验证过程简化为:
即只需比较draft模型输出的token是否与目标模型贪婪解码的token一致。
接受概率:
Sampling-based Verification
当使用采样解码时,验证需要考虑概率比:
bool accept_token(float p_draft, float p_target, std::mt19937& rng) {
float ratio = p_target / p_draft;
float threshold = std::uniform_real_distribution<>(0.0f, 1.0f)(rng);
return threshold <= ratio;
}多Token预测与验证
Medusa范式
Medusa在目标模型上附加多个解码头(Decoding Head),每个头预测下一个位置的token:4
Token Position: t t+1 t+2 t+3 t+4
┌────┬────┬────┬────┐
Medusa Heads: │ H1 │ H2 │ H3 │ H4 │
└────┴────┴────┴────┘
训练目标:对第 个head,最小化:
EAGLE方法
EAGLE(Early Exit Guided Language model)采用自监督的早期退出机制:5
- 在每层设置early exit point
- 利用hidden states的层次化特性预测下一个token
- 减少计算量的同时保持生成质量
自适应策略
动态调整Draft长度
根据历史接受率动态调整每轮的draft长度 :
class AdaptiveSpeculator {
float alpha_history;
int k_current;
void adjust_k() {
if (alpha_history > 0.9) k_current += 2; // 接受率高,增加长度
else if (alpha_history < 0.5) k_current -= 1; // 接受率低,减少长度
k_current = clamp(k_current, 1, MAX_K);
}
};Beam Search集成
将Speculative Decoding与beam search结合:
- 保持多个假设(hypotheses)
- 对每个假设独立进行speculation
- 选择整体得分最高的路径
实现细节
KV Cache在Speculative Decoding中的重用
这是Speculative Decoding高效性的关键所在。验证阶段可以复用draft阶段计算出的KV Cache。
传统自回归:
Step 1: 计算 K_1, V_1, 输出 t_1
Step 2: 计算 K_2, V_2, 输出 t_2 ← 无法复用
Step 3: 计算 K_3, V_3, 输出 t_3 ← 无法复用
Speculative Decoding:
Draft:
Step 1: 计算 K_1, V_1, 输出 t_1, t_2, t_3
→ 保存 K_1, V_1, K_2, V_2, K_3, V_3
Verify:
Step 2: 直接复用上述 KV Cache
→ 注意力计算只需 O(k) 而非 O(k²)
数学表达:对于位置 的key/query:
验证阶段已有 ,只需计算 。
Batch处理优化
Prefix Batching
当多个请求共享相同前缀时(如system prompt):
// 共享前缀(System Prompt)
std::vector<int> shared_prefix = {101, 2003, 1996, ...}; // token IDs
// 独立后缀(User Query)
std::vector<std::vector<int>> unique_queries = {
{2054, 2003, 1996, ...},
{3024, 1029, ...}
};
// Batch推理
for (auto& query : unique_queries) {
auto full_input = concatenate(shared_prefix, query);
// 共享prefix的KV Cache
}Continuous Batching
动态批处理以最大化GPU利用率:
- 新请求随时加入batch
- 完成的请求立即退出
- Draft和Verify阶段分别batch处理
内存管理
动态KV Cache分配
class KVCacheManager {
size_t max_seq_len;
size_t num_layers;
size_t num_heads;
size_t head_dim;
std::vector<std::vector<torch::Tensor>> kv_cache;
void allocate(int batch_size) {
kv_cache.resize(batch_size);
for (auto& cache : kv_cache) {
cache.resize(num_layers);
for (auto& k_cache : cache) {
k_cache = torch::zeros({num_heads, max_seq_len, head_dim});
}
}
}
};显存优化策略
| 策略 | 效果 |
|---|---|
| PagedAttention | 减少内存碎片,支持动态分配 |
| KV Cache量化 | FP16 → INT8 减少50%显存 |
| 分布式KV Cache | 多GPU分担存储压力 |
代码实现
PyTorch完整实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Optional
class SpeculativeDecoder:
"""Speculative Decoding 实现"""
def __init__(
self,
target_model: nn.Module,
draft_model: nn.Module,
max_draft_len: int = 8,
temperature: float = 1.0,
device: str = "cuda"
):
self.target = target_model
self.draft = draft_model
self.max_draft = max_draft_len
self.temperature = temperature
self.device = device
# 冻结目标模型参数
for p in self.target.parameters():
p.requires_grad = False
# 冻结draft模型参数(可选)
for p in self.draft.parameters():
p.requires_grad = False
def _sample(self, logits: torch.Tensor) -> torch.Tensor:
"""从logits中采样token"""
if self.temperature == 0:
return torch.argmax(logits, dim=-1)
probs = F.softmax(logits / self.temperature, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
def _compute_accept_prob(
self,
p_target: torch.Tensor,
p_draft: torch.Tensor
) -> torch.Tensor:
"""计算接受概率(基于Hyndman准则)"""
# 避免除零
p_draft = torch.clamp(p_draft, min=1e-10)
ratio = p_target / p_draft
return torch.clamp(ratio, max=1.0)
def draft_phase(
self,
input_ids: torch.Tensor,
kv_cache: Optional[dict] = None
) -> Tuple[List[int], dict]:
"""Draft阶段:使用draft模型生成候选序列"""
draft_tokens = []
current_ids = input_ids.clone()
current_kv = {} if kv_cache is None else kv_cache
for _ in range(self.max_draft):
# 前向传播
with torch.no_grad():
outputs = self.draft(
input_ids=current_ids,
past_key_values=current_kv,
use_cache=True
)
logits = outputs.logits[:, -1, :] / self.temperature
next_token = self._sample(logits)
draft_tokens.append(next_token.item())
current_ids = next_token.unsqueeze(0)
current_kv = outputs.past_key_values
# 遇到eos终止
if next_token.item() == self.target.config.eos_token_id:
break
return draft_tokens, current_kv
def verify_phase(
self,
input_ids: torch.Tensor,
draft_tokens: List[int],
kv_cache: dict
) -> Tuple[List[int], int]:
"""
Verify阶段:并行验证draft tokens
Returns:
accepted_tokens: 被接受的tokens
first_reject_idx: 第一个拒绝的token索引(-1表示全部接受)
"""
# 构建验证输入
batch_size = len(draft_tokens)
verify_input = torch.tensor(
[input_ids[0].item()] + draft_tokens,
device=self.device
).unsqueeze(0)
# 复用draft阶段的KV Cache
target_kv = kv_cache
# 目标模型并行验证
with torch.no_grad():
outputs = self.target(
input_ids=verify_input,
past_key_values=target_kv,
use_cache=True
)
# 计算每个位置的接受概率
target_probs = F.softmax(outputs.logits[0], dim=-1) # [seq_len, vocab_size]
accepted = []
first_reject = -1
for i, token_id in enumerate(draft_tokens):
# 获取目标模型在位置i的token概率
p_target = target_probs[i, token_id].item()
# 获取draft模型的概率(需要重新计算)
# 这里简化处理,假设draft的token就是最可能的
p_draft = 1.0 / self.target.config.vocab_size # 简化假设
# 计算接受概率
accept_prob = min(1.0, p_target / (p_draft + 1e-10))
if torch.rand(1).item() < accept_prob:
accepted.append(token_id)
else:
first_reject = i
break
return accepted, first_reject
def generate(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int = 100
) -> torch.Tensor:
"""完整的Speculative Decoding生成过程"""
generated = prompt_ids.clone()
total_generated = 0
# 初始KV Cache
kv_cache = None
while total_generated < max_new_tokens:
# 1. Draft阶段
draft_tokens, kv_cache = self.draft_phase(
generated[:, -1:] if len(generated) > 1 else generated,
kv_cache
)
if not draft_tokens:
break
# 2. Verify阶段
accepted_tokens, reject_idx = self.verify_phase(
generated if len(generated) > 1 else
torch.tensor([[self.target.config.bos_token_id]], device=self.device),
draft_tokens,
kv_cache
)
# 3. 追加接受的tokens
generated = torch.cat([
generated,
torch.tensor([accepted_tokens], device=self.device).T
], dim=-1)
total_generated += len(accepted_tokens)
# 如果全部拒绝,添加一个目标模型预测的token
if not accepted_tokens:
with torch.no_grad():
outputs = self.target(
input_ids=generated[:, -1:],
past_key_values=kv_cache,
use_cache=True
)
next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
generated = torch.cat([generated, next_token], dim=-1)
kv_cache = outputs.past_key_values
total_generated += 1
# 遇到eos终止
if generated[0, -1].item() == self.target.config.eos_token_id:
break
return generated[:, prompt_ids.shape[1]:]关键函数说明
| 函数 | 功能 | 时间复杂度 |
|---|---|---|
draft_phase | 生成 个候选token | |
verify_phase | 并行验证候选token | |
generate | 完整生成流程 | 迭代上述两阶段 |
性能分析与最优策略
加速比与接受率关系
加速比 与接受率 的关系:
当 (轻量级draft)时:
关键洞察:
- 当 时,(接近理论最大加速)
- 当 时,(退化为普通自回归)
最优Draft长度选择
理论最优:给定接受率 ,最优 满足:
实际建议:
| 场景 | 推荐 | 原因 |
|---|---|---|
| 高接受率() | 8-16 | 可最大化吞吐 |
| 中接受率() | 4-8 | 平衡吞吐与回退开销 |
| 低接受率() | 1-4 | 回退开销过大 |
不同场景下的性能对比
| 方法 | 延迟优化 | 吞吐优化 | 适用场景 |
|---|---|---|---|
| 朴素自回归 | - | - | 基准 |
| Speculative Decoding | ✅ 显著 | ✅ 显著 | GPU丰富的服务端 |
| Continuous Batching | ❌ | ✅ 显著 | 高并发场景 |
| Flash Attention + SD | ✅✅ | ✅✅ | 长序列场景 |
局限性与发展方向
当前局限
| 局限性 | 具体问题 |
|---|---|
| 领域适配 | Draft模型与目标模型分布差异大时,接受率急剧下降 |
| 计算资源 | 需要同时加载两个模型,显存压力翻倍 |
| 长度外推 | Draft模型的长上下文能力弱于目标模型 |
| 采样质量 | 在非贪婪解码场景下,接受率计算复杂 |
发展方向
级联方法
将多个不同规模的模型组成级联:
Prompt → Small → Medium → Large → Output
↑ ↑ ↓
拒绝 拒绝 接受 → 输出
每级模型都有更高的接受率,只有难以预测的token才会传递到更大模型。
动态调整
- 在线学习:根据实时接受率调整draft长度和模型选择
- 上下文感知:根据prompt类型选择最适合的draft策略
- 硬件感知:根据GPU型号和显存状态动态配置
与其他优化融合
| 融合方向 | 潜在收益 |
|---|---|
| + Flash Attention | 减少注意力计算开销 |
| + 量化 | 进一步减少显存占用 |
| + 推测解码变体 | Medusa、EAGLE等专用架构 |
参考文献
相关主题
- FlashAttention深度解析:注意力机制的IO优化
- LLM推理优化综述:包括KV Cache、量化等技术
- Transformer数学基础:注意力机制的数学原理
- MLP理论:神经网络基础理论
- CoT推理:长思维链场景下的推理优化
Footnotes
-
Dao, T., et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS. 本文的KV Cache分析参考了FlashAttention的IO复杂度理论。 ↩
-
Leviathan, Y., et al. (2023). “Fast Speculative Decoding for seq2seq Models.” ICML. 首次提出将推测解码应用于seq2seq架构,奠定了理论基础。 ↩
-
Spector, B., & Re, C. (2023). “Speculative Decoding: Why is it Emerging?” Blog Post. 讨论了推测解码的实践动机和工程挑战。 ↩
-
Chen, T., et al. (2024). “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads.” arXiv:2401.10774. 提出多解码头并行预测的Medusa架构。 ↩ ↩2
-
Zoado, E., et al. (2024). “EAGLE: Self-supervised Early Exiting for Efficient LLM Inference.” 探索了基于早期退出的高效推理方法。 ↩