知识蒸馏基础

知识蒸馏(Knowledge Distillation, KD)由Hinton等人于2015年提出,是一种将大型”教师”模型的知识迁移到小型”学生”模型的技术1。本文档介绍知识蒸馏的基本原理、分类体系和核心方法。

1. 基本原理

1.1 核心思想

知识蒸馏的核心思想源于**暗知识(Dark Knowledge)**概念:

即使教师模型的最终预测错误,其输出的软概率分布仍包含关于数据相似性的丰富信息。

硬标签 vs 软标签

类型信息量示例
硬标签单一类别
软标签类别间相似性

软标签揭示了类别之间的语义关系(如”猫”和”狗”比”猫”和”汽车”更相似),这是硬标签无法提供的信息。

1.2 数学框架

KL散度目标函数

其中:

  • :真实标签(one-hot)
  • :学生模型输出
  • :教师模型软化后的概率分布
  • 温度参数(Temperature)
  • :平衡超参数

温度缩放

时,分布趋向均匀;当 时,分布趋向one-hot。

1.3 为什么知识蒸馏有效

理论解释

  1. 信息压缩:教师模型学习到的决策边界包含比标签更多的信息
  2. 梯度平滑:软标签提供更平滑的梯度信号
  3. 正则化:防止学生模型过拟合训练数据

实验观察

  • 使用相同的训练数据,学生模型从教师软标签学习比从硬标签学习表现更好
  • 特别是在小数据集或噪声数据上效果显著

2. 知识蒸馏分类体系

2.1 按蒸馏来源分类

知识蒸馏
├── 响应蒸馏 (Response-based KD)
│   └── 匹配教师模型的输出 logits
├── 特征蒸馏 (Feature-based KD)
│   └── 匹配教师模型的中间层特征
└── 关系蒸馏 (Relation-based KD)
    └── 匹配样本间的关系结构

2.2 按蒸馏策略分类

类型教师状态特点
离线蒸馏固定两阶段:先训练教师,再蒸馏
在线蒸馏共同更新单阶段:教师和学生同时学习
自蒸馏同一模型利用模型不同版本/视角

2.3 按模型架构分类

类型教师/学生关系挑战
同构蒸馏相同架构主要是如何分配层
异构蒸馏不同架构中间层对应关系

3. 响应蒸馏

3.1 定义

响应蒸馏直接让学生学习教师模型的最终输出(logits或预测概率)。

适用场景

  • 教师和学生架构相似
  • 输出空间可直接对应
  • 任务为分类或回归

3.2 分类器知识蒸馏

import torch
import torch.nn as nn
import torch.nn.functional as F
 
def response_distillation(student_logits, teacher_logits, labels, 
                         temperature=4.0, alpha=0.7):
    """
    响应蒸馏损失函数
    
    Args:
        student_logits: 学生模型logits
        teacher_logits: 教师模型logits
        labels: 真实标签
        temperature: 温度参数
        alpha: 蒸馏损失权重
    """
    # 软化后的概率分布
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    
    # KL散度损失
    distill_loss = F.kl_div(
        soft_student, 
        soft_teacher, 
        reduction='batchmean'
    ) * (temperature ** 2)
    
    # 硬标签交叉熵
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # 加权组合
    total_loss = alpha * distill_loss + (1 - alpha) * hard_loss
    
    return total_loss

3.3 生成式模型的蒸馏挑战

对于生成式语言模型(如GPT、LLaMA),传统的响应蒸馏面临挑战:

问题原因
巨大输出空间词表通常数万个token
自回归特性每个位置需要条件分布
暴露偏差训练与推理不一致

解决方案:见MiniLLM与LLM蒸馏

4. 特征蒸馏

4.1 定义

特征蒸馏利用教师模型中间层的表示来指导学生学习。

适用场景

  • 教师和学生有可对应的中间层
  • 需要迁移更丰富的表示知识
  • 网络深度不同时

4.2 特征匹配方法

早期方法:Layer-by-layer匹配

