引言
基于正则化的持续学习方法通过在损失函数中添加正则化项来限制模型参数的变化,从而保护旧任务学到的知识。这类方法不需要额外的样本存储(除非使用知识蒸馏),是持续学习研究中最早且最经典的方法之一。
1. EWC: 弹性权重整合
1.1 核心思想
EWC(Elastic Weight Consolidation)由 Kirkpatrick 等人在 2017 年提出,是持续学习领域的里程碑式工作。1
核心洞察:并非所有参数对旧任务都同等重要。那些对旧任务影响大的参数应该被「锁定」,而影响小的参数可以自由调整。
1.2 数学推导
对于任务 (旧任务),训练后得到最优参数 。现在要学习任务 (新任务)。
标准梯度下降会直接最小化 ,导致 远离 。
EWC 的目标:
其中:
- :正则化强度超参数
- :Fisher 信息矩阵的对角元素,表示参数 对旧任务的重要性
1.3 Fisher 信息矩阵
Fisher 信息矩阵 定义为:
对角近似:实际计算中使用对角近似以节省内存:
这也是经验 Fisher 的计算方式。
1.4 PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import copy
class EWC:
"""
弹性权重整合 (Elastic Weight Consolidation)
参考文献: Kirkpatrick et al. "Overcoming catastrophic forgetting
in neural networks", PNAS 2017
"""
def __init__(self, model, dataset, criterion, device='cuda'):
self.model = model
self.dataset = dataset # 旧任务的数据集
self.criterion = criterion
self.device = device
# 保存旧任务的最优参数
self.params = {n: p.clone().detach().cpu()
for n, p in model.named_parameters()
if p.requires_grad}
# 计算 Fisher 信息矩阵
self.fisher = self._compute_fisher()
def _compute_fisher(self):
"""计算 Fisher 信息矩阵的对角元素"""
fisher = {n: torch.zeros_like(p).cpu()
for n, p in self.model.named_parameters()
if p.requires_grad}
self.model.eval()
dataloader = DataLoader(self.dataset, batch_size=32, shuffle=True)
for inputs, targets in dataloader:
inputs, targets = inputs.to(self.device), targets.to(self.device)
self.model.zero_grad()
outputs = self.model(inputs)
loss = self.criterion(outputs, targets)
loss.backward()
# 累加梯度平方
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher[n] += p.grad.data.cpu() ** 2
# 归一化
n_samples = len(dataloader.dataset)
for n in fisher:
fisher[n] /= n_samples
return fisher
def penalty(self):
"""
计算 EWC 正则化损失项
Returns:
penalty: EWC 正则化损失
"""
penalty = 0
for n, p in self.model.named_parameters():
if p.requires_grad:
# 从当前参数中减去旧参数
param_diff = p.cpu() - self.params[n].to(p.device)
# Fisher 加权的 L2 损失
penalty += (self.fisher[n].to(p.device) * param_diff ** 2).sum()
return penalty / 2
def loss(self, current_loss, lambda_ewc=1000):
"""
计算带 EWC 正则化的总损失
Args:
current_loss: 当前任务的损失
lambda_ewc: EWC 正则化强度
Returns:
total_loss: 带正则化的总损失
"""
return current_loss + lambda_ewc * self.penalty()
class EWCBootstrap:
"""EWC++: 使用多次采样估计 Fisher"""
def __init__(self, model, dataset, criterion, n_samples=1000, device='cuda'):
self.model = model
self.criterion = criterion
self.device = device
# 保存旧参数
self.params = {n: p.clone().detach()
for n, p in model.named_parameters()
if p.requires_grad}
# 使用多次前向传播估计 Fisher
self.fisher = self._compute_bootstrap_fisher(dataset, n_samples)
def _compute_bootstrap_fisher(self, dataset, n_samples):
"""使用输出扰动估计 Fisher(适合分类任务)"""
fisher = {n: torch.zeros_like(p)
for n, p in self.model.named_parameters()
if p.requires_grad}
self.model.eval()
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
sample_count = 0
for inputs, _ in dataloader:
if sample_count >= n_samples:
break
inputs = inputs.to(self.device)
self.model.zero_grad()
outputs = self.model(inputs)
# 使用 softmax 概率作为目标
probs = F.softmax(outputs, dim=-1)
# 对每个类别计算梯度并累加
for c in range(outputs.size(-1)):
self.model.zero_grad()
loss = probs[0, c]
loss.backward()
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher[n] += p.grad ** 2
sample_count += 1
fisher = {n: f / (n_samples * outputs.size(-1))
for n, f in fisher.items()}
return fisher1.5 使用示例
# 训练流程示例
def train_with_ewc():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 初始化模型
model = MyModel().to(device)
criterion = nn.CrossEntropyLoss()
# 任务1: 训练基础模型
task1_loader = get_task_loader('task1')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
for inputs, targets in task1_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
# 任务1结束后,创建 EWC 备份
ewc_task1 = EWC(model, task1_dataset, criterion, device)
# 任务2: 带 EWC 的训练
task2_loader = get_task_loader('task2')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(10):
for inputs, targets in task2_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
# 当前任务损失
outputs = model(inputs)
current_loss = criterion(outputs, targets)
# EWC 正则化损失
ewc_loss = ewc_task1.loss(current_loss, lambda_ewc=1000)
ewc_loss.backward()
optimizer.step()
return model, ewc_task12. SI: 突触智能
2.1 核心思想
SI(Synaptic Intelligence)由 Zenke 等人在 2017 年提出。2
核心洞察:参数对损失的「二阶导数」反映了其重要性。SI 通过追踪训练过程中参数变化的「累积贡献」来估计重要性。
2.2 数学推导
SI 定义参数 的重要性为:
其中 是参数更新量。
正则化损失:
2.3 在线估计实现
class SynapticIntelligence:
"""
突触智能 (Synaptic Intelligence)
参考文献: Zenke et al. "Continual learning through synaptic intelligence", ICML 2017
"""
def __init__(self, model, lambda_si=1.0, eps=1e-3):
self.model = model
self.lambda_si = lambda_si
self.eps = eps # 防止除零
# 保存旧参数
self.params_old = {n: p.clone().detach()
for n, p in model.named_parameters()
if p.requires_grad}
# 累积重要性
self.W = {n: torch.zeros_like(p).cpu()
for n, p in model.named_parameters()
if p.requires_grad}
# 记录参数变化
self.delta_theta = {n: torch.zeros_like(p).cpu()
for n, p in model.named_parameters()
if p.requires_grad}
self.prev_params = {n: p.clone().detach()
for n, p in model.named_parameters()
if p.requires_grad}
def compute_importance(self):
"""
计算参数重要性
在每次参数更新后调用
"""
for n, p in self.model.named_parameters():
if p.requires_grad:
# 计算参数变化
delta = p.detach().cpu() - self.prev_params[n]
self.delta_theta[n] = delta
# 获取梯度
if p.grad is not None:
grad = p.grad.detach().cpu()
# 累积重要性: delta * (-grad)
self.W[n] += delta * (-grad)
# 更新参考点
self.prev_params[n] = p.detach().cpu().clone()
def penalty(self):
"""
计算 SI 正则化损失
"""
penalty = 0
for n, p in self.model.named_parameters():
if p.requires_grad:
param_diff = p.cpu() - self.params_old[n]
# 使用累积重要性加权
importance = self.W[n] / (self.delta_theta[n] ** 2 + self.eps)
penalty += (importance * param_diff ** 2).sum()
return penalty * self.lambda_si / 22.4 EWC vs SI 对比
| 特性 | EWC | SI |
|---|---|---|
| 重要性度量 | Fisher 信息矩阵(一阶统计) | 累积梯度贡献(二阶效应) |
| 计算时机 | 任务切换后离线计算 | 在线累积(每次更新后) |
| 存储需求 | O(参数量) | O(参数量) |
| 理论基础 | 贝叶斯后验近似 | 损失曲率近似 |
| 计算复杂度 | 中等(需遍历数据集) | 低(在线更新) |
3. LwF: 无遗忘学习
3.1 核心思想
LwF(Learning without Forgetting)由 Li 和 Hoiem 在 2017 年提出。3
核心洞察:使用知识蒸馏将旧模型的知识「迁移」到新模型中。具体做法是保存旧模型的输出(作为软标签),然后要求新模型同时满足:
- 在新任务上表现良好
- 在旧任务上保持类似的输出(通过知识蒸馏损失)
3.2 数学推导
总损失函数:
其中知识蒸馏损失为:
这里:
- 是旧模型的软输出
- 是新模型的软输出
- 是温度参数(通常 )
- 是蒸馏损失权重
3.3 PyTorch 实现
class LearningWithoutForgetting:
"""
无遗忘学习 (Learning without Forgetting)
参考文献: Li & Hoiem "Learning without forgetting", TPAMI 2017
"""
def __init__(self, model, temperature=2.0, alpha=1.0):
self.model = model
self.temperature = temperature
self.alpha = alpha # 新任务损失的权重
# 保存旧模型的参数快照
self.old_model = copy.deepcopy(model)
self.old_model.eval()
for param in self.old_model.parameters():
param.requires_grad = False
def distillation_loss(self, student_logits, teacher_logits):
"""
知识蒸馏损失
Args:
student_logits: 新模型的 logits
teacher_logits: 旧模型的 logits
Returns:
loss_kd: 蒸馏损失
"""
# 使用温度 T 进行软化
soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
log_soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
# KL 散度损失
loss_kd = F.kl_div(
log_soft_student,
soft_teacher,
reduction='batchmean'
) * (self.temperature ** 2)
return loss_kd
def forward(self, x_new, y_new, x_old=None):
"""
前向计算总损失
Args:
x_new: 新任务的输入
y_new: 新任务的标签
x_old: 旧任务的输入(用于蒸馏,如果为 None 则使用 x_new)
Returns:
total_loss: 总损失
loss_dict: 各项损失的字典
"""
# 新任务损失
outputs_new = self.model(x_new)
loss_new = F.cross_entropy(outputs_new, y_new)
# 知识蒸馏损失
# 如果没有旧任务数据,使用新任务的输入通过旧模型获取软标签
if x_old is None:
x_old = x_new.detach()
with torch.no_grad():
teacher_logits = self.old_model(x_old)
student_logits = self.model(x_old)
loss_kd = self.distillation_loss(student_logits, teacher_logits)
# 总损失
total_loss = loss_new + self.alpha * loss_kd
return total_loss, {
'loss_new': loss_new.item(),
'loss_kd': loss_kd.item(),
'total_loss': total_loss.item()
}
def update_old_model(self):
"""在新任务训练后更新旧模型快照"""
self.old_model = copy.deepcopy(self.model)
self.old_model.eval()
for param in self.old_model.parameters():
param.requires_grad = False3.4 完整训练流程
def train_with_lwf():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MyModel().to(device)
criterion = nn.CrossEntropyLoss()
# 初始化 LwF
lwf = LearningWithoutForgetting(model, temperature=2.0, alpha=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# === 任务1训练 ===
print("训练任务1...")
task1_loader = get_task_loader('task1')
for epoch in range(10):
for inputs, targets in task1_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets) # 任务1直接训练
loss.backward()
optimizer.step()
# 任务1结束后保存旧模型
lwf.update_old_model()
# === 任务2训练 ===
print("训练任务2 (带LwF)...")
task2_loader = get_task_loader('task2')
for epoch in range(10):
for inputs, targets in task2_loader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
total_loss, loss_dict = lwf.forward(inputs, targets)
total_loss.backward()
optimizer.step()
return model4. RWalk: 鲁棒无遗忘学习
4.1 核心思想
RWalk(Robust Walk)是 Chaudhry 等人在 2019 年提出的方法,结合了 Fisher 信息和路径重要性。4
核心洞察:综合使用 EWC 的 Fisher 重要性和 SI 的累积贡献,提供更鲁棒的重要性估计。
4.2 数学推导
RWalk 定义参数重要性为 Fisher 和路径重要性的加权组合:
其中 是基于路径的重要性度量:
4.3 实现要点
class RWalk:
"""
鲁棒无遗忘学习 (RWalk)
参考文献: Chaudhry et al. "Efficient lifelong learning with A-GEM", ICLR 2019
"""
def __init__(self, model, epsilon=0.2, lambda_rwalk=0.1):
self.model = model
self.epsilon = epsilon # 梯度投影阈值
self.lambda_rwalk = lambda_rwalk
# Fisher 信息
self.fisher = {}
# 累积重要性
self.omega = {}
# 旧参数
self.params_old = {}
def compute_importance(self, dataloader, criterion, device):
"""计算 Fisher 信息和累积重要性"""
# 初始化
for n, p in self.model.named_parameters():
if p.requires_grad:
self.fisher[n] = torch.zeros_like(p).cpu()
self.omega[n] = torch.zeros_like(p).cpu()
self.params_old[n] = p.clone().detach().cpu()
# 遍历数据集
self.model.eval()
for inputs, targets in dataloader:
inputs, targets = inputs.to(device), targets.to(device)
self.model.zero_grad()
outputs = self.model(inputs)
loss = criterion(outputs, targets)
loss.backward()
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
# Fisher: 梯度平方
self.fisher[n] += p.grad.data.cpu() ** 2
# Omega: 梯度与参数变化的乘积
delta = p.data.cpu() - self.params_old[n]
self.omega[n] += (delta * (-p.grad.data.cpu()))
# 归一化
n_samples = len(dataloader.dataset)
for n in self.fisher:
self.fisher[n] /= n_samples
# 更新旧参数
for n, p in self.model.named_parameters():
if p.requires_grad:
self.params_old[n] = p.clone().detach().cpu()
def penalty(self):
"""RWalk 正则化损失"""
penalty = 0
for n, p in self.model.named_parameters():
if p.requires_grad:
# 综合 Fisher 和 Omega
importance = self.lambda_rwalk * self.fisher[n] + \
(1 - self.lambda_rwalk) * self.omega[n]
importance = importance / (importance.max() + 1e-8)
param_diff = p.cpu() - self.params_old[n]
penalty += (importance * param_diff ** 2).sum()
return penalty / 25. 方法对比与实践建议
5.1 方法对比表
| 方法 | 正则化形式 | 计算开销 | 存储开销 | 适用场景 |
|---|---|---|---|---|
| EWC | Fisher 加权 L2 | 中等(需遍历数据) | O(参数量) | 任务边界明确 |
| SI | 累积贡献加权 L2 | 低(在线更新) | O(参数量) | 任务边界模糊 |
| LwF | 知识蒸馏 | 低 | O(模型参数) | 需要保持输出分布 |
| RWalk | 综合加权 L2 | 中等 | O(参数量) | 需要鲁棒性 |
5.2 超参数选择指南
| 参数 | 建议范围 | 调整策略 |
|---|---|---|
| 遗忘严重时增大 | ||
| 遗忘严重时增大 | ||
| (温度) | 软标签过于集中时增大 | |
| 旧任务性能下降时增大 |
5.3 实践注意事项
- 多任务场景:每学习一个新任务,需要保存其 Fisher/Omega 信息
- 计算效率:Fisher 计算可以只使用部分数据(采样)
- 数值稳定性:添加小的 epsilon 防止除零
- 与其他方法结合:正则化可以与回放方法结合使用
6. 代码资源推荐
| 仓库 | 语言 | 内容 |
|---|---|---|
| GMvandeVen/continual-learning | Python | 包含 EWC, SI, LwF 等方法的完整实现 |
| moskomule/ewc.pytorch | Python | EWC 的简洁 PyTorch 实现 |
| ContinuualAI/avalanche | Python | 端到端持续学习框架 |
参考资料
相关阅读:
- 持续学习基础 — 灾难性遗忘的定义与设置
- 知识蒸馏基础 — LwF 的理论基础
- Sharp vs Flat Minima — 优化景观与泛化的联系
Footnotes
-
Kirkpatrick, J., et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS. ↩
-
Zenke, F., et al. (2017). Continual learning through synaptic intelligence. ICML. ↩
-
Li, Z., & Hoiem, D. (2017). Learning without forgetting. TPAMI. ↩
-
Chaudhry, A., et al. (2019). Efficient lifelong learning with A-GEM. ICLR. ↩