引言

基于架构的持续学习方法通过动态调整网络结构来避免灾难性遗忘。其核心思想是为不同任务分配独立的参数或模块,使旧任务的知识被物理「隔离」。

这类方法理论上可以完全避免遗忘,但代价是参数量随任务数量线性增长。


1. PackNet: 迭代剪枝方法

1.1 核心思想

PackNet 由 Mallya 和 Lazebnik 在 2018 年提出。1

核心洞察:训练完成后,剪掉对当前任务「不重要」的权重,固定这些权重不被修改,然后释放剩余「冗余」参数用于后续任务。

┌────────────────────────────────────────────────────────────────┐
│                    PackNet 工作流程                             │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  任务1训练 ──→ 剪枝20%权重 ──→ 固定权重 ──→ 解冻80%权重        │
│      │                                                            │
│      ↓                                                            │
│  任务2训练 ──→ 剪枝20%权重 ──→ 固定权重 ──→ 解冻64%权重        │
│      │                                                            │
│      ↓                                                            │
│  任务3训练 ──→ ...                                              │
│                                                                │
│  参数量利用率: 20% → 36% → 51.2% → ... (收敛到约80%)          │
└────────────────────────────────────────────────────────────────┘

1.2 数学形式化

剪枝准则:使用权重幅值作为重要性指标

剪枝操作:保留 top- 的权重,其余置零:

其中 是阈值,使得 的权重被置零。

1.3 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
 
class PackNet:
    """
    迭代剪枝持续学习方法
    
    参考文献: Mallya & Lazebnik "PackNet: Adding Multiple Tasks 
             to a Single Network by Iterative Pruning", CVPR 2018
    """
    
    def __init__(self, model, prune_ratio=0.5):
        """
        Args:
            model: 要管理的模型
            prune_ratio: 每次剪枝的比例 (例如 0.5 表示剪掉50%)
        """
        self.model = model
        self.prune_ratio = prune_ratio
        
        # 存储每层的掩码
        self.masks = {}
        
        # 记录每任务后的参数状态
        self.task_boundaries = []  # [(task_id, params_snapshot), ...]
        
        # 当前可训练参数索引
        self.trainable_indices = {}  # {layer_name: [indices]}
        
    def save_model(self):
        """保存当前模型参数"""
        return {n: p.clone().detach() 
                for n, p in self.model.named_parameters()}
    
    def get_param_importance(self, dataloader, criterion):
        """
        计算参数重要性(基于梯度幅值)
        
        Returns:
            importance: {layer_name: importance_score}
        """
        self.model.eval()
        importance = {}
        
        # 初始化重要性
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                importance[n] = torch.zeros_like(p).cpu()
        
        # 累积梯度幅值
        total_samples = 0
        for inputs, targets in dataloader:
            self.model.zero_grad()
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            
            for n, p in self.model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    importance[n] += p.grad.data.abs().cpu()
            
            total_samples += inputs.size(0)
        
        # 平均
        for n in importance:
            importance[n] /= total_samples
            
        return importance
    
    def compute_pruning_mask(self, importance, prune_ratio):
        """
        根据重要性计算剪枝掩码
        
        Returns:
            mask: {layer_name: binary_mask}
        """
        mask = {}
        
        for n, imp in importance.items():
            # 计算阈值
            threshold = torch.quantile(imp.flatten(), prune_ratio)
            
            # 生成掩码:重要性高的保留 (1),低的剪掉 (0)
            binary_mask = (imp > threshold).float()
            mask[n] = binary_mask
            
        return mask
    
    def prune_and_freeze(self, dataloader=None, criterion=None):
        """
        执行剪枝并冻结不重要参数
        
        如果没有提供 dataloader,则使用随机剪枝
        """
        if dataloader is not None and criterion is not None:
            # 基于重要性的剪枝
            importance = self.get_param_importance(dataloader, criterion)
            new_mask = self.compute_pruning_mask(importance, self.prune_ratio)
        else:
            # 随机剪枝(简化版本)
            new_mask = {}
            for n, p in self.model.named_parameters():
                if p.requires_grad:
                    # 随机选择要剪枝的位置
                    mask = torch.rand_like(p) > self.prune_ratio
                    new_mask[n] = mask.float()
        
        # 更新掩码(与之前掩码取交集)
        for n, m in new_mask.items():
            if n in self.masks:
                # 已冻结的参数保持冻结
                self.masks[n] = self.masks[n] * m
            else:
                self.masks[n] = m
        
        # 应用掩码并冻结参数
        self.apply_masks()
        
        # 打印统计
        frozen_ratio = self.get_frozen_ratio()
        print(f"冻结比例: {frozen_ratio:.2%}")
        
    def apply_masks(self):
        """应用掩码并冻结不重要参数"""
        for n, p in self.model.named_parameters():
            if n in self.masks:
                # 保存原始梯度(用于后续分析)
                p.register_hook(lambda grad, n=n: grad * self.masks[n].to(grad.device))
                
                # 冻结参数
                if p.requires_grad:
                    p.requires_grad = False
    
    def unfreeze_trainable(self):
        """解冻可训练参数"""
        for n, p in self.model.named_parameters():
            if n in self.masks:
                # 如果该参数还有未被冻结的部分
                if (self.masks[n] > 0).any():
                    p.requires_grad = True
    
    def get_frozen_ratio(self):
        """获取冻结参数的比例"""
        total_params = 0
        frozen_params = 0
        
        for n, p in self.model.named_parameters():
            total_params += p.numel()
            if n in self.masks:
                frozen_params += (self.masks[n] == 0).sum().item()
        
        return frozen_params / total_params
    
    def train_task(self, train_loader, optimizer, criterion, epochs=10):
        """
        训练一个任务
        """
        self.unfreeze_trainable()
        
        self.model.train()
        for epoch in range(epochs):
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
    
    def finish_task(self, dataloader=None, criterion=None):
        """
        完成任务后的处理:保存模型状态,执行剪枝
        """
        # 保存当前参数状态
        self.task_boundaries.append(self.save_model())
        
        # 剪枝并冻结
        self.prune_and_freeze(dataloader, criterion)

