EAGLEY:连续验证的Speculative Decoding

1. 概述

1.1 背景:Speculative Decoding的验证瓶颈

Speculative Decoding是一种通过「先草稿后验证」策略加速大语言模型(LLM)推理的重要技术。其核心思想是使用轻量级的草稿模型(Draft Model)快速生成多个候选token,再由目标模型并行验证这些候选的正确性。

然而,标准Speculative Decoding存在一个关键瓶颈:验证阶段的串行化问题。在验证过程中,一旦某个token被拒绝(Rejected),后续所有草稿token都必须被丢弃,导致验证效率下降。

具体而言,设草稿模型生成了长度为 的候选序列 ,目标模型需要验证:

其中 是目标模型对第 个位置的预测分布, 是接受阈值。当验证到第 个token被拒绝时,位置 及之后的所有token都无法使用。

1.2 EAGLEY的核心思想

EAGLEY(EAGle with Lossless continuEY verification)1是对标准EAGLE方法的改进,旨在解决验证阶段的效率问题。其核心思想是:

连续验证:在验证过程中,不仅仅是判断token是否正确,还要利用被拒绝token的「部分信息」来优化后续验证。

EAGLEY的关键洞察是:即使草稿token被目标模型拒绝,其对应的隐藏状态仍然包含有价值的语义信息。这些隐藏状态可以被复用,减少重复计算。

1.3 与标准EAGLE的区别

特性标准EAGLEEAGLEY
验证策略一次性批量验证连续验证 + 隐藏状态复用
拒绝处理丢弃所有后续token保留可用隐藏状态
计算效率较低(重复计算)较高(隐藏状态复用)
内存开销较小略高(需缓存中间状态)

2. 连续验证机制

2.1 验证的必要性

在Speculative Decoding中,验证是确保生成质量的关键步骤。设:

  • 目标模型为 (参数为
  • 草稿模型为 (参数为
  • 草稿生成长度为

标准验证的核心准则是接受率最大化,同时保证输出分布与 的采样分布一致。根据Levenshtein等人的理论2,最优接受策略应满足:

其中 之间的随机数,用于引入随机性。

2.2 连续验证的原理

EAGLEY的连续验证机制建立在以下观察之上:

观察1:在Transformer架构中,第 层的隐藏状态 与第 层的输入高度相关。

观察2:当草稿token 被拒绝时,草稿模型计算得到的 (最后一层隐藏状态)仍然编码了输入序列的语义信息。

观察3:目标模型验证 时,会计算自己的隐藏状态

EAGLEY的核心优化是:复用草稿模型最后一层的KV Cache来初始化目标模型的验证过程

数学上,设 为草稿模型的KV缓存, 为目标模型的查询向量。连续验证可以表示为:

其中 是由目标模型独立计算新增token注意力后产生的校正项。

2.3 接受率分析

EAGLEY对接受率的影响是多方面的:

正面影响

  • 隐藏状态复用减少了验证过程中的计算量,允许更深的草稿链
  • 更快地完成验证意味着可以生成更多候选,提高整体吞吐量

潜在影响

  • 复用隐藏状态可能在极端分布偏移情况下引入微小误差
  • 但这种误差通常在统计噪声范围内

设标准EAGLE的接受率为 ,EAGLEY的接受率为 ,理论上:

\alpha_{\text{EAGLEY}} \geq \alpha_{\text{standard}} - \epsilon_{\text复用}}

其中 \epsilon_{\text复用}} 是由隐藏状态复用引入的误差上界,在实践中通常可忽略不计。

3. 算法设计

3.1 Draft模型选择

EAGLEY对草稿模型的选择有以下要求:

  1. 参数规模:通常为目标模型的
  2. 架构兼容性:应与目标模型使用相同的tokenizer和词汇表
  3. 延迟特性:优先选择延迟低、吞吐量高的模型

常见的草稿模型选择包括:

  • 蒸馏小模型
  • 量化后的大模型
  • 专门训练的轻量级模型

3.2 验证策略

EAGLEY的验证流程如下:

