知识蒸馏基础
知识蒸馏(Knowledge Distillation, KD)由Hinton等人于2015年提出,是一种将大型”教师”模型的知识迁移到小型”学生”模型的技术1。本文档介绍知识蒸馏的基本原理、分类体系和核心方法。
1. 基本原理
1.1 核心思想
知识蒸馏的核心思想源于**暗知识(Dark Knowledge)**概念:
即使教师模型的最终预测错误,其输出的软概率分布仍包含关于数据相似性的丰富信息。
硬标签 vs 软标签:
| 类型 | 信息量 | 示例 |
|---|---|---|
| 硬标签 | 单一类别 | |
| 软标签 | 类别间相似性 |
软标签揭示了类别之间的语义关系(如”猫”和”狗”比”猫”和”汽车”更相似),这是硬标签无法提供的信息。
1.2 数学框架
KL散度目标函数:
其中:
- :真实标签(one-hot)
- :学生模型输出
- :教师模型软化后的概率分布
- :温度参数(Temperature)
- :平衡超参数
温度缩放:
当 时,分布趋向均匀;当 时,分布趋向one-hot。
1.3 为什么知识蒸馏有效
理论解释:
- 信息压缩:教师模型学习到的决策边界包含比标签更多的信息
- 梯度平滑:软标签提供更平滑的梯度信号
- 正则化:防止学生模型过拟合训练数据
实验观察:
- 使用相同的训练数据,学生模型从教师软标签学习比从硬标签学习表现更好
- 特别是在小数据集或噪声数据上效果显著
2. 知识蒸馏分类体系
2.1 按蒸馏来源分类
知识蒸馏
├── 响应蒸馏 (Response-based KD)
│ └── 匹配教师模型的输出 logits
├── 特征蒸馏 (Feature-based KD)
│ └── 匹配教师模型的中间层特征
└── 关系蒸馏 (Relation-based KD)
└── 匹配样本间的关系结构
2.2 按蒸馏策略分类
| 类型 | 教师状态 | 特点 |
|---|---|---|
| 离线蒸馏 | 固定 | 两阶段:先训练教师,再蒸馏 |
| 在线蒸馏 | 共同更新 | 单阶段:教师和学生同时学习 |
| 自蒸馏 | 同一模型 | 利用模型不同版本/视角 |
2.3 按模型架构分类
| 类型 | 教师/学生关系 | 挑战 |
|---|---|---|
| 同构蒸馏 | 相同架构 | 主要是如何分配层 |
| 异构蒸馏 | 不同架构 | 中间层对应关系 |
3. 响应蒸馏
3.1 定义
响应蒸馏直接让学生学习教师模型的最终输出(logits或预测概率)。
适用场景:
- 教师和学生架构相似
- 输出空间可直接对应
- 任务为分类或回归
3.2 分类器知识蒸馏
import torch
import torch.nn as nn
import torch.nn.functional as F
def response_distillation(student_logits, teacher_logits, labels,
temperature=4.0, alpha=0.7):
"""
响应蒸馏损失函数
Args:
student_logits: 学生模型logits
teacher_logits: 教师模型logits
labels: 真实标签
temperature: 温度参数
alpha: 蒸馏损失权重
"""
# 软化后的概率分布
soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
# KL散度损失
distill_loss = F.kl_div(
soft_student,
soft_teacher,
reduction='batchmean'
) * (temperature ** 2)
# 硬标签交叉熵
hard_loss = F.cross_entropy(student_logits, labels)
# 加权组合
total_loss = alpha * distill_loss + (1 - alpha) * hard_loss
return total_loss3.3 生成式模型的蒸馏挑战
对于生成式语言模型(如GPT、LLaMA),传统的响应蒸馏面临挑战:
| 问题 | 原因 |
|---|---|
| 巨大输出空间 | 词表通常数万个token |
| 自回归特性 | 每个位置需要条件分布 |
| 暴露偏差 | 训练与推理不一致 |
解决方案:见MiniLLM与LLM蒸馏。
4. 特征蒸馏
4.1 定义
特征蒸馏利用教师模型中间层的表示来指导学生学习。
适用场景:
- 教师和学生有可对应的中间层
- 需要迁移更丰富的表示知识
- 网络深度不同时
4.2 特征匹配方法
早期方法:Layer-by-layer匹配
class FeatureDistillationLoss(nn.Module):
"""特征蒸馏损失"""
def __init__(self, temperature=2.0):
super().__init__()
self.temperature = temperature
def forward(self, student_features, teacher_features):
"""
计算特征蒸馏损失
Args:
student_features: List[Tensor],学生各层特征
teacher_features: List[Tensor],教师各层特征
"""
loss = 0
for s_feat, t_feat in zip(student_features, teacher_features):
# 确保维度匹配
if s_feat.shape != t_feat.shape:
s_feat = self._adapt_dimension(s_feat, t_feat.shape)
# MSE损失
loss += F.mse_loss(s_feat, t_feat)
return loss / len(student_features)
def _adapt_dimension(self, feat, target_shape):
"""维度适配层"""
# 使用1×1卷积或线性层适配维度
# 简化实现:插值
return F.interpolate(
feat,
size=target_shape[-2:],
mode='bilinear',
align_corners=False
)4.3 FitNet:深度特征蒸馏
FitNet2是最早的特征蒸馏方法之一,提出Hints-based训练:
class FitNetLoss(nn.Module):
"""
FitNet: 深度特征蒸馏
论文: "FitNet: How versatile are fitting relations in deeper networks?"
"""
def __init__(self, hint_layer_student, hint_layer_teacher):
super().__init__()
self.hint_layer_student = hint_layer_student
self.hint_layer_teacher = hint_layer_teacher
# 适配器:维度变换
self.adapter = None # 根据网络结构定义
def forward(self, student_features, teacher_features):
s_hint = self.hint_layer_student(student_features)
t_hint = self.hint_layer_teacher(teacher_features)
if self.adapter:
s_hint = self.adapter(s_hint)
return F.mse_loss(s_hint, t_hint)4.4 注意力蒸馏
Zagoruyko等人提出注意力迁移3,利用教师网络的注意力图指导学生:
注意力定义:
其中 是特征图的通道维度。
def attention_transfer(student_features, teacher_features, p=2):
"""
注意力迁移
Args:
p: 注意力计算的幂次(p=2为二范数注意力)
"""
loss = 0
for s_feat, t_feat in zip(student_features, teacher_features):
# 计算注意力图
s_attn = torch.sum(s_feat ** p, dim=1, keepdim=True)
t_attn = torch.sum(t_feat ** p, dim=1, keepdim=True)
# 归一化
s_attn = F.normalize(s_attn.view(s_attn.size(0), -1), dim=1)
t_attn = F.normalize(t_attn.view(t_attn.size(0), -1), dim=1)
loss += F.mse_loss(s_attn, t_attn)
return loss5. 关系蒸馏
5.1 定义
关系蒸馏不仅匹配单样本的输出或特征,还建模样本之间的关系知识。
核心假设:教师模型学到的样本间关系结构本身就是重要的知识。
5.2 关系知识蒸馏(RKD)
Distance-based RKD4:
def rkd_distance_loss(student_features, teacher_features, eps=1e-12):
"""
基于距离的关系蒸馏
匹配样本间距离的关系
"""
batch_size = student_features[0].size(0)
# 提取特征(使用最后一层)
s_feat = student_features[-1]
t_feat = teacher_features[-1]
# 归一化特征
s_feat = F.normalize(s_feat.view(batch_size, -1), dim=1)
t_feat = F.normalize(t_feat.view(batch_size, -1), dim=1)
# 计算距离矩阵
s_dist = pairwise_distance(s_feat)
t_dist = pairwise_distance(t_feat)
# MSE损失
return F.mse_loss(s_dist, t_dist)
def pairwise_distance(x):
"""计算成对距离矩阵"""
x_square = (x ** 2).sum(dim=1)
dist = x_square.unsqueeze(0) - 2 * x @ x.T + x_square.unsqueeze(1)
return dist.clamp(min=0).sqrt()Angle-based RKD:
考虑三元组样本间的角度关系:
def rkd_angle_loss(student_features, teacher_features):
"""
基于角度的关系蒸馏
"""
batch_size = student_features[-1].size(0)
s_feat = student_features[-1].view(batch_size, -1)
t_feat = teacher_features[-1].view(batch_size, -1)
# 归一化
s_feat = F.normalize(s_feat, dim=1)
t_feat = F.normalize(t_feat, dim=1)
loss = 0
for i in range(batch_size):
for j in range(i + 1, batch_size):
for k in range(j + 1, batch_size):
# 计算角度
s_angle = angle_triplet(s_feat[i], s_feat[j], s_feat[k])
t_angle = angle_triplet(t_feat[i], t_feat[j], t_feat[k])
loss += F.smooth_l1_loss(s_angle, t_angle)
return loss / (batch_size * (batch_size - 1) * (batch_size - 2) / 6)
def angle_triplet(a, b, c):
"""计算三元组角度向量"""
v1 = a - b
v2 = c - b
cos_angle = F.cosine_similarity(v1, v2, dim=0)
return cos_angle6. 蒸馏策略
6.1 离线蒸馏
两阶段流程:
阶段1: 训练教师模型
数据集 → 教师模型(充分训练)
阶段2: 知识蒸馏
数据集 + 教师软标签 → 学生模型
特点:
- 教师模型质量高
- 计算开销较大
- 教师需单独训练
6.2 在线蒸馏
单阶段流程:
同时训练教师和学生
↕ 互相学习 ↕
教师模型 ←→ 学生模型
def online_distillation_step(models, optimizers, data):
"""在线蒸馏:所有模型共同学习"""
# 多个模型ensemble作为教师
teachers_logits = [model(data) for model in models[:-1]]
teacher_avg = torch.stack(teachers_logits).mean(dim=0)
# 每个学生从ensemble学习
for student, opt in zip(models[-1:], optimizers):
opt.zero_grad()
student_logits = student(data)
# KL散度损失
loss = F.kl_div(
F.log_softmax(student_logits / 4, dim=-1),
F.softmax(teacher_avg / 4, dim=-1),
reduction='batchmean'
) * (4 ** 2)
loss.backward()
opt.step()6.3 自蒸馏
概念:使用同一模型的不同部分或版本作为教师和学生。
典型应用:
| 方法 | 教师来源 | 特点 |
|---|---|---|
| Self-Distillation | 同一模型深层→浅层 | 知识迁移 |
| L2P | 主网络→提示网络 | 持续学习 |
| Be Your Own Teacher | 完整模型→自身 | 性能提升 |
def self_distillation(model, data, temperature=4.0):
"""自蒸馏:深层特征指导浅层学习"""
# 获取不同层深度的特征
deep_features = model.get_deep_features(data)
shallow_features = model.get_shallow_features(data)
# 深层→浅层的知识迁移
loss = F.mse_loss(
model.embed(shallow_features),
model.embed(deep_features).detach()
)
return loss7. 蒸馏温度与损失设计
7.1 温度参数的作用
| 值 | 效果 |
|---|---|
| 标准softmax,无软化 | |
| 软化分布,突出次优类别的关系 | |
| 趋向均匀分布 | |
| 趋向硬标签 |
温度选择建议:
- 分类任务:
- 知识丰富度低时:使用较高温度
- 类别数多时:使用较高温度
7.2 损失函数变体
标签平滑交叉熵:
方差正则化:
def distillation_loss_with_variance(student_logits, teacher_logits, labels,
temperature=4.0, alpha=0.9):
"""
带方差正则化的蒸馏损失
"""
# KL散度
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
kl_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
# 硬标签损失
hard_loss = F.cross_entropy(student_logits, labels)
# 类别平衡正则化
class_weights = 1.0 / (student_logits.size(1) ** 0.5)
return alpha * kl_loss + (1 - alpha) * hard_loss8. 实践指南
8.1 蒸馏策略选择
| 场景 | 推荐策略 |
|---|---|
| 同架构压缩 | 响应蒸馏 |
| 不同架构压缩 | 特征蒸馏 |
| 层级对应困难 | 关系蒸馏 |
| 无教师模型 | 自蒸馏 |
8.2 超参数设置
| 超参数 | 建议值 | 调整原则 |
|---|---|---|
| 温度 | 2-5 | 类别越多越高 |
| 0.5-0.9 | 任务越难越低 | |
| 教师容量 | 学生3-10倍 | 受硬件限制 |
| 训练轮数 | 通常更长 | 需收敛保证 |
8.3 常见问题
| 问题 | 原因 | 解决方案 |
|---|---|---|
| 学生表现差 | 教师-学生差距大 | 分阶段蒸馏 |
| 训练不稳定 | 软标签过于尖锐 | 提高温度 |
| 性能提升有限 | 学生容量不足 | 减小容量差距 |
9. 参考资料
扩展阅读:
Footnotes
-
Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network. arXiv:1503.02531, 2015. ↩
-
Romero A, Ballas N, Kahou S E, et al. Fitnets: Hints for thin deep nets. ICLR, 2015. arXiv:1412.6550 ↩
-
Zagoruyko S, Komodakis N. Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. ICLR, 2017. arXiv:1612.03928 ↩
-
Park W, Kim D, Lu Y, et al. Relational knowledge distillation. CVPR, 2019. arXiv:1904.05068 ↩