引言

类别增量学习(Class-Incremental Learning, Class-IL)是持续学习中最具挑战性的设置之一。与 Task-IL 不同,Class-IL 在测试时不知道任务边界,需要模型能够识别所有已学过的类别。

核心困境:训练时只看到新类别的数据,但测试时需要区分所有类别。


1. 类别增量学习的独特挑战

1.1 问题形式化

设共有 个阶段,每阶段学习 个新类别。

  • 训练阶段 :只能访问类别 的数据
  • 测试阶段:需要识别所有 个类别

1.2 三大核心挑战

挑战描述影响
表示漂移旧类别的特征表示在学习新类别时发生变化旧类别分类器失效
分类器偏置新类别数量远多于旧类别,导致分类器偏向新类旧类别被误分类
数据不平衡每个阶段只能看到部分类别的数据训练不充分

1.3 表示漂移问题

┌────────────────────────────────────────────────────────────────┐
│                    表示漂移示意图                               │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  阶段1 (类别A, B):                                             │
│       A ●                                                      │
│                     ● B                                        │
│                                                                │
│  阶段2 (类别C, D):                                             │
│       A ●           ● C                                        │
│                     ● B                                        │
│                                   ● D                          │
│                                                                │
│  问题: 旧类别的特征边界可能被新类别入侵                         │
└────────────────────────────────────────────────────────────────┘

2. iCaRL: 增量分类器与表示学习

2.1 核心思想

iCaRL(Incremental Classifier and Representation Learning)由 Rebuffi 等人在 2017 年提出,是 Class-IL 的里程碑式工作。1

三大核心组件

  1. 原型分类器:使用类原型(平均特征)而非线性分类器
  2. 样本回放:存储每个类别的代表性样本
  3. 表示学习:使用蒸馏损失保持旧类别表示

2.2 原型分类器

类原型定义

其中 是类别 的样本集, 是特征提取器。

分类规则

即选择最近的原型对应的类别。

2.3 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
import numpy as np
 