1.4 使用示例

def train_with_packnet():
    model = MyModel()
    criterion = nn.CrossEntropyLoss()
    
    # 初始化 PackNet
    packnet = PackNet(model, prune_ratio=0.5)
    
    # === 任务1 ===
    print("训练任务1...")
    task1_loader = get_task_loader('task1')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    packnet.train_task(task1_loader, optimizer, criterion, epochs=10)
    packnet.finish_task(task1_loader, criterion)
    
    # === 任务2 ===
    print("训练任务2...")
    task2_loader = get_task_loader('task2')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    packnet.train_task(task2_loader, optimizer, criterion, epochs=10)
    packnet.finish_task(task2_loader, criterion)
    
    # === 任务3 ===
    print("训练任务3...")
    task3_loader = get_task_loader('task3')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    packnet.train_task(task3_loader, optimizer, criterion, epochs=10)
    packnet.finish_task(task3_loader, criterion)
    
    return model

2. Progressive Neural Networks

2.1 核心思想

Progressive Neural Networks (PNN) 由 Rusu 等人在 2016 年提出。2

核心洞察:不尝试复用已有参数,而是为每个新任务添加新的网络列(columns),通过横向连接复用之前任务的特征表示。

┌────────────────────────────────────────────────────────────────┐
│                 Progressive Neural Networks                     │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│   任务1列    任务2列    任务3列                                 │
│   ┌─────┐   ┌─────┐   ┌─────┐                                 │
│   │Col 1│──→│     │   │     │  横向连接复用特征                 │
│   └─────┘   │Col 2│──→│     │                                 │
│             └─────┘   │Col 3│                                 │
│                        └─────┘                                 │
│                                                                │
│   特点: 每列独立,横向连接促进迁移                              │
└────────────────────────────────────────────────────────────────┘

2.2 架构设计

PNN 的每一列是一个完整的神经网络,新增任务时:

  1. 添加新列:完整的网络结构(包括输入层到输出层)
  2. 横向连接:将之前任务列的输出连接到新列的隐藏层
  3. 固定旧列:不修改已训练列的参数

横向连接公式

其中 是第 列第 层的输出, 是从列 到列 的横向连接权重。

2.3 PyTorch 实现

class ProgressiveColumn(nn.Module):
    """PNN 的单列网络"""
    
    def __init__(self, input_dim, hidden_dims, output_dim, lateral_input_dim=0):
        super().__init__()
        
        layers = []
        prev_dim = input_dim
        
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim + lateral_input_dim, h_dim))
            layers.append(nn.ReLU())
            prev_dim = h_dim
        
        layers.append(nn.Linear(prev_dim, output_dim))
        
        self.network = nn.Sequential(*layers)
        
    def forward(self, x, lateral_input=None):
        if lateral_input is not None:
            x = torch.cat([x, lateral_input], dim=-1)
        return self.network(x)
 
 
