1. 概述

1.1 背景与动机

大型语言模型(LLM)的推理过程面临自回归生成的顺序性瓶颈。传统自回归解码在每个时间步只能生成一个Token,前向传播必须等待前一个Token生成完成才能开始:

这种串行生成模式导致GPU利用率低下,尤其在生成长序列时,计算资源大部分时间处于空闲状态。

1.2 多Token预测的直观想法

解决这一问题的直观思路是:在单个前向传播中预测多个后续Token。Medusa1正是基于这一思想,通过在LLM的主干网络旁添加多个并行的解码头,使模型能够一次性预测未来多个位置的Token。

与需要额外小模型进行预测的投机解码不同,Medusa的所有预测头与主模型共享同一主干网络,不需要训练额外的「草稿模型」。

1.3 Medusa与Speculative Decoding的区别

特性Speculative DecodingMedusa
草稿模型需要单独训练的小模型无需额外模型
预测分布小模型独立预测与主模型共享表示
硬件需求可能需双设备单设备即可
预测质量依赖小模型能力受益于主模型知识
灵活性较差(模型绑定)较好(可调节头数)

2. 架构设计

2.1 Medusa Heads:多个解码头并行预测

Medusa的核心思想是在Transformer主干网络的最后一层隐藏状态上,附加**个独立的解码头**(Medusa Heads)。每个头负责预测特定偏移量后的Token:

// Medusa架构伪代码
class MedusaModel(nn.Module):
    def __init__(self, base_model, num_heads=5):
        super().__init__()
        self.base_model = base_model  // 主干Transformer
        self.hidden_size = base_model.config.hidden_size
        
        // K个解码头,每个头预测偏移量为i+1的Token
        self.medusa_heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_size, self.hidden_size),
                nn.ReLU(),
                nn.Linear(self.hidden_size, vocab_size)
            ) for _ in range(num_heads)
        ])
    
    def forward(self, input_ids, position):
        // 获取主干网络的隐藏状态
        outputs = self.base_model(input_ids)
        hidden_states = outputs.last_hidden_state
        
        // 取目标位置的隐藏状态
        target_hidden = hidden_states[:, position, :]
        
        // K个头并行预测
        predictions = [head(target_hidden) for head in self.medusa_heads]
        return predictions  // 返回K个logits

对于第个Medusa Head,其训练目标是预测当前Token之后第个位置的Token:

2.2 树形注意力机制

在推理阶段,需要同时验证多个候选Token序列。Medusa采用树形注意力(Tree Attention)来高效处理这种结构:

传统自回归注意力(单路径):
[START] → [A] → [B] → [C] → ...

Medusa树形注意力(多路径):
                    [START]
                    /  |  \
                   A   B   C
                  /|\ /|  /|\
                 D E F G H I J  (不同路径长度)

在实现上,树形注意力通过attention mask控制Token只Attend到同一条路径上的前驱节点:

// 树形注意力掩码构建
def build_tree_attention_mask(tree_structure, seq_len):
    """
    tree_structure: list of (parent_idx, depth) for each position
    例如: [(0,0), (0,1), (0,1), (1,2), (1,2), (2,3)] 表示:
          - 位置1是根节点的子节点(深度1)
          - 位置2,3是根节点的子节点(深度1)
          - 位置4,5是位置1的子节点(深度2)
          - 位置6是位置2的子节点(深度3)
    """
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        parent_i, depth_i = tree_structure[i]
        for j in range(i):
            parent_j, depth_j = tree_structure[j]
            // 位置i只能Attend到同路径上的节点
            if is_ancestor(parent_j, parent_i, tree_structure):
                mask[i, j] = 1
    return mask

2.3 训练目标:多任务学习框架

Medusa的训练采用多任务学习框架,联合优化主模型和所有Medusa Head:

其中:

  • :主模型的交叉熵损失
  • :第个头的预测损失
  • :平衡系数(通常设为1.0)

2.4 与主模型的联合训练

训练策略有两种选择:

  1. 从零训练:在预训练阶段就加入Medusa Head
  2. 后训练:在已训练好的模型上添加Head并微调

实践中发现,后训练通常效果更好,因为:

  • 主模型已具备强大的语言建模能力
  • 只需让Medusa Head学习主模型的隐式预测能力
  • 训练代价更小,收敛更快

3. 解码算法

3.1 候选生成与验证

Medusa的解码流程分为两个阶段:

阶段一:候选生成

  • 给定当前上下文,使用所有Medusa Head并行生成个候选Token
  • 每个Head 预测位置的Token

阶段二:验证

  • 个候选Token组成树形结构的候选序列
  • 一次性前向传播验证整个候选序列
  • 根据验证结果决定接受哪些Token