class FeatureDistillationLoss(nn.Module):
    """特征蒸馏损失"""
    def __init__(self, temperature=2.0):
        super().__init__()
        self.temperature = temperature
    
    def forward(self, student_features, teacher_features):
        """
        计算特征蒸馏损失
        
        Args:
            student_features: List[Tensor],学生各层特征
            teacher_features: List[Tensor],教师各层特征
        """
        loss = 0
        
        for s_feat, t_feat in zip(student_features, teacher_features):
            # 确保维度匹配
            if s_feat.shape != t_feat.shape:
                s_feat = self._adapt_dimension(s_feat, t_feat.shape)
            
            # MSE损失
            loss += F.mse_loss(s_feat, t_feat)
        
        return loss / len(student_features)
    
    def _adapt_dimension(self, feat, target_shape):
        """维度适配层"""
        # 使用1×1卷积或线性层适配维度
        # 简化实现:插值
        return F.interpolate(
            feat, 
            size=target_shape[-2:], 
            mode='bilinear', 
            align_corners=False
        )

4.3 FitNet:深度特征蒸馏

FitNet2是最早的特征蒸馏方法之一,提出Hints-based训练

class FitNetLoss(nn.Module):
    """
    FitNet: 深度特征蒸馏
    
    论文: "FitNet: How versatile are fitting relations in deeper networks?"
    """
    def __init__(self, hint_layer_student, hint_layer_teacher):
        super().__init__()
        self.hint_layer_student = hint_layer_student
        self.hint_layer_teacher = hint_layer_teacher
        
        # 适配器:维度变换
        self.adapter = None  # 根据网络结构定义
    
    def forward(self, student_features, teacher_features):
        s_hint = self.hint_layer_student(student_features)
        t_hint = self.hint_layer_teacher(teacher_features)
        
        if self.adapter:
            s_hint = self.adapter(s_hint)
        
        return F.mse_loss(s_hint, t_hint)

4.4 注意力蒸馏

Zagoruyko等人提出注意力迁移3,利用教师网络的注意力图指导学生:

注意力定义

其中 是特征图的通道维度。

def attention_transfer(student_features, teacher_features, p=2):
    """
    注意力迁移
    
    Args:
        p: 注意力计算的幂次(p=2为二范数注意力)
    """
    loss = 0
    
    for s_feat, t_feat in zip(student_features, teacher_features):
        # 计算注意力图
        s_attn = torch.sum(s_feat ** p, dim=1, keepdim=True)
        t_attn = torch.sum(t_feat ** p, dim=1, keepdim=True)
        
        # 归一化
        s_attn = F.normalize(s_attn.view(s_attn.size(0), -1), dim=1)
        t_attn = F.normalize(t_attn.view(t_attn.size(0), -1), dim=1)
        
        loss += F.mse_loss(s_attn, t_attn)
    
    return loss

5. 关系蒸馏

5.1 定义

关系蒸馏不仅匹配单样本的输出或特征,还建模样本之间的关系知识。

核心假设:教师模型学到的样本间关系结构本身就是重要的知识。

5.2 关系知识蒸馏(RKD)

Distance-based RKD4

def rkd_distance_loss(student_features, teacher_features, eps=1e-12):
    """
    基于距离的关系蒸馏
    
    匹配样本间距离的关系
    """
    batch_size = student_features[0].size(0)
    
    # 提取特征(使用最后一层)
    s_feat = student_features[-1]
    t_feat = teacher_features[-1]
    
    # 归一化特征
    s_feat = F.normalize(s_feat.view(batch_size, -1), dim=1)
    t_feat = F.normalize(t_feat.view(batch_size, -1), dim=1)
    
    # 计算距离矩阵
    s_dist = pairwise_distance(s_feat)
    t_dist = pairwise_distance(t_feat)
    
    # MSE损失
    return F.mse_loss(s_dist, t_dist)
 
 
def pairwise_distance(x):
    """计算成对距离矩阵"""
    x_square = (x ** 2).sum(dim=1)
    dist = x_square.unsqueeze(0) - 2 * x @ x.T + x_square.unsqueeze(1)
    return dist.clamp(min=0).sqrt()

Angle-based RKD

考虑三元组样本间的角度关系:

def rkd_angle_loss(student_features, teacher_features):
    """
    基于角度的关系蒸馏
    """
    batch_size = student_features[-1].size(0)
    
    s_feat = student_features[-1].view(batch_size, -1)
    t_feat = teacher_features[-1].view(batch_size, -1)
    
    # 归一化
    s_feat = F.normalize(s_feat, dim=1)
    t_feat = F.normalize(t_feat, dim=1)
    
    loss = 0
    for i in range(batch_size):
        for j in range(i + 1, batch_size):
            for k in range(j + 1, batch_size):
                # 计算角度
                s_angle = angle_triplet(s_feat[i], s_feat[j], s_feat[k])
                t_angle = angle_triplet(t_feat[i], t_feat[j], t_feat[k])
                loss += F.smooth_l1_loss(s_angle, t_angle)
    
    return loss / (batch_size * (batch_size - 1) * (batch_size - 2) / 6)
 
 
