Speculative Decoding理论:LLM推理加速

概述

自回归生成的瓶颈问题

大语言模型(Large Language Model, LLM)的推理过程本质上是一个自回归生成(Autoregressive Generation)过程。在每个解码步骤中,模型需要:

  1. 将已生成的 token 序列作为输入
  2. 计算注意力机制(Attention)
  3. 输出下一个 token 的概率分布

这一过程面临严峻的效率瓶颈:

瓶颈类型具体表现
计算瓶颈每个解码步骤都需要完整的前向传播, 的注意力计算无法并行
内存带宽瓶颈KV Cache的读写成为主要延迟来源1
自回归串行化下一个token依赖前一个token的生成,无法像prefill阶段那样充分并行

Speculative Decoding基本思想

Speculative Decoding的核心思想是将生成过程分解为两个阶段:

  1. Draft阶段:使用一个轻量级的「草稿模型」(Draft Model)快速生成多个候选token
  2. Verification阶段:使用目标模型(Target Model)并行验证这些候选token的正确性

这种「投机取巧」的方法利用了以下观察:

  • 验证多个token的正确性可以完全并行化
  • 轻量级模型生成的token有相当一部分是正确的
  • 即使需要回退(Rollback),也比完全自回归生成更高效

与传统的自回归生成相比:


理论框架

形式化定义

设目标模型为 ,draft模型为 ,输入序列为

定义(目标分布):目标模型在位置 的输出分布为

其中 是logits, 是温度参数( 时退化为贪婪解码)。

定义(Draft分布):Draft模型生成的分布为

核心假设:Draft模型是目标模型的一阶近似,即对于大多数token:

接受率分析

Speculative Decoding的正确性保证来自于Hyndman定理(也称为拒绝采样的接受准则)。2

定理(Hyndman Acceptance Criterion):

为提议分布(Proposal Distribution), 为目标分布,若对所有 满足:

其中 是常数,则以下采样-接受算法得到来自 的样本:

  1. 采样
  2. 以概率 接受

应用于Speculative Decoding

在位置 ,我们希望验证draft模型采样的token 。定义接受概率:

期望接受率

时,

期望加速比推导

设每轮Draft阶段生成 个token,验证阶段的接受率为 ,则:

每轮生成的token数期望

实际计算量分析

  • 自回归方式生成 个token需要 次完整前向传播
  • Speculative Decoding需要 1 次 draft前向 + 1 次 target前向( token并行验证)

加速比

其中 分别是目标模型和draft模型单次前向的时间。

更精确的模型:考虑回退开销,定义有效接受率


核心算法

Draft模型选择标准

Draft模型的选择对整体性能至关重要。理想模型应满足:

标准说明
质量匹配 分布接近
推理速度快单次前向传播时间远小于目标模型
参数量小适合部署在有限计算资源下

常见选择

  1. 同系列小模型:如 Llama-7B 作为 Llama-70B 的draft
  2. SSM架构:如 Mamba 模型,适合快速生成
  3. Speculative Heads:在目标模型上附加轻量级预测头3
  4. Medusa结构:共享backbone,多个并行解码头4

验证机制

Greedy Verification

在贪婪解码()场景下,验证过程简化为:

即只需比较draft模型输出的token是否与目标模型贪婪解码的token一致。

接受概率

Sampling-based Verification

当使用采样解码时,验证需要考虑概率比:

bool accept_token(float p_draft, float p_target, std::mt19937& rng) {
    float ratio = p_target / p_draft;
    float threshold = std::uniform_real_distribution<>(0.0f, 1.0f)(rng);
    return threshold <= ratio;
}

多Token预测与验证

Medusa范式

Medusa在目标模型上附加多个解码头(Decoding Head),每个头预测下一个位置的token:4

Token Position:    t   t+1   t+2   t+3   t+4
                  ┌────┬────┬────┬────┐
Medusa Heads:     │ H1 │ H2 │ H3 │ H4 │
                  └────┴────┴────┴────┘

训练目标:对第 个head,最小化:

EAGLE方法

EAGLE(Early Exit Guided Language model)采用自监督的早期退出机制:5

  1. 在每层设置early exit point
  2. 利用hidden states的层次化特性预测下一个token
  3. 减少计算量的同时保持生成质量

自适应策略

动态调整Draft长度

根据历史接受率动态调整每轮的draft长度

class AdaptiveSpeculator {
    float alpha_history;
    int k_current;
    
    void adjust_k() {
        if (alpha_history > 0.9) k_current += 2;  // 接受率高,增加长度
        else if (alpha_history < 0.5) k_current -= 1;  // 接受率低,减少长度
        k_current = clamp(k_current, 1, MAX_K);
    }
};

Beam Search集成

将Speculative Decoding与beam search结合:

  1. 保持多个假设(hypotheses)
  2. 对每个假设独立进行speculation
  3. 选择整体得分最高的路径

实现细节

KV Cache在Speculative Decoding中的重用

这是Speculative Decoding高效性的关键所在。验证阶段可以复用draft阶段计算出的KV Cache。

传统自回归