class ProgressiveNeuralNetwork(nn.Module):
    """
    Progressive Neural Networks
    
    参考文献: Rusu et al. "Progressive neural networks", arXiv 2016
    """
    
    def __init__(self, input_dim, hidden_dims, output_dim):
        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        
        # 列列表
        self.columns = nn.ModuleList()
        # 横向连接列表
        self.lateral_connections = nn.ModuleDict()
        
        # 创建第一个列(无横向连接)
        self.add_column()
        
    def add_column(self):
        """添加新的网络列"""
        column_idx = len(self.columns)
        
        # 创建新列
        # 对于第一个列,没有横向输入
        # 对于后续列,横向输入维度 = 之前所有列的输出维度之和
        lateral_dim = sum(self.hidden_dims) if column_idx > 0 else 0
        
        column = ProgressiveColumn(
            input_dim=self.input_dim,
            hidden_dims=self.hidden_dims,
            output_dim=self.output_dim,
            lateral_input_dim=lateral_dim
        )
        
        self.columns.append(column)
        
        # 创建横向连接(从之前所有列到新列)
        if column_idx > 0:
            for prev_idx in range(column_idx):
                lateral = nn.ModuleDict()
                for layer_idx, h_dim in enumerate(self.hidden_dims):
                    lateral[f'h{layer_idx}'] = nn.Linear(
                        self.hidden_dims[layer_idx],  # 输入:之前列的该层输出
                        self.hidden_dims[layer_idx]    # 输出:新列的该层输入
                    )
                self.lateral_connections[f'{prev_idx}->{column_idx}'] = lateral
        
    def forward(self, x, task_id=None):
        """
        前向传播
        
        Args:
            x: 输入
            task_id: 指定任务的列(如果为None,使用最后一列)
        """
        if task_id is None:
            task_id = len(self.columns) - 1
        
        outputs = []
        for col_idx in range(task_id + 1):
            column = self.columns[col_idx]
            
            if col_idx == 0:
                # 第一列没有横向连接
                col_output = column(x, lateral_input=None)
            else:
                # 聚合之前所有列的输出
                lateral_input = []
                for prev_idx in range(col_idx):
                    # 获取之前列的输出
                    prev_output = column.network[col_idx * 2](outputs[prev_idx])  # 经过激活
                    # 通过横向连接
                    lateral_conn = self.lateral_connections[f'{prev_idx}->{col_idx}']
                    lateral_transformed = lateral_conn[f'h{col_idx}'](prev_output)
                    lateral_input.append(lateral_transformed)
                
                lateral_input = torch.cat(lateral_input, dim=-1)
                col_output = column(x, lateral_input=lateral_input)
            
            outputs.append(col_output)
        
        return outputs[task_id]
    
    def freeze_columns(self, column_ids):
        """冻结指定的列"""
        for col_id in column_ids:
            for param in self.columns[col_id].parameters():
                param.requires_grad = False
                
    def unfreeze_columns(self, column_ids):
        """解冻指定的列"""
        for col_id in column_ids:
            for param in self.columns[col_id].parameters():
                param.requires_grad = True

3. HAT: 硬注意力机制

3.1 核心思想

HAT(Hard Attention to the Task)由 Serra 等人在 2018 年提出。3

核心洞察:使用二进制注意力掩码来隔离不同任务的参数。与 PackNet 的剪枝不同,HAT 通过学习注意力掩码来「软性」地隔离参数。

3.2 数学形式化

注意力掩码

其中 是第 层的参数数量, 是任务索引。

参数掩码应用

这里 表示逐元素乘法。直观上,对于新任务的每个参数:

  • 如果掩码为 1,保持旧值(不更新)
  • 如果掩码为 0,使用新值

3.3 PyTorch 实现