class iCaRL:
    """
    iCaRL: 增量分类器与表示学习
    
    参考文献: Rebuffi et al. "iCaRL: Incremental classifier and 
             representation learning", CVPR 2017
    """
    
    def __init__(self, feature_extractor, classifier, memory_per_class=20):
        """
        Args:
            feature_extractor: 特征提取网络
            classifier: 分类头
            memory_per_class: 每个类别存储的样本数
        """
        self.feature_extractor = feature_extractor
        self.classifier = classifier
        self.memory_per_class = memory_per_class
        
        # 类别数量
        self.n_classes = 0
        self.n_known_classes = 0
        
        # 样本存储
        self.exemplar_sets = defaultdict(list)  # {class: [sample1, sample2, ...]}
        
        # 旧模型(用于蒸馏)
        self.old_feature_extractor = None
        self.old_classifier = None
        
    def update_representation(self, dataset, batch_size=128, epochs=20):
        """
        更新表示学习部分
        
        使用新旧类别的数据训练特征提取器
        """
        # 保存旧模型
        self.old_feature_extractor = copy.deepcopy(self.feature_extractor)
        self.old_classifier = copy.deepcopy(self.classifier)
        self.old_feature_extractor.eval()
        self.old_classifier.eval()
        
        # 获取当前任务的类别数
        old_n_classes = self.n_known_classes
        new_n_classes = self.n_classes
        
        # 更新分类器输出维度
        if self.classifier.out_features < new_n_classes:
            old_weight = self.classifier.fc.weight.data
            old_bias = self.classifier.fc.bias.data
            self.classifier.fc = nn.Linear(
                self.classifier.fc.in_features, new_n_classes
            ).to(self.classifier.fc.weight.device)
            # 初始化新输出
            self.classifier.fc.weight.data[:old_n_classes] = old_weight
            self.classifier.fc.bias.data[:old_n_classes] = old_bias
        
        # 训练配置
        optimizer = torch.optim.Adam(self.feature_extractor.parameters(), lr=0.001)
        criterion = nn.CrossEntropyLoss()
        
        self.feature_extractor.train()
        
        for epoch in range(epochs):
            for inputs, targets in dataset:
                optimizer.zero_grad()
                
                features = self.feature_extractor(inputs)
                outputs = self.classifier(features)
                
                # 1. 分类损失(新类别)
                loss_cls = criterion(outputs[:, :self.n_classes], targets)
                
                # 2. 蒸馏损失(保持旧类别表示)
                loss_kd = 0
                if self.old_feature_extractor is not None:
                    with torch.no_grad():
                        old_features = self.old_feature_extractor(inputs)
                    
                    # 特征蒸馏
                    loss_kd = F.mse_loss(features, old_features)
                
                # 3. 对比损失(可选:增强特征区分度)
                # loss_contrast = self.contrastive_loss(features, targets)
                
                loss = loss_cls + 0.5 * loss_kd
                loss.backward()
                optimizer.step()
    
    def construct_exemplar_set(self, dataset, class_idx, num_samples=20):
        """
        为指定类别构建示例集
        
        使用 herding 策略选择代表性样本
        """
        self.feature_extractor.eval()
        
        features = []
        with torch.no_grad():
            for inputs, _ in dataset:
                feats = self.feature_extractor(inputs)
                features.append(feats)
        
        features = torch.cat(features, dim=0)
        class_mean = features.mean(dim=0)
        
        # Herding 策略:逐步选择使运行均值最接近类均值的样本
        selected = []
        running_mean = torch.zeros_like(class_mean)
        
        for _ in range(num_samples):
            # 计算每个候选样本加入后的距离
            min_dist = float('inf')
            best_idx = 0
            
            for i, feat in enumerate(features):
                if i in selected:
                    continue
                    
                candidate_mean = (running_mean * len(selected) + feat) / (len(selected) + 1)
                dist = torch.norm(candidate_mean - class_mean)
                
                if dist < min_dist:
                    min_dist = dist
                    best_idx = i
            
            selected.append(best_idx)
            running_mean = (running_mean * (len(selected) - 1) + features[best_idx]) / len(selected)
        
        return selected
    
    def reduce_exemplar_set(self, class_idx):
        """减少示例集大小"""
        if class_idx in self.exemplar_sets:
            self.exemplar_sets[class_idx] = self.exemplar_sets[class_idx][:self.memory_per_class]
    
    def update_exemplars(self, dataset):
        """
        更新所有类别的示例集
        """
        # 先压缩现有示例
        for cls in range(self.n_known_classes):
            self.reduce_exemplar_set(cls)
        
        # 为新类别创建示例
        for cls in range(self.n_known_classes, self.n_classes):
            # 获取该类别的样本
            cls_dataset = [(x, y) for x, y in dataset if y == cls]
            selected_idx = self.construct_exemplar_set(cls_dataset, cls, self.memory_per_class)
            self.exemplar_sets[cls] = [cls_dataset[i][0] for i in selected_idx]
    
    def classify(self, x):
        """
        原型分类
        
        Args:
            x: 输入样本
            
        Returns:
            predictions: 预测的类别
        """
        self.feature_extractor.eval()
        
        with torch.no_grad():
            features = self.feature_extractor(x)
        
        # 计算与每个类原型的距离
        distances = []
        for cls in range(self.n_classes):
            if cls in self.exemplar_sets and len(self.exemplar_sets[cls]) > 0:
                # 计算类原型
                exemplar_features = []
                for ex in self.exemplar_sets[cls]:
                    with torch.no_grad():
                        feat = self.feature_extractor(ex.unsqueeze(0))
                    exemplar_features.append(feat.squeeze(0))
                
                class_mean = torch.stack(exemplar_features).mean(dim=0)
                dist = torch.norm(features - class_mean, dim=-1)
            else:
                dist = torch.full((features.size(0),), float('inf'))
            
            distances.append(dist)
        
        distances = torch.stack(distances, dim=1)
        predictions = distances.argmin(dim=1)
        
        return predictions
    
    def add_task(self, dataset, n_new_classes):
        """添加新任务"""
        self.n_known_classes = self.n_classes
        self.n_classes += n_new_classes
        
        # 更新表示
        self.update_representation(dataset)
        
        # 更新示例
        self.update_exemplars(dataset)

2.4 Herding 策略详解

Herding 是一种智能样本选择策略,目标是选择最「代表性」的样本:

算法流程

  1. 计算类均值
  2. 迭代选择样本,使运行均值最接近
  3. 每次选择:

优势:相比随机采样,herding 选择的样本能更好地代表类别分布。


3. PODNet: 池化输出蒸馏

3.1 核心思想

PODNet(Pooled Output Distillation)由 Douillard 等人在 2020 年提出。2

核心洞察:不仅蒸馏最终输出,还蒸馏中间池化层的表示,保留更丰富的结构信息。

3.2 多粒度蒸馏

