引言
基于架构的持续学习方法通过动态调整网络结构来避免灾难性遗忘。其核心思想是为不同任务分配独立的参数或模块,使旧任务的知识被物理「隔离」。
这类方法理论上可以完全避免遗忘,但代价是参数量随任务数量线性增长。
1. PackNet: 迭代剪枝方法
1.1 核心思想
PackNet 由 Mallya 和 Lazebnik 在 2018 年提出。1
核心洞察:训练完成后,剪掉对当前任务「不重要」的权重,固定这些权重不被修改,然后释放剩余「冗余」参数用于后续任务。
┌────────────────────────────────────────────────────────────────┐
│ PackNet 工作流程 │
├────────────────────────────────────────────────────────────────┤
│ │
│ 任务1训练 ──→ 剪枝20%权重 ──→ 固定权重 ──→ 解冻80%权重 │
│ │ │
│ ↓ │
│ 任务2训练 ──→ 剪枝20%权重 ──→ 固定权重 ──→ 解冻64%权重 │
│ │ │
│ ↓ │
│ 任务3训练 ──→ ... │
│ │
│ 参数量利用率: 20% → 36% → 51.2% → ... (收敛到约80%) │
└────────────────────────────────────────────────────────────────┘
1.2 数学形式化
剪枝准则:使用权重幅值作为重要性指标
剪枝操作:保留 top- 的权重,其余置零:
其中 是阈值,使得 的权重被置零。
1.3 PyTorch 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class PackNet:
"""
迭代剪枝持续学习方法
参考文献: Mallya & Lazebnik "PackNet: Adding Multiple Tasks
to a Single Network by Iterative Pruning", CVPR 2018
"""
def __init__(self, model, prune_ratio=0.5):
"""
Args:
model: 要管理的模型
prune_ratio: 每次剪枝的比例 (例如 0.5 表示剪掉50%)
"""
self.model = model
self.prune_ratio = prune_ratio
# 存储每层的掩码
self.masks = {}
# 记录每任务后的参数状态
self.task_boundaries = [] # [(task_id, params_snapshot), ...]
# 当前可训练参数索引
self.trainable_indices = {} # {layer_name: [indices]}
def save_model(self):
"""保存当前模型参数"""
return {n: p.clone().detach()
for n, p in self.model.named_parameters()}
def get_param_importance(self, dataloader, criterion):
"""
计算参数重要性(基于梯度幅值)
Returns:
importance: {layer_name: importance_score}
"""
self.model.eval()
importance = {}
# 初始化重要性
for n, p in self.model.named_parameters():
if p.requires_grad:
importance[n] = torch.zeros_like(p).cpu()
# 累积梯度幅值
total_samples = 0
for inputs, targets in dataloader:
self.model.zero_grad()
outputs = self.model(inputs)
loss = criterion(outputs, targets)
loss.backward()
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
importance[n] += p.grad.data.abs().cpu()
total_samples += inputs.size(0)
# 平均
for n in importance:
importance[n] /= total_samples
return importance
def compute_pruning_mask(self, importance, prune_ratio):
"""
根据重要性计算剪枝掩码
Returns:
mask: {layer_name: binary_mask}
"""
mask = {}
for n, imp in importance.items():
# 计算阈值
threshold = torch.quantile(imp.flatten(), prune_ratio)
# 生成掩码:重要性高的保留 (1),低的剪掉 (0)
binary_mask = (imp > threshold).float()
mask[n] = binary_mask
return mask
def prune_and_freeze(self, dataloader=None, criterion=None):
"""
执行剪枝并冻结不重要参数
如果没有提供 dataloader,则使用随机剪枝
"""
if dataloader is not None and criterion is not None:
# 基于重要性的剪枝
importance = self.get_param_importance(dataloader, criterion)
new_mask = self.compute_pruning_mask(importance, self.prune_ratio)
else:
# 随机剪枝(简化版本)
new_mask = {}
for n, p in self.model.named_parameters():
if p.requires_grad:
# 随机选择要剪枝的位置
mask = torch.rand_like(p) > self.prune_ratio
new_mask[n] = mask.float()
# 更新掩码(与之前掩码取交集)
for n, m in new_mask.items():
if n in self.masks:
# 已冻结的参数保持冻结
self.masks[n] = self.masks[n] * m
else:
self.masks[n] = m
# 应用掩码并冻结参数
self.apply_masks()
# 打印统计
frozen_ratio = self.get_frozen_ratio()
print(f"冻结比例: {frozen_ratio:.2%}")
def apply_masks(self):
"""应用掩码并冻结不重要参数"""
for n, p in self.model.named_parameters():
if n in self.masks:
# 保存原始梯度(用于后续分析)
p.register_hook(lambda grad, n=n: grad * self.masks[n].to(grad.device))
# 冻结参数
if p.requires_grad:
p.requires_grad = False
def unfreeze_trainable(self):
"""解冻可训练参数"""
for n, p in self.model.named_parameters():
if n in self.masks:
# 如果该参数还有未被冻结的部分
if (self.masks[n] > 0).any():
p.requires_grad = True
def get_frozen_ratio(self):
"""获取冻结参数的比例"""
total_params = 0
frozen_params = 0
for n, p in self.model.named_parameters():
total_params += p.numel()
if n in self.masks:
frozen_params += (self.masks[n] == 0).sum().item()
return frozen_params / total_params
def train_task(self, train_loader, optimizer, criterion, epochs=10):
"""
训练一个任务
"""
self.unfreeze_trainable()
self.model.train()
for epoch in range(epochs):
for inputs, targets in train_loader:
optimizer.zero_grad()
outputs = self.model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
def finish_task(self, dataloader=None, criterion=None):
"""
完成任务后的处理:保存模型状态,执行剪枝
"""
# 保存当前参数状态
self.task_boundaries.append(self.save_model())
# 剪枝并冻结
self.prune_and_freeze(dataloader, criterion)1.4 使用示例
def train_with_packnet():
model = MyModel()
criterion = nn.CrossEntropyLoss()
# 初始化 PackNet
packnet = PackNet(model, prune_ratio=0.5)
# === 任务1 ===
print("训练任务1...")
task1_loader = get_task_loader('task1')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
packnet.train_task(task1_loader, optimizer, criterion, epochs=10)
packnet.finish_task(task1_loader, criterion)
# === 任务2 ===
print("训练任务2...")
task2_loader = get_task_loader('task2')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
packnet.train_task(task2_loader, optimizer, criterion, epochs=10)
packnet.finish_task(task2_loader, criterion)
# === 任务3 ===
print("训练任务3...")
task3_loader = get_task_loader('task3')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
packnet.train_task(task3_loader, optimizer, criterion, epochs=10)
packnet.finish_task(task3_loader, criterion)
return model2. Progressive Neural Networks
2.1 核心思想
Progressive Neural Networks (PNN) 由 Rusu 等人在 2016 年提出。2
核心洞察:不尝试复用已有参数,而是为每个新任务添加新的网络列(columns),通过横向连接复用之前任务的特征表示。
┌────────────────────────────────────────────────────────────────┐
│ Progressive Neural Networks │
├────────────────────────────────────────────────────────────────┤
│ │
│ 任务1列 任务2列 任务3列 │
│ ┌─────┐ ┌─────┐ ┌─────┐ │
│ │Col 1│──→│ │ │ │ 横向连接复用特征 │
│ └─────┘ │Col 2│──→│ │ │
│ └─────┘ │Col 3│ │
│ └─────┘ │
│ │
│ 特点: 每列独立,横向连接促进迁移 │
└────────────────────────────────────────────────────────────────┘
2.2 架构设计
PNN 的每一列是一个完整的神经网络,新增任务时:
- 添加新列:完整的网络结构(包括输入层到输出层)
- 横向连接:将之前任务列的输出连接到新列的隐藏层
- 固定旧列:不修改已训练列的参数
横向连接公式:
其中 是第 列第 层的输出, 是从列 到列 的横向连接权重。
2.3 PyTorch 实现
class ProgressiveColumn(nn.Module):
"""PNN 的单列网络"""
def __init__(self, input_dim, hidden_dims, output_dim, lateral_input_dim=0):
super().__init__()
layers = []
prev_dim = input_dim
for h_dim in hidden_dims:
layers.append(nn.Linear(prev_dim + lateral_input_dim, h_dim))
layers.append(nn.ReLU())
prev_dim = h_dim
layers.append(nn.Linear(prev_dim, output_dim))
self.network = nn.Sequential(*layers)
def forward(self, x, lateral_input=None):
if lateral_input is not None:
x = torch.cat([x, lateral_input], dim=-1)
return self.network(x)
class ProgressiveNeuralNetwork(nn.Module):
"""
Progressive Neural Networks
参考文献: Rusu et al. "Progressive neural networks", arXiv 2016
"""
def __init__(self, input_dim, hidden_dims, output_dim):
super().__init__()
self.input_dim = input_dim
self.hidden_dims = hidden_dims
self.output_dim = output_dim
# 列列表
self.columns = nn.ModuleList()
# 横向连接列表
self.lateral_connections = nn.ModuleDict()
# 创建第一个列(无横向连接)
self.add_column()
def add_column(self):
"""添加新的网络列"""
column_idx = len(self.columns)
# 创建新列
# 对于第一个列,没有横向输入
# 对于后续列,横向输入维度 = 之前所有列的输出维度之和
lateral_dim = sum(self.hidden_dims) if column_idx > 0 else 0
column = ProgressiveColumn(
input_dim=self.input_dim,
hidden_dims=self.hidden_dims,
output_dim=self.output_dim,
lateral_input_dim=lateral_dim
)
self.columns.append(column)
# 创建横向连接(从之前所有列到新列)
if column_idx > 0:
for prev_idx in range(column_idx):
lateral = nn.ModuleDict()
for layer_idx, h_dim in enumerate(self.hidden_dims):
lateral[f'h{layer_idx}'] = nn.Linear(
self.hidden_dims[layer_idx], # 输入:之前列的该层输出
self.hidden_dims[layer_idx] # 输出:新列的该层输入
)
self.lateral_connections[f'{prev_idx}->{column_idx}'] = lateral
def forward(self, x, task_id=None):
"""
前向传播
Args:
x: 输入
task_id: 指定任务的列(如果为None,使用最后一列)
"""
if task_id is None:
task_id = len(self.columns) - 1
outputs = []
for col_idx in range(task_id + 1):
column = self.columns[col_idx]
if col_idx == 0:
# 第一列没有横向连接
col_output = column(x, lateral_input=None)
else:
# 聚合之前所有列的输出
lateral_input = []
for prev_idx in range(col_idx):
# 获取之前列的输出
prev_output = column.network[col_idx * 2](outputs[prev_idx]) # 经过激活
# 通过横向连接
lateral_conn = self.lateral_connections[f'{prev_idx}->{col_idx}']
lateral_transformed = lateral_conn[f'h{col_idx}'](prev_output)
lateral_input.append(lateral_transformed)
lateral_input = torch.cat(lateral_input, dim=-1)
col_output = column(x, lateral_input=lateral_input)
outputs.append(col_output)
return outputs[task_id]
def freeze_columns(self, column_ids):
"""冻结指定的列"""
for col_id in column_ids:
for param in self.columns[col_id].parameters():
param.requires_grad = False
def unfreeze_columns(self, column_ids):
"""解冻指定的列"""
for col_id in column_ids:
for param in self.columns[col_id].parameters():
param.requires_grad = True3. HAT: 硬注意力机制
3.1 核心思想
HAT(Hard Attention to the Task)由 Serra 等人在 2018 年提出。3
核心洞察:使用二进制注意力掩码来隔离不同任务的参数。与 PackNet 的剪枝不同,HAT 通过学习注意力掩码来「软性」地隔离参数。
3.2 数学形式化
注意力掩码:
其中 是第 层的参数数量, 是任务索引。
参数掩码应用:
这里 表示逐元素乘法。直观上,对于新任务的每个参数:
- 如果掩码为 1,保持旧值(不更新)
- 如果掩码为 0,使用新值
3.3 PyTorch 实现
class HardAttentionTask(nn.Module):
"""
硬注意力任务模块
为每个任务学习一个注意力掩码
"""
def __init__(self, num_layers, embed_dim, num_heads=4):
super().__init__()
self.num_layers = num_layers
self.embed_dim = embed_dim
self.num_heads = num_heads
# 任务嵌入
self.task_embed = nn.Embedding(1, embed_dim)
# 层归一化
self.layer_norm = nn.LayerNorm(embed_dim)
# 注意力网络:为每层生成掩码
self.attention = nn.ModuleList([
nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
for _ in range(num_layers)
])
# 掩码生成器
self.mask_generator = nn.ModuleList([
nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.ReLU(),
nn.Linear(embed_dim, 1),
nn.Sigmoid() # 输出 [0, 1] 范围
)
for _ in range(num_layers)
])
# 可学习的缩放参数
self.s = nn.Parameter(torch.tensor(5.0))
def get_mask(self, task_id, num_params):
"""
生成二值掩码
Args:
task_id: 任务ID
num_params: 参数数量
Returns:
mask: 二值掩码
"""
# 获取任务嵌入
task_embedding = self.task_embed(
torch.tensor([task_id], device=next(self.parameters()).device)
)
# 对嵌入进行归一化
task_embedding = self.layer_norm(task_embedding)
# 通过注意力层增强
for attn_layer in self.attention:
task_embedding, _ = attn_layer(
task_embedding.unsqueeze(0),
task_embedding.unsqueeze(0),
task_embedding.unsqueeze(0)
)
task_embedding = self.layer_norm(task_embedding.squeeze(0))
# 生成掩码
mask = self.mask_generator[0](task_embedding)
# 阈值化为二值掩码
mask = (mask > 0.5).float()
return mask
def forward(self, x, task_id):
return x
class HATContinualLearner(nn.Module):
"""
HAT 持续学习器
参考文献: Serra et al. "Overcoming catastrophic forgetting with
hard attention to the task", ICML 2018
"""
def __init__(self, model, num_tasks, embed_dim=64):
super().__init__()
self.model = model
self.num_tasks = num_tasks
self.embed_dim = embed_dim
# HAT 模块
self.hat_modules = nn.ModuleDict()
# 为每个任务创建 HAT 模块
for task_id in range(num_tasks):
self.hat_modules[f'task_{task_id}'] = HardAttentionTask(
num_layers=len(list(model.parameters())),
embed_dim=embed_dim
)
# 存储每个任务的重要掩码
self.task_masks = {}
def forward(self, x, task_id):
"""前向传播"""
# 获取任务掩码
mask = self.get_task_mask(task_id)
# 应用掩码到参数
param_idx = 0
for p in self.model.parameters():
if p.requires_grad and param_idx < len(mask):
mask_val = mask[param_idx].to(p.device)
# 这里需要特殊处理,因为掩码需要 reshape 为参数形状
# 简化实现:使用传入的掩码
pass
param_idx += 1
return self.model(x)
def get_task_mask(self, task_id):
"""获取任务的注意力掩码"""
if task_id not in self.task_masks:
hat_module = self.hat_modules[f'task_{task_id}']
# 生成掩码(需要知道参数形状,这里简化处理)
total_params = sum(p.numel() for p in self.model.parameters())
self.task_masks[task_id] = hat_module.get_mask(task_id, total_params)
return self.task_masks[task_id]
def get_attention_loss(self, task_id, lambda_attn=0.01):
"""
计算注意力正则化损失
促进不同任务使用不同的参数子集
"""
loss = 0
masks = list(self.task_masks.values())
# 鼓励不同任务的掩码正交
for i in range(len(masks)):
for j in range(i + 1, len(masks)):
overlap = (masks[i] * masks[j]).sum()
loss += overlap # 最小化重叠
return lambda_attn * loss4. P&C: 进展与压缩
4.1 核心思想
P&C(Progress & Compress)由 Schwarz 等人在 2018 年提出。4
核心洞察:交替执行「进展阶段」(学习新任务)和「压缩阶段」(将知识蒸馏到基座网络),结合了架构方法和知识蒸馏的优势。
┌────────────────────────────────────────────────────────────────┐
│ P&C 工作流程 │
├────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ │
│ │ 活跃列 │ ← 当前任务的列,可训练 │
│ └─────────────┘ │
│ │ │
│ ↓ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ 进展阶段 │───→│ 压缩阶段 │ │
│ │ (学习新任务) │ │ (蒸馏到基座) │ │
│ └─────────────┘ └─────────────┘ │
│ │ │
│ ↓ │
│ ┌─────────────────────────────────────────────┐ │
│ │ 基座网络(知识库) │ │
│ └─────────────────────────────────────────────┘ │
└────────────────────────────────────────────────────────────────┘
4.2 阶段说明
| 阶段 | 描述 | 可训练参数 |
|---|---|---|
| 进展阶段 | 使用活跃列学习新任务 | 活跃列 + 新增容量 |
| 压缩阶段 | 知识蒸馏到基座网络 | 基座网络 |
4.3 实现框架
class ProgressAndCompress(nn.Module):
"""
进展与压缩框架
参考文献: Schwarz et al. "Progress & Compress: A scalable
framework for continual learning", ICML 2018
"""
def __init__(self, base_model, num_tasks, compression_ratio=0.5):
super().__init__()
self.base_model = base_model # 基座网络
self.num_tasks = num_tasks
self.compression_ratio = compression_ratio
# 活跃列
self.active_column = copy.deepcopy(base_model)
# 知识库(用于存储压缩后的知识)
self.knowledge_base = copy.deepcopy(base_model)
# 每任务后的 EWC 备份
self.ewc_modules = {}
def progress_phase(self, task_data, optimizer, criterion, epochs=10):
"""
进展阶段:学习新任务
"""
self.active_column.train()
for epoch in range(epochs):
for inputs, targets in task_data:
optimizer.zero_grad()
outputs = self.active_column(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
def compress_phase(self, task_data, optimizer, criterion, epochs=5):
"""
压缩阶段:将活跃列的知识蒸馏到基座网络
"""
self.base_model.train()
for epoch in range(epochs):
for inputs, targets in task_data:
optimizer.zero_grad()
# 基座网络输出
base_outputs = self.base_model(inputs)
# 活跃列输出(作为软目标)
with torch.no_grad():
active_outputs = self.active_column(inputs)
# 知识蒸馏损失
T = 2.0 # 温度
soft_target = F.softmax(active_outputs / T, dim=-1)
soft_pred = F.log_softmax(base_outputs / T, dim=-1)
loss_kd = F.kl_div(soft_pred, soft_target, reduction='batchmean') * (T ** 2)
# 任务损失
loss_task = criterion(base_outputs, targets)
# 总损失
loss = loss_task + 0.5 * loss_kd
loss.backward()
optimizer.step()
# 更新知识库
self.knowledge_base = copy.deepcopy(self.base_model)
def train_task(self, task_data, epochs=10):
"""
训练单个任务:进展 + 压缩
"""
optimizer = torch.optim.Adam(
list(self.active_column.parameters()) +
list(self.base_model.parameters()),
lr=0.001
)
criterion = nn.CrossEntropyLoss()
# 进展阶段
print("进展阶段...")
self.progress_phase(task_data, optimizer, criterion, epochs)
# 压缩阶段
print("压缩阶段...")
self.compress_phase(task_data, optimizer, criterion, epochs)
def forward(self, x, task_id):
"""
推理时使用基座网络
"""
return self.base_model(x)5. 方法对比与实践建议
5.1 方法对比表
| 方法 | 参数量增长 | 计算开销 | 存储开销 | 优点 | 缺点 |
|---|---|---|---|---|---|
| PackNet | 恒定 | 中等 | 低 | 无额外参数 | 可能剪掉重要参数 |
| PNN | O(任务数) | 高 | O(任务数) | 完全隔离 | 参数量大 |
| HAT | 恒定 + O(任务数) | 中 | 中 | 软性隔离 | 掩码学习复杂 |
| P&C | 恒定 | 高 | 中 | 结合蒸馏 | 两阶段训练 |
5.2 参数量分析
假设基础网络有 个参数,每个新任务:
| 方法 | 新增参数 | 个任务后总参数量 |
|---|---|---|
| PackNet | 0 | |
| PNN | ||
| HAT | ||
| P&C | 0 | (两列) |
5.3 实践建议
- 任务数量少 (< 5):优先使用 PNN,完全隔离无遗忘风险
- 参数量受限:使用 PackNet,效率高
- 需要平衡效果和效率:使用 P&C,结合蒸馏优势
- 需要灵活控制:使用 HAT,可学习掩码
5.4 扩展阅读
- 网络剪枝:PackNet 的思想可扩展到权重重要性排序
- 模块化网络:HAT 可与 Mixture of Experts 结合
- 知识蒸馏:P&C 框架可与多种蒸馏方法结合
参考资料
相关阅读:
Footnotes
-
Mallya, A., & Lazebnik, S. (2018). PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning. CVPR. ↩
-
Rusu, A. A., et al. (2016). Progressive neural networks. arXiv. ↩
-
Serra, J., et al. (2018). Overcoming catastrophic forgetting with hard attention to the task. ICML. ↩
-
Schwarz, J., et al. (2018). Progress & Compress: A scalable framework for continual learning. ICML. ↩