Introduction

Replay-based continual learning methods counter catastrophic forgetting by storing samples (or representations) from old tasks and training on them alongside the data of each new task. These methods are simple and intuitive, and they remain among the most effective approaches to continual learning.

The core idea can be summed up in one phrase, "revisit the old to learn the new" (温故而知新): periodically reviewing old knowledge consolidates memory.


1. Experience Replay (ER)

1.1 Core Idea

Experience Replay (ER) is the simplest replay method. Its core idea is to maintain a memory buffer that stores representative samples from old tasks.[1]

┌────────────────────────────────────────────────────────────────┐
│                   Experience Replay Overview                   │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│   Task A data ──┬──→ Memory buffer ──┬──→ Training             │
│                 │                    │                         │
│   Task B data ──┤                    ├──→ Model update         │
│                 │                    │                         │
│   Task C data ──┘                    └──→ ...                  │
│                                                                │
│   Buffer policies: uniform / prioritized / class-balanced      │
└────────────────────────────────────────────────────────────────┘

1.2 Loss Function

The overall training objective combines the new-task loss with a replay term:

$$\mathcal{L} = \mathcal{L}_{\text{new}} + \alpha \, \mathcal{L}_{\text{replay}}$$

where $\mathcal{L}_{\text{replay}}$ is the replay loss computed on samples drawn from the buffer, and $\alpha$ weights the replay term (cf. replay_alpha in the implementation below).
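
As a minimal sketch, assuming a model, a cross-entropy criterion, a new-task batch (x_new, y_new), and a buffer batch (x_mem, y_mem), all names illustrative:

# Combined ER objective: L = L_new + alpha * L_replay
alpha = 0.5  # replay weight (hyperparameter)

loss_new = criterion(model(x_new), y_new)      # loss on the new task
loss_replay = criterion(model(x_mem), y_mem)   # loss on replayed samples
loss = loss_new + alpha * loss_replay
loss.backward()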

1.3 PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import random
from collections import defaultdict
 
class SimpleReplayBuffer:
    """
    A simple experience replay buffer

    Storage format: {class_id: [sample1, sample2, ...]}
    """

    def __init__(self, max_size_per_class=50):
        self.buffer = defaultdict(list)  # {class: [(x, y), ...]}
        self.max_size_per_class = max_size_per_class

    def update(self, class_idx, samples):
        """
        Update the buffer

        Args:
            class_idx: class index
            samples: list of samples [(x1, y1), (x2, y2), ...]
        """
        self.buffer[class_idx].extend(samples)

        # Enforce the per-class size limit
        if len(self.buffer[class_idx]) > self.max_size_per_class:
            # FIFO policy: keep the most recent samples
            self.buffer[class_idx] = self.buffer[class_idx][-self.max_size_per_class:]

    def sample(self, batch_size):
        """
        Sample from the buffer

        Args:
            batch_size: number of samples to draw

        Returns:
            samples: list of sampled items
        """
        all_samples = []
        for samples in self.buffer.values():
            all_samples.extend(samples)

        if len(all_samples) == 0:
            return None

        return random.sample(all_samples, min(batch_size, len(all_samples)))

    def sample_balanced(self, batch_size):
        """
        Class-balanced sampling
        """
        classes = list(self.buffer.keys())
        if not classes:
            return None

        # Draw an equal number of samples per class
        samples_per_class = batch_size // len(classes)
        samples = []

        for cls in classes:
            cls_samples = self.buffer[cls]
            if len(cls_samples) >= samples_per_class:
                samples.extend(random.sample(cls_samples, samples_per_class))
            else:
                samples.extend(cls_samples)

        # Top up with extra samples if we fell short
        if len(samples) < batch_size:
            additional = batch_size - len(samples)
            all_samples = [s for samples in self.buffer.values() for s in samples]
            samples.extend(random.sample(all_samples, min(additional, len(all_samples))))

        return samples

    def __len__(self):
        return sum(len(s) for s in self.buffer.values())
 
 
class ExperienceReplay:
    """
    Continual learner based on experience replay
    """

    def __init__(self, model, buffer_size_per_class=50, replay_alpha=0.5):
        self.model = model
        self.buffer = SimpleReplayBuffer(max_size_per_class=buffer_size_per_class)
        self.replay_alpha = replay_alpha  # fraction of each batch devoted to replay samples

    def update_buffer(self, dataloader, num_samples=1000):
        """
        Update the buffer by sampling from a dataloader

        Args:
            dataloader: data loader of the task just finished
            num_samples: number of samples to collect
        """
        self.model.eval()
        samples_collected = 0

        with torch.no_grad():
            for inputs, targets in dataloader:
                # Collect samples
                for x, y in zip(inputs, targets):
                    self.buffer.update(y.item(), [(x, y)])
                    samples_collected += 1
                    if samples_collected >= num_samples:
                        break
                if samples_collected >= num_samples:
                    break

    def train_step(self, x_new, y_new, optimizer, criterion):
        """
        One training step

        Args:
            x_new: new-task inputs
            y_new: new-task labels
            optimizer: optimizer
            criterion: loss function
        """
        batch_size = x_new.size(0)
        replay_size = int(batch_size * self.replay_alpha)

        # Build the new-task part of the batch
        new_indices = torch.randperm(batch_size)[:batch_size - replay_size]
        x_batch = x_new[new_indices]
        y_batch = y_new[new_indices]

        # Add replayed samples
        replay_samples = self.buffer.sample(replay_size)
        if replay_samples:
            replay_x = torch.stack([s[0] for s in replay_samples])
            replay_y = torch.stack([s[1] for s in replay_samples])
            x_batch = torch.cat([x_batch, replay_x], dim=0)
            y_batch = torch.cat([y_batch, replay_y], dim=0)

        # Forward pass and loss
        optimizer.zero_grad()
        outputs = self.model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()

        return loss.item()
 
 
class ReservoirSamplingBuffer:
    """
    Reservoir sampling buffer

    Advantage: maintains a uniform sample over the whole stream,
    with no bias toward the most recent samples
    """

    def __init__(self, max_size=1000):
        self.max_size = max_size
        self.buffer = []
        self.n_seen = 0

    def update(self, samples):
        """
        Reservoir sampling update

        While the buffer is not full, always append;
        once full, each new sample replaces a stored one
        with probability max_size / n_seen
        """
        for sample in samples:
            self.n_seen += 1

            if len(self.buffer) < self.max_size:
                self.buffer.append(sample)
            else:
                # Replace with probability max_size / n_seen
                j = random.randint(0, self.n_seen - 1)
                if j < self.max_size:
                    self.buffer[j] = sample

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def __len__(self):
        return len(self.buffer)
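
A brief usage sketch of the classes above, assuming a small classifier and two task dataloaders (model, task_a_loader, and task_b_loader are placeholders):

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

learner = ExperienceReplay(model, buffer_size_per_class=50, replay_alpha=0.5)

# Task A: train normally, then store representative samples
for x, y in task_a_loader:
    optimizer.zero_grad()
    criterion(model(x), y).backward()
    optimizer.step()
learner.update_buffer(task_a_loader, num_samples=500)

# Task B: every step mixes new data with replayed samples
for x, y in task_b_loader:
    loss = learner.train_step(x, y, optimizer, criterion)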

1.4 Buffer Design Strategies

| Strategy | Description | Pros | Cons |
| --- | --- | --- | --- |
| FIFO | Keep the most recent samples | Simple | May drop important old samples |
| Uniform sampling | Keep a random subset | Preserves diversity | May drop boundary samples |
| Prioritized sampling | Keep high-loss samples | Focuses on hard samples | Extra computation |
| Class-balanced sampling | Keep an equal number per class | Balances classes | Requires class labels |
| Reservoir sampling | Uniform over the stream | Unbiased | Slightly more complex to implement |
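
The buffers above cover the FIFO, class-balanced, and reservoir policies; prioritized sampling is not implemented. A minimal sketch, under the assumption that each stored tuple is extended with the sample's most recent loss value:

class PrioritizedReplayBuffer(SimpleReplayBuffer):
    """Keeps the highest-loss samples per class (a simple priority rule)."""

    def update_with_loss(self, class_idx, samples_with_loss):
        # samples_with_loss: [(x, y, loss_value), ...]
        self.buffer[class_idx].extend(samples_with_loss)

        # When over capacity, keep the hardest (highest-loss) samples
        if len(self.buffer[class_idx]) > self.max_size_per_class:
            self.buffer[class_idx].sort(key=lambda s: s[2], reverse=True)
            self.buffer[class_idx] = self.buffer[class_idx][:self.max_size_per_class]

The inherited sampling methods still work; stored items simply carry the extra loss field.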

2. GEM: Gradient Episodic Memory

2.1 Core Idea

GEM (Gradient Episodic Memory) was proposed by Lopez-Paz and Ranzato in 2017.[2]

Key insight: do not just store old samples; also constrain the direction of the gradient update so that learning the new task does not increase the loss on old tasks.

2.2 Mathematical Formulation

Let $g$ be the gradient of the new task and $g_k$ the gradient computed on the stored samples of old task $k$.

GEM solves the following optimization problem:

$$\min_{\tilde{g}} \ \tfrac{1}{2}\|\tilde{g} - g\|_2^2 \quad \text{s.t.} \quad \langle \tilde{g}, g_k \rangle \ge 0 \quad \forall k$$

where $\langle \cdot, \cdot \rangle$ denotes the inner product.

Interpreting the constraints

  • $\langle \tilde{g}, g_k \rangle \ge 0$ means the projection of $\tilde{g}$ onto each $g_k$ is nonnegative
  • To first order, this ensures the update direction does not increase the loss of any old task
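
A small numeric check of the single-constraint projection used in the implementation below: when $\langle g, g_k \rangle < 0$, subtracting the component of $g$ along $g_k$ makes the inner product exactly zero.

g = torch.tensor([1.0, -1.0])      # new-task gradient
g_k = torch.tensor([0.0, 1.0])     # old-task gradient

dot = torch.dot(g, g_k)            # -1.0 < 0: constraint violated
g_proj = g - (dot / torch.dot(g_k, g_k)) * g_k

print(g_proj)                      # tensor([1., 0.])
print(torch.dot(g_proj, g_k))      # tensor(0.), constraint satisfied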

2.3 PyTorch Implementation

class GradientEpisodicMemory:
    """
    Gradient Episodic Memory (GEM)

    Reference: Lopez-Paz & Ranzato, "Gradient episodic memory
               for continual learning", NeurIPS 2017
    """

    def __init__(self, model, buffer, device='cuda'):
        self.model = model
        self.buffer = buffer  # experience replay buffer
        self.device = device

    def compute_gradients(self, x, y, criterion):
        """Compute gradients for the given samples"""
        self.model.zero_grad()
        outputs = self.model(x)
        loss = criterion(outputs, y)
        loss.backward()

        # Collect gradients
        grads = {n: p.grad.clone()
                 for n, p in self.model.named_parameters()
                 if p.requires_grad and p.grad is not None}
        return grads
    
    def project_gradient(self, g_new, g_replay):
        """
        Project the gradient onto the feasible region

        Solves: min ||g - g_new||^2  s.t.  ⟨g, g_k⟩ >= 0

        Note: this is a simplified per-parameter projection; the original
        GEM paper solves a small quadratic program over all constraints.
        """
        # Reference gradient: the average of all old-task gradients
        g_ref = {}
        for n in g_replay[0].keys():
            g_ref[n] = torch.stack([g[n] for g in g_replay]).mean(dim=0)

        # Simplified projection: if the inner product is negative,
        # subtract the offending component
        g_projected = {}
        for n in g_new.keys():
            dot_product = torch.sum(g_new[n] * g_ref[n])

            if dot_product < 0:
                # Remove the component that conflicts with the reference
                proj = (dot_product / (torch.sum(g_ref[n] ** 2) + 1e-8)) * g_ref[n]
                g_projected[n] = g_new[n] - proj
            else:
                g_projected[n] = g_new[n]

        return g_projected

    def project_gradient_multi_ref(self, g_new, g_replay_list):
        """
        Projection against multiple reference gradients

        Greedily projects onto each constraint in turn
        """
        g = g_new

        for g_ref in g_replay_list:
            g = self.project_gradient(g, [g_ref])

        return g
    
    def train_step(self, x_new, y_new, optimizer, criterion):
        """
        One GEM training step
        """
        batch_size = x_new.size(0)
        self.model.zero_grad()

        # 1. Gradient of the new task
        outputs_new = self.model(x_new)
        loss_new = criterion(outputs_new, y_new)
        loss_new.backward()

        g_new = {n: p.grad.clone()
                for n, p in self.model.named_parameters()
                if p.requires_grad and p.grad is not None}

        # 2. Gradient of the old tasks (from replayed samples)
        g_replay_list = []
        replay_samples = self.buffer.sample(min(batch_size, 50))

        if replay_samples:
            replay_x = torch.stack([s[0] for s in replay_samples]).to(self.device)
            replay_y = torch.stack([s[1] for s in replay_samples]).to(self.device)

            self.model.zero_grad()
            outputs_replay = self.model(replay_x)
            loss_replay = criterion(outputs_replay, replay_y)
            loss_replay.backward()

            g_replay = {n: p.grad.clone()
                       for n, p in self.model.named_parameters()
                       if p.requires_grad and p.grad is not None}
            g_replay_list.append(g_replay)

        # 3. Gradient projection
        if g_replay_list:
            g_final = self.project_gradient(g_new, g_replay_list)

            # Apply the projected gradient
            for n, p in self.model.named_parameters():
                if p.requires_grad and n in g_final:
                    p.grad = g_final[n]

        optimizer.step()

        return loss_new.item()
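
Usage mirrors ExperienceReplay; a brief sketch with placeholder loaders, model, optimizer, and criterion:

buffer = SimpleReplayBuffer(max_size_per_class=50)
gem = GradientEpisodicMemory(model, buffer, device='cpu')

# Fill the buffer with samples from the previous task
for x, y in task_a_loader:
    for xi, yi in zip(x, y):
        buffer.update(yi.item(), [(xi, yi)])

# Train on the new task with gradient projection
for x, y in task_b_loader:
    loss = gem.train_step(x, y, optimizer, criterion)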

2.4 GEM vs. Standard ER

| Aspect | GEM | Standard ER |
| --- | --- | --- |
| Gradient control | ✅ constrains the update direction | ❌ none |
| Theoretical guarantee | ✅ old-task loss does not increase | ❌ none |
| Compute overhead | Higher (multiple gradient computations) | Lower |
| Implementation complexity | Moderate | Simple |

3. A-GEM: Efficient GEM

3.1 Core Idea

A-GEM (Averaged GEM) is an efficiency-oriented refinement of GEM, proposed by Chaudhry et al. in 2019.[3]

Key improvements

  1. Instead of keeping one gradient (and one constraint) per old task, keep only an averaged reference gradient
  2. Compute the reference gradient only when needed

3.2 Mathematical Formulation

A-GEM relaxes GEM's per-task constraints to a single one:

$$\langle g, g_{\text{ref}} \rangle \ge 0$$

where $g_{\text{ref}}$ is the stored average reference gradient.

If the constraint is violated, the gradient is projected:

$$\tilde{g} = g - \frac{\langle g, g_{\text{ref}} \rangle}{\langle g_{\text{ref}}, g_{\text{ref}} \rangle} \, g_{\text{ref}}$$
3.3 PyTorch Implementation

class AGEM:
    """
    Averaged Gradient Episodic Memory (A-GEM)

    Reference: Chaudhry et al., "Efficient lifelong learning with A-GEM", ICLR 2019
    """

    def __init__(self, model, buffer, device='cuda', reference_every=1):
        self.model = model
        self.buffer = buffer  # experience replay buffer
        self.device = device

        # Reference gradient (maintained as a moving average)
        self.g_ref = None
        self.reference_every = reference_every
        self.step_counter = 0

    def update_reference_gradient(self, criterion):
        """
        Update the reference gradient
        """
        replay_samples = self.buffer.sample_balanced(64)
        if not replay_samples:
            return

        replay_x = torch.stack([s[0] for s in replay_samples]).to(self.device)
        replay_y = torch.stack([s[1] for s in replay_samples]).to(self.device)

        self.model.zero_grad()
        outputs = self.model(replay_x)
        loss = criterion(outputs, replay_y)
        loss.backward()

        # Gradient on the current memory batch
        g_current = {n: p.grad.clone()
                    for n, p in self.model.named_parameters()
                    if p.requires_grad and p.grad is not None}

        # Moving-average update (the original paper instead recomputes
        # the reference from a fresh memory batch at every step)
        if self.g_ref is None:
            self.g_ref = g_current
        else:
            for n in self.g_ref.keys():
                self.g_ref[n] = 0.99 * self.g_ref[n] + 0.01 * g_current[n]
    
    def project_gradient(self, g):
        """
        Project the gradient into the region satisfying the constraint
        """
        if self.g_ref is None:
            return g

        g_proj = {}
        for n in g.keys():
            dot = torch.sum(g[n] * self.g_ref[n])
            norm_sq = torch.sum(self.g_ref[n] ** 2) + 1e-8

            if dot < 0:
                proj_coeff = dot / norm_sq
                g_proj[n] = g[n] - proj_coeff * self.g_ref[n]
            else:
                g_proj[n] = g[n]

        return g_proj
    
    def train_step(self, x_new, y_new, optimizer, criterion):
        """
        One A-GEM training step
        """
        # Refresh the reference gradient
        self.step_counter += 1
        if self.step_counter % self.reference_every == 0:
            self.update_reference_gradient(criterion)

        # Forward and backward pass on the new task
        self.model.zero_grad()
        outputs = self.model(x_new)
        loss = criterion(outputs, y_new)
        loss.backward()

        # Collect gradients
        g_new = {n: p.grad.clone()
                for n, p in self.model.named_parameters()
                if p.requires_grad and p.grad is not None}

        # Gradient projection
        g_final = self.project_gradient(g_new)

        # Apply the (possibly projected) gradient
        for n, p in self.model.named_parameters():
            if p.requires_grad and n in g_final:
                p.grad = g_final[n]

        optimizer.step()

        return loss.item()
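
The implementation above projects each parameter tensor independently. The A-GEM paper instead flattens all gradients into one vector and applies a single global projection; a sketch of that variant (the helper name project_flat is mine):

def project_flat(g, g_ref):
    """One global projection over the concatenated gradient vector."""
    names = list(g.keys())
    flat_g = torch.cat([g[n].flatten() for n in names])
    flat_ref = torch.cat([g_ref[n].flatten() for n in names])

    dot = torch.dot(flat_g, flat_ref)
    if dot >= 0:
        return g  # constraint already satisfied

    flat_proj = flat_g - (dot / (torch.dot(flat_ref, flat_ref) + 1e-8)) * flat_ref

    # Unflatten back into per-parameter tensors
    g_proj, offset = {}, 0
    for n in names:
        numel = g[n].numel()
        g_proj[n] = flat_proj[offset:offset + numel].view_as(g[n])
        offset += numel
    return g_proj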

4. GSS: Gradient-Based Sample Selection

4.1 Core Idea

GSS (Gradient-based Sample Selection) was proposed by Aljundi et al. in 2019.[4]

Key insight: not all old samples are equally important for resisting forgetting. Selecting samples that are diverse in gradient space makes the buffer more effective.

4.2 Scoring Function

GSS assigns each sample a "gradient diversity score"; one natural form, matching the implementation below, is the cosine dissimilarity from the buffer's average gradient:

$$s(x_i) = 1 - \frac{\langle g_i, \bar{g} \rangle}{\|g_i\|\,\|\bar{g}\|}$$

where $g_i$ is the gradient of sample $x_i$ and $\bar{g}$ is the average gradient over the samples in the buffer.

High-score samples: differ strongly from the average gradient and contribute complementary information
Low-score samples: resemble the average gradient and are largely redundant

4.3 Implementation Notes

class GradientDiversityBuffer:
    """
    Buffer based on gradient diversity

    Reference: Aljundi et al., "Gradient based sample selection
               for online continual learning", NeurIPS 2019
    """

    def __init__(self, max_size=500, buffer_per_class=50):
        self.max_size = max_size
        self.buffer_per_class = buffer_per_class
        self.buffer = defaultdict(list)  # {class: [(x, y, score), ...]}
        self.device = 'cuda'

    def compute_sample_gradient(self, model, x, y, criterion):
        """
        Compute the flattened gradient of a single sample
        """
        model.zero_grad()
        output = model(x.unsqueeze(0))
        loss = criterion(output, y.unsqueeze(0))
        grads = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
        return torch.cat([g.flatten() for g in grads])

    def update(self, model, dataloader, criterion):
        """
        Update the buffer based on gradient diversity
        """
        model.eval()
        entries = []  # (x, y, flattened gradient)

        # Per-sample gradients need autograd, so no torch.no_grad() here
        for x_batch, y_batch in dataloader:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)

            for i in range(len(x_batch)):
                grad = self.compute_sample_gradient(model, x_batch[i], y_batch[i], criterion)
                entries.append((x_batch[i].cpu(), y_batch[i].cpu(), grad))

        if not entries:
            return

        # Score each sample by cosine dissimilarity to the mean gradient
        mean_grad = torch.stack([g for _, _, g in entries]).mean(dim=0)
        for x, y, grad in entries:
            score = (1.0 - F.cosine_similarity(grad, mean_grad, dim=0)).item()
            cls = y.item()

            # Add directly if the class slot is not full
            if len(self.buffer[cls]) < self.buffer_per_class:
                self.buffer[cls].append((x, y, score))
            else:
                # Otherwise replace the lowest-scoring sample
                min_idx = min(range(len(self.buffer[cls])),
                             key=lambda i: self.buffer[cls][i][2])
                if score > self.buffer[cls][min_idx][2]:
                    self.buffer[cls][min_idx] = (x, y, score)

    def sample(self, batch_size):
        """Sampling favors high-score samples"""
        all_samples = []
        for samples in self.buffer.values():
            all_samples.extend(samples)

        if len(all_samples) == 0:
            return None

        # Sort by score (descending) and sample from the top candidates
        all_samples.sort(key=lambda s: s[2], reverse=True)
        top_k = min(2 * batch_size, len(all_samples))

        return random.sample(all_samples[:top_k],
                            min(batch_size, top_k))
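
A brief usage sketch: after finishing a task, rebuild the buffer with update, which scores every sample in the loader (names are placeholders):

gss_buffer = GradientDiversityBuffer(max_size=500, buffer_per_class=50)
gss_buffer.device = 'cpu'  # or 'cuda' when available

# Score and store samples from the task just finished
gss_buffer.update(model, task_a_loader, criterion)

# Later, draw gradient-diverse samples for replay
replay_samples = gss_buffer.sample(batch_size=32)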

5. Hybrid Replay Methods

5.1 Real Data + Synthetic Data

Recent work suggests that replay strategies mixing real and synthetic data can obtain the advantages of both.[5]

| Aspect | Real-data replay | Synthetic-data replay |
| --- | --- | --- |
| Advantage | Preserves the true distribution | No raw-sample storage |
| Drawback | Storage cost, privacy concerns | Quality depends on the generator |
| Best for | Small datasets | Large-scale settings |

5.2 Implementation Framework

class HybridReplayBuffer:
    """
    Hybrid replay buffer: real samples + generated samples
    """

    def __init__(self, real_buffer_size=100, generator=None):
        self.real_buffer = SimpleReplayBuffer(max_size_per_class=real_buffer_size)
        self.generator = generator  # conditional generator (e.g. a cGAN)

    def generate_samples(self, cls, num_samples):
        """Synthesize samples with the generator"""
        if self.generator is None:
            return []

        device = next(self.generator.parameters()).device
        with torch.no_grad():
            z = torch.randn(num_samples, self.generator.latent_dim, device=device)
            labels = torch.full((num_samples,), cls, device=device)
            fake_samples = self.generator(z, labels)
        return fake_samples

    def sample(self, batch_size, include_generated=True):
        """Mixed sampling from real and generated sources"""
        samples = self.real_buffer.sample(batch_size // 2)

        if include_generated and self.generator is not None:
            # Top up with generated samples
            num_generated = batch_size - len(samples) if samples else batch_size
            generated = []

            # Pick the classes to generate for
            classes = list(set(s[1].item() for s in samples)) if samples else [0]
            for cls in classes:
                gen_samples = self.generate_samples(cls, num_generated // len(classes))
                for s in gen_samples:
                    generated.append((s, torch.tensor(cls)))

            samples = (samples or []) + generated

        return samples
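
A usage sketch with a stub conditional generator; any module exposing latent_dim and forward(z, labels) would fit (the stub is purely illustrative, not a trained cGAN):

class StubGenerator(nn.Module):
    """Illustrative stand-in for a trained conditional generator."""
    def __init__(self, latent_dim=64, out_dim=28 * 28):
        super().__init__()
        self.latent_dim = latent_dim
        self.net = nn.Linear(latent_dim + 1, out_dim)

    def forward(self, z, labels):
        cond = torch.cat([z, labels.float().unsqueeze(1)], dim=1)
        return self.net(cond)

hybrid = HybridReplayBuffer(real_buffer_size=100, generator=StubGenerator())
mixed = hybrid.sample(batch_size=32)  # half real (when available), half generated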

6. Method Comparison and Selection Guide

6.1 Comparison Table

| Method | Compute overhead | Storage overhead | When to use |
| --- | --- | --- | --- |
| ER | Low | Sample buffer | Baseline method |
| GEM | High (multiple gradient computations) | Sample buffer | Theoretical guarantees needed |
| A-GEM | Moderate | Sample buffer | Efficiency-effectiveness balance |
| GSS | High (per-sample gradients) | Sample buffer | When sample selection matters |
| Hybrid replay | Highest | Low (generator replaces raw storage) | Storage-constrained settings |

6.2 Choosing the Buffer Size

| Dataset size | Recommended buffer size | Notes |
| --- | --- | --- |
| Small (<10K) | 50-100 per class | Can afford to store many samples |
| Medium (10K-100K) | 20-50 per class | Balanced sampling matters |
| Large (>100K) | 10-20 per class | Relies on high-quality selection |

6.3 Practical Recommendations

  1. Try A-GEM first: its results are on par with GEM at a much lower computational cost
  2. Combine it with class-balanced sampling to avoid forgetting caused by class imbalance
  3. Update the buffer regularly so it does not hold stale representations
  4. Mind privacy requirements: when raw data cannot be kept, use synthetic-data replay

7. Combining Replay with Regularization

Replay methods and regularization methods are complementary and can be used together:

class CombinedContinualLearner:
    """
    Continual learner combining replay and regularization
    """

    def __init__(self, model, buffer, ewc, device='cuda'):
        self.model = model
        self.buffer = buffer  # experience replay buffer
        self.ewc = ewc        # EWC regularizer
        self.device = device

    def train_step(self, x_new, y_new, optimizer, criterion):
        # 1. Draw replay samples
        replay_samples = self.buffer.sample_balanced(x_new.size(0) // 2)

        # 2. Merge batches
        if replay_samples:
            replay_x = torch.stack([s[0] for s in replay_samples]).to(self.device)
            replay_y = torch.stack([s[1] for s in replay_samples]).to(self.device)
            x_batch = torch.cat([x_new, replay_x], dim=0)
            y_batch = torch.cat([y_new, replay_y], dim=0)
        else:
            x_batch, y_batch = x_new, y_new

        # 3. Loss on the merged batch
        optimizer.zero_grad()
        outputs = self.model(x_batch)
        loss_current = criterion(outputs, y_batch)

        # 4. Add the EWC regularization term
        loss_ewc = self.ewc.penalty()

        # 5. Total loss
        loss_total = loss_current + loss_ewc
        loss_total.backward()
        optimizer.step()

        return loss_total.item()
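
A usage sketch, assuming an EWC object that exposes penalty() and returns a scalar tensor; the stub below is illustrative, a real implementation would compute the Fisher-weighted penalty:

class StubEWC:
    """Illustrative stand-in for an EWC regularizer."""
    def penalty(self):
        # A real EWC returns sum_i F_i * (theta_i - theta_i_star)^2
        return torch.tensor(0.0)

combined = CombinedContinualLearner(model, buffer, StubEWC(), device='cpu')

for x, y in task_b_loader:
    loss = combined.train_step(x, y, optimizer, criterion)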


Footnotes

  [1] Rolnick, D., et al. (2019). Experience replay for continual learning. NeurIPS.

  [2] Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.

  [3] Chaudhry, A., et al. (2019). Efficient lifelong learning with A-GEM. ICLR.

  [4] Aljundi, R., et al. (2019). Gradient based sample selection for online continual learning. NeurIPS.

  [5] Koh, H., et al. (2025). Hybrid Memory Replay. ICLR.