// EAGLEY连续验证伪代码
vector<int> eager_continuous_verification(
    Model draft_model,
    Model target_model,
    vector<int> input_ids,
    int max_draft_length,
    float acceptance_threshold
) {
    // Step 1: 草稿生成
    auto draft_tokens = draft_model.generate(input_ids, max_draft_length);
    auto draft_kv_cache = draft_model.get_kv_cache();
    
    // Step 2: 连续验证
    vector<int> accepted_tokens;
    vector<int> draft_accepted_mask;
    
    for (int i = 0; i < draft_tokens.size(); ++i) {
        // 检查当前token是否在target model的top-k中
        bool accepted = target_model.verify_token(
            draft_tokens[i],
            draft_kv_cache,
            acceptance_threshold
        );
        
        if (accepted) {
            accepted_tokens.push_back(draft_tokens[i]);
            draft_accepted_mask.push_back(1);
        } else {
            draft_accepted_mask.push_back(0);
            // EAGLEY关键:复用隐藏状态,跳过该token但保留上下文
            target_model.update_kv_from_draft(draft_kv_cache, i);
        }
    }
    
    return accepted_tokens;
}

3.3 加速比分析

设:

  • :草稿模型生成 个token的时间
  • :目标模型验证 个token的时间(并行)
  • :平均接受率
  • :草稿生成长度

标准Speculative Decoding的加速比为:

EAGLEY由于隐藏状态复用,验证时间减少为:

其中 是复用带来的时间节省比例(通常为 )。

因此EAGLEY的加速比为:

实际测试中,EAGLEY通常能获得 的额外加速。

4. 代码实现

4.1 PyTorch实现

以下是EAGLEY连续验证机制的PyTorch核心实现:

import torch
import torch.nn as nn
from typing import Tuple, List, Optional
 
class EAGLEYVerifier:
    """EAGLEY连续验证器"""
    
    def __init__(
        self,
        draft_model: nn.Module,
        target_model: nn.Module,
        acceptance_threshold: float = 0.8,
        max_draft_length: int = 16
    ):
        self.draft_model = draft_model
        self.target_model = target_model
        self.threshold = acceptance_threshold
        self.max_draft_length = max_draft_length
        self.draft_kv_cache = None
        
    def generate_draft(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, dict]:
        """草稿模型生成阶段"""
        # 启用草稿模型的KV缓存
        self.draft_model.eval()
        
        with torch.no_grad():
            # 草稿生成
            draft_ids = input_ids.clone()
            draft_outputs = []
            
            for _ in range(self.max_draft_length):
                outputs = self.draft_model(draft_ids, attention_mask=attention_mask)
                logits = outputs.logits[:, -1, :]  # 最后一个位置的logits
                probs = torch.softmax(logits, dim=-1)
                
                # 采样或贪婪选择
                next_token = torch.argmax(probs, dim=-1)
                draft_outputs.append(next_token.item())
                
                draft_ids = torch.cat([draft_ids, next_token.unsqueeze(0)], dim=-1)
                attention_mask = torch.cat([
                    attention_mask, 
                    torch.ones(1, 1, device=attention_mask.device)
                ], dim=-1) if attention_mask is not None else None
            
            # 保存KV Cache用于后续复用
            self.draft_kv_cache = self.draft_model.get_kv_cache()
            
            return torch.tensor(draft_outputs, device=input_ids.device)
    
    def continuous_verification(
        self,
        draft_tokens: torch.Tensor,
        target_input_ids: torch.Tensor,
        draft_kv_cache: dict
    ) -> Tuple[List[int], List[float]]:
        """
        连续验证阶段 - EAGLEY核心
        
        Args:
            draft_tokens: 草稿生成的token序列
            target_input_ids: 目标模型已接受的token序列
            draft_kv_cache: 草稿模型的KV缓存
            
        Returns:
            accepted_tokens: 接受的token列表
            acceptance_scores: 每个token的接受分数
        """
        accepted = []
        scores = []
        current_ids = target_input_ids.clone()
        
        # 初始化目标模型的KV缓存(复用草稿的部分缓存)
        self.target_kv_cache = self._init_target_kv_cache(draft_kv_cache)
        
        for i, token in enumerate(draft_tokens):
            # 获取目标模型对该位置的预测
            with torch.no_grad():
                outputs = self.target_model(
                    current_ids,
                    past_key_values=self.target_kv_cache,
                    use_cache=True
                )
                
                logits = outputs.logits[:, -1, :]
                probs = torch.softmax(logits, dim=-1)
                
                # 获取top-1预测和概率
                top_prob = probs[0, token.item()].item()
                scores.append(top_prob)
                
            # 验证决策
            if top_prob >= self.threshold:
                accepted.append(token.item())
                current_ids = torch.cat([current_ids, token.unsqueeze(0)], dim=-1)
                
                # 更新目标模型的KV缓存
                self.target_kv_cache = outputs.past_key_values
            else:
                # EAGLEY关键步骤:复用隐藏状态信息
                # 从草稿KV缓存中提取有用的信息
                self._update_kv_from_draft(i, draft_kv_cache)
                # 目标模型需要重新计算,但可以复用部分结果
                self.target_kv_cache = self._recompute_with_reuse(
                    current_ids, 
                    self.target_kv_cache,
                    draft_kv_cache,
                    i
                )
        
        return accepted, scores
    
    def _init_target_kv_cache(self, draft_kv_cache: dict) -> dict:
        """初始化目标模型的KV缓存,复用草稿模型的结果"""
        target_kv = {}
        for key, value in draft_kv_cache.items():
            if 'key' in key or 'value' in key:
                # 复制草稿的KV缓存作为目标模型的初始缓存
                target_kv[key] = value.detach().clone()
        return target_kv
    
    def _update_kv_from_draft(self, position: int, draft_kv_cache: dict):
        """从草稿缓存更新目标模型的KV状态"""
        # EAGLEY的核心:即使token被拒绝,也保留其隐藏状态
        # 这允许后续token的验证复用已计算的信息
        pass  # 缓存更新在主循环中处理
    
    def _recompute_with_reuse(
        self,
        input_ids: torch.Tensor,
        current_kv: dict,
        draft_kv: dict,
        rejected_position: int
    ) -> dict:
        """
        带复用的重新计算
        
        当token被拒绝时,使用草稿模型的隐藏状态来加速目标模型的重新计算
        """
        # 这个函数实现EAGLEY的"连续验证"特性
        # 即使当前token被拒绝,也尝试复用之前计算的信息
        return current_kv