Step 1: 计算 K_1, V_1, 输出 t_1
Step 2: 计算 K_2, V_2, 输出 t_2  ← 无法复用
Step 3: 计算 K_3, V_3, 输出 t_3  ← 无法复用

Speculative Decoding

Draft: 
  Step 1: 计算 K_1, V_1, 输出 t_1, t_2, t_3
  → 保存 K_1, V_1, K_2, V_2, K_3, V_3

Verify:
  Step 2: 直接复用上述 KV Cache
  → 注意力计算只需 O(k) 而非 O(k²)

数学表达:对于位置 的key/query:

验证阶段已有 ,只需计算

Batch处理优化

Prefix Batching

当多个请求共享相同前缀时(如system prompt):

// 共享前缀(System Prompt)
std::vector<int> shared_prefix = {101, 2003, 1996, ...};  // token IDs
 
// 独立后缀(User Query)
std::vector<std::vector<int>> unique_queries = {
    {2054, 2003, 1996, ...},
    {3024, 1029, ...}
};
 
// Batch推理
for (auto& query : unique_queries) {
    auto full_input = concatenate(shared_prefix, query);
    // 共享prefix的KV Cache
}

Continuous Batching

动态批处理以最大化GPU利用率:

  1. 新请求随时加入batch
  2. 完成的请求立即退出
  3. Draft和Verify阶段分别batch处理

内存管理

动态KV Cache分配

class KVCacheManager {
    size_t max_seq_len;
    size_t num_layers;
    size_t num_heads;
    size_t head_dim;
    
    std::vector<std::vector<torch::Tensor>> kv_cache;
    
    void allocate(int batch_size) {
        kv_cache.resize(batch_size);
        for (auto& cache : kv_cache) {
            cache.resize(num_layers);
            for (auto& k_cache : cache) {
                k_cache = torch::zeros({num_heads, max_seq_len, head_dim});
            }
        }
    }
};

显存优化策略

策略效果
PagedAttention减少内存碎片,支持动态分配
KV Cache量化FP16 → INT8 减少50%显存
分布式KV Cache多GPU分担存储压力

代码实现

PyTorch完整实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Optional
 
class SpeculativeDecoder:
    """Speculative Decoding 实现"""
    
    def __init__(
        self,
        target_model: nn.Module,
        draft_model: nn.Module,
        max_draft_len: int = 8,
        temperature: float = 1.0,
        device: str = "cuda"
    ):
        self.target = target_model
        self.draft = draft_model
        self.max_draft = max_draft_len
        self.temperature = temperature
        self.device = device
        
        # 冻结目标模型参数
        for p in self.target.parameters():
            p.requires_grad = False
            
        # 冻结draft模型参数(可选)
        for p in self.draft.parameters():
            p.requires_grad = False
    
    def _sample(self, logits: torch.Tensor) -> torch.Tensor:
        """从logits中采样token"""
        if self.temperature == 0:
            return torch.argmax(logits, dim=-1)
        probs = F.softmax(logits / self.temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)
    
    def _compute_accept_prob(
        self,
        p_target: torch.Tensor,
        p_draft: torch.Tensor
    ) -> torch.Tensor:
        """计算接受概率(基于Hyndman准则)"""
        # 避免除零
        p_draft = torch.clamp(p_draft, min=1e-10)
        ratio = p_target / p_draft
        return torch.clamp(ratio, max=1.0)
    
    def draft_phase(
        self,
        input_ids: torch.Tensor,
        kv_cache: Optional[dict] = None
    ) -> Tuple[List[int], dict]:
        """Draft阶段:使用draft模型生成候选序列"""
        draft_tokens = []
        current_ids = input_ids.clone()
        current_kv = {} if kv_cache is None else kv_cache
        
        for _ in range(self.max_draft):
            # 前向传播
            with torch.no_grad():
                outputs = self.draft(
                    input_ids=current_ids,
                    past_key_values=current_kv,
                    use_cache=True
                )
            
            logits = outputs.logits[:, -1, :] / self.temperature
            next_token = self._sample(logits)
            
            draft_tokens.append(next_token.item())
            current_ids = next_token.unsqueeze(0)
            current_kv = outputs.past_key_values
            
            # 遇到eos终止
            if next_token.item() == self.target.config.eos_token_id:
                break
        
        return draft_tokens, current_kv
    
    def verify_phase(
        self,
        input_ids: torch.Tensor,
        draft_tokens: List[int],
        kv_cache: dict
    ) -> Tuple[List[int], int]:
        """
        Verify阶段:并行验证draft tokens
        
        Returns:
            accepted_tokens: 被接受的tokens
            first_reject_idx: 第一个拒绝的token索引(-1表示全部接受)
        """
        # 构建验证输入
        batch_size = len(draft_tokens)
        verify_input = torch.tensor(
            [input_ids[0].item()] + draft_tokens,
            device=self.device
        ).unsqueeze(0)
        
        # 复用draft阶段的KV Cache
        target_kv = kv_cache
        
        # 目标模型并行验证
        with torch.no_grad():
            outputs = self.target(
                input_ids=verify_input,
                past_key_values=target_kv,
                use_cache=True
            )
        
        # 计算每个位置的接受概率
        target_probs = F.softmax(outputs.logits[0], dim=-1)  # [seq_len, vocab_size]
        
        accepted = []
        first_reject = -1
        
        for i, token_id in enumerate(draft_tokens):
            # 获取目标模型在位置i的token概率
            p_target = target_probs[i, token_id].item()
            
            # 获取draft模型的概率(需要重新计算)
            # 这里简化处理,假设draft的token就是最可能的
            p_draft = 1.0 / self.target.config.vocab_size  # 简化假设
            
            # 计算接受概率
            accept_prob = min(1.0, p_target / (p_draft + 1e-10))
            
            if torch.rand(1).item() < accept_prob:
                accepted.append(token_id)
            else:
                first_reject = i
                break
        
        return accepted, first_reject
    
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 100
    ) -> torch.Tensor:
        """完整的Speculative Decoding生成过程"""
        generated = prompt_ids.clone()
        total_generated = 0
        
        # 初始KV Cache
        kv_cache = None
        
        while total_generated < max_new_tokens:
            # 1. Draft阶段
            draft_tokens, kv_cache = self.draft_phase(
                generated[:, -1:] if len(generated) > 1 else generated,
                kv_cache
            )
            
            if not draft_tokens:
                break
                
            # 2. Verify阶段
            accepted_tokens, reject_idx = self.verify_phase(
                generated if len(generated) > 1 else 
                    torch.tensor([[self.target.config.bos_token_id]], device=self.device),
                draft_tokens,
                kv_cache
            )
            
            # 3. 追加接受的tokens
            generated = torch.cat([
                generated,
                torch.tensor([accepted_tokens], device=self.device).T
            ], dim=-1)
            
            total_generated += len(accepted_tokens)
            
            # 如果全部拒绝,添加一个目标模型预测的token
            if not accepted_tokens:
                with torch.no_grad():
                    outputs = self.target(
                        input_ids=generated[:, -1:],
                        past_key_values=kv_cache,
                        use_cache=True
                    )
                next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
                generated = torch.cat([generated, next_token], dim=-1)
                kv_cache = outputs.past_key_values
                total_generated += 1
            
            # 遇到eos终止
            if generated[0, -1].item() == self.target.config.eos_token_id:
                break
        
        return generated[:, prompt_ids.shape[1]:]