// Medusa解码流程
def medusa_decode(model, input_ids, medusa_heads, num_heads=5):
    current_ids = input_ids.clone()
    
    while not end_of_generation(current_ids):
        // 1. 获取主干隐藏状态
        outputs = model.base_model(current_ids)
        hidden = outputs.last_hidden_state[:, -1, :]  // 最后一个位置
        
        // 2. Medusa Head并行预测
        candidates = []
        for i, head in enumerate(medusa_heads):
            logits = head(hidden)
            token = logits.argmax(dim=-1)
            candidates.append(token.item())
        
        // 3. 构建候选序列并进行树形验证
        tree_candidates = build_tree_candidates(candidates)  // 构建树结构
        accepted, accepted_count = verify_tree(
            model, current_ids, tree_candidates
        )
        
        // 4. 更新已接受的Token
        if accepted_count > 0:
            current_ids = extend_with_accepted(current_ids, accepted)
        else:
            // 如果没有Token被接受,使用贪婪采样
            next_token = candidates[0]
            current_ids = concat(current_ids, next_token)
    
    return current_ids

3.2 树搜索解码策略

Medusa使用树搜索策略组织候选序列:

// 树形候选序列构建
vector<vector<int>> build_tree_candidates(vector<int> candidates) {
    // 贪婪路径 + 采样路径
    vector<vector<int>> tree;
    
    // 路径0:贪婪选择
    tree.push_back({candidates[0]});
    
    // 路径1:第一个Head用贪婪,其余采样
    tree.push_back({candidates[0], sample_from_logits(candidates[1])});
    
    // 路径2:不同组合
    tree.push_back({candidates[0], candidates[1], sample_from_logits(candidates[2])});
    
    // ... 更多路径
    
    return tree;
}

典型的树结构如下(以为例):

深度0: [ROOT]
深度1: [T1_0](Head 1 贪婪预测)
深度2: [T2_0](Head 2 贪婪), [T2_1](Head 2 采样)
深度3: [T3_0], [T3_1], [T3_2](Head 3 组合)
深度4: [T4_0], [T4_1], [T4_2], [T4_3](Head 4 组合)
深度5: [T5_0], [T5_1], [T5_2], [T5_3], [T5_4](Head 5 组合)

3.3 贪婪vs采样验证

Medusa支持两种验证模式:

模式贪婪验证采样验证
选择策略argmax从分布中采样
生成多样性
接受率较高较低
适用场景代码生成、精确输出对话、创意写作
// 采样验证函数
int sample_from_logits(torch::Tensor logits, float temperature = 1.0) {
    if (temperature == 0.0) {
        return logits.argmax().item<int>();
    }
    // 温度采样
    auto probs = torch::softmax(logits / temperature, -1);
    auto dist = torch::distributions::Categorical(probs);
    return dist.sample().item<int>();
}

3.4 接受率与加速比分析

Medusa的加速比取决于接受率(Acceptance Rate):