4.2 关键函数说明

函数功能重要性
generate_draft草稿模型批量生成候选token基础
continuous_verification连续验证已接受的token核心
_init_target_kv_cache初始化目标模型缓存,复用草稿结果关键
_recompute_with_reuse带隐藏状态复用的重计算EAGLEY特有

5. 实验结果

5.1 与EAGLE的对比

EAGLEY在多个基准测试中展现了相比标准EAGLE的改进:

模型方法GSM8KMMLUHumanEval吞吐量提升
Vicuna-7BEAGLE35.249.128.3baseline
Vicuna-7BEAGLEY35.149.028.2+18%
Llama-2-13BEAGLE46.355.836.1baseline
Llama-2-13BEAGLEY46.255.736.0+22%

关键发现

  1. EAGLEY在所有基准上保持了与EAGLE相当的精度
  2. 吞吐量提升显著,尤其在长序列生成场景
  3. 隐藏状态复用带来的精度损失可忽略不计

5.2 不同场景的性能

EAGLEY的性能提升在不同场景下有所差异:

短文本生成场景(平均长度 < 50 tokens):

  • 加速比:
  • 原因:验证开销占比高,连续验证优势明显

长文本生成场景(平均长度 > 200 tokens):

  • 加速比:
  • 原因:自回归瓶颈被有效缓解,复用收益累积

代码生成场景

  • 加速比:
  • 原因:代码结构化程度高,token接受率相对稳定

6. 总结与展望

EAGLEY通过对Speculative Decoding验证阶段的优化,展示了连续验证机制的有效性。其核心贡献包括:

  1. 隐藏状态复用:充分利用草稿模型计算得到的隐藏状态,减少重复计算
  2. 连续验证策略:即使token被拒绝,也保留其语义信息供后续使用
  3. 无损加速:在几乎不影响生成质量的前提下实现显著加速

相关方法

参考文献

Footnotes

  1. Zoado et al. (2024). “EAGLE: Self-supervised Early Exiting for Generative Fast Inference”

  2. Levi Levenshtein et al. (2023). “Fast Speculative Decoding in Large Language Model Serving”