class PODNet:
    """
    PODNet: 池化输出蒸馏
    
    参考文献: Douillard et al. "PODNet: Pooled outputs distillation 
             for small-tasks incremental learning", ECCV 2020
    """
    
    def __init__(self, model, n_classes,蒸馏损失权重=1.0):
        self.model = model
        self.old_model = None
        self.lambda_dist = distillation_weight
        
        # 中间层池化器
        self.poolers = nn.ModuleDict()
        
    def compute_pod_loss(self, features_old, features_new, pool='avg'):
        """
        计算池化输出蒸馏损失
        
        Args:
            features_old: 旧模型的中间特征
            features_new: 新模型的中间特征
            pool: 池化方式 ('avg', 'max', 'both')
        """
        if pool == 'avg':
            old_pooled = F.adaptive_avg_pool2d(features_old, 1)
            new_pooled = F.adaptive_avg_pool2d(features_new, 1)
        elif pool == 'max':
            old_pooled = F.adaptive_max_pool2d(features_old, 1)
            new_pooled = F.adaptive_max_pool2d(features_new, 1)
        elif pool == 'both':
            old_avg = F.adaptive_avg_pool2d(features_old, 1)
            new_avg = F.adaptive_avg_pool2d(features_new, 1)
            old_max = F.adaptive_max_pool2d(features_old, 1)
            new_max = F.adaptive_max_pool2d(features_new, 1)
            old_pooled = torch.cat([old_avg, old_max], dim=1)
            new_pooled = torch.cat([new_avg, new_max], dim=1)
        
        # 平铺后计算 MSE
        loss = F.mse_loss(
            old_pooled.view(old_pooled.size(0), -1),
            new_pooled.view(new_pooled.size(0), -1)
        )
        
        return loss

3.3 多步蒸馏策略

PODNet 还提出了多步蒸馏策略:

  • 不只蒸馏到上一个阶段,而是蒸馏到所有之前的阶段
  • 每个阶段有独立的蒸馏权重
  • 总蒸馏损失:

4. 分类器偏置问题

4.1 问题分析

在学习新类别时,分类器倾向于将样本预测为数量占优的新类别,这是因为:

  1. 训练数据不平衡:旧类别样本少,新类别样本多
  2. 决策边界偏移:分类器将更多空间分配给新类别
  3. 特征表示偏移:旧类别的特征被新类别「挤压」

4.2 解决方案分类

方案描述代表方法
分类器校准调整分类器参数或阈值WA, BiC
日志its校准对 logits 进行后处理LDAM
特征正则化约束新旧类别的特征分布L2W
对比正则化增强类别间区分度COIL

4.3 BiC: 双分类器方法

BiC(Bias Correction)由 Wu 等人在 2019 年提出。3

核心思想:使用两个分类器——一个用于旧类别,一个用于新类别,然后融合。