其中:

  • :每个前向传播预测的Token数(通常为
  • :平均接受率

,接受率

这个公式表明,只有接受率足够高才能获得加速。实践中:

  • 第一个Token的接受率最高(90%)
  • 后续Token接受率递减
  • 需要精心调优树结构和温度参数

4. 训练策略

4.1 分布匹配损失

Medusa Head的训练目标是使其预测分布与主模型的条件分布匹配:

KL散度损失相比交叉熵更稳定,因为:

  • 避免了对主模型高概率Token的过度惩罚
  • 鼓励Medusa Head学习主模型的整体分布结构
  • 减少训练初期的梯度爆炸问题

4.2 温度调优

Medusa Head的温度参数对接受率有显著影响:

# 温度与接受率的关系(实验观察)
temperatures = [0.5, 0.7, 1.0, 1.2, 1.5]
acceptance_rates = [0.85, 0.78, 0.65, 0.52, 0.38]  # Head 1
# 温度越低,接受率越高,但生成多样性降低
 
# 推荐配置
MEDUSA_TEMPERATURES = {
    'head_1': 0.5,   # 高接受率
    'head_2': 0.7,
    'head_3': 1.0,
    'head_4': 1.2,
    'head_5': 1.5,   # 更多探索
}

4.3 训练稳定性

为保证训练稳定性,Medusa采用以下策略:

  1. 渐进式添加Head:先训练前面的Head,再逐步添加后面的Head
  2. 学习率调度:Head的学习率通常为主模型的10-50%
  3. 梯度裁剪:防止梯度爆炸
  4. 分布正则化:使用KL散度而非交叉熵
# 训练配置示例
training_config = {
    'num_medusa_heads': 5,
    'head_learning_rate': 3e-5,  # 约为主模型的1/10
    'warmup_steps': 1000,
    'grad_clip_norm': 1.0,
    'use_kl_loss': True,  # 使用KL散度损失
    'progressive_training': True,  # 渐进式训练
}

5. 代码实现

5.1 PyTorch实现核心部分

#include <torch/torch.h>
#include <vector>
 
class MedusaHeadImpl : public torch::nn::Module {
public:
    MedusaHeadImpl(int hidden_size, int vocab_size)
        : fc1(hidden_size, hidden_size),
          fc2(hidden_size, vocab_size) {
        register_module("fc1", fc1);
        register_module("fc2", fc2);
    }
    
    torch::Tensor forward(torch::Tensor x) {
        x = torch::relu(fc1->forward(x));
        return fc2->forward(x);
    }
    
private:
    torch::nn::Linear fc1, fc2;
};
 
class MedusaModelImpl : public torch::nn::Module {
public:
    MedusaModelImpl(
        torch::nn::Module base_model,
        int num_heads,
        int vocab_size,
        int hidden_size)
        : base_model(base_model),
          medusa_heads(num_heads) {
        // 初始化Medusa Heads
        for (int i = 0; i < num_heads; i++) {
            medusa_heads->push_back(
                MedusaHeadImpl(hidden_size, vocab_size));
        }
        register_module("base_model", base_model);
        register_module("medusa_heads", medusa_heads);
    }
    
    std::vector<torch::Tensor> forward(
        torch::Tensor input_ids,
        torch::Tensor positions) {
        // 获取主干模型输出
        auto base_output = base_model->forward(input_ids);
        auto hidden_states = base_output.last_hidden_state;
        
        // 获取目标位置的隐藏状态
        auto target_hidden = hidden_states.index({
            torch::indexing::Slice(),
            positions,
            torch::indexing::Slice()
        });
        
        // K个Head并行预测
        std::vector<torch::Tensor> predictions;
        for (auto& head : *medusa_heads) {
            predictions.push_back(head->as<MedusaHeadImpl>()->forward(target_hidden));
        }
        
        return predictions;
    }
    
private:
    torch::nn::Module base_model;
    torch::nn::ModuleList medusa_heads;
};

5.2 树形注意力实现

// 树形注意力掩码CUDA实现
__global__ void tree_attention_kernel(
    float* query,
    float* key,
    float* value,
    float* output,
    int* parent_ids,    // 每个位置的父节点ID
    int* depth_ids,     // 每个位置的深度
    int seq_len,
    int num_heads,
    int head_dim) {
    
    int bid = blockIdx.x;
    int tid = threadIdx.x;
    int head_idx = bid % num_heads;
    int batch_idx = bid / num_heads;
    
    for (int i = tid; i < seq_len; i += blockDim.x) {
        float sum = 0.0f;
        float max_val = -INFINITY;
        
        // 找同路径上的所有可Attend节点
        for (int j = 0; j < i; j++) {
            if (is_in_same_path(parent_ids, depth_ids, i, j)) {
                // 计算注意力分数
                float score = compute_dot_product(
                    query, key, i, j, head_idx, head_dim);
                score = expf(score - max_val);
                sum += score;
            }
        }
        
        // softmax归一化
        output[bid * seq_len * head_dim + i * head_dim + tid] = sum;
    }
}
 
// 检查两个位置是否在同一条路径上
__device__ bool is_in_same_path(int* parent, int* depth, int a, int b) {
    // 深度大的节点是深度小的节点的后代
    if (depth[a] > depth[b]) {
        return find_ancestor(parent, a, depth[a] - depth[b]) == b;
    } else {
        return find_ancestor(parent, b, depth[b] - depth[a]) == a;
    }
}

5.3 验证逻辑

// 树形验证逻辑
struct VerificationResult {
    std::vector<int> accepted_tokens;
    int accepted_count;
    float log_prob_sum;
};
 
VerificationResult verify_tree(
    torch::nn::Module& model,
    torch::Tensor input_ids,
    std::vector<std::vector<int>>& tree_candidates,
    torch::Tensor attention_mask) {
    
    // 构建完整的输入序列
    int max_depth = tree_candidates.size();
    int num_paths = tree_candidates[0].size();
    
    // 展平树结构为验证序列
    std::vector<torch::Tensor> flat_tokens;
    for (int depth = 0; depth < max_depth; depth++) {
        for (int path = 0; path < tree_candidates[depth].size(); path++) {
            flat_tokens.push_back(
                torch::tensor(tree_candidates[depth][path]));
        }
    }
    
    // 一次性前向传播
    auto flat_ids = torch::stack(flat_tokens).unsqueeze(0);
    auto outputs = model.forward(flat_ids);
    auto logits = outputs.logits;
    
    // 验证每个候选Token
    VerificationResult result;
    result.accepted_tokens = {};
    result.accepted_count = 0;
    
    for (int i = 0; i < tree_candidates[0].size(); i++) {
        int predicted = logits[0][i].argmax().item<int>();
        int target = tree_candidates[0][i];
        
        if (predicted == target) {
            result.accepted_tokens.push_back(target);
            result.accepted_count++;
        } else {
            break;  // 贪婪策略:一旦失败,停止接受
        }
    }
    
    return result;
}

6. 实验结果

6.1 加速比分析

在Vicuna基准测试上的加速效果1

配置平均接受率加速比
基线(无Medusa)-1.0×
Medusa-1(1个头)92%1.4×
Medusa-3(3个头)75%2.0×
Medusa-5(5个头)68%2.5×
Medusa-10(10个头)52%2.8×

关键发现

  • 第一个Token接受率最高(通常90%)
  • 后续Token接受率呈指数递减
  • 存在最优Head数量(边际收益递减)

6.2 生成质量对比

Medusa在多个基准上的生成质量与基线相当:

基准基线Medusa差异
MT-Bench6.926.88-0.04
Vicuna-Bench7.057.01-0.04
MMLU68.2%68.1%-0.1%

结论:Medusa在加速推理的同时,基本不损失生成质量。

6.3 内存开销分析

组件参数量内存占用(FP16)
LLaMA-7B主干7B~14GB
LLaMA-7B + Medusa-57.05B~14.1GB
LLaMA-13B主干13B~26GB
LLaMA-13B + Medusa-513.05B~26.1GB

结论:Medusa Head的参数量极小(通常基模型的1%),内存开销可忽略。

7. 与其他方法的对比

7.1 vs Speculative Decoding

特性Speculative DecodingMedusa
架构双模型(Draft + Target)单模型 + 多Head
预测能力依赖Draft模型依赖主模型知识
训练成本需训练小模型仅训练解码头
内存占用较高(双模型)较低
接受率60-80%65-75%
加速比2-3×2-3×

7.2 vs EAGLE

EAGLE(Executable and Alternating Generation)是另一种自推测解码方法:

特性EAGLEMedusa
预测方式自回归(逐Token)并行(多Head)
树结构动态构建静态定义
训练目标回归损失分类损失
实现复杂度较高较低
接受率稍高略低

7.3 vs 其他多Token预测方法

方法预测Token数实现方式优缺点
Medusa1-10多解码头简单高效
Eagle动态自回归Draft接受率高
LookaheadN=5n-gram缓存无需训练
Constrained Decoding可变树搜索适合受限生成

8. 实践建议

8.1 何时使用Medusa

适合场景

  • 对推理延迟敏感的在线服务
  • 生成长文本(如代码、报告)
  • 批量推理场景
  • 单GPU部署(无额外Draft模型空间)

不适合场景

  • 内存受限环境(极小模型)
  • 需要精确控制生成的过程
  • 推理设备不支持树形注意力

8.2 调优建议

# 推荐的超参数配置
config = {
    'num_heads': 5,        # 5个头通常是最优选择
    'head_lr_ratio': 0.1,  # Head学习率为主模型的10%
    'kl_loss_weight': 1.0,
    'temperatures': [0.5, 0.7, 1.0, 1.2, 1.5],
    'grad_clip': 1.0,
    'warmup_steps': 500,
}
 
# 训练策略
# 1. 先冻结主干,只训练Head
freeze_base_model(epochs=2)
 
# 2. 解冻主干,联合训练
unfreeze_and_joint_train(epochs=3, lower_lr=True)
 
# 3. 可选:微调温度参数
fine_tune_temperatures()

9. 总结

Medusa是一种优雅的LLM推理加速方案,通过:

  1. 共享主干:复用主模型的表示能力,无需额外Draft模型
  2. 并行预测:多个解码头同时预测后续Token
  3. 树形验证:高效验证多条候选路径
  4. 轻量训练:仅需微调解码头参数

实现2-3倍的推理加速,同时保持生成质量基本不变。

在实际应用中,Medusa特别适合部署在单GPU环境下、对延迟敏感的LLM服务场景。

References

Footnotes

  1. Chen, Y., et al. (2024). “Medusa: Simple LLM Inference Acceleration with Multiple Decoding Heads.” arXiv preprint arXiv:2401.10774. Link 2