class HardAttentionTask(nn.Module):
    """
    硬注意力任务模块
    
    为每个任务学习一个注意力掩码
    """
    
    def __init__(self, num_layers, embed_dim, num_heads=4):
        super().__init__()
        
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        
        # 任务嵌入
        self.task_embed = nn.Embedding(1, embed_dim)
        
        # 层归一化
        self.layer_norm = nn.LayerNorm(embed_dim)
        
        # 注意力网络:为每层生成掩码
        self.attention = nn.ModuleList([
            nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        
        # 掩码生成器
        self.mask_generator = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                nn.ReLU(),
                nn.Linear(embed_dim, 1),
                nn.Sigmoid()  # 输出 [0, 1] 范围
            )
            for _ in range(num_layers)
        ])
        
        # 可学习的缩放参数
        self.s = nn.Parameter(torch.tensor(5.0))
        
    def get_mask(self, task_id, num_params):
        """
        生成二值掩码
        
        Args:
            task_id: 任务ID
            num_params: 参数数量
            
        Returns:
            mask: 二值掩码
        """
        # 获取任务嵌入
        task_embedding = self.task_embed(
            torch.tensor([task_id], device=next(self.parameters()).device)
        )
        
        # 对嵌入进行归一化
        task_embedding = self.layer_norm(task_embedding)
        
        # 通过注意力层增强
        for attn_layer in self.attention:
            task_embedding, _ = attn_layer(
                task_embedding.unsqueeze(0),
                task_embedding.unsqueeze(0),
                task_embedding.unsqueeze(0)
            )
            task_embedding = self.layer_norm(task_embedding.squeeze(0))
        
        # 生成掩码
        mask = self.mask_generator[0](task_embedding)
        
        # 阈值化为二值掩码
        mask = (mask > 0.5).float()
        
        return mask
    
    def forward(self, x, task_id):
        return x
 
 
class HATContinualLearner(nn.Module):
    """
    HAT 持续学习器
    
    参考文献: Serra et al. "Overcoming catastrophic forgetting with 
             hard attention to the task", ICML 2018
    """
    
    def __init__(self, model, num_tasks, embed_dim=64):
        super().__init__()
        
        self.model = model
        self.num_tasks = num_tasks
        self.embed_dim = embed_dim
        
        # HAT 模块
        self.hat_modules = nn.ModuleDict()
        
        # 为每个任务创建 HAT 模块
        for task_id in range(num_tasks):
            self.hat_modules[f'task_{task_id}'] = HardAttentionTask(
                num_layers=len(list(model.parameters())),
                embed_dim=embed_dim
            )
        
        # 存储每个任务的重要掩码
        self.task_masks = {}
        
    def forward(self, x, task_id):
        """前向传播"""
        # 获取任务掩码
        mask = self.get_task_mask(task_id)
        
        # 应用掩码到参数
        param_idx = 0
        for p in self.model.parameters():
            if p.requires_grad and param_idx < len(mask):
                mask_val = mask[param_idx].to(p.device)
                # 这里需要特殊处理,因为掩码需要 reshape 为参数形状
                # 简化实现:使用传入的掩码
                pass
            param_idx += 1
        
        return self.model(x)
    
    def get_task_mask(self, task_id):
        """获取任务的注意力掩码"""
        if task_id not in self.task_masks:
            hat_module = self.hat_modules[f'task_{task_id}']
            # 生成掩码(需要知道参数形状,这里简化处理)
            total_params = sum(p.numel() for p in self.model.parameters())
            self.task_masks[task_id] = hat_module.get_mask(task_id, total_params)
        
        return self.task_masks[task_id]
    
    def get_attention_loss(self, task_id, lambda_attn=0.01):
        """
        计算注意力正则化损失
        促进不同任务使用不同的参数子集
        """
        loss = 0
        masks = list(self.task_masks.values())
        
        # 鼓励不同任务的掩码正交
        for i in range(len(masks)):
            for j in range(i + 1, len(masks)):
                overlap = (masks[i] * masks[j]).sum()
                loss += overlap  # 最小化重叠
        
        return lambda_attn * loss

4. P&C: 进展与压缩

4.1 核心思想

P&C(Progress & Compress)由 Schwarz 等人在 2018 年提出。4

核心洞察:交替执行「进展阶段」(学习新任务)和「压缩阶段」(将知识蒸馏到基座网络),结合了架构方法和知识蒸馏的优势。

┌────────────────────────────────────────────────────────────────┐
│                    P&C 工作流程                                 │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  ┌─────────────┐                                               │
│  │  活跃列      │ ← 当前任务的列,可训练                         │
│  └─────────────┘                                               │
│       │                                                         │
│       ↓                                                         │
│  ┌─────────────┐    ┌─────────────┐                            │
│  │  进展阶段    │───→│  压缩阶段    │                           │
│  │ (学习新任务)  │    │ (蒸馏到基座) │                           │
│  └─────────────┘    └─────────────┘                            │
│                              │                                  │
│                              ↓                                  │
│  ┌─────────────────────────────────────────────┐               │
│  │              基座网络(知识库)                │               │
│  └─────────────────────────────────────────────┘               │
└────────────────────────────────────────────────────────────────┘

4.2 阶段说明

