Large Language Model Distillation

Transferring knowledge from a large language model (LLM) into a small model is challenging because generative models have an enormous output space and an autoregressive decoding process. This document surveys distillation methods designed specifically for LLMs.

1. Challenges of Generative LLM Distillation

1.1 Differences from Classification Models

| Property | Classification model | Generative LLM |
| --- | --- | --- |
| Output space | Fixed set of classes (~1,000) | Huge vocabulary (~50,000) |
| Autoregression | Single forward pass | Token-by-token generation |
| Training objective | Cross-entropy | Language modeling |
| Exposure bias | Absent | Severe |

1.2 The Exposure Bias Problem

Problem definition: at training time the model is fed ground-truth (teacher-forced) tokens as input, but at inference time it is fed its own previous outputs, causing a train/inference distribution shift.

Training:  [SOS] The cat sat [TEACHER] on the mat
Inference: [SOS] The cat sat [STUDENT] on the ...
                                 ↑
                         errors accumulate
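
A minimal sketch of the mismatch, assuming a Hugging Face-style causal LM (`model` and `prompt_ids` are hypothetical): during free-running decoding, each step conditions on the model's own previous predictions, so an early mistake contaminates every later step.

import torch

def free_running_decode(model, prompt_ids, max_new_tokens):
    # Inference: each step conditions on the model's OWN prefix,
    # unlike training, where the prefix is always ground truth
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits = model(input_ids=ids).logits[:, -1, :]
        next_id = logits.argmax(dim=-1, keepdim=True)   # greedy pick
        ids = torch.cat([ids, next_id], dim=-1)         # mistakes feed forward
    return ids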

1.3 The Core Insight of MiniLLM

MiniLLM [1] argues that for generative LLM distillation the reverse KL divergence is a better fit than the standard (forward) KL.

Forward KL vs. Reverse KL

| Divergence | Formula | Behavior |
| --- | --- | --- |
| Forward KL | D_KL(p_T ‖ p_S) | Mode-covering: pushes the student to place mass on the teacher's low-probability regions |
| Reverse KL | D_KL(p_S ‖ p_T) | Mode-seeking: the student commits to one high-probability mode of the teacher |

Why Reverse KL Works Better

The output distribution of a generative LLM is typically multimodal. Reverse KL lets the student "lock onto" one mode of the teacher distribution, avoiding the blurry generations that result from spreading probability mass over low-probability regions.
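
A small numerical illustration (toy distributions chosen for the example): against a bimodal teacher, forward KL favors a student that smears mass across both modes, including the teacher's low-probability valley, while reverse KL favors a student that commits to a single mode.

import torch

teacher       = torch.tensor([0.495, 0.010, 0.495])  # bimodal, deep valley in the middle
mode_seeker   = torch.tensor([0.980, 0.010, 0.010])  # commits to one mode
mode_averager = torch.tensor([0.250, 0.500, 0.250])  # puts mass in the valley

def kl(p, q):
    return torch.sum(p * (p / q).log())

for name, s in [('mode_seeker', mode_seeker), ('mode_averager', mode_averager)]:
    print(f"{name}: forward KL(p_T||p_S) = {kl(teacher, s).item():.3f}, "
          f"reverse KL(p_S||p_T) = {kl(s, teacher).item():.3f}")
# forward KL prefers mode_averager (~0.64 vs ~1.59);
# reverse KL prefers mode_seeker (~0.63 vs ~1.61)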

2. The MiniLLM Method

2.1 Objective Function

MiniLLM uses the reverse KL divergence as the distillation loss:

$$\theta^{*} \;=\; \arg\min_{\theta}\, D_{\mathrm{KL}}\!\left(q_{\theta} \,\|\, p\right) \;=\; \arg\min_{\theta}\, \mathbb{E}_{x \sim p_{x},\; y \sim q_{\theta}(\cdot \mid x)}\!\left[\log \frac{q_{\theta}(y \mid x)}{p(y \mid x)}\right]$$

where $p$ is the teacher distribution, $q_{\theta}$ is the student, and $y$ is a sequence sampled from the student policy (sampling is not differentiable, so MiniLLM optimizes this objective with policy-gradient methods).

Simplified form (without sequence-level sampling)

On a fixed corpus the loss can instead be computed token by token:

$$\mathcal{L}_{\mathrm{RKL}} \;=\; \sum_{t=1}^{|y|} \sum_{v \in V} q_{\theta}\!\left(v \mid y_{<t}, x\right)\, \log \frac{q_{\theta}\!\left(v \mid y_{<t}, x\right)}{p\!\left(v \mid y_{<t}, x\right)}$$

which is what the implementation below computes.

2.2 Training Procedure

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

class MiniLLM(nn.Module):
    """
    MiniLLM: Knowledge Distillation for Language Models

    Core idea: distill a generative model with the reverse KL divergence.
    """
    def __init__(self, teacher_name, student_name, temperature=2.0,
                 beta=1.0, device='cuda'):
        super().__init__()
        self.teacher = AutoModelForCausalLM.from_pretrained(teacher_name).to(device)
        self.student = AutoModelForCausalLM.from_pretrained(student_name).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(student_name)
        self.temperature = temperature
        self.beta = beta
        self.device = device

        # Freeze the teacher model
        for p in self.teacher.parameters():
            p.requires_grad = False

    def compute_reverse_kl_loss(self, input_ids, attention_mask):
        """
        Token-level reverse KL loss.

        At each position:
        D_KL(p_S || p_T) = Σ_y p_S(y) * log(p_S(y) / p_T(y))
        """
        # Teacher forward pass (no gradients)
        with torch.no_grad():
            teacher_outputs = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            teacher_logits = teacher_outputs.logits  # (B, L, V)

        # Student forward pass
        student_outputs = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        student_logits = student_outputs.logits

        # Temperature-scaled student and teacher distributions
        student_probs = F.softmax(student_logits / self.temperature, dim=-1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)

        # Reverse KL: D_KL(p_S || p_T) = Σ p_S * (log p_S - log p_T)
        kl_loss = torch.sum(
            student_probs * (student_log_probs - teacher_log_probs),
            dim=-1
        )  # (B, L)
        kl_loss = kl_loss * (self.temperature ** 2)  # temperature scaling

        # Masked mean (ignore padding)
        mask = attention_mask.float()
        kl_loss = (kl_loss * mask).sum() / mask.sum()

        return kl_loss

    def forward(self, input_ids, attention_mask):
        """Forward pass; returns the distillation losses."""
        # Reverse KL distillation loss
        kl_loss = self.compute_reverse_kl_loss(input_ids, attention_mask)

        # Standard language-modeling loss (optional), shifted by one token
        student_outputs = self.student(input_ids=input_ids, attention_mask=attention_mask)
        vocab_size = student_outputs.logits.size(-1)
        shift_logits = student_outputs.logits[:, :-1].reshape(-1, vocab_size)
        shift_labels = input_ids[:, 1:].reshape(-1)
        shift_mask = attention_mask[:, 1:].reshape(-1).float()
        lm_loss = F.cross_entropy(shift_logits, shift_labels, reduction='none')
        lm_loss = (lm_loss * shift_mask).sum() / shift_mask.sum()

        # Total loss
        total_loss = self.beta * kl_loss + lm_loss

        return total_loss, kl_loss, lm_loss

2.3 Full Training Loop

def train_minillm(minillm, train_dataloader, num_epochs, lr=1e-4):
    """MiniLLM training loop."""
    optimizer = torch.optim.AdamW(minillm.student.parameters(), lr=lr)

    for epoch in range(num_epochs):
        for step, batch in enumerate(train_dataloader):
            input_ids = batch['input_ids'].to(minillm.device)
            attention_mask = batch['attention_mask'].to(minillm.device)

            optimizer.zero_grad()

            total_loss, kl_loss, lm_loss = minillm(input_ids, attention_mask)

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(minillm.student.parameters(), 1.0)
            optimizer.step()

            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}: "
                      f"Loss={total_loss.item():.4f}, "
                      f"KL={kl_loss.item():.4f}, "
                      f"LM={lm_loss.item():.4f}")

3. GKD (Generalized Knowledge Distillation)

3.1 Core Idea

GKD [2] targets the distribution mismatch of offline distillation with an on-policy strategy.

Problem: offline distillation trains the student only on a fixed corpus, so it gets no feedback on the sequences it actually produces at inference time; the student may fit some regions of the data well while remaining poor exactly where its own generations land.

Solution: sample training sequences from the student itself and distill the teacher's token-level distributions on those sequences, as sketched below.
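
A minimal sketch of on-policy data collection, assuming Hugging Face causal LMs and a tokenizer whose pad token is set (`prompts` is a list of strings):

import torch

def sample_on_policy_batch(student, tokenizer, prompts, max_new_tokens=64):
    """Draw sequences from the student's own distribution."""
    inputs = tokenizer(prompts, return_tensors='pt', padding=True).to(student.device)
    with torch.no_grad():
        ids = student.generate(
            **inputs,
            do_sample=True,                 # sample rather than decode greedily
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id,
        )
    attention_mask = (ids != tokenizer.pad_token_id).long()
    return ids, attention_mask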

3.2 The GKD Objective

GKD measures the teacher–student discrepancy with a generalized Jensen–Shannon divergence:

$$\mathcal{L}_{\mathrm{GKD}} \;=\; D_{\mathrm{JSD}(\beta)}\!\left(p_{T} \,\|\, p_{S}\right) \;=\; \beta\, D_{\mathrm{KL}}\!\left(p_{T} \,\|\, m\right) + (1-\beta)\, D_{\mathrm{KL}}\!\left(p_{S} \,\|\, m\right)$$

where $m$ is the mixture distribution:

$$m \;=\; \beta\, p_{T} + (1-\beta)\, p_{S}$$

Setting $\beta = 1$ recovers forward KL and $\beta = 0$ recovers reverse KL; on supervised data, the teacher-side target can additionally be mixed with the one-hot distribution of the ground-truth labels.

3.3 GKD Implementation

class GKD:
    """
    GKD: Generalized Knowledge Distillation

    On-policy distillation with a generalized Jensen-Shannon divergence.
    """
    def __init__(self, teacher, student, beta=0.5, temperature=4.0):
        self.teacher = teacher
        self.student = student
        self.beta = beta  # mixing coefficient of the generalized JSD
        self.temperature = temperature

    def gkd_loss(self, input_ids, attention_mask):
        """
        Generalized JSD loss:

            beta * KL(p_T || m) + (1 - beta) * KL(p_S || m),
            m = beta * p_T + (1 - beta) * p_S

        For on-policy GKD, input_ids holds sequences sampled from the student
        (e.g. via sample_on_policy_batch above); for offline GKD it holds
        sequences from a fixed corpus.
        """
        # Teacher distribution (frozen)
        with torch.no_grad():
            teacher_logits = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).logits

        # Student distribution
        student_logits = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).logits

        t = self.temperature
        teacher_probs = F.softmax(teacher_logits / t, dim=-1)
        student_probs = F.softmax(student_logits / t, dim=-1)

        # Mixture distribution m = beta*p_T + (1-beta)*p_S
        mixture = self.beta * teacher_probs + (1 - self.beta) * student_probs
        log_mixture = torch.log(mixture.clamp_min(1e-9))

        # Per-position KL(p_T || m) and KL(p_S || m), summed over the vocabulary
        kl_t_m = torch.sum(
            teacher_probs * (torch.log(teacher_probs.clamp_min(1e-9)) - log_mixture),
            dim=-1
        )
        kl_s_m = torch.sum(
            student_probs * (torch.log(student_probs.clamp_min(1e-9)) - log_mixture),
            dim=-1
        )
        jsd = self.beta * kl_t_m + (1 - self.beta) * kl_s_m  # (B, L)
        jsd = jsd * (t ** 2)  # temperature scaling

        # Masked mean over non-padding positions
        mask = attention_mask.float()
        return (jsd * mask).sum() / mask.sum()
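
Putting the pieces together, one hedged on-policy training step might look like this (`teacher`, `student`, `tokenizer`, `prompts`, and `optimizer` are assumed to exist):

gkd = GKD(teacher, student, beta=0.5, temperature=4.0)

# One on-policy step: sample from the student, score against the teacher
input_ids, attention_mask = sample_on_policy_batch(student, tokenizer, prompts)
loss = gkd.gkd_loss(input_ids, attention_mask)

optimizer.zero_grad()
loss.backward()
optimizer.step()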

4. Chain-of-Thought Distillation

4.1 Distilling Step-by-Step

Distilling Step-by-Step [3], proposed by Hsieh et al., leverages the reasoning ability of LLMs for distillation.

Core idea: distill not only the final answer but also the intermediate reasoning steps (rationales); see the format sketch below.
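
A sketch of the resulting multi-task format: one (question, rationale, answer) triple yields two training examples, one per task prefix (the `[label]`/`[rationale]` prefixes follow the paper; the arithmetic example is illustrative):

question  = "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total?"
rationale = "It takes 2 / 2 = 1 bolt of white fiber, so 2 + 1 = 3 bolts in total."
answer    = "3"

examples = [
    {"input": f"[label] {question}",     "target": answer},     # task 1: predict the answer
    {"input": f"[rationale] {question}", "target": rationale},  # task 2: generate the rationale
]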

class DistillStepByStep:
    """
    Distilling Step-by-Step: Outperforming larger language models
    with less training data and smaller model sizes

    Paper: ACL 2023
    """
    def __init__(self, prompt, llm_api_key=None, rationale_weight=1.0):
        # LLMWrapper is a placeholder for whatever client calls the
        # teacher LLM API (a 540B-scale model in the paper)
        self.llm_api = LLMWrapper(api_key=llm_api_key)
        self.prompt = prompt  # few-shot chain-of-thought prompt
        self.rationale_weight = rationale_weight

    def extract_reasoning_chain(self, prompt, question):
        """Elicit a reasoning chain from the large model."""
        response = self.llm_api.generate(
            f"{prompt}\nQ: {question}\nA: Let me think step by step."
        )
        return response  # answer with intermediate reasoning steps

    def parse_response(self, response):
        """Split a response into (answer, rationale); format-dependent."""
        rationale, _, answer = response.rpartition("The answer is")
        return answer.strip().rstrip('.'), rationale.strip()

    def create_distillation_dataset(self, questions):
        """Build the distillation dataset."""
        dataset = []
        for q in questions:
            result = self.extract_reasoning_chain(self.prompt, q)
            # Parse the answer and the reasoning steps
            answer, rationale = self.parse_response(result)
            dataset.append({
                'question': q,
                'rationale': rationale,
                'answer': answer
            })
        return dataset

    def train_with_rationale(self, student_model, dataset):
        """Train the student on rationales (multi-task learning)."""
        # Multi-task objective: predict the answer AND generate the rationale;
        # compute_answer_loss / compute_rationale_loss are task-specific
        # helpers (not shown)
        for item in dataset:
            loss_answer = self.compute_answer_loss(student_model, item)
            loss_rationale = self.compute_rationale_loss(student_model, item)

            loss = loss_answer + self.rationale_weight * loss_rationale
            loss.backward()
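
For concreteness, a minimal sketch of the two loss terms with a T5-style student, as used in the paper (`student` and `tokenizer` are assumed to be a matching Hugging Face pair):

def multitask_losses(student, tokenizer, item):
    """Answer-prediction and rationale-generation losses for one example."""
    losses = {}
    for prefix, target in [('[label]', item['answer']),
                           ('[rationale]', item['rationale'])]:
        enc = tokenizer(f"{prefix} {item['question']}", return_tensors='pt')
        labels = tokenizer(target, return_tensors='pt').input_ids
        # HF seq2seq models compute the cross-entropy internally
        # when labels are passed
        losses[prefix] = student(**enc, labels=labels).loss
    return losses['[label]'], losses['[rationale]']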

4.2 Beyond Imitation

Beyond Imitation [4] proposes dual chain-of-thought distillation:

  1. Imitation chain: learn the teacher's standard reasoning steps.
  2. Innovation chain: encourage the student to discover more concise reasoning paths.

class DualChainDistillation:
    """
    Dual chain-of-thought distillation.
    """
    def __init__(self, teacher, student):
        # Both models are assumed to return an object with
        # .rationale, .answer and .logits attributes
        self.teacher = teacher
        self.student = student

    def dual_chain_loss(self, question, teacher_rationale, teacher_answer):
        """
        Compute the dual-chain distillation loss.
        """
        # Student output
        student_output = self.student(question)
        student_rationale = student_output.rationale

        # Teacher output (frozen), used by the consistency term below
        with torch.no_grad():
            teacher_output = self.teacher(question)

        # Imitation loss: learn from the teacher's rationale
        # (rationale_matching_loss is a task-specific helper, not shown)
        imitation_loss = self.rationale_matching_loss(
            student_rationale, teacher_rationale
        )

        # Innovation loss: encourage discovering better reasoning paths
        # (simplified here to a consistency regularizer on the logits)
        innovation_loss = F.mse_loss(
            student_output.logits,
            teacher_output.logits.detach()
        )

        # Answer loss
        answer_loss = F.cross_entropy(
            student_output.logits,
            teacher_answer
        )

        return {
            'imitation': imitation_loss,
            'innovation': innovation_loss,
            'answer': answer_loss,
            'total': imitation_loss + 0.5 * innovation_loss + answer_loss
        }

5. Preference-Based Distillation

5.1 Direct Preference Knowledge Distillation

DPKD [5] combines human preference data with distillation:

class DirectPreferenceKD:
    """
    Direct Preference Knowledge Distillation

    Combines a DPO-style preference loss with distillation.
    """
    def __init__(self, teacher, student, ref_model, tokenizer, dpo_beta=0.1):
        self.teacher = teacher
        self.student = student
        self.ref_model = ref_model  # reference model (usually the SFT checkpoint)
        self.tokenizer = tokenizer
        self.dpo_beta = dpo_beta

    def compute_preference_loss(self, prompts, chosen_responses,
                                rejected_responses):
        """
        DPO-style preference loss plus a distillation term.

        sequence_logprob (defined after this class) returns the summed
        log-probability of each response given its prompt.
        """
        # Student log-probabilities of chosen and rejected responses
        chosen_logps = sequence_logprob(self.student, self.tokenizer,
                                        prompts, chosen_responses)
        rejected_logps = sequence_logprob(self.student, self.tokenizer,
                                          prompts, rejected_responses)

        # Reference-model log-probabilities (frozen)
        with torch.no_grad():
            ref_chosen = sequence_logprob(self.ref_model, self.tokenizer,
                                          prompts, chosen_responses)
            ref_rejected = sequence_logprob(self.ref_model, self.tokenizer,
                                            prompts, rejected_responses)

        # DPO loss: -log sigmoid(beta * [log-ratio(chosen) - log-ratio(rejected)])
        margin = (chosen_logps - ref_chosen) - (rejected_logps - ref_rejected)
        loss = -F.logsigmoid(self.dpo_beta * margin).mean()

        # Distillation term: match the teacher's scores on both responses
        with torch.no_grad():
            teacher_chosen = sequence_logprob(self.teacher, self.tokenizer,
                                              prompts, chosen_responses)
            teacher_rejected = sequence_logprob(self.teacher, self.tokenizer,
                                                prompts, rejected_responses)

        distill_loss = (F.mse_loss(chosen_logps, teacher_chosen) +
                        F.mse_loss(rejected_logps, teacher_rejected))

        return loss + 0.1 * distill_loss
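
The `sequence_logprob` helper used above is an assumption; a minimal implementation for Hugging Face causal LMs might look like this (prompt and response are tokenized together, and only the response tokens are scored):

import torch
import torch.nn.functional as F

def sequence_logprob(model, tokenizer, prompts, responses):
    """Summed log-probability of each response, conditioned on its prompt."""
    logps = []
    for prompt, response in zip(prompts, responses):
        prompt_len = tokenizer(prompt, return_tensors='pt').input_ids.size(1)
        ids = tokenizer(prompt + response, return_tensors='pt').input_ids.to(model.device)
        logits = model(input_ids=ids).logits
        # log P(token_t | tokens_<t): logits at position t-1 score token t
        log_probs = F.log_softmax(logits[:, :-1], dim=-1)
        token_logps = log_probs.gather(-1, ids[:, 1:].unsqueeze(-1)).squeeze(-1)
        # keep only the response tokens (drop the prompt prefix)
        logps.append(token_logps[:, prompt_len - 1:].sum(dim=-1))
    return torch.cat(logps)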

6. Practical Recommendations

6.1 Recommended Distillation Configurations

| Scale | Teacher | Student | Temperature | β |
| --- | --- | --- | --- | --- |
| 7B → 3B | LLaMA-2-7B | LLaMA-2-3B | 4.0 | 0.5 |
| 13B → 7B | LLaMA-2-13B | LLaMA-2-7B | 4.0 | 0.7 |
| 70B → 13B | LLaMA-2-70B | LLaMA-2-13B | 4.0 | 0.8 |

6.2 Training Tips

  1. Learning-rate schedule: use cosine annealing (see the sketch after this list)
  2. Weight decay: 1e-4
  3. Gradient accumulation: use it when batches are small
  4. Early stopping: monitor validation-set perplexity
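
A minimal sketch wiring these settings together with PyTorch's built-in scheduler (reusing `minillm` from Section 2; `train_dataloader` and `num_training_steps` are assumed):

from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.AdamW(minillm.student.parameters(),
                              lr=1e-4, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_training_steps)

accumulation_steps = 4  # gradient accumulation for small batches
optimizer.zero_grad()
for step, batch in enumerate(train_dataloader):
    total_loss, _, _ = minillm(batch['input_ids'].to(minillm.device),
                               batch['attention_mask'].to(minillm.device))
    (total_loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()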

6.3 Evaluation Metrics

| Metric | Notes |
| --- | --- |
| Perplexity (PPL) | Lower is better |
| Downstream task accuracy | Classification, QA, etc. |
| Generation quality | Diversity, consistency |
| Distribution match | Similarity to the teacher |
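
For reference, perplexity is the exponentiated mean per-token negative log-likelihood; a hedged sketch over a held-out loader (same batch format as above):

import math
import torch
import torch.nn.functional as F

def perplexity(model, dataloader, device='cuda'):
    """exp(mean NLL per non-padding token) on a held-out set."""
    total_nll, total_tokens = 0.0, 0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            logits = model(input_ids=ids, attention_mask=mask).logits
            nll = F.cross_entropy(
                logits[:, :-1].reshape(-1, logits.size(-1)),
                ids[:, 1:].reshape(-1),
                reduction='none'
            )
            m = mask[:, 1:].reshape(-1).float()
            total_nll += (nll * m).sum().item()
            total_tokens += m.sum().item()
    return math.exp(total_nll / total_tokens)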

7. References

  1. Gu Y, Dong L, Wei F, Huang M. MiniLLM: Knowledge Distillation of Large Language Models. ICLR, 2024. arXiv:2306.08543.

  2. Agarwal R, Vieillard N, Zhou Y, et al. On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes (GKD). ICLR, 2024. arXiv:2306.13649.

  3. Hsieh C Y, Li C L, Yeh C K, et al. Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes. Findings of ACL, 2023. arXiv:2305.02301.

  4. Beyond Imitation: Learning Key Reasoning Steps from Dual Chain-of-Thoughts. arXiv:2405.19737, 2024.

  5. Direct Preference Knowledge Distillation for Large Language Models. arXiv:2406.19774, 2024.