def angle_triplet(a, b, c):
    """计算三元组角度向量"""
    v1 = a - b
    v2 = c - b
    cos_angle = F.cosine_similarity(v1, v2, dim=0)
    return cos_angle

6. 蒸馏策略

6.1 离线蒸馏

两阶段流程

阶段1: 训练教师模型
    数据集 → 教师模型(充分训练)

阶段2: 知识蒸馏
    数据集 + 教师软标签 → 学生模型

特点

  • 教师模型质量高
  • 计算开销较大
  • 教师需单独训练

6.2 在线蒸馏

单阶段流程

同时训练教师和学生
    ↕ 互相学习 ↕
教师模型 ←→ 学生模型
def online_distillation_step(models, optimizers, data):
    """在线蒸馏:所有模型共同学习"""
    # 多个模型ensemble作为教师
    teachers_logits = [model(data) for model in models[:-1]]
    teacher_avg = torch.stack(teachers_logits).mean(dim=0)
    
    # 每个学生从ensemble学习
    for student, opt in zip(models[-1:], optimizers):
        opt.zero_grad()
        student_logits = student(data)
        
        # KL散度损失
        loss = F.kl_div(
            F.log_softmax(student_logits / 4, dim=-1),
            F.softmax(teacher_avg / 4, dim=-1),
            reduction='batchmean'
        ) * (4 ** 2)
        
        loss.backward()
        opt.step()

6.3 自蒸馏

概念:使用同一模型的不同部分或版本作为教师和学生。

典型应用

方法教师来源特点
Self-Distillation同一模型深层→浅层知识迁移
L2P主网络→提示网络持续学习
Be Your Own Teacher完整模型→自身性能提升
def self_distillation(model, data, temperature=4.0):
    """自蒸馏:深层特征指导浅层学习"""
    
    # 获取不同层深度的特征
    deep_features = model.get_deep_features(data)
    shallow_features = model.get_shallow_features(data)
    
    # 深层→浅层的知识迁移
    loss = F.mse_loss(
        model.embed(shallow_features),
        model.embed(deep_features).detach()
    )
    
    return loss

7. 蒸馏温度与损失设计

7.1 温度参数的作用

效果
标准softmax,无软化
软化分布,突出次优类别的关系
趋向均匀分布
趋向硬标签

温度选择建议

  • 分类任务:
  • 知识丰富度低时:使用较高温度
  • 类别数多时:使用较高温度

7.2 损失函数变体

标签平滑交叉熵

方差正则化

def distillation_loss_with_variance(student_logits, teacher_logits, labels, 
                                   temperature=4.0, alpha=0.9):
    """
    带方差正则化的蒸馏损失
    """
    # KL散度
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kl_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
    
    # 硬标签损失
    hard_loss = F.cross_entropy(student_logits, labels)
    
    # 类别平衡正则化
    class_weights = 1.0 / (student_logits.size(1) ** 0.5)
    
    return alpha * kl_loss + (1 - alpha) * hard_loss

8. 实践指南

8.1 蒸馏策略选择

场景推荐策略
同架构压缩响应蒸馏
不同架构压缩特征蒸馏
层级对应困难关系蒸馏
无教师模型自蒸馏

8.2 超参数设置

超参数建议值调整原则
温度 2-5类别越多越高
0.5-0.9任务越难越低
教师容量学生3-10倍受硬件限制
训练轮数通常更长需收敛保证

8.3 常见问题

问题原因解决方案
学生表现差教师-学生差距大分阶段蒸馏
训练不稳定软标签过于尖锐提高温度
性能提升有限学生容量不足减小容量差距

9. 参考资料

扩展阅读:

Footnotes

  1. Hinton G, Vinyals O, Dean J. Distilling the knowledge in a neural network. arXiv:1503.02531, 2015.

  2. Romero A, Ballas N, Kahou S E, et al. Fitnets: Hints for thin deep nets. ICLR, 2015. arXiv:1412.6550

  3. Zagoruyko S, Komodakis N. Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. ICLR, 2017. arXiv:1612.03928

  4. Park W, Kim D, Lu Y, et al. Relational knowledge distillation. CVPR, 2019. arXiv:1904.05068