阶段描述可训练参数
进展阶段使用活跃列学习新任务活跃列 + 新增容量
压缩阶段知识蒸馏到基座网络基座网络

4.3 实现框架

class ProgressAndCompress(nn.Module):
    """
    进展与压缩框架
    
    参考文献: Schwarz et al. "Progress & Compress: A scalable 
             framework for continual learning", ICML 2018
    """
    
    def __init__(self, base_model, num_tasks, compression_ratio=0.5):
        super().__init__()
        
        self.base_model = base_model  # 基座网络
        self.num_tasks = num_tasks
        self.compression_ratio = compression_ratio
        
        # 活跃列
        self.active_column = copy.deepcopy(base_model)
        
        # 知识库(用于存储压缩后的知识)
        self.knowledge_base = copy.deepcopy(base_model)
        
        # 每任务后的 EWC 备份
        self.ewc_modules = {}
        
    def progress_phase(self, task_data, optimizer, criterion, epochs=10):
        """
        进展阶段:学习新任务
        """
        self.active_column.train()
        
        for epoch in range(epochs):
            for inputs, targets in task_data:
                optimizer.zero_grad()
                outputs = self.active_column(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
    
    def compress_phase(self, task_data, optimizer, criterion, epochs=5):
        """
        压缩阶段:将活跃列的知识蒸馏到基座网络
        """
        self.base_model.train()
        
        for epoch in range(epochs):
            for inputs, targets in task_data:
                optimizer.zero_grad()
                
                # 基座网络输出
                base_outputs = self.base_model(inputs)
                
                # 活跃列输出(作为软目标)
                with torch.no_grad():
                    active_outputs = self.active_column(inputs)
                
                # 知识蒸馏损失
                T = 2.0  # 温度
                soft_target = F.softmax(active_outputs / T, dim=-1)
                soft_pred = F.log_softmax(base_outputs / T, dim=-1)
                loss_kd = F.kl_div(soft_pred, soft_target, reduction='batchmean') * (T ** 2)
                
                # 任务损失
                loss_task = criterion(base_outputs, targets)
                
                # 总损失
                loss = loss_task + 0.5 * loss_kd
                loss.backward()
                optimizer.step()
        
        # 更新知识库
        self.knowledge_base = copy.deepcopy(self.base_model)
    
    def train_task(self, task_data, epochs=10):
        """
        训练单个任务:进展 + 压缩
        """
        optimizer = torch.optim.Adam(
            list(self.active_column.parameters()) + 
            list(self.base_model.parameters()),
            lr=0.001
        )
        criterion = nn.CrossEntropyLoss()
        
        # 进展阶段
        print("进展阶段...")
        self.progress_phase(task_data, optimizer, criterion, epochs)
        
        # 压缩阶段
        print("压缩阶段...")
        self.compress_phase(task_data, optimizer, criterion, epochs)
    
    def forward(self, x, task_id):
        """
        推理时使用基座网络
        """
        return self.base_model(x)

5. 方法对比与实践建议

5.1 方法对比表

方法参数量增长计算开销存储开销优点缺点
PackNet恒定中等无额外参数可能剪掉重要参数
PNNO(任务数)O(任务数)完全隔离参数量大
HAT恒定 + O(任务数)软性隔离掩码学习复杂
P&C恒定结合蒸馏两阶段训练

5.2 参数量分析

假设基础网络有 个参数,每个新任务:

方法新增参数 个任务后总参数量
PackNet0
PNN
HAT
P&C0(两列)

5.3 实践建议

  1. 任务数量少 (< 5):优先使用 PNN,完全隔离无遗忘风险
  2. 参数量受限:使用 PackNet,效率高
  3. 需要平衡效果和效率:使用 P&C,结合蒸馏优势
  4. 需要灵活控制:使用 HAT,可学习掩码

5.4 扩展阅读

  • 网络剪枝:PackNet 的思想可扩展到权重重要性排序
  • 模块化网络:HAT 可与 Mixture of Experts 结合
  • 知识蒸馏:P&C 框架可与多种蒸馏方法结合

参考资料


相关阅读

Footnotes

  1. Mallya, A., & Lazebnik, S. (2018). PackNet: Adding Multiple Tasks to a Single Network by Iterative Pruning. CVPR.

  2. Rusu, A. A., et al. (2016). Progressive neural networks. arXiv.

  3. Serra, J., et al. (2018). Overcoming catastrophic forgetting with hard attention to the task. ICML.

  4. Schwarz, J., et al. (2018). Progress & Compress: A scalable framework for continual learning. ICML.