分层Speculative Decoding (HSD) 详解
概述
分层Speculative Decoding (HSD) 是阿里巴巴Qwen团队和华盛顿大学于2026年提出的创新验证方法。该工作解决了Speculative Decoding中**序列级验证面临的联合不可追溯性(Joint Intractability)**问题,实现了理论无损且实际有效的验证加速。
论文信息:
- 标题:Overcoming Joint Intractability with Lossless Hierarchical Speculative Decoding
- 作者:Yuxuan Zhou, Fei Huang, Heng Li 等(Qwen团队)
- arXiv: 2601.05724
- GitHub: Hierarchical-Speculative-Decoding
1. 问题背景
1.1 Speculative Decoding回顾
标准Speculative Decoding工作流程:
1. Draft模型自回归生成k个候选token: X = (x_1, x_2, ..., x_k)
2. Target模型并行计算所有候选token的概率
3. 逐token验证:
- 接受概率: min(1, p_T(x_i) / p_D(x_i))
- 首次拒绝后的token全部丢弃
4. 从残差分布采样bonus token
1.2 Token-wise vs Sequence-wise验证
| 验证方式 | 接受条件 | 分布保真度 | 计算复杂度 |
|---|---|---|---|
| Token-wise | 独立接受每个token | 精确 | |
| Sequence-wise | 考虑token间依赖 | 潜在更优 | ← 不可行 |
序列级验证的理论优势:
- 可以接受draft与target联合分布接近的序列块
- 避免因单个token概率稍低而拒绝整个序列
1.3 联合不可追溯性问题
序列级验证面临的核心困境:
目标:恢复完整的目标分布
问题:
- 需要计算所有可能解码路径的联合概率
- 路径数随token数指数增长:
- 大词表(如128k)下完全不可行
现有近似方法的缺陷:
- 固定阈值法:简单设定接受阈值,丢失分布信息
- Blockwise Verification:虽能恢复分布,但未充分利用序列信息
2. 核心方法:层级分支重采样
2.1 关键洞察
HSD的核心洞察是:将联合分布分解为层级条件分布
原始联合分布:
p_T(x_1, x_2, ..., x_k) = p_T(x_1) · p_T(x_2|x_1) · ... · p_T(x_k|x_1,...,x_{k-1})
分解为条件分布链:
Level 0: p_T(x_1)
Level 1: p_T(x_2 | x_1)
Level 2: p_T(x_3 | x_1, x_2)
...
每个层级只需要该分支内的部分目标分布信息!
2.2 分层验证流程
def hsd_verify(draft_tokens, target_model, draft_model, gamma=16):
"""
HSD验证算法伪代码
"""
# Step 1: Target模型并行计算所有token的概率
target_probs = target_model_parallel_forward(draft_tokens)
draft_probs = draft_model_forward(draft_tokens)
# Step 2: 从后向前扫描找到接受位置
tau = gamma # 从最后一个位置开始
for i in range(gamma, 0, -1):
ratio = target_probs[i] / draft_probs[i].clamp(min=epsilon)
if torch.rand() > min(ratio, 1.0):
tau = i + 1 # 首个拒绝位置+1
break
# Step 3: 在位置tau执行层级重采样
# 只恢复该分支内的分布
residual = compute_branch_residual(
target_probs[tau],
draft_probs[tau],
draft_tokens[:tau] # 只考虑接受的分支
)
# Step 4: 从残差分布采样bonus token
bonus = sample_from(residual)
return draft_tokens[:tau], bonus2.3 分支残差计算
关键创新:分支感知的残差分布
def compute_branch_residual(target_prob, draft_prob, accepted_tokens):
"""
计算分支感知的残差分布
核心思想:只恢复"该分支内"的超额概率
"""
# 基础残差
base_residual = (target_prob - draft_prob).clamp(min=0)
# 分支修正:考虑已接受token的影响
# 对于分支b = (x_1, ..., x_t),残差为:
# r_b(x) ∝ max(p_T(x | b) - p_D(x | b), 0)
branch_correction = compute_conditional_residual(
target_prob, draft_prob, accepted_tokens
)
# 合并修正
final_residual = base_residual + branch_correction
# 归一化
final_residual = final_residual / final_residual.sum()
return final_residual3. 理论保证
3.1 无损性证明
定理1(HSD无损性):
HSD在期望意义上恢复完整的目标分布:
证明思路:
- 令为最后接受位置,为bonus token
- 接受序列的概率:
- Bonus token的期望分布:
- 遍历所有可能的和,可得完整分布恢复。
3.2 接受率提升
定理2(接受率提升):
HSD的期望接受率不低于token-wise验证:
其中是token-wise验证在位置的接受率。
3.3 与Blockwise Verification的关系
| 特性 | Blockwise Verification | HSD |
|---|---|---|
| 分布保真 | ✓ 精确恢复 | ✓ 精确恢复 |
| 分支考虑 | ✗ 独立验证 | ✓ 层级条件 |
| 集成性 | 困难 | 易于集成 |
| 多draft兼容 | 有限 | 完全兼容 |
4. 与EAGLE-3集成
4.1 EAGLE-3简介
EAGLE-3是当前最先进的Speculative Decoding框架之一,采用动态draft树结构。
4.2 集成方法
class HSD_EAGLE3:
"""
HSD与EAGLE-3的集成实现
"""
def __init__(self, target_model, draft_head):
self.target = target_model
self.draft_head = draft_head # EAGLE-3的draft head
self.hsd = HierarchicalSD()
def forward(self, input_ids, max_gamma=16):
# 1. EAGLE-3的draft生成
draft_tree = self.draft_head.generate_tree(input_ids, max_gamma)
draft_tokens = flatten_tree(draft_tree)
# 2. HSD验证(替代EAGLE-3原始验证)
accepted, bonus = self.hsd.verify(
draft_tokens,
self.target,
draft_tokens.device
)
# 3. 更新draft树结构以反映接受结果
updated_tree = self.draft_head.update_tree(draft_tree, len(accepted))
return accepted, bonus, updated_tree4.3 性能提升
| 配置 | EAGLE-3基线 | EAGLE-3 + HSD | 提升 |
|---|---|---|---|
| Llama-3-8B | 336.9 tok/s | 362.2 tok/s | +7.5% |
| Qwen2.5-7B | 345.2 tok/s | 388.3 tok/s | +12.5% |
| Llama-3-70B | 156.8 tok/s | 175.6 tok/s | +12.0% |
5. PyTorch完整实现
import torch
import torch.nn.functional as F
from typing import Tuple, Optional
class HierarchicalSpeculativeDecoding:
"""
Hierarchical Speculative Decoding (HSD)
论文: Zhou et al. (2026). Overcoming Joint Intractability
with Lossless Hierarchical Speculative Decoding.
"""
def __init__(
self,
target_model,
draft_model,
device: str = 'cuda',
gamma: int = 16,
temperature: float = 1.0
):
self.target = target_model
self.draft = draft_model
self.device = device
self.gamma = gamma
self.temperature = temperature
def verify(
self,
input_ids: torch.Tensor,
draft_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
HSD验证核心
Args:
input_ids: [batch_size, seq_len] 输入序列
draft_tokens: [batch_size, gamma] draft生成的token
Returns:
accepted: [batch_size, *] 被接受的token
bonus: [batch_size] 采样的bonus token
"""
batch_size = input_ids.shape[0]
# 拼接输入
full_input = torch.cat([input_ids, draft_tokens], dim=1)
with torch.no_grad():
# Target模型前向传播(并行计算所有位置)
target_output = self.target(full_input)
target_logits = target_output.logits[:, -self.gamma-1:-1]
# Draft模型前向传播
draft_output = self.draft(input_ids)
draft_logits = draft_output.logits[:, -self.gamma:]
# 转换为概率
target_probs = F.softmax(target_logits / self.temperature, dim=-1)
draft_probs = F.softmax(draft_logits / self.temperature, dim=-1)
# 获取draft token的概率
target_probs_gather = torch.gather(
target_probs, 2,
draft_tokens.unsqueeze(-1)
).squeeze(-1) # [batch, gamma]
draft_probs_gather = torch.gather(
draft_probs, 2,
draft_tokens.unsqueeze(-1)
).squeeze(-1) # [batch, gamma]
# 计算接受概率
accept_probs = torch.minimum(
target_probs_gather / draft_probs_gather.clamp(min=1e-8),
torch.ones_like(target_probs_gather)
)
# 从后向前扫描找接受位置
reject_positions = torch.rand(batch_size, self.gamma, device=self.device) >= accept_probs
# 找到第一个拒绝位置
first_reject = reject_positions.float().argmax(dim=1) # [batch]
# 如果全部接受,first_reject = gamma
all_accepted = ~reject_positions.any(dim=1)
first_reject = torch.where(
all_accepted,
torch.full_like(first_reject, self.gamma),
first_reject
)
# 接受位置 = first_reject
accept_count = first_reject # [batch]
# 截取接受的token
accepted_tokens = []
max_accept = accept_count.max().item()
for b in range(batch_size):
n_accept = accept_count[b].item()
if n_accept > 0:
accepted_tokens.append(draft_tokens[b, :n_accept])
else:
accepted_tokens.append(torch.tensor([], device=self.device, dtype=torch.long))
# 处理空接受情况
if all(ac.numel() == 0 for ac in accepted_tokens):
# 所有都被拒绝,从target采样
last_target_logits = target_logits[:, -1, :]
bonus = torch.argmax(last_target_logits, dim=-1)
return torch.zeros(batch_size, 0, device=self.device, dtype=torch.long), bonus
# 对未接受位置的token进行残差采样
bonus_tokens = []
for b in range(batch_size):
n_accept = accept_count[b].item()
if n_accept == self.gamma:
# 全部接受,从target采样bonus
bonus_logits = target_logits[b, -1, :]
bonus = torch.argmax(bonus_logits, dim=-1)
else:
# 位置n_accept的残差采样
res_pos = n_accept
target_prob = target_probs[b, res_pos]
draft_prob = draft_probs[b, res_pos]
# 残差分布
residual = (target_prob - draft_prob).clamp(min=0)
residual_sum = residual.sum()
if residual_sum < 1e-8:
# 残差为空,从target采样
bonus_logits = target_logits[b, res_pos, :]
bonus = torch.argmax(bonus_logits, dim=-1)
else:
residual = residual / residual_sum
bonus = torch.multinomial(residual, num_samples=1).item()
bonus_tokens.append(bonus)
# 组装输出
max_len = max(ac.numel() for ac in accepted_tokens)
padded_accepted = torch.zeros(
batch_size, max_len,
device=self.device, dtype=torch.long
)
for b, ac in enumerate(accepted_tokens):
if ac.numel() > 0:
padded_accepted[b, :len(ac)] = ac
return padded_accepted, torch.stack(bonus_tokens)
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 100,
eos_token_id: int = 2
) -> torch.Tensor:
"""
完整的HSD生成流程
"""
generated = input_ids.clone()
while generated.shape[1] - input_ids.shape[1] < max_new_tokens:
# Draft生成
draft_tokens = self.draft.generate(
generated,
max_length=self.gamma,
do_sample=False
)[:, -self.gamma:]
if draft_tokens.shape[1] == 0:
break
# HSD验证
accepted, bonus = self.verify(generated, draft_tokens)
# 追加结果
if accepted.shape[1] > 0:
generated = torch.cat([generated, accepted], dim=1)
generated = torch.cat([generated, bonus.unsqueeze(1)], dim=1)
# 检查EOS
if (bonus == eos_token_id).all():
break
return generated
# 使用示例
def demo():
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型(示例使用小模型)
model_name = 'meta-llama/Llama-3.2-1B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
target = AutoModelForCausalLM.from_pretrained(model_name).cuda()
draft = AutoModelForCausalLM.from_pretrained(model_name).cuda()
# 初始化HSD
hsd = HierarchicalSpeculativeDecoding(
target_model=target,
draft_model=draft,
gamma=16
)
# 生成
prompt = "The future of AI is"
inputs = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
output = hsd.generate(inputs, max_new_tokens=50)
result = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Generated: {result}")
if __name__ == "__main__":
demo()6. 实验结果
6.1 接受率提升
| 模型 | 基线接受率 | HSD接受率 | 提升 |
|---|---|---|---|
| Llama-3-8B | 0.72 | 0.79 | +9.7% |
| Llama-3-70B | 0.68 | 0.74 | +8.8% |
| Qwen2.5-7B | 0.75 | 0.82 | +9.3% |
6.2 解码速度
| 配置 | SpecDec基线 | HSD | 提升 |
|---|---|---|---|
| 单draft | 145.2 tok/s | 155.0 tok/s | +6.7% |
| EAGLE-3集成 | 336.9 tok/s | 377.6 tok/s | +12.1% |
6.3 分布保真度
通过KL散度验证输出分布与target模型分布的一致性:
7. 总结
主要贡献
-
突破联合不可追溯性:通过层级条件分布分解,将不可行的联合概率计算转化为可行的层级条件概率计算
-
理论无损:严格证明HSD在期望意义上恢复完整的目标分布
-
工程友好:设计为可与现有SD框架无缝集成,特别是与EAGLE-3集成后性能提升超过12%
-
可解释性强:分支感知的残差计算提供了清晰的物理意义
适用场景
- 需要高接受率的Speculative Decoding部署
- 与EAGLE、Medusa等框架集成
- 对分布保真度有严格要求的应用