简介
弹性权重整合(Elastic Weight Consolidation, EWC)是由Kirkpatrick等人于2017年提出的里程碑式持续学习方法。EWC通过Fisher信息矩阵衡量参数重要性,并以此作为正则化项来保护旧任务的知识。尽管EWC在实践中表现良好,其理论保证长期缺乏严格的数学分析。本文建立EWC的完整理论框架,包括Fisher信息矩阵的理论性质、正则化效果的数学刻画、以及遗忘量的上界证明。123
1. Fisher信息矩阵的理论基础
1.1 定义与性质
定义1(Fisher信息矩阵):设 是参数化模型,Fisher信息矩阵定义为:
对于分类任务,,则:
1.2 Fisher与Hessian的关系
定理1(Fisher-Hessian关系)1:设 是负对数似然损失。则:
其中 是Hessian矩阵,当且仅当分布是指数族分布时精确成立;否则,Fisher是Hessian的上界()。
含义:Fisher信息矩阵编码了损失曲面的局部曲率信息,参数变化对损失的影响程度与Fisher对角元素成正比。
1.3 Fisher信息矩阵的理论性质
定理2(正定性):Fisher信息矩阵是半正定的:
当参数可识别时(不同的参数对应不同的分布),Fisher是正定的。
定理3(参数变换的协方差解释):Fisher信息矩阵的第 个对角元素满足:
这提供了Fisher的直观解释:参数扰动导致的对数似然变化的方差。
1.4 经验Fisher的统计性质
定义2(经验Fisher):在有限样本 上:
定理4(经验Fisher的收敛性)2:设样本独立同分布,则:
且收敛速率满足:
其中 是参数维度。
import torch
import torch.nn as nn
import numpy as np
class FisherInformationMatrix:
"""Fisher信息矩阵计算器"""
def __init__(self, model):
self.model = model
self.fisher = None
self.n_samples = 0
def compute_empirical_fisher(self, dataloader, device='cuda'):
"""
计算经验Fisher信息矩阵
F_ii = (1/N) * sum_n (d log p / d theta_i)^2
Returns:
fisher_diagonal: Fisher对角元素的字典
"""
self.model.eval()
# 初始化Fisher
fisher = {}
for n, p in self.model.named_parameters():
if p.requires_grad:
fisher[n] = torch.zeros_like(p)
n_samples = 0
for x, y in dataloader:
x, y = x.to(device), y.to(device)
self.model.zero_grad()
# 前向传播
output = self.model(x)
# 对于分类任务,使用交叉熵的梯度
# log p(y|x; theta) = log_softmax(output)[y]
log_probs = nn.functional.log_softmax(output, dim=-1)
loss = -log_probs[range(len(y)), y].mean()
# 反向传播
loss.backward()
# 累积梯度平方
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher[n] += p.grad.data ** 2
n_samples += x.size(0)
# 平均
for n in fisher:
fisher[n] /= n_samples
self.fisher = fisher
self.n_samples = n_samples
return fisher
def compute_fisher_bound(self, param_name):
"""
计算特定参数的重要性上界
基于Cramér-Rao下界
"""
if self.fisher is None:
raise ValueError("Fisher matrix not computed")
if param_name not in self.fisher:
raise ValueError(f"Parameter {param_name} not found")
fisher_diag = self.fisher[param_name]
# 参数估计方差的Cramér-Rao下界
crlb = 1.0 / (fisher_diag + 1e-8)
return crlb
def compute_importance_ranking(self):
"""
计算参数重要性排名
返回按重要性排序的参数名列表
"""
if self.fisher is None:
raise ValueError("Fisher matrix not computed")
# 计算每个参数的平均重要性
importance = {}
for n, f in self.fisher.items():
importance[n] = f.mean().item()
# 排序
sorted_params = sorted(importance.items(), key=lambda x: x[1], reverse=True)
return sorted_params2. EWC的数学框架
2.1 EWC损失函数
标准EWC损失:
其中:
- :新任务的损失
- :第 个参数的Fisher对角元素
- :旧任务的最优参数
- :正则化强度
2.2 多任务EWC
多任务EWC损失(顺序累积Fisher):
在线EWC(使用_running average更新Fisher):
其中 和 是累积的Fisher和参数估计。
2.3 EWC的正则化效果
定理5(参数偏移约束):设 是最小化 的解。则在旧任务上:
证明:由泰勒展开:
在 处 ,且 (因为 是 的上界),故:
3. 收敛性分析
3.1 EWC作为近似贝叶斯后验
定理6(EWC作为变分推断)2:EWC的最优解 可以解释为以下变分推断问题的解:
其中 是旧任务的贝叶斯后验近似。
含义:EWC在参数空间中搜索一个分布,其均值接近旧任务的贝叶斯后验均值,同时能拟合新任务数据。
3.2 收敛速率分析
定理7(EWC收敛速率):设 是学习任务 后的EWC最优解。则:
含义:
- 正则化强度 越大,参数偏移越小
- Fisher最小特征值越大(曲率越大),偏移越小
3.3 收敛到Pareto前沿
定理8(Pareto最优性):设 是多任务学习的最优Pareto前沿。则EWC的解序列 满足:
当且仅当:
- 任务序列是有限长的
- 随 适当衰减
class EWCConsvergenceAnalyzer:
"""EWC收敛性分析器"""
def __init__(self, model, lambda_reg=1000):
self.model = model
self.lambda_reg = lambda_reg
self.task_params = {} # 存储每个任务后的参数
self.task_fishers = {} # 存储每个任务的Fisher
self.convergence_history = []
def ewc_loss(self, current_task_loader, prev_task_id, device='cuda'):
"""
计算EWC损失
Returns:
total_loss: 总损失
new_loss: 新任务损失
ewc_loss: EWC正则化损失
"""
# 新任务损失
self.model.zero_grad()
new_loss = 0
for x, y in current_task_loader:
x, y = x.to(device), y.to(device)
output = self.model(x)
new_loss += nn.functional.cross_entropy(output, y)
new_loss /= len(current_task_loader)
# EWC正则化损失
ewc_loss = 0
for task_id, fisher in self.task_fishers.items():
old_params = self.task_params[task_id]
for (name, p), (_, old_p), (_, f) in zip(
self.model.named_parameters(),
old_params.items(),
fisher.items()
):
if p.requires_grad:
ewc_loss += self.lambda_reg * torch.sum(
f * (p - old_p) ** 2
)
ewc_loss = ewc_loss / 2
total_loss = new_loss + ewc_loss
return total_loss, new_loss, ewc_loss
def compute_parameter_shift(self, task_id):
"""
计算与旧任务参数的偏移量
"""
if task_id not in self.task_params:
return 0
old_params = self.task_params[task_id]
total_shift = 0
for (name, p), (_, old_p) in zip(
self.model.named_parameters(),
old_params.items()
):
shift = torch.norm(p - old_p).item()
total_shift += shift ** 2
return np.sqrt(total_shift)
def update_fisher(self, task_loader, device='cuda'):
"""更新Fisher信息矩阵"""
fisher_computer = FisherInformationMatrix(self.model)
fisher = fisher_computer.compute_empirical_fisher(task_loader, device)
# 存储Fisher
task_id = len(self.task_fishers)
self.task_fishers[task_id] = fisher
return fisher
def update_params(self):
"""存储当前参数"""
task_id = len(self.task_params)
self.task_params[task_id] = {
n: p.clone().detach()
for n, p in self.model.named_parameters()
}
def analyze_convergence(self, current_task_loader, prev_task_id, device='cuda'):
"""
分析EWC收敛性
"""
# 计算参数偏移
shift = self.compute_parameter_shift(prev_task_id)
# 计算EWC损失
_, new_loss, ewc_loss = self.ewc_loss(
current_task_loader, prev_task_id, device
)
# 计算Fisher条件数(与收敛速率相关)
min_eigenvalue = float('inf')
max_eigenvalue = 0
for fisher in self.task_fishers.values():
for f in fisher.values():
f_cpu = f.cpu()
f_mean = f_cpu.mean().item()
min_eigenvalue = min(min_eigenvalue, max(f_mean, 1e-10))
max_eigenvalue = max(max_eigenvalue, f_mean)
condition_number = max_eigenvalue / min_eigenvalue if min_eigenvalue > 0 else float('inf')
self.convergence_history.append({
'shift': shift,
'new_loss': new_loss.item(),
'ewc_loss': ewc_loss.item(),
'condition_number': condition_number
})
return self.convergence_history[-1]4. 遗忘上界证明
4.1 单任务遗忘界
定理9(EWC遗忘上界)3:设 是学习任务 后的EWC最优参数, 是学习任务 后的EWC最优参数。则在任务 上的遗忘满足:
其中 是Mahalanobis距离。
证明概要:
- EWC目标函数在 处的一阶最优条件:
- 由此得:
- 对任务 的损失做Taylor展开:
- 由于 且 :
4.2 累积遗忘界
定理10(多任务累积遗忘界):设 是累积Fisher。则在学习 个任务后,累积遗忘满足:
4.3 最优正则化强度
定理11(最优 选择):为最小化累积遗忘,最优正则化强度满足:
实际选择:
- 过大: 过高(可塑性不足)
- 过小: 过高(遗忘过多)
- 经验选择: 从 到 之间
class EWCForgettingBoundEstimator:
"""EWC遗忘上界估计器"""
def __init__(self, model):
self.model = model
self.task_fishers = {}
self.task_gradients = {}
def estimate_single_task_forgetting(self, new_task_id, prev_task_id):
"""
估计学习新任务后对旧任务的遗忘上界
定理9:遗忘 ≤ (1/2λ) * ||∇L_new||^2_{F^{-1}}
"""
if prev_task_id not in self.task_fishers:
return float('inf')
# 获取Fisher的逆(近似使用对角逆)
fisher = self.task_fishers[prev_task_id]
# 获取新任务在旧任务最优参数处的梯度
gradient = self.task_gradients[new_task_id]
# 计算Mahalanobis距离
mahalanobis_dist_sq = 0
param_idx = 0
for name, p in self.model.named_parameters():
if p.requires_grad and name in fisher:
f_diag = fisher[name].flatten()
g = gradient[param_idx:param_idx + p.numel()]
# v^T F^{-1} v ≈ sum_i v_i^2 / f_ii
mahalanobis_dist_sq += torch.sum(g ** 2 / (f_diag + 1e-8)).item()
param_idx += p.numel()
return 0.5 * mahalanobis_dist_sq
def estimate_optimal_lambda(self, lambda_candidates):
"""
估计最优正则化强度
使用验证集搜索
"""
results = []
for lambda_reg in lambda_candidates:
# 模拟不同λ下的遗忘
forgetting = self._simulate_ewc(lambda_reg)
results.append({
'lambda': lambda_reg,
'forgetting': forgetting
})
# 选择遗忘最小的λ
best = min(results, key=lambda x: x['forgetting'])
return best['lambda'], results
def _simulate_ewc(self, lambda_reg):
"""模拟EWC在不同λ下的遗忘"""
# 简化模拟:遗忘与λ成反比
base_forgetting = sum(
self.estimate_single_task_forgetting(t, t-1)
for t in range(1, len(self.task_fishers))
)
# 遗忘随λ增加而减少(但非线性)
simulated_forgetting = base_forgetting / np.sqrt(lambda_reg)
return simulated_forgetting5. 与其他正则化方法的比较
5.1 SI(Synaptic Intelligence)
SI损失:
其中 是基于参数轨迹的重要性度量。
比较:
| 方法 | 重要性度量 | 理论基础 |
|---|---|---|
| EWC | Fisher信息 | 贝叶斯后验近似 |
| SI | 参数轨迹 | 在线学习视角 |
| MAS | 输出敏感度 | 函数空间度量 |
5.2 RWalk(Reweighting Walk)
RWalk损失2:
其中:
- :累积Fisher
- :累积参数变化
定理12(RWalk优于EWC):RWalk的遗忘上界严格小于EWC:
5.3 理论比较总结
| 方法 | 遗忘界 | 正则化强度依赖 | 计算复杂度 |
|---|---|---|---|
| EWC | 是 | ||
| SI | 是 | ||
| RWalk | 是 | ||
| EWC+ | 自适应 | 否 |
6. 自适应EWC变体
6.1 在线EWC
问题:原始EWC需要存储所有之前任务的Fisher,计算和存储成本 。
在线EWC:使用指数移动平均近似累积Fisher:
其中 是衰减率。
定理13(在线EWC的遗忘界):设 是 步的累积Fisher估计。则:
6.2 自适应λ-EWC
定理14(自适应λ的理论保证):设任务 的难度为 。则自适应 给出最优权衡:
class AdaptiveEWC:
"""自适应EWC实现"""
def __init__(self, model, base_lambda=1000):
self.model = model
self.base_lambda = base_lambda
self.task_fishers = {}
self.task_params = {}
self.ema_fisher = None # 指数移动平均Fisher
self.alpha = 0.9 # 衰减率
def update_ema_fisher(self, fisher):
"""
更新指数移动平均Fisher
用于在线EWC
"""
if self.ema_fisher is None:
self.ema_fisher = fisher
else:
for name in fisher:
self.ema_fisher[name] = (
self.alpha * self.ema_fisher[name] +
(1 - self.alpha) * fisher[name]
)
return self.ema_fisher
def compute_adaptive_lambda(self, task_id, new_task_loader, device='cuda'):
"""
计算自适应正则化强度
λ_t* = sqrt(D_t * D_{t-1})
其中 D_t = ||∇L_t||_{F^{-1}}
"""
# 计算新任务梯度
self.model.zero_grad()
total_grad = None
n_samples = 0
for x, y in new_task_loader:
x, y = x.to(device), y.to(device)
output = self.model(x)
loss = nn.functional.cross_entropy(output, y)
loss.backward()
grad = torch.cat([
p.grad.flatten()
for p in self.model.parameters()
if p.grad is not None
])
if total_grad is None:
total_grad = grad
else:
total_grad += grad
n_samples += x.size(0)
self.model.zero_grad()
grad_norm_sq = (total_grad / n_samples).norm().item() ** 2
# 获取Fisher
if task_id in self.task_fishers:
fisher = self.task_fishers[task_id]
else:
fisher = self.ema_fisher
# 计算D_t
D_t_sq = grad_norm_sq
if fisher is not None:
# 近似计算Mahalanobis距离
param_idx = 0
for name, p in self.model.named_parameters():
if p.requires_grad and name in fisher:
f_diag = fisher[name].flatten()
D_t_sq = (total_grad / n_samples)[param_idx:param_idx+p.numel()] ** 2 / (f_diag + 1e-8)
D_t_sq = D_t_sq.sum().item()
break
# 获取D_{t-1}
if task_id > 0 and (task_id - 1) in self.task_fishers:
D_prev_sq = self._compute_difficulty(task_id - 1)
else:
D_prev_sq = D_t_sq
# 自适应λ
adaptive_lambda = np.sqrt(D_t_sq * D_prev_sq)
return adaptive_lambda
def _compute_difficulty(self, task_id):
"""计算任务难度"""
# 简化:使用Fisher范数
fisher = self.task_fishers[task_id]
total_fisher_norm = 0
for f in fisher.values():
total_fisher_norm += (f ** 2).sum().item()
return total_fisher_norm7. 实践指南
7.1 Fisher计算策略
class EfficientFisherComputation:
"""高效Fisher计算"""
def __init__(self, model):
self.model = model
def compute_diag_fisher(self, dataloader, device='cuda', n_samples=None):
"""
计算Fisher对角元素(存储高效)
时间复杂度: O(N * d)
空间复杂度: O(d)
"""
self.model.eval()
fisher_diag = {}
for n, p in self.model.named_parameters():
if p.requires_grad:
fisher_diag[n] = torch.zeros_like(p)
n_total = 0
for i, (x, y) in enumerate(dataloader):
if n_samples and i >= n_samples:
break
x, y = x.to(device), y.to(device)
self.model.zero_grad()
output = self.model(x)
loss = nn.functional.cross_entropy(output, y)
loss.backward()
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
fisher_diag[n] += p.grad.data ** 2
n_total += x.size(0)
# 平均
for n in fisher_diag:
fisher_diag[n] /= n_total
return fisher_diag
def compute_block_diag_fisher(self, dataloader, block_size=100,
device='cuda', n_samples=None):
"""
计算块对角Fisher(精度与效率的权衡)
将参数分成多个块,只计算块内的Fisher
"""
self.model.eval()
params = [p for _, p in self.model.named_parameters() if p.requires_grad]
n_params = sum(p.numel() for p in params)
n_blocks = (n_params + block_size - 1) // block_size
fisher_blocks = {} # {(i,j): block_matrix}
n_total = 0
for i, (x, y) in enumerate(dataloader):
if n_samples and i >= n_samples:
break
x, y = x.to(device), y.to(device)
self.model.zero_grad()
output = self.model(x)
loss = nn.functional.cross_entropy(output, y)
loss.backward()
# 收集梯度
grad = torch.cat([p.grad.flatten() for p in params])
# 更新块Fisher
for block_i in range(n_blocks):
start_i = block_i * block_size
end_i = min(start_i + block_size, n_params)
for block_j in range(block_i, n_blocks):
start_j = block_j * block_size
end_j = min(start_j + block_size, n_params)
block_grad = grad[start_i:end_i]
block_key = (block_i, block_j)
if block_key not in fisher_blocks:
fisher_blocks[block_key] = torch.zeros(
end_i - start_i, end_j - start_j, device=device
)
fisher_blocks[block_key] += torch.outer(
block_grad, block_grad[:end_j-start_j]
)
n_total += x.size(0)
# 平均
for key in fisher_blocks:
fisher_blocks[key] /= n_total
return fisher_blocks7.2 超参数调优
class EWCHyperparameterTuner:
"""EWC超参数调优"""
def __init__(self, model, train_loaders, val_loaders):
self.model = model
self.train_loaders = train_loaders
self.val_loaders = val_loaders
def tune_lambda(self, lambda_candidates, n_trials=3):
"""
网格搜索最优λ
"""
results = []
for lambda_reg in lambda_candidates:
# 多次试验取平均
forgetting_scores = []
for trial in range(n_trials):
# 重置模型
self.model.reset_parameters()
# 训练
ewc = AdaptiveEWC(self.model, lambda_reg)
for t, loader in enumerate(self.train_loaders):
# 更新Fisher
fisher_computer = FisherInformationMatrix(self.model)
fisher = fisher_computer.compute_empirical_fisher(loader)
if t == 0:
ewc.task_fishers[t] = fisher
else:
# 合并Fisher
for name in fisher:
if name in ewc.task_fishers[t-1]:
ewc.task_fishers[t][name] = (
ewc.task_fishers[t-1][name] + fisher[name]
)
else:
ewc.task_fishers[t][name] = fisher[name]
# 存储参数
ewc.task_params[t] = {
n: p.clone() for n, p in self.model.named_parameters()
}
# 训练新任务(简化:只做几步梯度下降)
optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
for _ in range(10):
for x, y in loader:
x, y = x.to('cuda'), y.to('cuda')
optimizer.zero_grad()
loss = nn.functional.cross_entropy(
self.model(x), y
)
loss.backward()
optimizer.step()
# 评估遗忘
forgetting = self._evaluate_forgetting(ewc, self.val_loaders)
forgetting_scores.append(forgetting)
results.append({
'lambda': lambda_reg,
'forgetting': np.mean(forgetting_scores),
'std': np.std(forgetting_scores)
})
# 选择遗忘最小的λ
best = min(results, key=lambda x: x['forgetting'])
return best, results
def _evaluate_forgetting(self, ewc, val_loaders):
"""评估遗忘"""
forgetting = []
for t, loader in enumerate(val_loaders):
# 加载任务t训练后的参数
self._load_params(ewc.task_params[t])
# 评估在任务t上的性能
correct = 0
total = 0
self.model.eval()
with torch.no_grad():
for x, y in loader:
x, y = x.to('cuda'), y.to('cuda')
pred = self.model(x).argmax(dim=-1)
correct += (pred == y).sum().item()
total += y.size(0)
accuracy = correct / total
forgetting.append(1 - accuracy)
return forgetting8. 总结
核心定理
| 定理 | 内容 | 实践意义 |
|---|---|---|
| 定理1 | Fisher-Hessian关系 | Fisher可作为曲率的上界估计 |
| 定理5 | EWC正则化效果 | 参数偏移与Fisher对角元素的关系 |
| 定理7 | 收敛速率 | 收敛速度与 成正比 |
| 定理9 | 单任务遗忘上界 | 遗忘被 控制 |
| 定理11 | 最优λ选择 | λ与梯度Mahalanobis距离相关 |
实践建议
- Fisher计算:使用对角近似以节省存储,关注数值稳定性
- λ选择:从 - 范围开始,根据遗忘情况调整
- 在线EWC:使用EMA更新Fisher,平衡历史信息与当前信息
- 自适应λ:根据任务难度动态调整正则化强度
理论启示
- EWC的有效性来源于Fisher对参数重要性的估计
- 遗忘量被正则化强度控制,存在理论下界
- 任务相似性影响Fisher的结构,进而影响EWC效果
参考资料
相关阅读:
- 基于正则化的持续学习方法 — EWC等正则化方法的实践指南
- 灾难性遗忘数学理论 — 遗忘的数学机制
- 持续学习泛化理论 — PAC-Bayes框架下的泛化分析
- 信息瓶颈与持续学习 — EWC的IB理论解释