## Introduction
Replay-based continual learning methods combat catastrophic forgetting by storing samples (or representations) from old tasks and training on them alongside new-task data. These methods are simple and intuitive, and remain among the most effective approaches to continual learning.

The core idea can be summed up in one phrase, 「温故而知新」 ("review the old to learn the new"): periodically revisit old knowledge to consolidate memory.
## 1. Experience Replay

### 1.1 Core Idea
Experience Replay (ER) is the simplest replay method. Its core idea is to maintain a memory buffer that stores representative samples from old tasks.[^1]
```text
┌────────────────────────────────────────────────────────────────┐
│                      Experience replay                         │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  Task A data ──┬──→ Memory buffer ──┬──→ Training              │
│                │                    │                          │
│  Task B data ──┤                    ├──→ Model update          │
│                │                    │                          │
│  Task C data ──┘                    └──→ ...                   │
│                                                                │
│  Buffer strategies: uniform / prioritized / class-balanced    │
└────────────────────────────────────────────────────────────────┘
```
### 1.2 Loss Function
The training objective mixes the new-task loss with a replay loss:

$$\mathcal{L} = (1-\alpha)\,\mathcal{L}_{\text{new}} + \alpha\,\mathcal{L}_{\text{replay}}$$

where $\mathcal{L}_{\text{replay}}$ is the replay loss computed on samples drawn from the buffer, and $\alpha$ is the replay fraction (the `replay_alpha` used in the implementation below).
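A minimal sketch of this objective, computing the two losses separately rather than mixing one batch. This is one common variant, not the only one; `model`, `buffer`, and `alpha` are illustrative names, and `buffer` is assumed to expose the `sample` interface defined in §1.3 below:

```python
import torch
import torch.nn.functional as F

def replay_loss_step(model, x_new, y_new, buffer, alpha=0.5):
    # Weighted-loss form of ER: L = (1 - alpha) * L_new + alpha * L_replay
    loss_new = F.cross_entropy(model(x_new), y_new)
    replay = buffer.sample(x_new.size(0))  # returns None while the buffer is empty
    if replay is None:
        return loss_new
    x_old = torch.stack([s[0] for s in replay])
    y_old = torch.stack([s[1] for s in replay])
    loss_replay = F.cross_entropy(model(x_old), y_old)
    return (1 - alpha) * loss_new + alpha * loss_replay
```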
### 1.3 PyTorch Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import random
from collections import defaultdict


class SimpleReplayBuffer:
    """
    Simple experience replay buffer.
    Storage format: {class_id: [(x, y), ...]}
    """
    def __init__(self, max_size_per_class=50):
        self.buffer = defaultdict(list)  # {class: [(x, y), ...]}
        self.max_size_per_class = max_size_per_class

    def update(self, class_idx, samples):
        """
        Update the buffer.

        Args:
            class_idx: class index
            samples: list of samples [(x1, y1), (x2, y2), ...]
        """
        self.buffer[class_idx].extend(samples)
        # Enforce the per-class size limit
        if len(self.buffer[class_idx]) > self.max_size_per_class:
            # FIFO eviction: keep the most recent samples
            self.buffer[class_idx] = self.buffer[class_idx][-self.max_size_per_class:]

    def sample(self, batch_size):
        """
        Sample uniformly from the buffer.

        Args:
            batch_size: number of samples to draw
        Returns:
            list of sampled (x, y) pairs, or None if the buffer is empty
        """
        all_samples = []
        for samples in self.buffer.values():
            all_samples.extend(samples)
        if len(all_samples) == 0:
            return None
        return random.sample(all_samples, min(batch_size, len(all_samples)))

    def sample_balanced(self, batch_size):
        """Class-balanced sampling: draw an equal number of samples per class."""
        classes = list(self.buffer.keys())
        if not classes:
            return None
        samples_per_class = batch_size // len(classes)
        samples = []
        for cls in classes:
            cls_samples = self.buffer[cls]
            if len(cls_samples) >= samples_per_class:
                samples.extend(random.sample(cls_samples, samples_per_class))
            else:
                samples.extend(cls_samples)
        # Top up with uniform sampling if the balanced pass fell short
        if len(samples) < batch_size:
            additional = batch_size - len(samples)
            all_samples = [s for samples in self.buffer.values() for s in samples]
            samples.extend(random.sample(all_samples, min(additional, len(all_samples))))
        return samples

    def __len__(self):
        return sum(len(s) for s in self.buffer.values())


class ExperienceReplay:
    """Experience-replay continual learner."""
    def __init__(self, model, buffer_size_per_class=50, replay_alpha=0.5):
        self.model = model
        self.buffer = SimpleReplayBuffer(max_size_per_class=buffer_size_per_class)
        self.replay_alpha = replay_alpha  # fraction of each batch taken from replay

    def update_buffer(self, dataloader, num_samples=1000):
        """
        Populate the buffer by sampling from a dataloader.
        No forward pass is needed here; we just copy raw samples.

        Args:
            dataloader: data loader for the finished task
            num_samples: number of samples to collect
        """
        samples_collected = 0
        for inputs, targets in dataloader:
            for x, y in zip(inputs, targets):
                self.buffer.update(y.item(), [(x, y)])
                samples_collected += 1
                if samples_collected >= num_samples:
                    break
            if samples_collected >= num_samples:
                break

    def train_step(self, x_new, y_new, optimizer, criterion):
        """
        One training step on a mixed batch of new and replayed samples.

        Args:
            x_new: new-task inputs
            y_new: new-task labels
            optimizer: optimizer
            criterion: loss function
        """
        batch_size = x_new.size(0)
        replay_size = int(batch_size * self.replay_alpha)
        replay_samples = self.buffer.sample(replay_size) if replay_size > 0 else None
        if replay_samples:
            # Make room in the batch for the replayed samples
            new_indices = torch.randperm(batch_size)[:batch_size - len(replay_samples)]
            x_batch = x_new[new_indices]
            y_batch = y_new[new_indices]
            replay_x = torch.stack([s[0] for s in replay_samples])
            replay_y = torch.stack([s[1] for s in replay_samples])
            x_batch = torch.cat([x_batch, replay_x], dim=0)
            y_batch = torch.cat([y_batch, replay_y], dim=0)
        else:
            # Empty buffer (e.g. during the first task): train on the new batch only
            x_batch, y_batch = x_new, y_new
        # Forward pass, loss, and update
        optimizer.zero_grad()
        outputs = self.model(x_batch)
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        return loss.item()


class ReservoirSamplingBuffer:
    """
    Reservoir sampling buffer.
    Advantage: maintains a uniform sample over everything seen so far,
    with no bias toward the most recent samples.
    """
    def __init__(self, max_size=1000):
        self.max_size = max_size
        self.buffer = []
        self.n_seen = 0

    def update(self, samples):
        """
        Reservoir-sampling update:
        - if the buffer is not full, append;
        - if it is full, replace a random slot with probability max_size / n_seen.
        """
        for sample in samples:
            self.n_seen += 1
            if len(self.buffer) < self.max_size:
                self.buffer.append(sample)
            else:
                # Replace with probability max_size / n_seen
                j = random.randint(0, self.n_seen - 1)
                if j < self.max_size:
                    self.buffer[j] = sample

    def sample(self, batch_size):
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))

    def __len__(self):
        return len(self.buffer)
```
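A brief usage sketch of the learner above. `SimpleCNN`, `task_a_loader`, and `task_b_loader` are hypothetical stand-ins that are not defined in this post; `torch` and `nn` come from the imports above:

```python
# Hypothetical usage: train on task A, then on task B with replay.
model = SimpleCNN(num_classes=10)          # assumed model, not defined here
learner = ExperienceReplay(model, buffer_size_per_class=50, replay_alpha=0.5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for x, y in task_a_loader:                 # assumed DataLoader for task A
    learner.train_step(x, y, optimizer, criterion)
learner.update_buffer(task_a_loader)       # snapshot task A into the buffer

for x, y in task_b_loader:                 # assumed DataLoader for task B
    learner.train_step(x, y, optimizer, criterion)  # mixed new + replay batches
```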
### 1.4 Buffer Design Strategies
| Strategy | Description | Pros | Cons |
|---|---|---|---|
| FIFO | Keep the most recent samples | Simple | May discard important old samples |
| Uniform sampling | Keep a random subset | Preserves diversity | May lose boundary samples |
| Prioritized sampling | Keep high-loss samples | Focuses on hard samples | Extra computation |
| Class-balanced | Keep equal counts per class | Balances classes | Requires class labels |
| Reservoir sampling | Uniform over the stream | Unbiased | Slightly more complex |
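The table mentions prioritized sampling, which has no code above. Here is a minimal loss-based sketch, assuming the same `(x, y)` sample format as `SimpleReplayBuffer`; the per-sample-loss scoring rule is our assumption, one common choice among several:

```python
import heapq
import random
import torch

class PriorityReplayBuffer:
    """Keep the samples the model currently finds hardest (highest loss)."""
    def __init__(self, max_size=500):
        self.max_size = max_size
        self.heap = []          # min-heap of (loss, counter, x, y)
        self._counter = 0       # tie-breaker so tensors are never compared

    def update(self, model, criterion, samples):
        """Score candidates by per-sample loss; evict the lowest-loss entries."""
        model.eval()
        with torch.no_grad():
            for x, y in samples:
                loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0)).item()
                self._counter += 1
                item = (loss, self._counter, x, y)
                if len(self.heap) < self.max_size:
                    heapq.heappush(self.heap, item)
                elif loss > self.heap[0][0]:
                    heapq.heapreplace(self.heap, item)  # drop the easiest sample

    def sample(self, batch_size):
        pool = [(x, y) for _, _, x, y in self.heap]
        if not pool:
            return None
        return random.sample(pool, min(batch_size, len(pool)))
```

Note the trade-off from the table: scores go stale as the model changes, so `update` should be re-run periodically, which is exactly the extra computation prioritized sampling pays for.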
## 2. GEM: Gradient Episodic Memory

### 2.1 Core Idea

GEM (Gradient Episodic Memory) was proposed by Lopez-Paz and Ranzato in 2017.[^2]

Key insight: do not just store old samples; also constrain the direction of the gradient update, so that learning the new task does not increase the loss on old tasks.
### 2.2 Mathematical Formulation

Let $g$ be the gradient of the new task and $g_k$ the gradient computed on the episodic memory of old task $k$. GEM solves the following optimization problem:

$$\min_{\tilde{g}} \; \frac{1}{2}\,\lVert \tilde{g} - g \rVert_2^2 \quad \text{s.t.} \quad \langle \tilde{g}, g_k \rangle \ge 0 \;\; \forall k$$

where $\langle \cdot, \cdot \rangle$ denotes the inner product.

Interpreting the constraints:

- $\langle \tilde{g}, g_k \rangle \ge 0$ means the projection of $\tilde{g}$ onto $g_k$ is non-negative
- this ensures, to first order, that the update direction does not increase the loss of any old task
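Before the full implementation, a tiny self-contained check of the single-constraint projection used below. The toy vectors stand in for flattened gradients, and the numbers are made up purely for illustration:

```python
import torch

g_new = torch.tensor([1.0, -2.0])   # new-task gradient (toy values)
g_old = torch.tensor([0.0, 1.0])    # old-task gradient (toy values)

dot = torch.dot(g_new, g_old)        # -2.0 < 0: the constraint is violated
g_proj = g_new - (dot / torch.dot(g_old, g_old)) * g_old

print(torch.dot(g_proj, g_old))      # tensor(0.): the constraint now holds
print(g_proj)                        # tensor([1., 0.])
```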
### 2.3 PyTorch Implementation

```python
class GradientEpisodicMemory:
    """
    Gradient Episodic Memory (GEM).

    Reference: Lopez-Paz & Ranzato, "Gradient episodic memory
    for continual learning", NeurIPS 2017.

    Note: the original GEM solves a quadratic program over the flattened
    gradient vector; the projection below is a simplified per-parameter
    approximation.
    """
    def __init__(self, model, buffer, device='cuda'):
        self.model = model
        self.buffer = buffer  # experience replay buffer
        self.device = device

    def compute_gradients(self, x, y, criterion):
        """Compute gradients for a batch of samples."""
        self.model.zero_grad()
        outputs = self.model(x)
        loss = criterion(outputs, y)
        loss.backward()
        # Collect per-parameter gradients
        grads = {n: p.grad.clone()
                 for n, p in self.model.named_parameters()
                 if p.requires_grad and p.grad is not None}
        return grads

    def project_gradient(self, g_new, g_replay):
        """
        Project the gradient into the feasible region.
        Solves (approximately): min ||g - g_new||^2  s.t.  <g, g_k> >= 0
        """
        # Reference gradient: average over the old-task gradients
        g_ref = {}
        for n in g_replay[0].keys():
            g_ref[n] = torch.stack([g[n] for g in g_replay]).mean(dim=0)
        # Simplified projection: if the inner product is negative,
        # subtract the offending component (done per parameter tensor)
        g_projected = {}
        for n in g_new.keys():
            dot_product = torch.sum(g_new[n] * g_ref[n])
            if dot_product < 0:
                proj = (dot_product / (torch.sum(g_ref[n] ** 2) + 1e-8)) * g_ref[n]
                g_projected[n] = g_new[n] - proj
            else:
                g_projected[n] = g_new[n]
        return g_projected

    def project_gradient_multi_ref(self, g_new, g_replay_list):
        """
        Projection against multiple reference gradients,
        applied greedily one constraint at a time.
        """
        g = g_new
        for g_ref in g_replay_list:
            g = self.project_gradient(g, [g_ref])
        return g

    def train_step(self, x_new, y_new, optimizer, criterion):
        """One GEM training step."""
        batch_size = x_new.size(0)
        self.model.zero_grad()
        # 1. Gradient of the new task
        outputs_new = self.model(x_new)
        loss_new = criterion(outputs_new, y_new)
        loss_new.backward()
        g_new = {n: p.grad.clone()
                 for n, p in self.model.named_parameters()
                 if p.requires_grad and p.grad is not None}
        # 2. Gradient of the old tasks (from the replay buffer)
        g_replay_list = []
        replay_samples = self.buffer.sample(min(batch_size, 50))
        if replay_samples:
            replay_x = torch.stack([s[0] for s in replay_samples]).to(self.device)
            replay_y = torch.stack([s[1] for s in replay_samples]).to(self.device)
            self.model.zero_grad()
            outputs_replay = self.model(replay_x)
            loss_replay = criterion(outputs_replay, replay_y)
            loss_replay.backward()
            g_replay = {n: p.grad.clone()
                        for n, p in self.model.named_parameters()
                        if p.requires_grad and p.grad is not None}
            g_replay_list.append(g_replay)
        # 3. Project the new-task gradient, then apply it
        if g_replay_list:
            g_final = self.project_gradient(g_new, g_replay_list)
        else:
            g_final = g_new  # no memory yet: plain SGD step
        for n, p in self.model.named_parameters():
            if p.requires_grad and n in g_final:
                p.grad = g_final[n]
        optimizer.step()
        return loss_new.item()
```
### 2.4 GEM vs. Standard ER
| Property | GEM | Standard ER |
|---|---|---|
| Gradient control | ✅ constrains the update direction | ❌ none |
| Theoretical guarantee | ✅ old-task losses do not increase (to first order) | ❌ none |
| Computational cost | Higher (multiple gradient passes) | Low |
| Implementation complexity | Moderate | Simple |
## 3. A-GEM: Efficient GEM

### 3.1 Core Idea

A-GEM (Averaged GEM) is an efficient refinement of GEM, proposed by Chaudhry et al. in 2019.[^3]

Key improvements:
- store only an averaged reference gradient instead of per-task gradients
- compute the reference gradient only when it is needed
### 3.2 Mathematical Formulation

A-GEM reduces GEM's per-task constraints to a single one:

$$\langle g, g_{\text{ref}} \rangle \ge 0$$

where $g_{\text{ref}}$ is the stored average reference gradient. If the constraint is violated, the gradient is projected:

$$\tilde{g} = g - \frac{\langle g, g_{\text{ref}} \rangle}{\langle g_{\text{ref}}, g_{\text{ref}} \rangle}\, g_{\text{ref}}$$
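For reference, the A-GEM paper applies this projection to the single flattened gradient vector rather than per parameter tensor. A minimal sketch under that convention; the function names are ours, not from the paper's code:

```python
import torch

def agem_project(g: torch.Tensor, g_ref: torch.Tensor) -> torch.Tensor:
    """Project the flattened gradient g so that <g, g_ref> >= 0."""
    dot = torch.dot(g, g_ref)
    if dot >= 0:
        return g                      # constraint already satisfied
    return g - (dot / torch.dot(g_ref, g_ref)) * g_ref

def flatten_grads(model) -> torch.Tensor:
    """Concatenate all parameter gradients into one vector."""
    return torch.cat([p.grad.flatten()
                      for p in model.parameters()
                      if p.grad is not None])
```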
### 3.3 PyTorch Implementation

```python
class AGEM:
    """
    Averaged Gradient Episodic Memory (A-GEM).

    Reference: Chaudhry et al., "Efficient lifelong learning with A-GEM", ICLR 2019.

    Note: this variant keeps a moving average of the reference gradient and
    projects per parameter tensor; the original paper recomputes g_ref from
    a fresh memory batch each step and projects the flattened gradient.
    """
    def __init__(self, model, buffer, device='cuda', reference_every=1):
        self.model = model
        self.buffer = buffer  # experience replay buffer
        self.device = device
        # Reference gradient (moving average)
        self.g_ref = None
        self.reference_every = reference_every
        self.step_counter = 0

    def update_reference_gradient(self, criterion):
        """Recompute the reference gradient from a memory batch."""
        replay_samples = self.buffer.sample_balanced(64)
        if not replay_samples:
            return
        replay_x = torch.stack([s[0] for s in replay_samples]).to(self.device)
        replay_y = torch.stack([s[1] for s in replay_samples]).to(self.device)
        self.model.zero_grad()
        outputs = self.model(replay_x)
        loss = criterion(outputs, replay_y)
        loss.backward()
        # Current reference gradient
        g_current = {n: p.grad.clone()
                     for n, p in self.model.named_parameters()
                     if p.requires_grad and p.grad is not None}
        # Moving-average update of the stored reference gradient
        if self.g_ref is None:
            self.g_ref = g_current
        else:
            for n in self.g_ref.keys():
                self.g_ref[n] = 0.99 * self.g_ref[n] + 0.01 * g_current[n]

    def project_gradient(self, g):
        """Project the gradient so it does not conflict with g_ref."""
        if self.g_ref is None:
            return g
        g_proj = {}
        for n in g.keys():
            dot = torch.sum(g[n] * self.g_ref[n])
            norm_sq = torch.sum(self.g_ref[n] ** 2) + 1e-8
            if dot < 0:
                proj_coeff = dot / norm_sq
                g_proj[n] = g[n] - proj_coeff * self.g_ref[n]
            else:
                g_proj[n] = g[n]
        return g_proj

    def train_step(self, x_new, y_new, optimizer, criterion):
        """One A-GEM training step."""
        # Periodically refresh the reference gradient
        self.step_counter += 1
        if self.step_counter % self.reference_every == 0:
            self.update_reference_gradient(criterion)
        # Forward and backward pass on the new batch
        self.model.zero_grad()
        outputs = self.model(x_new)
        loss = criterion(outputs, y_new)
        loss.backward()
        g_new = {n: p.grad.clone()
                 for n, p in self.model.named_parameters()
                 if p.requires_grad and p.grad is not None}
        # Project the gradient, then apply it
        g_final = self.project_gradient(g_new)
        for n, p in self.model.named_parameters():
            if p.requires_grad and n in g_final:
                p.grad = g_final[n]
        optimizer.step()
        return loss.item()
```
## 4. GSS: Gradient-Based Sample Selection
### 4.1 Core Idea

GSS (Gradient-based Sample Selection) was proposed by Aljundi et al. in 2019.[^4]

Key insight: not all old samples are equally important for resisting forgetting. Selecting samples that are diverse in gradient space makes the buffer more effective.
### 4.2 Scoring Function

GSS assigns each sample a "gradient diversity" score:

$$s_i = 1 - \frac{\langle g_i, \bar{g} \rangle}{\lVert g_i \rVert \, \lVert \bar{g} \rVert}$$

where $g_i$ is the gradient of sample $i$ and $\bar{g}$ is the average gradient over all samples in the buffer.

- High-scoring samples: differ strongly from the average gradient and contribute complementary information
- Low-scoring samples: resemble the average gradient, so their information is largely redundant
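A small sketch computing this score on flattened per-sample gradients, a direct transcription of the formula above; the helper names are ours, and `torch`/`F` come from the earlier imports:

```python
import torch
import torch.nn.functional as F

def per_sample_grad(model, criterion, x, y) -> torch.Tensor:
    """Flattened gradient of the loss on a single sample."""
    model.zero_grad()
    loss = criterion(model(x.unsqueeze(0)), y.unsqueeze(0))
    grads = torch.autograd.grad(loss, [p for p in model.parameters()
                                       if p.requires_grad])
    return torch.cat([g.flatten() for g in grads])

def diversity_scores(model, criterion, buffer_samples):
    """s_i = 1 - cos(g_i, g_mean) over the buffered (x, y) pairs."""
    grads = torch.stack([per_sample_grad(model, criterion, x, y)
                         for x, y in buffer_samples])
    g_mean = grads.mean(dim=0)
    cos = F.cosine_similarity(grads, g_mean.unsqueeze(0), dim=1)
    return 1.0 - cos
```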
### 4.3 Implementation Sketch

```python
class GradientDiversityBuffer:
    """
    Gradient-diversity-based buffer.

    Reference: Aljundi et al., "Gradient based sample selection
    for online continual learning", NeurIPS 2019.

    Note: as a simplification, this sketch scores samples by gradient norm
    rather than by the cosine-based diversity score defined above.
    """
    def __init__(self, max_size=500, buffer_per_class=50, device='cuda'):
        self.max_size = max_size
        self.buffer_per_class = buffer_per_class
        self.buffer = defaultdict(list)  # {class: [(x, y, score), ...]}
        self.device = device

    def compute_sample_gradient(self, model, x, y, criterion):
        """Flattened gradient of the loss on a single sample."""
        model.zero_grad()
        output = model(x.unsqueeze(0))
        loss = criterion(output, y.unsqueeze(0))
        grad = torch.autograd.grad(loss, [p for p in model.parameters()
                                          if p.requires_grad])
        return torch.cat([g.flatten() for g in grad if g is not None])

    def update(self, model, dataloader, criterion):
        """Update the buffer using gradient-based scores."""
        model.eval()
        all_scores = []
        # Gradients are needed here, so this loop must NOT run
        # inside torch.no_grad()
        for x, y in dataloader:
            x, y = x.to(self.device), y.to(self.device)
            for i in range(len(x)):
                grad = self.compute_sample_gradient(model, x[i], y[i], criterion)
                score = torch.norm(grad).item()  # gradient norm as a proxy score
                all_scores.append((x[i].cpu(), y[i].cpu(), score))
        # Per-class updates
        for x, y, score in all_scores:
            cls = y.item()
            if cls not in self.buffer or len(self.buffer[cls]) < self.buffer_per_class:
                # Class slot not yet full: add directly
                self.buffer[cls].append((x, y, score))
            else:
                # Slot full: replace the lowest-scoring sample if ours beats it
                min_idx = min(range(len(self.buffer[cls])),
                              key=lambda i: self.buffer[cls][i][2])
                if score > self.buffer[cls][min_idx][2]:
                    self.buffer[cls][min_idx] = (x, y, score)

    def sample(self, batch_size):
        """Sample with priority on high-scoring entries."""
        all_samples = []
        for samples in self.buffer.values():
            all_samples.extend(samples)
        if len(all_samples) == 0:
            return None
        # Sort by score (descending), then sample randomly from the top pool
        all_samples.sort(key=lambda s: s[2], reverse=True)
        top_k = min(2 * batch_size, len(all_samples))
        return random.sample(all_samples[:top_k], min(batch_size, top_k))
```
## 5. Hybrid Replay Methods
### 5.1 Real Data + Synthetic Data

Recent work suggests that replay strategies mixing real and synthetic data can combine the advantages of both.[^5]

| Component | Real-data replay | Synthetic-data replay |
|---|---|---|
| Advantage | Preserves the true distribution | No storage for raw samples |
| Disadvantage | Storage cost, privacy concerns | Quality depends on the generator |
| Best suited for | Small datasets | Large-scale settings |
### 5.2 Implementation Framework

```python
class HybridReplayBuffer:
    """Hybrid replay buffer: real samples + generated samples."""
    def __init__(self, real_buffer_size=100, generator=None):
        self.real_buffer = SimpleReplayBuffer(max_size_per_class=real_buffer_size)
        self.generator = generator  # conditional generator (e.g. a cGAN)

    def generate_samples(self, cls, num_samples):
        """Synthesize samples for a class with the generator."""
        if self.generator is None:
            return []
        device = next(self.generator.parameters()).device
        with torch.no_grad():
            z = torch.randn(num_samples, self.generator.latent_dim, device=device)
            labels = torch.full((num_samples,), cls, device=device)
            fake_samples = self.generator(z, labels)
        return fake_samples

    def sample(self, batch_size, include_generated=True):
        """Draw a mixed batch of real and generated samples."""
        samples = self.real_buffer.sample(batch_size // 2)
        if include_generated and self.generator is not None:
            # Top up the batch with generated samples
            num_generated = batch_size - len(samples) if samples else batch_size
            generated = []
            # Generate for the classes present in the real half (fallback: class 0)
            classes = list(set(s[1].item() for s in samples)) if samples else [0]
            for cls in classes:
                gen_samples = self.generate_samples(cls, num_generated // len(classes))
                for s in gen_samples:
                    # Move to CPU so all buffered samples live on one device
                    generated.append((s.cpu(), torch.tensor(cls)))
            samples = (samples or []) + generated
        return samples
```
## 6. Method Comparison and Selection Guide
### 6.1 Method Comparison

| Method | Compute cost | Storage cost | Effectiveness | Best suited for |
|---|---|---|---|---|
| ER | Low | Medium | Medium | Baseline |
| GEM | High | Medium | High | When theoretical guarantees matter |
| A-GEM | Medium | Medium | High | Efficiency/quality trade-off |
| GSS | Medium | Medium | High | When sample selection matters |
| Hybrid replay | High | Low | Highest | Storage-constrained settings |
### 6.2 Choosing the Buffer Size

| Dataset size | Recommended buffer size | Notes |
|---|---|---|
| Small (<10K) | 50-100 per class | Can afford to store many samples |
| Medium (10K-100K) | 20-50 per class | Needs balanced sampling |
| Large (>100K) | 10-20 per class | Relies on high-quality selection |
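A tiny helper encoding the table's recommendations; the exact thresholds and the midpoints chosen here are our assumptions, not prescriptions:

```python
def recommended_buffer_size_per_class(num_train_samples: int) -> int:
    """Map dataset size to a per-class buffer size, following the table above."""
    if num_train_samples < 10_000:
        return 75    # small datasets: 50-100 per class
    if num_train_samples <= 100_000:
        return 35    # medium datasets: 20-50 per class
    return 15        # large datasets: 10-20 per class
```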
### 6.3 Practical Recommendations

- Start with A-GEM: results comparable to GEM at much lower computational cost
- Combine with class-balanced sampling to avoid forgetting driven by class imbalance
- Refresh the buffer periodically so it does not hold stale representations
- Mind privacy requirements: when raw data cannot be kept, use synthetic-data replay
## 7. Combining Replay with Regularization

Replay methods combine naturally with regularization methods; the two are complementary:
```python
class CombinedContinualLearner:
    """Continual learner combining replay with EWC regularization."""
    def __init__(self, model, buffer, ewc, device='cuda'):
        self.model = model
        self.buffer = buffer  # experience replay buffer
        self.ewc = ewc        # EWC regularizer
        self.device = device

    def train_step(self, x_new, y_new, optimizer, criterion):
        # 1. Draw replay samples
        replay_samples = self.buffer.sample_balanced(x_new.size(0) // 2)
        # 2. Merge into one batch
        if replay_samples:
            replay_x = torch.stack([s[0] for s in replay_samples]).to(self.device)
            replay_y = torch.stack([s[1] for s in replay_samples]).to(self.device)
            x_batch = torch.cat([x_new, replay_x], dim=0)
            y_batch = torch.cat([y_new, replay_y], dim=0)
        else:
            x_batch, y_batch = x_new, y_new
        # 3. Current-task (plus replay) loss
        optimizer.zero_grad()
        outputs = self.model(x_batch)
        loss_current = criterion(outputs, y_batch)
        # 4. EWC regularization penalty
        loss_ewc = self.ewc.penalty()
        # 5. Total loss and update
        loss_total = loss_current + loss_ewc
        loss_total.backward()
        optimizer.step()
        return loss_total.item()
```
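A brief wiring sketch. The `EWC` class and its `penalty()` method are assumed to come from a regularization implementation defined elsewhere; the names and the `importance` parameter are illustrative assumptions:

```python
# Hypothetical wiring: replay buffer + an EWC object exposing penalty().
buffer = SimpleReplayBuffer(max_size_per_class=50)
ewc = EWC(model, importance=1000.0)   # assumed interface, defined elsewhere
learner = CombinedContinualLearner(model, buffer, ewc, device='cuda')

for x, y in task_loader:              # assumed DataLoader
    x, y = x.to('cuda'), y.to('cuda')
    learner.train_step(x, y, optimizer, criterion)
```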
## References
[^1]: Rolnick, D., et al. (2019). Experience replay for continual learning. NeurIPS.
[^2]: Lopez-Paz, D., & Ranzato, M. (2017). Gradient episodic memory for continual learning. NeurIPS.
[^3]: Chaudhry, A., et al. (2019). Efficient lifelong learning with A-GEM. ICLR.
[^4]: Aljundi, R., et al. (2019). Gradient based sample selection for online continual learning. NeurIPS.
[^5]: Koh, H., et al. (2025). Hybrid Memory Replay. ICLR.