Large Language Model Distillation
Transferring the knowledge of a large language model (LLM) into a smaller model is challenging because generative models have an enormous output space and an autoregressive structure. This document surveys distillation methods tailored to LLMs.
1. Challenges of Generative LLM Distillation
1.1 Differences from Classification Models
| Property | Classification model | Generative LLM |
|---|---|---|
| Output space | Fixed set of classes (~1,000) | Huge vocabulary (~50,000) |
| Autoregression | Single forward pass | Token-by-token generation |
| Training objective | Cross-entropy | Language modeling |
| Exposure bias | Not present | Severe |
1.2 The Exposure Bias Problem
Problem definition: during training the model is conditioned on ground-truth (teacher-forced) tokens, while at inference it is conditioned on its own outputs, causing a distribution shift (a minimal code illustration follows the diagram below).
```text
Training:  [SOS] The cat sat [TEACHER] on the mat
Inference: [SOS] The cat sat [STUDENT] on the ...
                                  ↑
                          error accumulation
```
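This mismatch is easy to reproduce. The snippet below is a minimal sketch using Hugging Face `transformers`, with GPT-2 standing in for an arbitrary causal LM; it contrasts the teacher-forced forward pass used during training with free-running generation, where each step is conditioned on the model's own previous outputs.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# GPT-2 is only a stand-in here; any causal LM behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

ids = tokenizer("The cat sat on the mat", return_tensors="pt").input_ids

# Training (teacher forcing): every position is conditioned on ground-truth tokens.
teacher_forced_logits = model(input_ids=ids, labels=ids).logits

# Inference (free running): each step is conditioned on the model's OWN outputs,
# so an early mistake shifts the context for all later steps (error accumulation).
generated = model.generate(ids[:, :3], max_new_tokens=5, do_sample=False)
print(tokenizer.decode(generated[0]))
```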
1.3 The Core Insight of MiniLLM
MiniLLM[^1] argues that the reverse KL divergence is better suited than the standard (forward) KL for distilling generative LLMs.
Forward KL vs. reverse KL:

| Divergence | Formula | Behavior |
|---|---|---|
| Forward KL | D_KL(p_T ‖ p_S) | Mode-covering: pushes the student to spread mass into the teacher's low-probability regions |
| Reverse KL | D_KL(p_S ‖ p_T) | Mode-seeking: the student commits to a single most likely mode of the teacher |
Why reverse KL is better:
The output distribution of a generative LLM is typically multi-modal. Reverse KL lets the student "lock onto" one mode of the teacher distribution, avoiding the blurry generations that arise when probability mass is spread over low-probability regions.
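A toy numeric check makes this concrete. The distributions below are made up purely for illustration: the "teacher" is bimodal over a four-token vocabulary, student A commits to one mode, and student B spreads its mass. Forward KL scores the spread-out student as the better fit, while reverse KL prefers the one locked onto a single mode.

```python
import torch

def kl(p, q):
    """D_KL(p || q) for discrete distributions."""
    return torch.sum(p * (p.log() - q.log())).item()

teacher   = torch.tensor([0.49, 0.01, 0.01, 0.49])  # bimodal teacher
student_a = torch.tensor([0.96, 0.02, 0.01, 0.01])  # commits to one mode
student_b = torch.tensor([0.25, 0.25, 0.25, 0.25])  # spreads over everything

for name, s in [("A (one mode)", student_a), ("B (spread)", student_b)]:
    print(f"{name}: forward KL(T||S) = {kl(teacher, s):.2f}, "
          f"reverse KL(S||T) = {kl(s, teacher):.2f}")
# Forward KL is lower for student B; reverse KL is lower for student A.
```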
2. The MiniLLM Method
2.1 Objective Function
MiniLLM uses the reverse KL divergence as the distillation loss:

$$\mathcal{L}(\theta) = \mathbb{E}_{x \sim p_x,\; y \sim p_T(\cdot\mid x)}\left[\sum_{t=1}^{|y|} D_{\mathrm{KL}}\big(p_S(\cdot \mid y_{<t}, x)\;\big\|\;p_T(\cdot \mid y_{<t}, x)\big)\right]$$

where $y$ is a sequence sampled from the teacher $p_T$, $p_S$ is the student with parameters $\theta$, and $x$ is the prompt.

Simplified form (without teacher sampling), computed per token over sequences from the training set; this is what the implementation below uses:

$$\mathcal{L}(\theta) = \frac{1}{|y|}\sum_{t=1}^{|y|}\sum_{v \in V} p_S(v \mid y_{<t}, x)\,\log\frac{p_S(v \mid y_{<t}, x)}{p_T(v \mid y_{<t}, x)}$$
2.2 Training Procedure

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
class MiniLLM:
"""
MiniLLM: Knowledge Distillation for Language Models
    Core innovation: use the reverse KL divergence for generative-model distillation.
"""
def __init__(self, teacher_name, student_name, temperature=2.0,
beta=1.0, device='cuda'):
self.teacher = AutoModelForCausalLM.from_pretrained(teacher_name).to(device)
self.student = AutoModelForCausalLM.from_pretrained(student_name).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(student_name)
self.temperature = temperature
self.beta = beta
self.device = device
        # Freeze the teacher model
for p in self.teacher.parameters():
p.requires_grad = False
def compute_reverse_kl_loss(self, input_ids, attention_mask):
"""
计算反向KL散度损失
对于每个token,计算:
D_KL(p_S || p_T) = Σ p_S(y) * log(p_S(y) / p_T(y))
"""
        # Teacher forward pass (no gradient computation)
with torch.no_grad():
teacher_outputs = self.teacher(
input_ids=input_ids,
attention_mask=attention_mask
)
teacher_logits = teacher_outputs.logits # (B, L, V)
        # Student forward pass
student_outputs = self.student(
input_ids=input_ids,
attention_mask=attention_mask
)
student_logits = student_outputs.logits
        # Compute the reverse KL loss
        batch_size, seq_len, vocab_size = student_logits.shape
        # Student probabilities and log-probabilities
        student_probs = F.softmax(student_logits / self.temperature, dim=-1)
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        # Teacher log-probabilities
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)
        # Reverse KL: D_KL(p_S || p_T) = Σ p_S * (log p_S - log p_T)
        # (cross-entropy against the teacher minus the student's entropy)
        kl_loss = torch.sum(student_probs * (student_log_probs - teacher_log_probs), dim=-1)  # (B, L)
        kl_loss = kl_loss * (self.temperature ** 2)  # temperature scaling
        # Masked average (ignore padding)
mask = attention_mask.float()
kl_loss = (kl_loss * mask).sum() / mask.sum()
return kl_loss
    def forward(self, input_ids, attention_mask):
        """Forward pass; returns the total, distillation, and LM losses."""
        # Reverse KL distillation loss
        kl_loss = self.compute_reverse_kl_loss(input_ids, attention_mask)
        # Standard language-modeling loss (optional)
        student_outputs = self.student(input_ids=input_ids, attention_mask=attention_mask)
        shift_logits = student_outputs.logits[:, :-1].contiguous()
        shift_labels = input_ids[:, 1:].contiguous()
        lm_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            reduction='mean'
        )
        # Total loss
        total_loss = self.beta * kl_loss + lm_loss
        return total_loss, kl_loss, lm_loss
```

2.3 Full Training Loop

```python
def train_minillm(minillm, train_dataloader, num_epochs, lr=1e-4):
"""MiniLLM训练循环"""
optimizer = torch.optim.AdamW(minillm.student.parameters(), lr=lr)
for epoch in range(num_epochs):
for step, batch in enumerate(train_dataloader):
input_ids = batch['input_ids'].to(minillm.device)
attention_mask = batch['attention_mask'].to(minillm.device)
optimizer.zero_grad()
total_loss, kl_loss, lm_loss = minillm(input_ids, attention_mask)
total_loss.backward()
torch.nn.utils.clip_grad_norm_(minillm.student.parameters(), 1.0)
optimizer.step()
            if step % 100 == 0:
                print(f"Epoch {epoch}, Step {step}: "
                      f"Loss={total_loss.item():.4f}, "
                      f"KL={kl_loss.item():.4f}, "
                      f"LM={lm_loss.item():.4f}")
```

3. GKD (Generalized Knowledge Distillation)
3.1 Core Idea
GKD[^2] proposes an on-policy distillation strategy to address the distribution-mismatch problem of offline distillation.
Problem: offline distillation assumes the student and teacher learn on the same fixed inputs, but in practice the student fits some regions of that data well and others poorly, and at inference it visits states the teacher never graded.
Solution: sample training sequences from the student itself (on-policy), so the teacher provides feedback on exactly the states the student actually visits; a sketch of this data-collection step follows.
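A minimal sketch of the on-policy data-collection step, assuming `student` and `teacher` are Hugging Face causal LMs and `prompt_ids` is a batch of tokenized prompts (all names are illustrative): the student generates its own continuations, and the teacher is then queried on exactly those sequences to provide distillation targets.

```python
import torch

@torch.no_grad()
def collect_on_policy_batch(student, teacher, prompt_ids, max_new_tokens=64):
    """Sample continuations from the student, then score them with the teacher."""
    # 1) On-policy rollout: the student generates from its own distribution.
    sequences = student.generate(
        prompt_ids, max_new_tokens=max_new_tokens, do_sample=True, top_p=0.95
    )
    # 2) Teacher feedback on the student's own tokens.
    teacher_logits = teacher(input_ids=sequences).logits
    return sequences, teacher_logits
```

The returned sequences and teacher logits can then be fed to a token-level divergence such as the reverse KL from Section 2.2 or the mixed-target KL from Section 3.3.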
3.2 The GKD Objective

$$\mathcal{L}_{\mathrm{GKD}}(\theta) = \mathbb{E}_{(x,\,y)}\left[\sum_{t} D_{\mathrm{KL}}\big(\tilde{p}(\cdot \mid y_{<t}, x)\;\big\|\;p_S(\cdot \mid y_{<t}, x)\big)\right]$$

where $\tilde{p}$ is the mixed distribution

$$\tilde{p}(\cdot \mid y_{<t}, x) = \beta\, p_T(\cdot \mid y_{<t}, x) + (1-\beta)\,\mathbf{1}_{y_t}(\cdot),$$

and $\mathbf{1}_{y_t}$ is the one-hot distribution of the ground-truth label $y_t$.
3.3 GKD Implementation

```python
class GKD:
    """
    GKD: Generalized Knowledge Distillation
    An on-policy distillation framework.
    """
    def __init__(self, teacher, student, beta=0.5, temperature=4.0):
        self.teacher = teacher
        self.student = student
        self.beta = beta  # mixing coefficient
        self.temperature = temperature
    def gkd_loss(self, input_ids, attention_mask, labels=None,
                 use_teacher_sampling=False):
        """
        Compute the GKD loss.
        """
        # Teacher distribution (no gradients)
        with torch.no_grad():
            teacher_logits = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask
            ).logits
        # Student distribution
        student_logits = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).logits

        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        vocab_size = teacher_probs.size(-1)

        # Hard targets: either tokens sampled from the teacher, or the ground-truth labels
        if use_teacher_sampling:
            sampled = torch.multinomial(teacher_probs.view(-1, vocab_size), 1)
            hard_targets = sampled.view(input_ids.shape)
        else:
            hard_targets = labels if labels is not None else input_ids

        # Mixed target distribution from Section 3.2:
        # beta * teacher + (1 - beta) * one-hot(hard target)
        one_hot = F.one_hot(hard_targets, num_classes=vocab_size).float()
        mixed_target = self.beta * teacher_probs + (1 - self.beta) * one_hot

        # KL divergence from the mixed target to the student: D_KL(p_mix || p_S)
        kl_loss = F.kl_div(
            student_log_probs,
            mixed_target,
            reduction='batchmean'
        ) * (self.temperature ** 2)

        # Auxiliary cross-entropy of the student under the teacher distribution
        ce_aux = -torch.sum(teacher_probs * student_log_probs, dim=-1).mean()
        return kl_loss + 0.1 * ce_aux
```

4. Chain-of-Thought Distillation
4.1 Distilling Step-by-Step
Hsieh et al.'s Distilling Step-by-Step[^3] exploits the reasoning ability of large LLMs for distillation.
Core idea: distill not only the final answer but also the intermediate reasoning steps (rationales).

```python
class DistillStepByStep:
    """
    Distilling Step-by-Step: Outperforming larger language models
    with less training data and smaller model sizes
    Paper: ACL 2023 (Findings)
    """
    def __init__(self, llm_api_key=None, few_shot_prompt=""):
        # LLMWrapper is a placeholder client for a large teacher LLM API (e.g. PaLM 540B)
        self.llm_api = LLMWrapper(api_key=llm_api_key)
        self.prompt = few_shot_prompt  # few-shot prompt used when querying the teacher
    def extract_reasoning_chain(self, prompt, question):
        """Query the large model for an answer with its reasoning chain."""
        response = self.llm_api.generate(
            f"{prompt}\nQ: {question}\nA: Let me think step by step."
        )
        return response  # answer including intermediate reasoning steps
    def create_distillation_dataset(self, questions):
        """Build the distillation dataset."""
        dataset = []
        for q in questions:
            result = self.extract_reasoning_chain(self.prompt, q)
            # Parse the final answer and the rationale (parse_response is a placeholder helper)
            answer, rationale = self.parse_response(result)
            dataset.append({
                'question': q,
                'rationale': rationale,
                'answer': answer
            })
        return dataset
    def train_with_rationale(self, student_model, dataset, lambda_rationale=0.5):
        """Multi-task training: predict the answer and generate the rationale."""
        for item in dataset:
            # compute_answer_loss / compute_rationale_loss are placeholder helpers
            loss_answer = self.compute_answer_loss(student_model, item)
            loss_rationale = self.compute_rationale_loss(student_model, item)
            loss = loss_answer + lambda_rationale * loss_rationale
            loss.backward()
```

4.2 Beyond Imitation
Beyond Imitation[^4] proposes dual chain-of-thought distillation:
- Imitation chain: learn the teacher's standard reasoning steps
- Innovation chain: encourage the student to discover more concise reasoning paths

```python
class DualChainDistillation:
"""
    Dual chain-of-thought distillation.
"""
def __init__(self, teacher, student):
self.teacher = teacher
self.student = student
    def dual_chain_loss(self, question, teacher_rationale, teacher_answer):
        """
        Compute the dual-chain distillation loss.
        """
        # Student and teacher forward passes
        student_output = self.student(question)
        with torch.no_grad():
            teacher_output = self.teacher(question)
        student_rationale = student_output.rationale
        student_answer = student_output.answer
        # Imitation loss: learn from the teacher's rationale
        imitation_loss = self.rationale_matching_loss(
            student_rationale, teacher_rationale
        )
        # Innovation loss: encourage discovering better reasoning (simplified here
        # as a consistency regularizer on the logits)
        innovation_loss = F.mse_loss(
            student_output.logits,
            teacher_output.logits.detach()
        )
        # Answer loss
        answer_loss = F.cross_entropy(
            student_output.logits,
            teacher_answer
        )
        return {
            'imitation': imitation_loss,
            'innovation': innovation_loss,
            'answer': answer_loss,
            'total': imitation_loss + 0.5 * innovation_loss + answer_loss
        }
```

5. Preference-Learning Distillation
5.1 Direct Preference Knowledge Distillation
DPKD[^5] incorporates human preference data into distillation:

```python
class DirectPreferenceKD:
"""
Direct Preference Knowledge Distillation
    Distillation combined with DPO-style preference learning.
"""
def __init__(self, teacher, student, ref_model):
self.teacher = teacher
self.student = student
        self.ref_model = ref_model  # reference model (typically the SFT model)
    def compute_preference_loss(self, prompts, chosen_responses,
                                rejected_responses):
        """
        Compute the preference loss.
        """
        # Student scores for the chosen and rejected responses
        # (the models are assumed to return a scalar score per sequence)
        chosen_logits = self.student(prompts, chosen_responses)
        rejected_logits = self.student(prompts, rejected_responses)
        # DPO-style loss:
        # -log sigmoid( score(preferred) - score(dispreferred) )
        loss = -F.logsigmoid(chosen_logits - rejected_logits).mean()
        # Distillation term: match the teacher's scores
        with torch.no_grad():
            teacher_chosen = self.teacher(prompts, chosen_responses)
            teacher_rejected = self.teacher(prompts, rejected_responses)
        distill_loss = F.mse_loss(
            chosen_logits,
            teacher_chosen
        ) + F.mse_loss(
            rejected_logits,
            teacher_rejected
        )
        return loss + 0.1 * distill_loss
```

6. Practical Recommendations
6.1 Recommended Distillation Configurations
| Model scale | Teacher | Student | Temperature | β |
|---|---|---|---|---|
| 7B → 3B | LLaMA-2-7B | LLaMA-2-3B | 4.0 | 0.5 |
| 13B → 7B | LLaMA-2-13B | LLaMA-2-7B | 4.0 | 0.7 |
| 70B → 13B | LLaMA-2-70B | LLaMA-2-13B | 4.0 | 0.8 |
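As a usage sketch, the 13B → 7B row maps onto the `MiniLLM` wrapper from Section 2.2 roughly as follows (the Hugging Face checkpoint identifiers are placeholders chosen for illustration and require access approval; substitute whichever checkpoints you actually use):

```python
distiller = MiniLLM(
    teacher_name="meta-llama/Llama-2-13b-hf",
    student_name="meta-llama/Llama-2-7b-hf",
    temperature=4.0,  # from the table above
    beta=0.7,
)
```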
6.2 Training Tips
- Learning-rate schedule: cosine annealing (a combined setup sketch follows this list)
- Weight decay: 1e-4
- Gradient accumulation: use it when the per-device batch size is small
- Early stopping: monitor validation-set perplexity
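A minimal sketch of these settings together (plain PyTorch; the `train_minillm` loop from Section 2.3 would need small changes to call the scheduler and accumulate gradients, so this is shown standalone):

```python
import torch

def build_optimizer_and_scheduler(student, num_training_steps, lr=1e-4):
    """AdamW with weight decay plus cosine annealing, as recommended above."""
    optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_training_steps)
    return optimizer, scheduler

ACCUM_STEPS = 8  # gradient accumulation for small per-device batches

def accumulation_step(loss, step, optimizer, scheduler):
    """Average gradients over ACCUM_STEPS micro-batches before each update."""
    (loss / ACCUM_STEPS).backward()
    if (step + 1) % ACCUM_STEPS == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
```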
6.3 Evaluation Metrics
| Metric | Notes |
|---|---|
| Perplexity (PPL) | Lower is better (see the sketch after this table) |
| Downstream task accuracy | Classification, QA, etc. |
| Generation quality | Diversity, coherence |
| Distribution match | Similarity to the teacher distribution |
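For the perplexity row, a sketch of validation-set PPL (assumes a Hugging Face causal LM and a dataloader yielding `input_ids` / `attention_mask`, as in the training loop above):

```python
import math
import torch

@torch.no_grad()
def evaluate_perplexity(model, dataloader, device="cuda"):
    """Exponentiated average token-level negative log-likelihood."""
    model.eval()
    total_nll, total_tokens = 0.0, 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # ignore padding positions in the loss
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        n_tokens = (labels[:, 1:] != -100).sum().item()
        total_nll += out.loss.item() * n_tokens
        total_tokens += n_tokens
    return math.exp(total_nll / total_tokens)
```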
7. References

[^1]: Gu Y, Dong L, Wei F, et al. MiniLLM: Knowledge Distillation of Large Language Models. ICLR 2024. arXiv:2306.08543
[^2]: Agarwal R, Vieillard N, Zhou Y, et al. On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes (GKD). ICLR 2024. arXiv:2306.13649
[^3]: Hsieh C Y, Li C L, Yeh C K, et al. Distilling Step-by-Step! Outperforming Larger Language Models with Less Training Data and Smaller Model Sizes. Findings of ACL 2023. arXiv:2305.02301
[^4]: Beyond Imitation: Learning Key Reasoning Steps from Dual Chain-of-Thoughts. arXiv:2405.19737, 2024.
[^5]: Direct Preference Knowledge Distillation for Large Language Models. arXiv:2406.19774, 2024.