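Introduction

Elastic Weight Consolidation (EWC), proposed by Kirkpatrick et al. in 2017, is a landmark continual-learning method. EWC uses the Fisher information matrix to measure parameter importance and turns that measure into a regularization term that protects knowledge from earlier tasks. Although EWC works well in practice, its theoretical guarantees have long lacked rigorous mathematical analysis. This article builds a complete theoretical framework for EWC: the theoretical properties of the Fisher information matrix, a mathematical characterization of the regularization effect, and proofs of upper bounds on forgetting. [1][2][3]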


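1. Theoretical Foundations of the Fisher Information Matrix

1.1 Definition and Properties

Definition 1 (Fisher information matrix): Let $p(x; \theta)$ be a parametric model. The Fisher information matrix is defined as:

$$F(\theta) = \mathbb{E}_{x \sim p(x;\theta)}\left[\nabla_\theta \log p(x;\theta)\, \nabla_\theta \log p(x;\theta)^\top\right]$$

For a classification task, $p(y \mid x; \theta) = \mathrm{softmax}(f_\theta(x))_y$, so:

$$F(\theta) = \mathbb{E}_{x}\, \mathbb{E}_{y \sim p(y \mid x;\theta)}\left[\nabla_\theta \log p(y \mid x;\theta)\, \nabla_\theta \log p(y \mid x;\theta)^\top\right]$$

1.2 Relationship Between the Fisher Matrix and the Hessian

Theorem 1 (Fisher-Hessian relationship) [1]: Let $L(\theta) = -\log p(x;\theta)$ be the negative log-likelihood loss. Then:

$$F(\theta) = \mathbb{E}\left[\nabla_\theta^2 L(\theta)\right] = \mathbb{E}\left[H(\theta)\right]$$

where $H(\theta)$ is the Hessian matrix. The identity holds exactly if and only if the distribution belongs to an exponential family; otherwise the Fisher matrix is an upper bound on the Hessian ($F \succeq H$).

Implication: The Fisher information matrix encodes the local curvature of the loss surface. To second order, the effect of a parameter change on the loss is proportional to the corresponding diagonal Fisher entry.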
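As a worked check of Definition 1 and Theorem 1, the sketch below uses a one-parameter Bernoulli model. The model and the helper names (`score`, `nll_second_derivative`, `expectation`) are illustrative assumptions, not part of the text above. Because there are only two outcomes, both expectations can be computed exactly and compared with the known closed form $1/(\theta(1-\theta))$.

```python
def score(y, theta):
    """d/dtheta of log p(y; theta) for a Bernoulli model with P(y=1) = theta."""
    return y / theta - (1 - y) / (1 - theta)

def nll_second_derivative(y, theta):
    """Second derivative in theta of the negative log-likelihood -log p(y; theta)."""
    return y / theta ** 2 + (1 - y) / (1 - theta) ** 2

def expectation(fn, theta):
    """Exact expectation of fn(y, theta) over y ~ Bernoulli(theta)."""
    return theta * fn(1, theta) + (1 - theta) * fn(0, theta)

theta = 0.3
fisher = expectation(lambda y, t: score(y, t) ** 2, theta)   # E[score^2]
expected_hessian = expectation(nll_second_derivative, theta) # E[Hessian of NLL]
closed_form = 1.0 / (theta * (1.0 - theta))
mean_score = expectation(score, theta)                       # the score has zero mean

print(fisher, expected_hessian, closed_form, mean_score)
```

The Bernoulli family is an exponential family, so this is the case in which Theorem 1 states that the identity is exact.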

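1.3 Theoretical Properties of the Fisher Information Matrix

Theorem 2 (Positive semidefiniteness): The Fisher information matrix is positive semidefinite:

$$v^\top F(\theta)\, v = \mathbb{E}\left[\left(v^\top \nabla_\theta \log p(x;\theta)\right)^2\right] \ge 0 \quad \text{for all } v$$

When the parameters are identifiable (different parameters correspond to different distributions), the Fisher matrix is positive definite.

Theorem 3 (Covariance interpretation under parameter perturbation): The $i$-th diagonal entry of the Fisher information matrix satisfies:

$$F_{ii}(\theta) = \mathrm{Var}\left[\frac{\partial \log p(x;\theta)}{\partial \theta_i}\right]$$

This gives an intuitive reading of the Fisher matrix: it is the variance of the change in log-likelihood caused by perturbing a parameter.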
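Theorems 2 and 3 can be checked on a small example. The sketch below assumes a two-parameter logistic model evaluated at a single fixed input (an illustrative setup, with hypothetical helper names). With only one input the two weights cannot be told apart along the direction orthogonal to `x`, so the Fisher matrix is positive semidefinite but not positive definite.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def fisher_logistic(w, x):
    """Exact Fisher matrix of p(y=1|x) = sigmoid(w . x) at one fixed input x."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
    d = len(w)
    fisher = [[0.0] * d for _ in range(d)]
    for y, prob in ((1, p), (0, 1.0 - p)):
        s = [(y - p) * xi for xi in x]        # score vector for label y
        for i in range(d):
            for j in range(d):
                fisher[i][j] += prob * s[i] * s[j]
    return fisher

def quad_form(matrix, v):
    n = len(v)
    return sum(v[i] * matrix[i][j] * v[j] for i in range(n) for j in range(n))

w, x = [0.5, -0.25], [1.0, 2.0]               # w . x = 0, so p = 0.5
F = fisher_logistic(w, x)

along_x = quad_form(F, x)                     # strictly positive direction
orthogonal = quad_form(F, [2.0, -1.0])        # direction orthogonal to x
print(F, along_x, orthogonal)
```

The diagonal entries equal the variance of the corresponding score component, because the score has zero mean under the model.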

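1.4 Statistical Properties of the Empirical Fisher

Definition 2 (Empirical Fisher): On a finite sample $\{(x_n, y_n)\}_{n=1}^{N}$:

$$\hat{F}_N(\theta) = \frac{1}{N} \sum_{n=1}^{N} \nabla_\theta \log p(y_n \mid x_n;\theta)\, \nabla_\theta \log p(y_n \mid x_n;\theta)^\top$$

Theorem 4 (Convergence of the empirical Fisher) [2]: Let the samples be independent and identically distributed. Then:

$$\hat{F}_N(\theta) \to F(\theta) \quad \text{as } N \to \infty$$

and the convergence rate satisfies:

where $d$ is the parameter dimension.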
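The convergence in Theorem 4 can be observed numerically. The sketch below is a Monte Carlo illustration under an assumed Bernoulli model whose samples are drawn from the model itself; the function name and the sample sizes are arbitrary choices for the example.

```python
import random

def empirical_fisher(samples, theta):
    """Mean squared score of a Bernoulli(theta) model over the observed samples."""
    total = 0.0
    for y in samples:
        s = y / theta - (1 - y) / (1 - theta)
        total += s * s
    return total / len(samples)

random.seed(0)
theta = 0.3
exact = 1.0 / (theta * (1.0 - theta))   # closed-form Fisher information

errors = {}
for n in (100, 10_000, 200_000):
    samples = [1 if random.random() < theta else 0 for _ in range(n)]
    errors[n] = abs(empirical_fisher(samples, theta) - exact)
    print(n, errors[n])
```

A single run is noisy, so the error need not shrink monotonically from one sample size to the next; only the overall trend reflects the theorem.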

import torch
import torch.nn as nn
import numpy as np

class FisherInformationMatrix:
    """Fisher information matrix calculator."""

    def __init__(self, model):
        self.model = model
        self.fisher = None
        self.n_samples = 0

    def compute_empirical_fisher(self, dataloader, device='cuda'):
        """
        Compute the diagonal of the empirical Fisher information matrix.

        F_ii = (1/N) * sum_n (d log p / d theta_i)^2

        Returns:
            fisher_diagonal: dict mapping parameter names to diagonal Fisher entries
        """
        self.model.eval()

        # Initialize the Fisher accumulator
        fisher = {}
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                fisher[n] = torch.zeros_like(p)

        n_samples = 0

        for x, y in dataloader:
            x, y = x.to(device), y.to(device)

            # Use per-sample gradients: squaring a batch-averaged gradient
            # would not match the formula in the docstring.
            for i in range(x.size(0)):
                self.model.zero_grad()

                # Forward pass on a single sample
                output = self.model(x[i:i + 1])

                # For classification:
                # log p(y|x; theta) = log_softmax(output)[y]
                log_probs = nn.functional.log_softmax(output, dim=-1)
                loss = -log_probs[0, y[i]]

                # Backward pass
                loss.backward()

                # Accumulate squared gradients
                for n, p in self.model.named_parameters():
                    if p.requires_grad and p.grad is not None:
                        fisher[n] += p.grad.detach() ** 2

                n_samples += 1

        # Average over samples
        for n in fisher:
            fisher[n] /= max(n_samples, 1)

        self.fisher = fisher
        self.n_samples = n_samples

        return fisher

    def compute_fisher_bound(self, param_name):
        """
        Compute the bound associated with one parameter tensor.

        Based on the Cramér-Rao lower bound.
        """
        if self.fisher is None:
            raise ValueError("Fisher matrix not computed")

        if param_name not in self.fisher:
            raise ValueError(f"Parameter {param_name} not found")

        fisher_diag = self.fisher[param_name]

        # Cramér-Rao lower bound on the variance of the parameter estimate
        # (diagonal approximation; 1e-8 avoids division by zero)
        crlb = 1.0 / (fisher_diag + 1e-8)

        return crlb

    def compute_importance_ranking(self):
        """
        Rank parameter tensors by importance.

        Returns a list of (parameter name, mean Fisher value) pairs,
        sorted from most to least important.
        """
        if self.fisher is None:
            raise ValueError("Fisher matrix not computed")

        # Mean importance of each parameter tensor
        importance = {}
        for n, f in self.fisher.items():
            importance[n] = f.mean().item()

        # Sort in descending order
        sorted_params = sorted(importance.items(), key=lambda x: x[1], reverse=True)

        return sorted_params

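2. Mathematical Framework of EWC

2.1 The EWC Loss Function

Standard EWC loss:

$$L_{\mathrm{EWC}}(\theta) = L_{\mathrm{new}}(\theta) + \frac{\lambda}{2} \sum_i F_i \left(\theta_i - \theta^*_{\mathrm{old},i}\right)^2$$

where:

  • $L_{\mathrm{new}}(\theta)$: loss on the new task
  • $F_i$: the $i$-th diagonal entry of the Fisher matrix
  • $\theta^*_{\mathrm{old}}$: optimal parameters of the old task
  • $\lambda$: regularization strength

2.2 Multi-Task EWC

Multi-task EWC loss (Fisher terms accumulated sequentially over tasks):

$$L(\theta) = L_t(\theta) + \frac{\lambda}{2} \sum_{k<t} \sum_i F^{(k)}_i \left(\theta_i - \theta^{*(k)}_i\right)^2$$

Online EWC (the Fisher matrix is updated with a running average):

$$L(\theta) = L_t(\theta) + \frac{\lambda}{2} \sum_i \tilde{F}_i \left(\theta_i - \tilde{\theta}_i\right)^2$$

where $\tilde{F}$ and $\tilde{\theta}$ are the accumulated Fisher and parameter estimates.

2.3 Regularization Effect of EWC

Theorem 5 (Parameter-shift constraint): Let $\theta^*$ be the minimizer of $L_{\mathrm{EWC}}$. Then on the old task:

$$L_{\mathrm{old}}(\theta^*) - L_{\mathrm{old}}(\theta^*_{\mathrm{old}}) \le \frac{1}{2} \sum_i F_i \left(\theta^*_i - \theta^*_{\mathrm{old},i}\right)^2$$

Proof: By Taylor expansion around $\theta^*_{\mathrm{old}}$, with $\Delta = \theta^* - \theta^*_{\mathrm{old}}$:

$$L_{\mathrm{old}}(\theta^*) \approx L_{\mathrm{old}}(\theta^*_{\mathrm{old}}) + \nabla L_{\mathrm{old}}(\theta^*_{\mathrm{old}})^\top \Delta + \frac{1}{2} \Delta^\top H \Delta$$

Here $\nabla L_{\mathrm{old}}(\theta^*_{\mathrm{old}}) = 0$, and $H \preceq F$ (since $F$ is an upper bound on $H$), hence, to second order and with the diagonal Fisher approximation:

$$L_{\mathrm{old}}(\theta^*) - L_{\mathrm{old}}(\theta^*_{\mathrm{old}}) \le \frac{1}{2} \Delta^\top F \Delta = \frac{1}{2} \sum_i F_i \Delta_i^2$$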
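A small quadratic model makes the parameter-shift argument concrete. The sketch below assumes diagonal quadratic losses for both tasks and a Fisher diagonal that dominates the old-task curvature ($h_i \le F_i$); all numbers and function names are illustrative assumptions. In this setting the EWC minimizer has a closed form, so the forgetting can be compared directly with the Fisher-weighted shift.

```python
def ewc_minimizer(a, c, fisher, theta_old, lam):
    """Minimizer of sum_i [a_i/2 (th_i - c_i)^2 + lam/2 * F_i (th_i - old_i)^2]."""
    return [(ai * ci + lam * fi * oi) / (ai + lam * fi)
            for ai, ci, fi, oi in zip(a, c, fisher, theta_old)]

def old_task_loss(theta, h, theta_old):
    """Quadratic old-task loss with diagonal Hessian h and minimum at theta_old."""
    return 0.5 * sum(hi * (ti - oi) ** 2 for hi, ti, oi in zip(h, theta, theta_old))

theta_old = [0.0, 0.0]
h = [4.0, 0.1]                   # old-task curvature (assumed diagonal)
fisher = [5.0, 0.2]              # assumed Fisher diagonal with h_i <= F_i
a, c = [1.0, 1.0], [1.0, 1.0]    # the new task pulls both parameters toward 1
lam = 2.0

theta_star = ewc_minimizer(a, c, fisher, theta_old, lam)
shift = [t - o for t, o in zip(theta_star, theta_old)]
forgetting = old_task_loss(theta_star, h, theta_old)
fisher_bound = 0.5 * sum(f * s ** 2 for f, s in zip(fisher, shift))

print(theta_star, forgetting, fisher_bound)
```

The parameter with the larger Fisher entry moves much less than the other one, which is the protective effect the regularizer is meant to have.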


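3. Convergence Analysis

3.1 EWC as an Approximate Bayesian Posterior

Theorem 6 (EWC as variational inference) [2]: The EWC optimum $\theta^*$ can be interpreted as the solution of the following variational inference problem:

where $q_{\mathrm{old}}$ is the approximate Bayesian posterior of the old task.

Implication: EWC searches parameter space for a distribution whose mean stays close to the old task's Bayesian posterior mean while still fitting the new-task data.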
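One way to see the Bayesian reading is to treat the old-task posterior as a diagonal Gaussian centered at the old optimum with precision $\lambda F_i$. That Gaussian (Laplace-style) form is an assumption made for this sketch, and the function names are hypothetical. Under it, the negative log-density differs from the EWC penalty only by a constant, so differences between any two parameter vectors agree.

```python
import math

def neg_log_gaussian(theta, mean, precision):
    """Negative log-density of independent 1-D Gaussians with the given precisions."""
    return sum(0.5 * math.log(2.0 * math.pi / p) + 0.5 * p * (t - m) ** 2
               for t, m, p in zip(theta, mean, precision))

def ewc_penalty(theta, theta_old, fisher, lam):
    """(lam / 2) * sum_i F_i (theta_i - theta_old_i)^2."""
    return 0.5 * lam * sum(f * (t - o) ** 2
                           for t, o, f in zip(theta, theta_old, fisher))

theta_old, fisher, lam = [0.5, -1.0], [4.0, 0.25], 3.0
precision = [lam * f for f in fisher]     # assumed posterior precision: lam * F_i

theta_a, theta_b = [0.7, -0.2], [0.1, -1.5]
gap_prior = (neg_log_gaussian(theta_a, theta_old, precision)
             - neg_log_gaussian(theta_b, theta_old, precision))
gap_penalty = (ewc_penalty(theta_a, theta_old, fisher, lam)
               - ewc_penalty(theta_b, theta_old, fisher, lam))

print(gap_prior, gap_penalty)
```

Minimizing the new-task loss plus the EWC penalty is therefore the same as a MAP estimate with this Gaussian as the prior.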

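3.2 Convergence Rate Analysis

Theorem 7 (EWC convergence rate): Let $\theta^*_t$ be the EWC optimum after learning task $t$. Then:

Implications:

  • The larger the regularization strength $\lambda$, the smaller the parameter shift
  • The larger the smallest eigenvalue of the Fisher matrix (that is, the larger the curvature), the smaller the shift

3.3 Convergence to the Pareto Front

Theorem 8 (Pareto optimality): Let $\mathcal{P}$ be the optimal Pareto front of multi-task learning. Then the sequence of EWC solutions $\{\theta^*_t\}$ satisfies:

if and only if:

  1. The task sequence has finite length
  2. $\lambda$ decays appropriately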

class EWCConvergenceAnalyzer:
    """Convergence analyzer for EWC."""

    def __init__(self, model, lambda_reg=1000):
        self.model = model
        self.lambda_reg = lambda_reg
        self.task_params = {}   # parameters stored after each task
        self.task_fishers = {}  # diagonal Fisher stored for each task
        self.convergence_history = []

    def ewc_loss(self, current_task_loader, prev_task_id=None, device='cuda'):
        """
        Compute the EWC loss.

        prev_task_id is kept for interface compatibility; the penalty
        sums over every stored task.

        Returns:
            total_loss: total loss
            new_loss: loss on the new task
            ewc_loss: EWC regularization loss
        """
        # New-task loss, averaged over batches
        self.model.zero_grad()
        new_loss = 0
        for x, y in current_task_loader:
            x, y = x.to(device), y.to(device)
            output = self.model(x)
            new_loss += nn.functional.cross_entropy(output, y)
        new_loss /= len(current_task_loader)

        # EWC penalty. Parameters are looked up by name: the Fisher dict only
        # holds trainable parameters, so zipping it with named_parameters()
        # could silently misalign tensors.
        params = dict(self.model.named_parameters())
        ewc_loss = torch.zeros((), device=device)
        for task_id, fisher in self.task_fishers.items():
            old_params = self.task_params[task_id]
            for name, f in fisher.items():
                ewc_loss = ewc_loss + self.lambda_reg * torch.sum(
                    f * (params[name] - old_params[name]) ** 2
                )

        ewc_loss = ewc_loss / 2

        total_loss = new_loss + ewc_loss

        return total_loss, new_loss, ewc_loss

    def compute_parameter_shift(self, task_id):
        """
        Compute the L2 distance between the current parameters and those
        stored for an old task.
        """
        if task_id not in self.task_params:
            return 0.0

        old_params = self.task_params[task_id]

        total_shift = 0.0
        for name, p in self.model.named_parameters():
            total_shift += torch.sum((p.detach() - old_params[name]) ** 2).item()

        return float(np.sqrt(total_shift))

    def update_fisher(self, task_loader, device='cuda'):
        """Compute and store the Fisher information for the current task."""
        fisher_computer = FisherInformationMatrix(self.model)
        fisher = fisher_computer.compute_empirical_fisher(task_loader, device)

        # Store the Fisher under the next task id
        task_id = len(self.task_fishers)
        self.task_fishers[task_id] = fisher

        return fisher

    def update_params(self):
        """Store a copy of the current parameters."""
        task_id = len(self.task_params)
        self.task_params[task_id] = {
            n: p.clone().detach()
            for n, p in self.model.named_parameters()
        }

    def analyze_convergence(self, current_task_loader, prev_task_id, device='cuda'):
        """
        Analyze EWC convergence.
        """
        # Parameter shift relative to the old task
        shift = self.compute_parameter_shift(prev_task_id)

        # EWC loss terms
        _, new_loss, ewc_loss = self.ewc_loss(
            current_task_loader, prev_task_id, device
        )

        # Condition-number proxy (related to the convergence rate): the mean
        # Fisher value of each tensor stands in for an eigenvalue.
        min_eigenvalue = float('inf')
        max_eigenvalue = 0.0

        for fisher in self.task_fishers.values():
            for f in fisher.values():
                f_mean = f.mean().item()
                min_eigenvalue = min(min_eigenvalue, max(f_mean, 1e-10))
                max_eigenvalue = max(max_eigenvalue, f_mean)

        if self.task_fishers:
            condition_number = max_eigenvalue / min_eigenvalue
        else:
            condition_number = float('nan')  # no Fisher stored yet

        self.convergence_history.append({
            'shift': shift,
            'new_loss': new_loss.item(),
            'ewc_loss': ewc_loss.item(),
            'condition_number': condition_number
        })

        return self.convergence_history[-1]

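4. Proof of Forgetting Upper Bounds

4.1 Single-Task Forgetting Bound

Theorem 9 (EWC forgetting upper bound) [3]: Let $\theta^*_{t-1}$ be the EWC-optimal parameters after learning task $t-1$, and $\theta^*_t$ the EWC-optimal parameters after learning task $t$. Then the forgetting on task $t-1$ satisfies:

$$L_{t-1}(\theta^*_t) - L_{t-1}(\theta^*_{t-1}) \le \frac{1}{2\lambda} \left\|\nabla L_t\right\|^2_{F^{-1}}$$

where $\|v\|_{F^{-1}} = \sqrt{v^\top F^{-1} v}$ is the Mahalanobis distance.

Proof sketch:

  1. First-order optimality condition of the EWC objective at $\theta^*_t$:
     $$\nabla L_t(\theta^*_t) + \lambda F \left(\theta^*_t - \theta^*_{t-1}\right) = 0$$
  2. Hence:
     $$\theta^*_t - \theta^*_{t-1} = -\frac{1}{\lambda} F^{-1} \nabla L_t(\theta^*_t)$$
  3. Taylor-expand the loss of task $t-1$ around $\theta^*_{t-1}$, where its gradient vanishes:
     $$L_{t-1}(\theta^*_t) \approx L_{t-1}(\theta^*_{t-1}) + \frac{1}{2} \left(\theta^*_t - \theta^*_{t-1}\right)^\top H_{t-1} \left(\theta^*_t - \theta^*_{t-1}\right)$$
  4. Since $H_{t-1} \preceq F$ (Theorem 1), the quadratic term is bounded by the Fisher-weighted norm of the step in item 2, which gives a bound in terms of $\|\nabla L_t\|^2_{F^{-1}}$.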
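The right-hand side of Theorem 9, as stated, is straightforward to evaluate once a diagonal Fisher is available. The sketch below only computes that expression; it does not verify the theorem. The function names and the `eps` guard are assumptions for the example.

```python
def mahalanobis_sq(grad, fisher_diag, eps=1e-8):
    """g^T F^{-1} g for a diagonal F; eps guards against zero Fisher entries."""
    return sum(g * g / (f + eps) for g, f in zip(grad, fisher_diag))

def forgetting_bound(grad, fisher_diag, lam, eps=1e-8):
    """Theorem 9 as stated: (1 / (2 * lam)) * ||grad||^2 in the F^{-1} norm."""
    return mahalanobis_sq(grad, fisher_diag, eps) / (2.0 * lam)

grad = [2.0, 0.3]        # new-task gradient (illustrative values)
fisher = [4.0, 0.01]     # old-task diagonal Fisher (illustrative values)

bounds = [forgetting_bound(grad, fisher, lam) for lam in (1.0, 10.0, 100.0)]
print(mahalanobis_sq(grad, fisher, eps=0.0), bounds)
```

Note how the second coordinate dominates the norm: a modest gradient on a parameter with a tiny Fisher entry contributes far more than a large gradient on a well-constrained parameter, which is why the `eps` term matters in practice.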

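4.2 Cumulative Forgetting Bound

Theorem 10 (Multi-task cumulative forgetting bound): Let $F_{1:T}$ be the accumulated Fisher matrix. Then after learning $T$ tasks, the cumulative forgetting satisfies:

4.3 Optimal Regularization Strength

Theorem 11 (Optimal choice of $\lambda$): To minimize cumulative forgetting, the optimal regularization strength satisfies:

Practical choice:

  • $\lambda$ too large: the new-task loss stays too high (insufficient plasticity)
  • $\lambda$ too small: the old-task loss rises too much (excessive forgetting)
  • Empirical choice: searched within a fixed range; the classes in this article default to $\lambda = 1000$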
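The trade-off in the list above can be seen in a one-dimensional toy model. The sketch assumes quadratic losses for both tasks, with the old optimum at 0 and the new optimum at `c`; the constants are arbitrary illustrative values and `ewc_tradeoff` is a hypothetical helper.

```python
def ewc_tradeoff(lam, a=1.0, c=1.0, h=2.0, fisher=2.0):
    """1-D toy: old loss h/2 * th^2, new loss a/2 * (th - c)^2, Fisher = fisher."""
    theta = a * c / (a + lam * fisher)        # closed-form EWC minimizer
    new_loss = 0.5 * a * (theta - c) ** 2     # cost in plasticity
    forgetting = 0.5 * h * theta ** 2         # increase of the old-task loss
    return new_loss, forgetting

lams = [0.01, 0.1, 1.0, 10.0, 100.0]
rows = [(lam, *ewc_tradeoff(lam)) for lam in lams]
for lam, new_loss, forgetting in rows:
    print(f"lambda={lam:>6}: new-task loss={new_loss:.4f}, forgetting={forgetting:.4f}")
```

As $\lambda$ grows the new-task loss rises while forgetting falls, so any useful choice has to sit between the two extremes.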
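
class EWCForgettingBoundEstimator:
    """Estimator for the EWC forgetting upper bound."""

    def __init__(self, model):
        self.model = model
        self.task_fishers = {}    # task_id -> {param_name: diagonal Fisher}
        self.task_gradients = {}  # task_id -> flat gradient tensor

    def estimate_single_task_forgetting(self, new_task_id, prev_task_id, lambda_reg=1.0):
        """
        Estimate the upper bound on forgetting of an old task after
        learning a new task.

        Theorem 9: forgetting <= (1 / (2 * lambda)) * ||grad L_new||^2_{F^{-1}}
        """
        if prev_task_id not in self.task_fishers:
            return float('inf')

        # Fisher of the old task (its inverse is approximated by the
        # elementwise inverse of the diagonal)
        fisher = self.task_fishers[prev_task_id]

        # Gradient of the new task at the old task's optimum, stored as one
        # flat tensor laid out in the same order as the loop below
        gradient = self.task_gradients[new_task_id]

        # Squared Mahalanobis norm
        mahalanobis_dist_sq = 0.0
        param_idx = 0

        for name, p in self.model.named_parameters():
            if p.requires_grad and name in fisher:
                f_diag = fisher[name].flatten()
                g = gradient[param_idx:param_idx + p.numel()]

                # v^T F^{-1} v ≈ sum_i v_i^2 / f_ii
                mahalanobis_dist_sq += torch.sum(g ** 2 / (f_diag + 1e-8)).item()

                param_idx += p.numel()

        # The 1 / lambda factor from the theorem; lambda_reg=1.0 leaves it out
        return 0.5 * mahalanobis_dist_sq / lambda_reg

    def estimate_optimal_lambda(self, lambda_candidates):
        """
        Estimate the best regularization strength.

        Each candidate is scored with the simplified model in _simulate_ewc;
        no validation data is used here.
        """
        results = []

        for lambda_reg in lambda_candidates:
            # Simulated forgetting for this lambda
            forgetting = self._simulate_ewc(lambda_reg)
            results.append({
                'lambda': lambda_reg,
                'forgetting': forgetting
            })

        # Pick the lambda with the smallest simulated forgetting
        best = min(results, key=lambda x: x['forgetting'])

        return best['lambda'], results

    def _simulate_ewc(self, lambda_reg):
        """Simulate EWC forgetting for a given lambda."""
        # Simplified simulation: start from the summed bounds at lambda = 1
        base_forgetting = sum(
            self.estimate_single_task_forgetting(t, t - 1)
            for t in range(1, len(self.task_fishers))
        )

        # Heuristic: forgetting decreases as lambda grows, but sublinearly
        # (1 / sqrt(lambda) is an assumed shape, not derived from Theorem 9)
        simulated_forgetting = base_forgetting / np.sqrt(lambda_reg)

        return simulated_forgetting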

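5. Comparison with Other Regularization Methods

5.1 SI (Synaptic Intelligence)

SI loss:

$$L_{\mathrm{SI}}(\theta) = L_{\mathrm{new}}(\theta) + \lambda \sum_i \Omega_i \left(\theta_i - \theta^*_{\mathrm{old},i}\right)^2$$

where $\Omega_i$ is an importance measure based on the parameter trajectory.

Comparison:

| Method | Importance measure | Theoretical basis |
| --- | --- | --- |
| EWC | Fisher information | Approximate Bayesian posterior |
| SI | Parameter trajectory | Online-learning view |
| MAS | Output sensitivity | Function-space metric |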
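To make the "parameter trajectory" column concrete, the sketch below computes a path-integral importance in the commonly used SI form: each parameter accumulates $-g_i\,\Delta\theta_i$ along the optimization path, and the total is normalized by the squared net displacement plus a damping term. The quadratic loss, the plain gradient-descent loop, and the names `si_importance` and `xi` are assumptions made for illustration.

```python
def loss(a, theta):
    return 0.5 * sum(ai * ti * ti for ai, ti in zip(a, theta))

def si_importance(a, theta0, lr=0.01, steps=500, xi=1e-3):
    """Path-integral importance for gradient descent on L = sum_i a_i/2 * th_i^2."""
    theta = list(theta0)
    omega = [0.0] * len(theta)                 # running per-parameter contribution
    for _ in range(steps):
        grad = [ai * ti for ai, ti in zip(a, theta)]
        step = [-lr * g for g in grad]
        # -g_i * delta_theta_i: parameter i's share of the loss decrease
        omega = [w - g * s for w, g, s in zip(omega, grad, step)]
        theta = [t + s for t, s in zip(theta, step)]
    displacement = [t - t0 for t, t0 in zip(theta, theta0)]
    importance = [w / (d * d + xi) for w, d in zip(omega, displacement)]
    return importance, omega, theta

a, theta0 = [4.0, 0.5], [1.0, 1.0]
importance, omega, theta_end = si_importance(a, theta0)
loss_drop = loss(a, theta0) - loss(a, theta_end)

print(importance, sum(omega), loss_drop)
```

With a small learning rate the accumulated contributions add up to roughly the total loss decrease, and the high-curvature parameter receives the larger importance. Unlike the Fisher, this quantity is gathered during training rather than computed afterwards.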

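5.2 RWalk (Riemannian Walk)

RWalk loss [2]:

$$L_{\mathrm{RWalk}}(\theta) = L_{\mathrm{new}}(\theta) + \lambda \sum_i \left(F_i + s_i\right) \left(\theta_i - \theta^*_{\mathrm{old},i}\right)^2$$

where:

  • $F_i$: accumulated Fisher
  • $s_i$: accumulated parameter change

Theorem 12 (RWalk improves on EWC): The forgetting upper bound of RWalk is strictly smaller than that of EWC:

5.3 Summary of the Theoretical Comparison

| Method | Forgetting bound | Dependence on regularization strength | Computational complexity |
| --- | --- | --- | --- |
| EWC | | | |
| SI | | | |
| RWalk | | | |
| EWC + adaptive | | | |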

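6. Adaptive EWC Variants

6.1 Online EWC

Problem: the original EWC has to store the Fisher matrix of every previous task, so compute and storage costs grow with the number of tasks.

Online EWC: approximate the accumulated Fisher with an exponential moving average:

$$\tilde{F}_t = \alpha\, \tilde{F}_{t-1} + (1 - \alpha)\, F_t$$

where $\alpha$ is the decay rate.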
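Unrolling the moving average shows how much weight each past task keeps. The sketch below assumes the average is initialized with the first task's Fisher and uses plain Python lists in place of tensors; the function names are hypothetical.

```python
def ema_update(ema, fisher, alpha):
    """One online-EWC step: ema <- alpha * ema + (1 - alpha) * fisher, elementwise."""
    if ema is None:                       # first task: start from its Fisher
        return list(fisher)
    return [alpha * e + (1 - alpha) * f for e, f in zip(ema, fisher)]

def ema_weights(num_tasks, alpha):
    """Weight that each task's Fisher carries after num_tasks updates."""
    weights = [alpha ** (num_tasks - 1)]                    # task 1
    weights += [(1 - alpha) * alpha ** (num_tasks - k)      # tasks 2..T
                for k in range(2, num_tasks + 1)]
    return weights

alpha = 0.9
task_fishers = [[1.0, 0.0], [3.0, 2.0], [0.5, 4.0]]

ema = None
for fisher in task_fishers:
    ema = ema_update(ema, fisher, alpha)

weights = ema_weights(len(task_fishers), alpha)
unrolled = [sum(w * f[i] for w, f in zip(weights, task_fishers)) for i in range(2)]

print(ema, weights, unrolled)
```

Only one Fisher-sized buffer is kept regardless of the number of tasks. With $\alpha = 0.9$ and this initialization the first task still dominates after three updates, so the decay rate controls how quickly old importance estimates fade.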

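Theorem 13 (Forgetting bound for online EWC): Let $\tilde{F}_t$ be the accumulated Fisher estimate at step $t$. Then:

6.2 Adaptive-λ EWC

Theorem 14 (Theoretical guarantee for adaptive $\lambda$): Let the difficulty of task $t$ be $D_t = \|\nabla L_t\|_{F^{-1}}$. Then the adaptive choice below gives the optimal trade-off:

$$\lambda^*_t = \sqrt{D_t\, D_{t-1}}$$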
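A small numeric example of the adaptive rule, assuming a diagonal Fisher and hand-picked gradients; `difficulty` and `adaptive_lambda` are hypothetical helper names used only here.

```python
import math

def difficulty(grad, fisher_diag, eps=1e-8):
    """D = ||grad|| in the F^{-1} norm for a diagonal Fisher; eps avoids division by zero."""
    return math.sqrt(sum(g * g / (f + eps) for g, f in zip(grad, fisher_diag)))

def adaptive_lambda(d_current, d_previous):
    """Geometric mean of two task difficulties: sqrt(D_t * D_{t-1})."""
    return math.sqrt(d_current * d_previous)

fisher = [4.0, 1.0]
d_prev = difficulty([4.0, 0.0], fisher, eps=0.0)   # sqrt(16 / 4) = 2
d_curr = difficulty([0.0, 8.0], fisher, eps=0.0)   # sqrt(64 / 1) = 8
lam = adaptive_lambda(d_curr, d_prev)              # sqrt(8 * 2) = 4

print(d_prev, d_curr, lam)
```

The geometric mean keeps $\lambda$ between the two difficulties, so one unusually hard or easy task does not swing the regularization strength as far as an arithmetic mean would.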

class AdaptiveEWC:
    """Adaptive EWC implementation."""

    def __init__(self, model, base_lambda=1000):
        self.model = model
        self.base_lambda = base_lambda
        self.task_fishers = {}
        self.task_params = {}
        self.ema_fisher = None  # exponential moving average of the Fisher
        self.alpha = 0.9  # decay rate

    def update_ema_fisher(self, fisher):
        """
        Update the exponential moving average of the Fisher.

        Used for online EWC.
        """
        if self.ema_fisher is None:
            # Clone so later in-place changes to `fisher` do not leak in
            self.ema_fisher = {name: f.clone() for name, f in fisher.items()}
        else:
            for name in fisher:
                self.ema_fisher[name] = (
                    self.alpha * self.ema_fisher[name] +
                    (1 - self.alpha) * fisher[name]
                )

        return self.ema_fisher

    def compute_adaptive_lambda(self, task_id, new_task_loader, device='cuda'):
        """
        Compute the adaptive regularization strength.

        lambda_t* = sqrt(D_t * D_{t-1})
        where D_t = ||grad L_t||_{F^{-1}}
        """
        # Trainable parameters in a fixed order, so that slices of the flat
        # gradient line up with parameter names
        named = [(n, p) for n, p in self.model.named_parameters() if p.requires_grad]

        # Mean gradient of the new-task loss over all samples
        total_grad = None
        n_samples = 0

        for x, y in new_task_loader:
            x, y = x.to(device), y.to(device)
            self.model.zero_grad()
            output = self.model(x)
            loss = nn.functional.cross_entropy(output, y)
            loss.backward()

            grad = torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
                for _, p in named
            ])

            # cross_entropy averages over the batch, so weight by batch size
            # before dividing by the total sample count below
            grad = grad * x.size(0)
            total_grad = grad if total_grad is None else total_grad + grad
            n_samples += x.size(0)

        self.model.zero_grad()
        mean_grad = total_grad / n_samples

        # Fisher of this task if stored, otherwise the running average
        fisher = self.task_fishers.get(task_id, self.ema_fisher)

        # D_t^2: plain squared norm without a Fisher, otherwise the
        # diagonal Mahalanobis norm accumulated over all parameters
        if fisher is None:
            D_t_sq = mean_grad.norm().item() ** 2
        else:
            D_t_sq = 0.0
            param_idx = 0
            for name, p in named:
                if name in fisher:
                    g = mean_grad[param_idx:param_idx + p.numel()]
                    f_diag = fisher[name].flatten()
                    D_t_sq += torch.sum(g ** 2 / (f_diag + 1e-8)).item()
                param_idx += p.numel()

        # D_{t-1}^2
        if task_id > 0 and (task_id - 1) in self.task_fishers:
            D_prev_sq = self._compute_difficulty(task_id - 1)
        else:
            D_prev_sq = D_t_sq

        # Adaptive lambda = sqrt(D_t * D_{t-1}), with D = sqrt(D^2)
        adaptive_lambda = float(np.sqrt(np.sqrt(D_t_sq) * np.sqrt(D_prev_sq)))

        return adaptive_lambda

    def _compute_difficulty(self, task_id):
        """Compute the (squared) task difficulty."""
        # Simplification: the squared Fisher norm stands in for D^2
        fisher = self.task_fishers[task_id]
        total_fisher_norm = 0

        for f in fisher.values():
            total_fisher_norm += (f ** 2).sum().item()

        return total_fisher_norm

7. Practical Guide

7.1 Fisher Computation Strategies

class EfficientFisherComputation:
    """Efficient Fisher computation."""

    def __init__(self, model):
        self.model = model

    def compute_diag_fisher(self, dataloader, device='cuda', n_samples=None):
        """
        Compute the diagonal of the Fisher (storage-efficient).

        Squares the batch-mean gradient, which approximates the per-sample
        empirical Fisher; the two coincide when batch_size == 1.

        Time complexity: O(N * d)
        Space complexity: O(d)
        """
        self.model.eval()

        fisher_diag = {}
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                fisher_diag[n] = torch.zeros_like(p)

        n_total = 0

        for x, y in dataloader:
            # n_samples caps the number of samples, not the number of batches
            if n_samples is not None and n_total >= n_samples:
                break

            x, y = x.to(device), y.to(device)
            batch_size = x.size(0)
            self.model.zero_grad()

            output = self.model(x)
            loss = nn.functional.cross_entropy(output, y)
            loss.backward()

            for n, p in self.model.named_parameters():
                if p.requires_grad and p.grad is not None:
                    # Weight by batch size so the division below is a
                    # sample-weighted average
                    fisher_diag[n] += batch_size * p.grad.detach() ** 2

            n_total += batch_size

        # Average
        for n in fisher_diag:
            fisher_diag[n] /= max(n_total, 1)

        return fisher_diag

    def compute_block_diag_fisher(self, dataloader, block_size=100,
                                  device='cuda', n_samples=None):
        """
        Compute a block-diagonal Fisher (a trade-off between accuracy and cost).

        The flattened parameters are split into blocks and only the
        within-block entries are computed.

        Returns:
            dict mapping (i, i) to the i-th diagonal block
        """
        self.model.eval()

        params = [p for _, p in self.model.named_parameters() if p.requires_grad]
        n_params = sum(p.numel() for p in params)
        n_blocks = (n_params + block_size - 1) // block_size

        fisher_blocks = {}  # {(i, i): block_matrix}

        n_total = 0

        for x, y in dataloader:
            if n_samples is not None and n_total >= n_samples:
                break

            x, y = x.to(device), y.to(device)
            batch_size = x.size(0)
            self.model.zero_grad()

            output = self.model(x)
            loss = nn.functional.cross_entropy(output, y)
            loss.backward()

            # Collect the flat gradient (zeros for unused parameters)
            grad = torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
                for p in params
            ]).detach()

            # Update only the diagonal blocks
            for block_i in range(n_blocks):
                start = block_i * block_size
                end = min(start + block_size, n_params)

                block_grad = grad[start:end]
                block_key = (block_i, block_i)

                if block_key not in fisher_blocks:
                    fisher_blocks[block_key] = torch.zeros(
                        end - start, end - start, device=device
                    )

                fisher_blocks[block_key] += batch_size * torch.outer(
                    block_grad, block_grad
                )

            n_total += batch_size

        # Average
        for key in fisher_blocks:
            fisher_blocks[key] /= max(n_total, 1)

        return fisher_blocks

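7.2 Hyperparameter Tuning

import copy

class EWCHyperparameterTuner:
    """Hyperparameter tuning for EWC."""

    def __init__(self, model, train_loaders, val_loaders, device='cuda'):
        self.model = model
        self.train_loaders = train_loaders
        self.val_loaders = val_loaders
        self.device = device
        # nn.Module has no generic reset_parameters(), so keep a copy of the
        # initial weights and restore it before every trial
        self.init_state = copy.deepcopy(model.state_dict())

    def tune_lambda(self, lambda_candidates, n_trials=3):
        """
        Grid search for the best lambda.
        """
        results = []

        for lambda_reg in lambda_candidates:
            # Average over several trials
            forgetting_scores = []

            for trial in range(n_trials):
                # Reset the model
                self.model.load_state_dict(self.init_state)

                ewc = AdaptiveEWC(self.model, lambda_reg)
                acc_after_training = []

                for t, loader in enumerate(self.train_loaders):
                    # Train on task t (the EWC penalty applies from task 1 on)
                    self._train_task(ewc, loader, lambda_reg, t)
                    acc_after_training.append(self._accuracy(self.val_loaders[t]))

                    # Fisher of task t at its trained parameters
                    fisher_computer = FisherInformationMatrix(self.model)
                    fisher = fisher_computer.compute_empirical_fisher(loader, self.device)

                    # Merge with the Fisher accumulated over earlier tasks
                    if t > 0:
                        for name in fisher:
                            if name in ewc.task_fishers[t - 1]:
                                fisher[name] = ewc.task_fishers[t - 1][name] + fisher[name]
                    ewc.task_fishers[t] = fisher

                    # Store the parameters reached after task t
                    ewc.task_params[t] = {
                        n: p.detach().clone() for n, p in self.model.named_parameters()
                    }

                # Evaluate forgetting for this trial
                forgetting_scores.append(self._evaluate_forgetting(acc_after_training))

            results.append({
                'lambda': lambda_reg,
                'forgetting': np.mean(forgetting_scores),
                'std': np.std(forgetting_scores)
            })

        # Pick the lambda with the least forgetting
        best = min(results, key=lambda x: x['forgetting'])

        return best, results

    def _train_task(self, ewc, loader, lambda_reg, t, n_epochs=10):
        """Simplified training: a few epochs of Adam with the EWC penalty."""
        self.model.train()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
        for _ in range(n_epochs):
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = nn.functional.cross_entropy(self.model(x), y)
                if t > 0:
                    fisher = ewc.task_fishers[t - 1]
                    anchor = ewc.task_params[t - 1]
                    for name, p in self.model.named_parameters():
                        if name in fisher:
                            loss = loss + 0.5 * lambda_reg * torch.sum(
                                fisher[name] * (p - anchor[name]) ** 2
                            )
                loss.backward()
                optimizer.step()

    def _accuracy(self, loader):
        """Classification accuracy of the current model on one loader."""
        correct = 0
        total = 0
        self.model.eval()
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x).argmax(dim=-1)
                correct += (pred == y).sum().item()
                total += y.size(0)
        return correct / max(total, 1)

    def _evaluate_forgetting(self, acc_after_training):
        """
        Evaluate forgetting: accuracy right after training each old task
        minus its accuracy under the final model, averaged over old tasks.
        """
        drops = [
            acc_after_training[t] - self._accuracy(loader)
            for t, loader in enumerate(self.val_loaders[:-1])
        ]
        return float(np.mean(drops)) if drops else 0.0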
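One common way to quantify forgetting from a continual-learning run is to compare, for each old task, the best accuracy reached at any earlier stage with the accuracy under the final model. The sketch below assumes the run has been summarized in a lower-triangular accuracy table; the table values and the function name are made up for illustration.

```python
def forgetting_per_task(acc):
    """acc[i][j] = accuracy on task j after training on tasks 0..i (j <= i)."""
    last = len(acc) - 1
    result = []
    for j in range(last):                    # the final task has nothing to forget yet
        best_earlier = max(acc[i][j] for i in range(j, last))
        result.append(best_earlier - acc[last][j])
    return result

# Hypothetical accuracies for a three-task sequence
acc = [
    [0.95],
    [0.90, 0.93],
    [0.85, 0.91, 0.94],
]

per_task = forgetting_per_task(acc)
average_forgetting = sum(per_task) / len(per_task)
print(per_task, average_forgetting)
```

Averaging this quantity over trials gives a score that can be compared across candidate values of λ.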

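8. Summary

Core theorems:

| Theorem | Content | Practical significance |
| --- | --- | --- |
| Theorem 1 | Fisher-Hessian relationship | The Fisher matrix can serve as an upper-bound estimate of curvature |
| Theorem 5 | Regularization effect of EWC | Relates the parameter shift to the diagonal Fisher entries |
| Theorem 7 | Convergence rate | Convergence speed depends on $\lambda$ and the smallest Fisher eigenvalue |
| Theorem 9 | Single-task forgetting bound | Forgetting is controlled by $\frac{1}{2\lambda}\|\nabla L_t\|^2_{F^{-1}}$ |
| Theorem 11 | Optimal choice of λ | λ is tied to the Mahalanobis norm of the gradient |

Practical recommendations:

  1. Fisher computation: use the diagonal approximation to save storage, and pay attention to numerical stability
  2. λ selection: start from a coarse range and adjust according to the observed forgetting
  3. Online EWC: update the Fisher with an EMA to balance historical and current information
  4. Adaptive λ: adjust the regularization strength dynamically according to task difficulty

Theoretical takeaways:

  • EWC's effectiveness comes from the Fisher matrix's estimate of parameter importance
  • Forgetting is controlled by the regularization strength and is bounded in theory
  • Task similarity shapes the structure of the Fisher matrix and thereby affects how well EWC works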

Footnotes

  1. Kirkpatrick et al. (2017). Overcoming catastrophic forgetting in neural networks. PNAS.

  2. Chaudhry et al. (2018). Efficient lifelong learning with A-GEM. ICLR.

  3. Ritter et al. (2018). Online structured Laplace approximations for overcoming catastrophic forgetting. NeurIPS.