关键函数说明

函数功能时间复杂度
draft_phase生成 个候选token
verify_phase并行验证候选token
generate完整生成流程迭代上述两阶段

性能分析与最优策略

加速比与接受率关系

加速比 与接受率 的关系:

(轻量级draft)时:

关键洞察

  • 时,(接近理论最大加速)
  • 时,(退化为普通自回归)

最优Draft长度选择

理论最优:给定接受率 ,最优 满足:

实际建议

场景推荐 原因
高接受率(8-16可最大化吞吐
中接受率(4-8平衡吞吐与回退开销
低接受率(1-4回退开销过大

不同场景下的性能对比

方法延迟优化吞吐优化适用场景
朴素自回归--基准
Speculative Decoding✅ 显著✅ 显著GPU丰富的服务端
Continuous Batching✅ 显著高并发场景
Flash Attention + SD✅✅✅✅长序列场景

局限性与发展方向

当前局限

局限性具体问题
领域适配Draft模型与目标模型分布差异大时,接受率急剧下降
计算资源需要同时加载两个模型,显存压力翻倍
长度外推Draft模型的长上下文能力弱于目标模型
采样质量在非贪婪解码场景下,接受率计算复杂

发展方向

级联方法

将多个不同规模的模型组成级联:

Prompt → Small → Medium → Large → Output
         ↑        ↑        ↓
       拒绝      拒绝    接受 → 输出

每级模型都有更高的接受率,只有难以预测的token才会传递到更大模型。

动态调整

  • 在线学习:根据实时接受率调整draft长度和模型选择
  • 上下文感知:根据prompt类型选择最适合的draft策略
  • 硬件感知:根据GPU型号和显存状态动态配置

与其他优化融合

融合方向潜在收益
+ Flash Attention减少注意力计算开销
+ 量化进一步减少显存占用
+ 推测解码变体Medusa、EAGLE等专用架构

参考文献


相关主题

Footnotes

  1. Dao, T., et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS. 本文的KV Cache分析参考了FlashAttention的IO复杂度理论。

  2. Leviathan, Y., et al. (2023). “Fast Speculative Decoding for seq2seq Models.” ICML. 首次提出将推测解码应用于seq2seq架构,奠定了理论基础。

  3. Spector, B., & Re, C. (2023). “Speculative Decoding: Why is it Emerging?” Blog Post. 讨论了推测解码的实践动机和工程挑战。

  4. Chen, T., et al. (2024). “Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads.” arXiv:2401.10774. 提出多解码头并行预测的Medusa架构。 2

  5. Zoado, E., et al. (2024). “EAGLE: Self-supervised Early Exiting for Efficient LLM Inference.” 探索了基于早期退出的高效推理方法。