class BiasCorrection:
    """
    BiC: 分类器偏置校正
    
    参考文献: Wu et al. "Large scale incremental learning", CVPR 2019
    """
    
    def __init__(self, model, n_old_classes):
        self.model = model
        self.n_old_classes = n_old_classes
        
        # 偏置校正层
        self.bias_correction = nn.Linear(1, 1)
        
    def train_bias_layer(self, val_loader, epochs=50):
        """
        在验证集上训练偏置校正层
        
        验证集包含新旧类别的混合样本
        """
        optimizer = torch.optim.Adam(self.bias_correction.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        
        for epoch in range(epochs):
            for inputs, targets in val_loader:
                optimizer.zero_grad()
                
                # 前向传播
                logits = self.model(inputs)
                
                # 分离新旧类别的 logits
                old_logits = logits[:, :self.n_old_classes]
                new_logits = logits[:, self.n_old_classes:]
                
                # 新类别 logits 加上学习的偏置
                beta = self.bias_correction.weight
                gamma = self.bias_correction.bias
                adjusted_new_logits = (new_logits + beta) * gamma
                
                # 合并
                adjusted_logits = torch.cat([old_logits, adjusted_new_logits], dim=1)
                
                loss = criterion(adjusted_logits, targets)
                loss.backward()
                optimizer.step()
    
    def predict(self, x):
        """带偏置校正的预测"""
        logits = self.model(x)
        
        old_logits = logits[:, :self.n_old_classes]
        new_logits = logits[:, self.n_old_classes:]
        
        beta = self.bias_correction.weight
        gamma = self.bias_correction.bias
        adjusted_new_logits = (new_logits + beta) * gamma
        
        adjusted_logits = torch.cat([old_logits, adjusted_new_logits], dim=1)
        
        return adjusted_logits.argmax(dim=1)

4.4 WA: 加权分类器

WA(Weight Alignment)是一种更简单的方法,通过对齐新旧分类器的权重来缓解偏置。4

class WeightAlignment:
    """
    WA: 权重对齐
    
    参考文献: Zhao et al. "Balanced softmax for class incremental learning"
    """
    
    def __init__(self, model, tau=2.0):
        self.model = model
        self.tau = tau  # 温度参数
        self.old_mean = None
        
    def compute_class_mean(self, dataloader):
        """计算每个类别的 logits 均值"""
        self.model.eval()
        
        class_means = defaultdict(list)
        
        with torch.no_grad():
            for inputs, targets in dataloader:
                logits = self.model(inputs)
                
                for t, l in zip(targets, logits):
                    class_means[t.item()].append(l)
        
        # 计算均值
        class_mean = {}
        for cls, logs in class_means.items():
            class_mean[cls] = torch.stack(logs).mean(dim=0)
        
        return class_mean
    
    def align_weights(self, n_old_classes):
        """
        对齐分类器权重
        
        使旧类别和新类别的 logits 均值对齐
        """
        # 获取分类器权重
        fc = self.model.classifier
        old_weights = fc.weight.data[:n_old_classes]
        new_weights = fc.weight.data[n_old_classes:]
        
        # 计算权重缩放因子
        if self.old_mean is not None:
            new_mean = fc.weight.data[n_old_classes:].mean(dim=0)
            scale = self.old_mean.norm() / new_mean.norm()
            
            # 缩放新类别权重
            fc.weight.data[n_old_classes:] *= scale
        
        # 保存当前均值用于下次对齐
        self.old_mean = fc.weight.data[:n_old_classes].mean(dim=0)

5. 其他代表性方法

5.1 方法对比表

方法表示学习样本回放偏置校正发表
iCaRL✓ (原型)-CVPR 2017
PODNet-ECCV 2020
BiCCVPR 2019
WA-CVPR 2020
LUCIRCVPR 2019
AFC-ECCV 2020

5.2 LUCIR: less Uniform Classifier

LUCIR 提出了三项改进:5

  1. 余量正则化:为旧类别设置更大的决策边界余量
  2. 蒸馏损失:结合特征蒸馏和分类器蒸馏
  3. 样本回放:使用 herding 选择代表性样本

5.3 最新进展:GACL

GACL(NeurIPS 2024)提出了无需样本存储的广义分析持续学习方法,在无回放设置下达到 SOTA 性能。6

核心思想:通过分析分类器权重的几何结构来保持类别可分性。


6. 评估与基准

6.1 常用数据集

数据集类别数图像数特点
CIFAR-10010060K32×32 彩色图像
ImageNet-100010001.2M大规模自然图像
Tiny-ImageNet200100K64×64 图像

6.2 评估协议

增量设置

  • 10阶段增量(每阶段10类)
  • 5阶段增量(每阶段20类)
  • 25阶段增量(每阶段4类)

6.3 评估指标

def evaluate_class_incremental(model, test_loaders, n_classes):
    """
    评估类别增量学习模型
    
    Args:
        model: 模型
        test_loaders: 每个阶段的测试集加载器
        n_classes: 总类别数
        
    Returns:
        results: 评估结果
    """
    n_stages = len(test_loaders)
    
    # 记录每个阶段后的性能
    stage_accuracy = []
    
    for stage in range(n_stages):
        correct = 0
        total = 0
        
        for inputs, targets in test_loaders[stage]:
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            
            correct += (preds == targets).sum().item()
            total += targets.size(0)
        
        accuracy = correct / total
        stage_accuracy.append(accuracy)
    
    # 计算平均准确率和遗忘率
    avg_accuracy = np.mean(stage_accuracy)
    forgetting = 0
    
    for stage in range(n_stages - 1):
        # 每个阶段的遗忘
        forgetting += stage_accuracy[stage] - stage_accuracy[-1]
    
    forgetting /= (n_stages - 1)
    
    return {
        'stage_accuracy': stage_accuracy,
        'avg_accuracy': avg_accuracy,
        'forgetting': forgetting,
        'final_accuracy': stage_accuracy[-1]
    }

7. 实践建议

7.1 方法选择指南

场景推荐方法原因
存储受限BiC + WA偏置校正有效,减少回放
存储充足iCaRL + PODNet样本回放效果最好
极端存储限制GACL无需样本存储
延迟敏感WA简单有效

7.2 超参数设置

参数建议值调整策略
记忆/类20-50数据集大时减少
蒸馏损失权重0.5-1.0遗忘严重时增大
分类器余量0.5-2.0类别不平衡时增大
学习率0.001与标准训练相同

7.3 常见问题排查

问题可能原因解决方案
旧类别准确率骤降表示漂移严重增大蒸馏损失权重
新类别准确率低数据不平衡使用类别平衡采样
所有类别准确率低学习率过高降低学习率
训练不稳定特征范数变化使用特征归一化

参考资料


相关阅读

Footnotes

  1. Rebuffi, S. A., et al. (2017). iCaRL: Incremental classifier and representation learning. CVPR.

  2. Douillard, A., et al. (2020). PODNet: Pooled outputs distillation for small-tasks incremental learning. ECCV.

  3. Wu, Y., et al. (2019). Large scale incremental learning. CVPR.

  4. Zhao, B., et al. (2020). Balanced softmax for class incremental learning. CVPR.

  5. Hou, S., et al. (2019). Learning a unified classifier incrementally via rebalancing. CVPR.

  6. NeurIPS 2024. GACL: Exemplar-Free Generalized Analytic Continual Learning.