Hierarchical Speculative Decoding (HSD) Explained

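Overview

Hierarchical Speculative Decoding (HSD) is a verification method proposed in 2026 by Alibaba's Qwen team and the University of Washington. It addresses the **joint intractability of sequence-level verification** in speculative decoding, giving verification that is lossless in theory and faster in practice.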

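Paper information

Zhou et al. (2026). Overcoming Joint Intractability with Lossless Hierarchical Speculative Decoding.
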
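1. Background

1.1 Speculative Decoding Recap

The standard speculative decoding workflow:

1. The draft model autoregressively generates k candidate tokens: X = (x_1, x_2, ..., x_k)
2. The target model computes the probabilities of all candidate tokens in a single parallel pass
3. Tokens are verified one at a time:
   - Acceptance probability: min(1, p_T(x_i) / p_D(x_i)), where p_T and p_D are the target and draft probabilities conditioned on the preceding tokens
   - At the first rejection, the rejected token and every token after it are discarded
4. One extra token is sampled: from the residual distribution at the rejected position, or from the target distribution as a bonus token if all k tokens were accepted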
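
The workflow above can be sketched in a few lines of plain Python. This is a minimal illustration over toy distributions (lists of probabilities), not the code of any particular framework; the function name `tokenwise_verify` and its argument layout are assumptions made for this example.

```python
import random

def tokenwise_verify(draft_tokens, target_dists, draft_dists, rng=random):
    """Token-wise speculative verification over toy distributions.

    draft_tokens: list of k token ids proposed by the draft model
    target_dists: list of k + 1 distributions; entry i is the target
                  distribution at draft position i, entry k is the slot
                  right after the draft
    draft_dists:  list of k draft distributions
    Returns (accepted_tokens, extra_token).
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = target_dists[i], draft_dists[i]
        # Accept with probability min(1, p_T / p_D)
        if rng.random() < min(1.0, p[tok] / q[tok]):
            accepted.append(tok)
            continue
        # First rejection: resample from the residual max(p_T - p_D, 0);
        # random.choices normalizes the weights itself
        residual = [max(pt - qd, 0.0) for pt, qd in zip(p, q)]
        extra = rng.choices(range(len(p)), weights=residual)[0]
        return accepted, extra
    # Every draft token was accepted: sample the bonus token from the target
    last = target_dists[len(draft_tokens)]
    return accepted, rng.choices(range(len(last)), weights=last)[0]

# Identical draft and target distributions: every token is accepted
print(tokenwise_verify([0, 1],
                       [[0.5, 0.5], [0.5, 0.5], [1.0, 0.0]],
                       [[0.5, 0.5], [0.5, 0.5]]))  # -> ([0, 1], 0)
```

When the draft and target agree exactly, the ratio is 1 at every position and nothing is rejected; when the target assigns zero probability to a drafted token, that token is always rejected and the replacement comes from the residual.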

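1.2 Token-wise vs Sequence-wise Verification

| Verification | Acceptance rule | Distribution fidelity | Computational complexity |
|---|---|---|---|
| Token-wise | Accepts each token independently | Exact | |
| Sequence-wise | Accounts for dependencies between tokens | Potentially better | Infeasible |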

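Theoretical advantages of sequence-level verification

  • It can accept blocks of tokens whose joint distribution under the draft is close to that under the target
  • It avoids rejecting a whole sequence because a single token has a slightly low probability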

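1.3 The Joint Intractability Problem

The core difficulty facing sequence-level verification:

Goal: recover the full target distribution.

Problems

  1. The joint probability of every possible decoding path has to be computed
  2. The number of paths grows exponentially with the number of tokens: |V|^k paths for vocabulary size |V| and k draft tokens
  3. With a large vocabulary (e.g. 128k) this is completely infeasible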
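
A quick calculation makes the scale concrete. The snippet below simply evaluates |V|^k for a 128k vocabulary; the helper name `num_paths` is hypothetical and exists only for this illustration.

```python
def num_paths(vocab_size: int, k: int) -> int:
    """Number of distinct length-k token sequences over the vocabulary."""
    return vocab_size ** k

# Path counts for a 128k vocabulary and a few draft lengths
for k in (1, 2, 4, 8):
    print(f"k={k}: {num_paths(128_000, k):.3e} paths")
```

Already at k = 2 the count exceeds ten billion, so enumerating joint probabilities over all paths is out of reach long before typical draft lengths are reached.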

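Shortcomings of existing approximations

  • Fixed threshold: simply sets an acceptance threshold and loses distribution information
  • Blockwise Verification: recovers the distribution but does not make full use of sequence information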

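2. Core Method: Hierarchical Branch Resampling

2.1 Key Insight

The core insight of HSD is to decompose the joint distribution into hierarchical conditional distributions.

Original joint distribution:
p_T(x_1, x_2, ..., x_k) = p_T(x_1) · p_T(x_2|x_1) · ... · p_T(x_k|x_1,...,x_{k-1})

Decomposed into a chain of conditional distributions:
Level 0: p_T(x_1)
Level 1: p_T(x_2 | x_1)
Level 2: p_T(x_3 | x_1, x_2)
...

Each level needs only the part of the target distribution that lies inside its own branch.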
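
The chain-rule decomposition can be illustrated with a toy model. The sketch below scores one specific branch with k conditional lookups instead of touching all |V|^k paths; `toy_target` is a made-up two-token model used only for this example and has nothing to do with the models in the paper.

```python
import math

def joint_log_prob(cond_prob, tokens):
    """log p(x_1..x_k) as the sum of log p(x_i | x_1..x_{i-1})."""
    total = 0.0
    for i, tok in enumerate(tokens):
        prefix = tuple(tokens[:i])  # the branch taken so far
        total += math.log(cond_prob(prefix, tok))
    return total

def toy_target(prefix, tok):
    """Hypothetical target over vocabulary {0, 1} that looks at the last token only."""
    if not prefix:
        return [0.6, 0.4][tok]
    transition = [[0.7, 0.3], [0.2, 0.8]]
    return transition[prefix[-1]][tok]

# Probability of the branch (0, 1, 1): 0.6 * 0.3 * 0.8
print(math.exp(joint_log_prob(toy_target, [0, 1, 1])))
```

Each level only queries the conditional distribution along the branch that was actually drafted, which is exactly the information a single parallel target pass provides.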

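2.2 Hierarchical Verification Procedure

def hsd_verify(draft_tokens, target_model, draft_model, gamma=16, epsilon=1e-8):
    """
    Pseudocode for HSD verification (simplified).

    target_model_parallel_forward, draft_model_forward and sample_from are
    placeholder helpers. The two forward helpers are assumed to return
    per-position distributions of shape [gamma + 1, vocab] and [gamma, vocab].
    The per-position ratio test is a simplification of the paper's rule.
    """
    # Step 1: the target model scores all draft positions in one parallel pass
    target_probs = target_model_parallel_forward(target_model, draft_tokens)
    draft_probs = draft_model_forward(draft_model, draft_tokens)

    # Step 2: scan from the last position backward to find the accepted length
    tau = 0  # number of accepted draft tokens
    for i in range(gamma - 1, -1, -1):
        token = draft_tokens[i]
        ratio = target_probs[i, token] / draft_probs[i, token].clamp(min=epsilon)
        if torch.rand(()) <= ratio.clamp(max=1.0):
            tau = i + 1  # positions 0..i are accepted
            break

    # Everything accepted: the bonus token comes straight from the target
    if tau == gamma:
        return draft_tokens, sample_from(target_probs[gamma])

    # Step 3: hierarchical resampling at position tau,
    # recovering only the distribution inside the accepted branch
    residual = compute_branch_residual(
        target_probs[tau],
        draft_probs[tau],
        draft_tokens[:tau]  # only the accepted branch is considered
    )

    # Step 4: sample the bonus token from the residual distribution
    bonus = sample_from(residual)

    return draft_tokens[:tau], bonus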

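2.3 Branch Residual Computation

Key innovation: a branch-aware residual distribution

def compute_branch_residual(target_prob, draft_prob, accepted_tokens):
    """
    Compute the branch-aware residual distribution.

    Core idea: recover only the excess probability "inside this branch".
    compute_conditional_residual is a placeholder for the branch correction
    and is not defined in this article.
    """
    # Base residual: max(p_T - p_D, 0) at the resampling position
    base_residual = (target_prob - draft_prob).clamp(min=0)

    # Branch correction: account for the tokens already accepted.
    # For a branch b = (x_1, ..., x_t) the residual is
    # r_b(x) ∝ max(p_T(x | b) - p_D(x | b), 0)

    branch_correction = compute_conditional_residual(
        target_prob, draft_prob, accepted_tokens
    )

    # Combine the two terms
    final_residual = base_residual + branch_correction

    # Normalize; fall back to the target distribution if the residual is empty
    total = final_residual.sum()
    if total <= 0:
        return target_prob

    return final_residual / total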

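3. Theoretical Guarantees

3.1 Proof of Losslessness

Theorem 1 (HSD is lossless)

In expectation, HSD recovers the full target distribution.

Proof sketch

  1. Let τ be the last accepted position, followed by the bonus token
  2. Write down the probability of accepting the sequence up to τ
  3. Write down the expected distribution of the bonus token
  4. Summing over all possible τ recovers the full distribution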
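
The same "accept, otherwise resample from the residual" argument can be checked exactly in the single-token case. The sketch below is the standard one-step identity for speculative sampling, not the paper's proof of Theorem 1; the function name `output_distribution` is an assumption made for this illustration.

```python
def output_distribution(p, q):
    """Exact output distribution of one draft-verify-resample step.

    p: target distribution, q: draft distribution (lists of probabilities).
    A token x is drafted with probability q[x] and accepted with
    probability min(1, p[x] / q[x]), so it is emitted via acceptance
    with probability min(p[x], q[x]).
    """
    accept = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_mass = 1.0 - sum(accept)  # probability that the draft is rejected
    residual = [max(pi - qi, 0.0) for pi, qi in zip(p, q)]
    z = sum(residual)
    if z == 0.0:  # p == q, nothing is ever rejected
        return accept
    # On rejection the token is drawn from the normalized residual
    return [a + reject_mass * r / z for a, r in zip(accept, residual)]

p = [0.5, 0.3, 0.2]
q = [0.2, 0.2, 0.6]
print(output_distribution(p, q))  # matches p up to floating-point rounding
```

The accepted mass min(p, q) and the resampled mass max(p - q, 0) add up to p token by token, which is why the step is lossless regardless of how poor the draft distribution is.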

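3.2 Acceptance Rate Improvement

Theorem 2 (acceptance rate improvement)

The expected acceptance rate of HSD is no lower than that of token-wise verification, measured against the token-wise acceptance rate at each position.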
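
For reference, the token-wise acceptance rate at a single position has a simple closed form: the probability that a token drawn from the draft distribution q is accepted against the target p is the sum of min(p(x), q(x)), which equals one minus the total variation distance. The sketch below computes both sides on a toy example; the function names are hypothetical.

```python
def tokenwise_acceptance_rate(p, q):
    """Probability that a token drafted from q is accepted against target p."""
    return sum(min(pi, qi) for pi, qi in zip(p, q))

def total_variation(p, q):
    """Total variation distance between two distributions."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))

p = [0.5, 0.3, 0.2]
q = [0.2, 0.2, 0.6]
print(tokenwise_acceptance_rate(p, q))  # equals 1 - total_variation(p, q)
```

This is the baseline quantity that Theorem 2 compares against: any gain from sequence-level verification has to come on top of this per-position rate.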

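3.3 Relationship to Blockwise Verification

| Property | Blockwise Verification | HSD |
|---|---|---|
| Distribution fidelity | ✓ Exact recovery | ✓ Exact recovery |
| Branch awareness | ✗ Independent verification | ✓ Hierarchical conditioning |
| Ease of integration | Difficult | Easy to integrate |
| Multi-draft compatibility | Limited | Fully compatible |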

4. Integration with EAGLE-3

4.1 About EAGLE-3

EAGLE-3 is one of the current state-of-the-art speculative decoding frameworks and uses a dynamic draft tree.

4.2 Integration Method

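class HSD_EAGLE3:
    """
    Sketch of integrating HSD with EAGLE-3.

    generate_tree, update_tree and flatten_tree are placeholder names for
    the EAGLE-3 drafting interface; they are not defined in this article.
    """
    def __init__(self, target_model, draft_head):
        self.target = target_model
        self.draft_head = draft_head  # EAGLE-3 draft head
        # The draft head is assumed to be callable like a causal LM
        self.hsd = HierarchicalSpeculativeDecoding(target_model, draft_head)

    def forward(self, input_ids, max_gamma=16):
        # 1. EAGLE-3 draft generation
        draft_tree = self.draft_head.generate_tree(input_ids, max_gamma)
        draft_tokens = flatten_tree(draft_tree)

        # 2. HSD verification (replaces EAGLE-3's original verification)
        accepted, bonus = self.hsd.verify(input_ids, draft_tokens)

        # 3. Update the draft tree to reflect the number of accepted tokens
        updated_tree = self.draft_head.update_tree(draft_tree, accepted.shape[1])

        return accepted, bonus, updated_tree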

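4.3 Performance Gains

| Configuration | EAGLE-3 baseline | EAGLE-3 + HSD | Improvement |
|---|---|---|---|
| Llama-3-8B | 336.9 tok/s | 362.2 tok/s | +7.5% |
| Qwen2.5-7B | 345.2 tok/s | 388.3 tok/s | +12.5% |
| Llama-3-70B | 156.8 tok/s | 175.6 tok/s | +12.0% |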

5. Full PyTorch Implementation

import torch
import torch.nn.functional as F
from typing import Tuple


class HierarchicalSpeculativeDecoding:
    """
    Hierarchical Speculative Decoding (HSD)

    Paper: Zhou et al. (2026). Overcoming Joint Intractability
           with Lossless Hierarchical Speculative Decoding.

    Note: verify() applies the per-token accept/reject rule with residual
    resampling as a simplified stand-in; it does not reproduce the paper's
    hierarchical branch resampling.
    """

    def __init__(
        self,
        target_model,
        draft_model,
        device: str = 'cuda',
        gamma: int = 16,
        temperature: float = 1.0
    ):
        self.target = target_model
        self.draft = draft_model
        self.device = device
        self.gamma = gamma
        self.temperature = temperature

    def verify(
        self,
        input_ids: torch.Tensor,
        draft_tokens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Core verification step.

        Args:
            input_ids: [batch_size, seq_len] input sequence
            draft_tokens: [batch_size, gamma] tokens proposed by the draft model

        Returns:
            accepted: [batch_size, max_accepted] accepted tokens, right-padded with 0
            bonus: [batch_size] token sampled after the accepted prefix
        """
        batch_size, gamma = draft_tokens.shape
        device = draft_tokens.device

        # Concatenate the context and the draft tokens
        full_input = torch.cat([input_ids, draft_tokens], dim=1)

        with torch.no_grad():
            # One parallel target pass. The last gamma + 1 positions cover the
            # gamma draft slots plus the slot right after them (bonus token).
            target_logits = self.target(full_input).logits[:, -gamma - 1:]
            # The draft model is scored on the same prefixes as the target
            draft_logits = self.draft(full_input).logits[:, -gamma - 1:-1]

        # Convert to probabilities
        target_probs = F.softmax(target_logits.float() / self.temperature, dim=-1)  # [batch, gamma + 1, vocab]
        draft_probs = F.softmax(draft_logits.float() / self.temperature, dim=-1)  # [batch, gamma, vocab]

        # Probability each model assigns to the drafted tokens
        index = draft_tokens.unsqueeze(-1)
        target_probs_gather = torch.gather(target_probs[:, :gamma], 2, index).squeeze(-1)  # [batch, gamma]
        draft_probs_gather = torch.gather(draft_probs, 2, index).squeeze(-1)  # [batch, gamma]

        # Acceptance probability min(1, p_T / p_D)
        accept_probs = torch.clamp(
            target_probs_gather / draft_probs_gather.clamp(min=1e-8), max=1.0
        )

        # Draw one uniform number per position and mark the rejections
        reject_positions = torch.rand(batch_size, gamma, device=device) >= accept_probs

        # Index of the first rejection (argmax returns the first True);
        # if nothing was rejected, all gamma tokens are accepted
        first_reject = reject_positions.float().argmax(dim=1)  # [batch]
        all_accepted = ~reject_positions.any(dim=1)
        accept_count = torch.where(
            all_accepted,
            torch.full_like(first_reject, gamma),
            first_reject
        )

        # Sample one extra token per sequence
        bonus_tokens = []
        for b in range(batch_size):
            n_accept = int(accept_count[b])
            if n_accept == gamma:
                # Everything accepted: sample the bonus token from the target
                dist = target_probs[b, gamma]
            else:
                # Rejected at position n_accept: sample from the residual max(p_T - p_D, 0)
                residual = (target_probs[b, n_accept] - draft_probs[b, n_accept]).clamp(min=0)
                residual_sum = residual.sum()
                # An empty residual means p_T and p_D coincide; use the target distribution
                dist = residual / residual_sum if residual_sum > 1e-8 else target_probs[b, n_accept]
            bonus_tokens.append(torch.multinomial(dist, num_samples=1).squeeze(0))

        # Assemble the output, right-padding shorter accepted prefixes with 0
        max_len = int(accept_count.max())
        padded_accepted = torch.zeros(batch_size, max_len, device=device, dtype=torch.long)
        for b in range(batch_size):
            n_accept = int(accept_count[b])
            padded_accepted[b, :n_accept] = draft_tokens[b, :n_accept]

        return padded_accepted, torch.stack(bonus_tokens)

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        eos_token_id: int = 2
    ) -> torch.Tensor:
        """
        Full generation loop. Assumes batch size 1: verify() pads accepted
        tokens, and with larger batches the padding would land mid-sequence.
        """
        generated = input_ids.clone()
        prompt_len = input_ids.shape[1]

        while generated.shape[1] - prompt_len < max_new_tokens:
            # Draft generation. The draft has to be sampled from the same
            # distribution that verify() scores, so sampling is enabled and
            # top-k / top-p truncation is switched off.
            draft_tokens = self.draft.generate(
                generated,
                max_new_tokens=self.gamma,
                do_sample=True,
                temperature=self.temperature,
                top_k=0,
                top_p=1.0
            )[:, generated.shape[1]:]

            if draft_tokens.shape[1] == 0:
                break

            # Verification
            accepted, bonus = self.verify(generated, draft_tokens)

            # Append the accepted prefix and the extra token
            generated = torch.cat([generated, accepted, bonus.unsqueeze(1)], dim=1)

            # Stop once EOS appears among the newly appended tokens
            n_new = accepted.shape[1] + 1
            if (generated[:, -n_new:] == eos_token_id).any():
                break

        return generated[:, :prompt_len + max_new_tokens]
 
 
# Usage example
def demo():
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the models (a small model is used for the example)
    model_name = 'meta-llama/Llama-3.2-1B'
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Target and draft load the same checkpoint only to exercise the code path;
    # a real setup pairs a large target with a much smaller draft model
    target = AutoModelForCausalLM.from_pretrained(model_name).cuda().eval()
    draft = AutoModelForCausalLM.from_pretrained(model_name).cuda().eval()

    # Initialize HSD
    hsd = HierarchicalSpeculativeDecoding(
        target_model=target,
        draft_model=draft,
        gamma=16
    )

    # Generate
    prompt = "The future of AI is"
    inputs = tokenizer(prompt, return_tensors='pt').input_ids.cuda()

    output = hsd.generate(
        inputs,
        max_new_tokens=50,
        eos_token_id=tokenizer.eos_token_id  # use the model's own EOS id
    )
    result = tokenizer.decode(output[0], skip_special_tokens=True)

    print(f"Generated: {result}")

if __name__ == "__main__":
    demo()

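6. Experimental Results

6.1 Acceptance Rate Improvement

| Model | Baseline acceptance rate | HSD acceptance rate | Improvement |
|---|---|---|---|
| Llama-3-8B | 0.72 | 0.79 | +9.7% |
| Llama-3-70B | 0.68 | 0.74 | +8.8% |
| Qwen2.5-7B | 0.75 | 0.82 | +9.3% |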
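
To get a feel for what such acceptance rates mean per verification step, a common simplification treats every position as accepted independently with the same rate α. That independence is an assumption of this sketch, not a claim made by the source, and the table does not state that its rates are per-token rates. Under that model, one step with γ draft tokens emits (1 - α^(γ+1)) / (1 - α) tokens on average, counting the extra token. The helper name and the choice γ = 16 (the default used in the code of this article) are illustrative.

```python
def expected_tokens_per_step(alpha: float, gamma: int) -> float:
    """Expected tokens emitted per verification step, assuming each of the
    gamma draft tokens is accepted independently with probability alpha."""
    if alpha >= 1.0:
        return float(gamma + 1)  # everything accepted plus the bonus token
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)

# Acceptance rates taken from the table above
for name, base, hsd in [("Llama-3-8B", 0.72, 0.79),
                        ("Llama-3-70B", 0.68, 0.74),
                        ("Qwen2.5-7B", 0.75, 0.82)]:
    before = expected_tokens_per_step(base, 16)
    after = expected_tokens_per_step(hsd, 16)
    print(f"{name}: {before:.2f} -> {after:.2f} tokens per step")
```

Because the formula is convex in α, a few points of acceptance rate translate into a larger relative change in tokens per step, although real throughput also depends on drafting and verification cost.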

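6.2 Decoding Speed

| Configuration | SpecDec baseline | HSD | Improvement |
|---|---|---|---|
| Single draft | 145.2 tok/s | 155.0 tok/s | +6.7% |
| EAGLE-3 integration | 336.9 tok/s | 377.6 tok/s | +12.1% |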

6.3 Distribution Fidelity

Agreement between the output distribution and the target model's distribution is checked with KL divergence.
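
One way such a check could be run on a small vocabulary is to compare the empirical distribution of sampled tokens against the known target. The sketch below is an assumed setup for illustration: the source does not say which direction of KL divergence or which estimator was used, and `empirical_kl` is a hypothetical helper.

```python
import math
from collections import Counter

def empirical_kl(samples, target):
    """KL(empirical || target) estimated from a list of sampled token ids.

    target: list of probabilities indexed by token id; every sampled token
    is assumed to have nonzero target probability.
    """
    counts = Counter(samples)
    n = len(samples)
    kl = 0.0
    for tok, c in counts.items():
        p_hat = c / n  # empirical frequency of this token
        kl += p_hat * math.log(p_hat / target[tok])
    return kl

print(empirical_kl([0, 0, 1, 1], [0.5, 0.5]))  # frequencies match the target
print(empirical_kl([0, 0, 0, 0], [0.5, 0.5]))  # all mass on one token
```

For a lossless verifier the estimate should shrink toward zero as the number of samples grows, while a biased verifier leaves a gap that does not close.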


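7. Summary

Main contributions

  1. Overcomes joint intractability: decomposing the joint distribution into hierarchical conditionals turns an infeasible joint probability computation into a feasible sequence of conditional probability computations

  2. Lossless in theory: HSD is proven to recover the full target distribution in expectation

  3. Engineering friendly: designed to integrate with existing speculative decoding frameworks; integrated with EAGLE-3 it improves performance by more than 12%

  4. Interpretable: the branch-aware residual computation has a clear intuitive meaning

Use cases

  • Speculative decoding deployments that need a high acceptance rate
  • Integration with frameworks such as EAGLE and Medusa
  • Applications with strict requirements on distribution fidelity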
