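Introduction

Experience Replay (ER) is one of the simplest and most effective ways to counter catastrophic forgetting in continual learning. The core idea is to maintain a memory buffer of representative samples from earlier tasks and to train on them alongside the new task's data. Yet basic questions remain open: why does replay work, how many old samples are needed to avoid forgetting, and how should the buffer be designed? These questions lack rigorous theoretical analysis. This article builds a systematic theoretical framework for memory replay and analyzes its effectiveness from three angles: information theory, generalization bounds, and optimization dynamics. [1][2][3]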


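1. Problem Setup for Replay

1.1 Standard Replay Setting

Consider $T$ sequential tasks $\mathcal{D}_1, \dots, \mathcal{D}_T$, where task $i$ has $m_i$ samples. A replay method maintains a buffer $\mathcal{M}$ that stores at most $B$ samples.

Learning objective: after learning task $t$, the model parameters $\theta_t$ should perform well on all tasks completed so far:

$$\theta_t \in \arg\min_{\theta} \sum_{i=1}^{t} \mathcal{L}_i(\theta), \qquad \mathcal{L}_i(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}_i}\big[\ell(f_\theta(x), y)\big]$$

Constraint: while training on task $t$, only $\mathcal{D}_t \cup \mathcal{M}$ can be accessed.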
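
To make "performs well on all completed tasks" measurable, a common choice is to track average accuracy over the tasks seen so far together with a per-task forgetting score. The sketch below is illustrative only: the accuracy-matrix layout `acc[k][i]`, the function names, and the numbers are assumptions, not something defined in this article.

```python
def average_accuracy(acc, t):
    """Mean accuracy over tasks 0..t, measured after training on task t.

    acc[k][i] is the accuracy on task i after training on task k
    (a hypothetical bookkeeping layout).
    """
    return sum(acc[t][i] for i in range(t + 1)) / (t + 1)


def forgetting(acc, t):
    """Per-task forgetting after task t: best earlier accuracy minus current accuracy."""
    drops = []
    for i in range(t):  # the task just learned cannot have been forgotten yet
        best_past = max(acc[k][i] for k in range(i, t))
        drops.append(best_past - acc[t][i])
    return drops


# Made-up numbers: three tasks, row k = accuracies after training on task k.
acc = [
    [0.90, 0.00, 0.00],
    [0.80, 0.85, 0.00],
    [0.70, 0.80, 0.95],
]
print(average_accuracy(acc, 2))  # mean of the last row
print(forgetting(acc, 2))        # one drop per earlier task
```

A replay method that meets the objective keeps the first number high and the drops in the second list small.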

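1.2 Mathematical Form of the Replay Loss

Standard replay loss

$$\mathcal{L}_{\text{replay}}(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}_t}\big[\ell(f_\theta(x), y)\big] + \lambda\, \mathbb{E}_{(x,y)\sim\hat{p}_{\mathcal{M}}}\big[\ell(f_\theta(x), y)\big]$$

where $\hat{p}_{\mathcal{M}}$ is the empirical distribution of the buffer samples and $\lambda$ is the replay strength parameter.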
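
As a small worked example of how the replay strength enters the objective, the sketch below assumes the common form "mean current-task loss plus the replay strength times mean buffer loss". The function name `replay_loss` and the per-sample loss lists are hypothetical; in practice the per-sample losses would come from the model.

```python
def replay_loss(current_losses, buffer_losses, lam=1.0):
    """Mean current-task loss plus lam times the mean loss on buffer samples."""
    current_term = sum(current_losses) / len(current_losses)
    if not buffer_losses:  # empty buffer: plain current-task training
        return current_term
    buffer_term = sum(buffer_losses) / len(buffer_losses)
    return current_term + lam * buffer_term


# Current-task mean is 1.0, buffer mean is 3.0, so lam = 0.5 gives 1.0 + 0.5 * 3.0.
print(replay_loss([0.5, 1.5], [2.0, 4.0], lam=0.5))
```

Setting `lam` to 0 recovers plain fine-tuning on the new task, while larger values put more weight on the stored samples.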

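Uniform replay vs. prioritized replay

  • Uniform replay: every buffer sample is drawn with the same probability, $p(x_j) = 1/|\mathcal{M}|$
  • Prioritized replay: sample $x_j$ is drawn in proportion to its priority $w_j$, $p(x_j) = w_j / \sum_k w_k$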

import random
from collections import deque

import numpy as np
import torch
import torch.nn as nn


class ReplayBuffer:
    """
    Experience replay buffer.

    Supports uniform and prioritized sampling. Samples are stored as tuples,
    either RL-style (state, action, reward, next_state, done) or
    supervised (x, y).
    """

    def __init__(self, capacity=1000, priority_mode='uniform'):
        """
        Args:
            capacity: maximum number of stored samples
            priority_mode: 'uniform' or 'priority'
        """
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        self.priority_mode = priority_mode
        self.priorities = deque(maxlen=capacity)

    def add(self, *sample, priority=None):
        """Add one sample; once the buffer is full the oldest sample is evicted."""
        self.buffer.append(tuple(sample))

        if self.priority_mode == 'priority':
            # In priority mode the caller supplies a priority (e.g. a TD error)
            self.priorities.append(priority if priority is not None else 1.0)
        else:
            self.priorities.append(1.0)

    def add_batch(self, samples, priorities=None):
        """Add several samples at once."""
        for i, sample in enumerate(samples):
            # "is not None" also works when priorities is a NumPy array
            p = priorities[i] if priorities is not None else None
            if isinstance(sample, tuple):
                self.add(*sample, priority=p)
            else:
                self.add(sample, priority=p)

    def sample(self, batch_size, device='cuda'):
        """Sample a batch from the buffer without replacement."""
        if len(self.buffer) == 0:
            raise ValueError("cannot sample from an empty buffer")
        n = min(batch_size, len(self.buffer))

        if self.priority_mode == 'priority':
            # Sample in proportion to the stored priorities
            probs = np.array(self.priorities, dtype=np.float64)
            probs = probs / probs.sum()
            indices = np.random.choice(len(self.buffer), size=n, replace=False, p=probs)
            batch = [self.buffer[i] for i in indices]
        else:
            # Uniform sampling
            batch = random.sample(list(self.buffer), n)

        # Unpack the batch according to the sample layout
        if len(batch[0]) == 5:  # (s, a, r, s', done)
            states, actions, rewards, next_states, dones = zip(*batch)
            states = torch.FloatTensor(np.array(states)).to(device)
            actions = torch.LongTensor(actions).to(device)
            rewards = torch.FloatTensor(rewards).to(device)
            next_states = torch.FloatTensor(np.array(next_states)).to(device)
            dones = torch.FloatTensor(dones).to(device)
            return states, actions, rewards, next_states, dones
        if len(batch[0]) == 2:  # (x, y)
            xs, ys = zip(*batch)
            return (torch.FloatTensor(np.array(xs)).to(device),
                    torch.LongTensor(ys).to(device))
        # Single-field samples: return the stacked inputs only
        return torch.FloatTensor(np.array([s[0] for s in batch])).to(device)

    def sample_balanced(self, batch_size, task_samples=None, device='cuda'):
        """
        Class-balanced sampling for (x, y) samples.

        Draws roughly the same number of samples from every class and returns
        (inputs, labels). task_samples is unused and kept only so existing
        call sites keep working.
        """
        # Group buffer indices by class label (second field of each sample)
        indices_by_class = {}
        for i, sample in enumerate(self.buffer):
            indices_by_class.setdefault(sample[1], []).append(i)
        if not indices_by_class:
            raise ValueError("cannot sample from an empty buffer")

        # Draw an equal share from each class
        samples_per_class = batch_size // len(indices_by_class)
        chosen = []
        for idxs in indices_by_class.values():
            chosen.extend(random.sample(idxs, min(samples_per_class, len(idxs))))

        # Fill any remaining slots at random from the samples not yet chosen
        if len(chosen) < batch_size:
            chosen_set = set(chosen)
            others = [i for i in range(len(self.buffer)) if i not in chosen_set]
            chosen.extend(random.sample(others, min(batch_size - len(chosen), len(others))))

        xs, ys = zip(*(self.buffer[i][:2] for i in chosen))
        return (torch.FloatTensor(np.array(xs)).to(device),
                torch.LongTensor(ys).to(device))

    def update_priorities(self, indices, td_errors):
        """Update priorities (used by prioritized replay)."""
        for idx, error in zip(indices, td_errors):
            self.priorities[idx] = abs(error) + 1e-5  # avoid zero priority

    def __len__(self):
        return len(self.buffer)

    def is_full(self):
        return len(self.buffer) >= self.capacity

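2. Information-Theoretic Analysis of Replay Effectiveness

2.1 The Mutual Information View

Definition 1 (task-representation mutual information)

$$I(Z; Y_i) = H(Y_i) - H(Y_i \mid Z)$$

where $Z$ is an intermediate-layer representation of the network and $Y_i$ is the label of task $i$.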
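
To make the quantity in Definition 1 concrete, the sketch below computes a plug-in estimate of the mutual information between labels and a discretized representation from co-occurrence counts. Treating the representation as a discrete code (for example a cluster id) is an assumption made purely for illustration, and `mutual_information` is a hypothetical helper rather than part of this article's code.

```python
import math
from collections import Counter


def mutual_information(pairs):
    """Plug-in estimate of I(Y; Z) in nats from a list of (y, z) pairs."""
    n = len(pairs)
    joint = Counter(pairs)
    count_y = Counter(y for y, _ in pairs)
    count_z = Counter(z for _, z in pairs)
    mi = 0.0
    for (y, z), c in joint.items():
        p_yz = c / n
        # p(y, z) / (p(y) p(z)) written with raw counts
        mi += p_yz * math.log(c * n / (count_y[y] * count_z[z]))
    return mi


# A representation that determines the label carries log(2) nats for two balanced classes...
informative = [(0, "a"), (0, "a"), (1, "b"), (1, "b")]
# ...while one that is independent of the label carries none.
independent = [(0, "a"), (0, "b"), (1, "a"), (1, "b")]
print(mutual_information(informative), mutual_information(independent))
```

In this picture, forgetting an old task corresponds to its representation drifting from the first case toward the second.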

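Theorem 1 (replay increases mutual information): Let $\hat{p}_{\mathcal{M}}$ be the distribution of buffer samples. Then:

where the bound involves an entropy-based measure of the distribution discrepancy.

Implication: by bringing in distributional information about old tasks, replay helps the network retain representations that remain useful for those tasks.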

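2.2 The Compression-Forgetting Trade-off

Theorem 2 (rate-distortion bound for replay): Given the information capacity of the network, there exists an optimal replay ratio such that:

The optimal ratio satisfies:

2.3 Information Bottleneck and Replay

Theorem 3 (replay in the IB framework) [1]: Given the current-task data and the buffer data, the replay objective can be formalized as:

where the second term preserves information about old tasks through replay.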

class InformationReplayBuffer:
    """
    Replay buffer with information-theoretic sample selection.

    Assumes the model exposes get_embedding() and classifier().
    """

    def __init__(self, model, capacity=1000):
        self.model = model
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        self.information_scores = {}

    def compute_mutual_information(self, samples, device='cuda'):
        """
        Score each sample with an entropy-based proxy for I(Y; Z).

        Note: the value returned is the predictive entropy of p(y | z),
        a heuristic informativeness score, not a mutual-information estimate.
        """
        self.model.eval()
        mutual_infos = []

        with torch.no_grad():
            for x, y in samples:
                x = x.to(device)

                # Intermediate representation
                z = self.model.get_embedding(x)

                # Class posterior p(y | z)
                logits = self.model.classifier(z)
                probs = torch.softmax(logits, dim=-1)

                # Predictive entropy of p(y | z)
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)

                mutual_info = entropy.mean().item()
                mutual_infos.append(mutual_info)

        return mutual_infos

    def compute_distributional_divergence(self, new_samples, buffer_samples,
                                        device='cuda'):
        """
        Estimate the divergence D(p_new || p_buffer) between new and buffered data.

        Approximated by the (biased) squared maximum mean discrepancy (MMD)
        with an RBF kernel in embedding space.
        """
        self.model.eval()

        with torch.no_grad():
            # Embed the new and the buffered samples
            z_new = []
            for x, _ in new_samples:
                x = x.to(device)
                z = self.model.get_embedding(x)
                z_new.append(z.cpu())
            z_new = torch.cat(z_new, dim=0)

            z_buffer = []
            for x, _ in buffer_samples:
                x = x.to(device)
                z = self.model.get_embedding(x)
                z_buffer.append(z.cpu())
            z_buffer = torch.cat(z_buffer, dim=0)

        # MMD estimate
        def rbf_kernel(x, y, sigma=1.0):
            diff = x.unsqueeze(1) - y.unsqueeze(0)
            return torch.exp(-torch.sum(diff ** 2, dim=-1) / (2 * sigma ** 2))

        mmd = (
            rbf_kernel(z_new, z_new).mean() +
            rbf_kernel(z_buffer, z_buffer).mean() -
            2 * rbf_kernel(z_new, z_buffer).mean()
        )

        return mmd.item()

    def select_diverse_samples(self, candidates, n_select, device='cuda'):
        """
        Greedily select samples that are both informative and diverse.
        """
        # Information score of every candidate
        mutual_infos = self.compute_mutual_information(candidates, device)

        # Embeddings used for the diversity bonus
        # (torch.cat assumes each x carries a leading batch dimension of size 1)
        with torch.no_grad():
            embeddings = []
            for x, _ in candidates:
                x = x.to(device)
                z = self.model.get_embedding(x)
                embeddings.append(z.cpu())
            embeddings = torch.cat(embeddings, dim=0)

        selected_indices = []
        remaining = list(range(len(candidates)))

        for _ in range(n_select):
            if not remaining:
                break

            # Pick the remaining candidate with the best combined score
            best_idx = None
            best_score = -float('inf')

            for idx in remaining:
                score = mutual_infos[idx]

                # Reward distance from the mean of the already selected embeddings
                if selected_indices:
                    selected_embs = embeddings[selected_indices]
                    dist = torch.norm(
                        embeddings[idx] - selected_embs.mean(dim=0)
                    ).item()
                    score += 0.1 * dist  # diversity bonus

                if score > best_score:
                    best_score = score
                    best_idx = idx

            if best_idx is not None:
                selected_indices.append(best_idx)
                remaining.remove(best_idx)

        return [candidates[i] for i in selected_indices]

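3. Sample Complexity Analysis

3.1 Single-Task Sample Complexity

Theorem 4 (single-task replay sample bound) [2]: Let the buffer contain $B_i$ samples from task $i$. Then, with probability at least $1-\delta$:

where the bound depends on the parameter dimension and on $m_i$, the original number of samples of task $i$.

3.2 Cumulative Multi-Task Sample Bound

Theorem 5 (cumulative sample complexity): Let the total buffer capacity be $B$ and the number of tasks be $T$. After all tasks have been learned, the generalization error of task $i$ satisfies:

where $B_i$ is the buffer size allocated to task $i$.

3.3 Optimal Buffer Allocation

Theorem 6 (optimal allocation): To minimize the cumulative generalization error, the optimal buffer allocation satisfies:

$$B_i^* \propto m_i \sqrt{1/\omega_i}, \qquad \sum_{i=1}^{T} B_i^* = B$$

where $\omega_i$ is the forgetting sensitivity of task $i$ relative to the other tasks.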

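Uniform allocation vs. proportional allocation

| Strategy | Advantage | Disadvantage |
| Uniform allocation | Simple to implement | Ignores task importance |
| Proportional allocation | Weights tasks by importance | Requires estimating the weights |
| Adaptive allocation | Adjusts dynamically to observed forgetting | High computational overhead |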
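
The following worked example applies the proportionality B_i* ∝ m_i * sqrt(1/ω_i) quoted in the allocator code of this section to three tasks. It is a sketch under assumptions: `allocate_buffer` is a hypothetical helper, the sensitivities are made-up numbers, and the largest-remainder rounding is one possible way to avoid the unused slots that plain `int()` truncation leaves behind.

```python
import math


def allocate_buffer(total_capacity, sample_counts, sensitivities, eps=1e-8):
    """Split total_capacity across tasks with weights m_i / sqrt(omega_i)."""
    weights = {
        t: sample_counts[t] / math.sqrt(max(sensitivities[t], 0.0) + eps)
        for t in sample_counts
    }
    total_weight = sum(weights.values())
    exact = {t: total_capacity * w / total_weight for t, w in weights.items()}

    sizes = {t: int(v) for t, v in exact.items()}  # floor of each exact share
    leftover = total_capacity - sum(sizes.values())
    # Hand the remaining slots to the tasks with the largest fractional parts.
    by_fraction = sorted(exact, key=lambda t: exact[t] - sizes[t], reverse=True)
    for t in by_fraction[:leftover]:
        sizes[t] += 1
    return sizes


counts = {0: 1000, 1: 1000, 2: 1000}          # m_i
sensitivities = {0: 0.04, 1: 0.16, 2: 0.64}   # omega_i (illustrative values)
sizes = allocate_buffer(100, counts, sensitivities)
print(sizes, sum(sizes.values()))
```

With equal sample counts, this rule gives the largest share to the task with the smallest sensitivity, which is worth keeping in mind when interpreting ω_i.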

class OptimalBufferAllocator:
    """Allocates buffer capacity across tasks."""

    def __init__(self, total_capacity):
        self.total_capacity = total_capacity
        self.task_sizes = {}  # buffer size per task
        self.task_weights = {}  # allocation weight per task

    def compute_uniform_allocation(self, n_tasks):
        """Uniform allocation: every task gets the same share."""
        size_per_task = self.total_capacity // n_tasks
        for t in range(n_tasks):
            self.task_sizes[t] = size_per_task
        return self.task_sizes

    def compute_proportional_allocation(self, task_sample_counts):
        """
        Allocation proportional to each task's sample count.
        """
        total_samples = sum(task_sample_counts.values())

        for t, count in task_sample_counts.items():
            self.task_sizes[t] = int(
                self.total_capacity * count / total_samples
            )

        return self.task_sizes

    def compute_sensitivity_weighted_allocation(self, task_sample_counts,
                                               sensitivity_scores):
        """
        Allocation weighted by forgetting sensitivity.

        B_i* ∝ m_i * sqrt(1/ω_i)
        """
        # Unnormalized weights; negative sensitivities are clamped to 0
        # so that the square root stays defined
        total_weight = 0
        for t in task_sample_counts:
            omega = max(sensitivity_scores[t], 0.0)
            weight = task_sample_counts[t] / np.sqrt(omega + 1e-8)
            self.task_weights[t] = weight
            total_weight += weight

        # Allocate in proportion to the weights
        for t in task_sample_counts:
            self.task_sizes[t] = int(
                self.total_capacity * self.task_weights[t] / total_weight
            )

        return self.task_sizes

    def estimate_sensitivity(self, model, task_loader, prev_task_loaders,
                           device='cuda'):
        """
        Estimate the forgetting sensitivity of each previous task.

        Sensitivity = drop in the task's accuracy after (simulated) learning
        of a new task. task_loader is currently unused because
        _simulate_learning trains on random data.
        """
        import copy

        sensitivity_scores = {}
        # Snapshot the weights so every task is measured from the same start
        initial_state = copy.deepcopy(model.state_dict())

        for t, loader in enumerate(prev_task_loaders):
            # Accuracy on task t with the current model
            perf_before = self._evaluate_task(model, loader, device)

            # Simulate learning a new task (simplification: random data)
            self._simulate_learning(model, device)

            # Accuracy on task t after the simulated learning
            perf_after = self._evaluate_task(model, loader, device)

            # Sensitivity = accuracy drop
            sensitivity_scores[t] = perf_before - perf_after

            # Restore the original weights before measuring the next task
            model.load_state_dict(initial_state)

        return sensitivity_scores

    def _evaluate_task(self, model, loader, device):
        """Return the classification accuracy of the model on one task."""
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=-1)
                correct += (pred == y).sum().item()
                total += y.size(0)

        return correct / total if total > 0 else 0

    def _simulate_learning(self, model, device, n_steps=10):
        """Simulate learning a new task with n_steps SGD steps on random data."""
        model.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

        # Random stand-in data; assumes 3x32x32 inputs and 10 classes
        x = torch.randn(32, 3, 32, 32, device=device)
        y = torch.randint(0, 10, (32,), device=device)

        for _ in range(n_steps):
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(x), y)
            loss.backward()
            optimizer.step()

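4. The Relationship Between Buffer Capacity and Forgetting

4.1 Capacity Lower Bound

Theorem 7 (capacity lower bound for preventing forgetting) [3]: Given the sample complexity of each task, to prevent task $i$ from being forgotten the buffer must contain at least:

where $\epsilon_i$ is the maximum acceptable forgetting for task $i$.

4.2 The Capacity-Forgetting Trade-off Curve

Theorem 8 (exponential forgetting model): Let the buffer capacity be $B$ and the number of tasks be $T$. Then the forgetting of task $i$ satisfies:

where $\alpha$ and $\beta$ are constants that depend on the task structure.

Implications

  • The larger the buffer ratio $B_i / m_i$, the smaller the forgetting
  • The larger the task gap $T - i$ (that is, the older the task), the larger the forgetting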

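4.3 Critical Capacity

Definition 2 (critical capacity): the minimum buffer size required to keep the forgetting $\Delta L$ below a threshold $\epsilon$:

$$B_c(\epsilon) = \min\{\, B : \Delta L(B) \le \epsilon \,\}$$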
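
As a worked example, the sketch below solves for the critical capacity under the exponential form used in this section's code, where forgetting is modeled as a base value times exp(-α · B/m). The function names, the rounding up to an integer, and the numbers are assumptions chosen for illustration.

```python
import math


def predicted_forgetting(buffer_size, n_samples, alpha=1.0, base_forgetting=1.0):
    """Forgetting predicted by the exponential model for a given buffer size."""
    return base_forgetting * math.exp(-alpha * buffer_size / n_samples)


def critical_capacity(n_samples, epsilon, alpha=1.0, base_forgetting=1.0):
    """Smallest integer buffer size whose predicted forgetting is at most epsilon."""
    if base_forgetting <= epsilon:
        return 0  # the threshold is already met without any buffer
    # Invert base * exp(-alpha * B / m) = epsilon for B, then round up.
    return math.ceil(n_samples / alpha * math.log(base_forgetting / epsilon))


b_c = critical_capacity(n_samples=1000, epsilon=0.1, alpha=5.0)
print(b_c)
print(predicted_forgetting(b_c, 1000, alpha=5.0))      # at or below the threshold
print(predicted_forgetting(b_c - 1, 1000, alpha=5.0))  # one slot fewer is not enough
```

Because the dependence on the threshold is logarithmic, halving the tolerated forgetting adds only a fixed number of slots (m · ln 2 / α) in this model.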

class ForgettingCapacityModel:
    """
    Forgetting-capacity model.
    """

    def __init__(self, alpha=1.0, beta=0.5):
        """
        Args:
            alpha: coefficient for the effect of the buffer ratio on forgetting
            beta: coefficient for the effect of the task gap on forgetting
        """
        self.alpha = alpha
        self.beta = beta
        self.forgetting_history = []

    def predict_forgetting(self, buffer_ratio, task_gap):
        """
        Predict the amount of forgetting.

        ΔL ≈ exp(-α * B/m) * exp(-β * gap)

        Note: with β > 0 this expression decreases as task_gap grows.

        Args:
            buffer_ratio: B_i / m_i, buffer size relative to the task's sample count
            task_gap: T - i, the task gap

        Returns:
            predicted_forgetting: the predicted forgetting
        """
        forgetting = np.exp(-self.alpha * buffer_ratio) * np.exp(
            -self.beta * task_gap
        )
        return forgetting

    def estimate_critical_capacity(self, initial_forgetting, threshold,
                                  task_gap=0):
        """
        Estimate the critical capacity as a buffer ratio.

        B_c / m = (1/α) * ln(ΔL_0 / ε)

        Multiply the result by the task's sample count m to get B_c.
        task_gap is accepted for interface symmetry but not used.
        """
        if initial_forgetting <= threshold:
            return 0

        buffer_ratio = (1 / self.alpha) * np.log(initial_forgetting / threshold)

        return buffer_ratio

    def fit_model(self, observed_data):
        """
        Fit alpha and beta to observations by least squares.

        observed_data: [(buffer_ratio, task_gap, forgetting), ...]
        """
        import scipy.optimize as opt

        def loss(params):
            alpha, beta = params
            total_loss = 0
            for ratio, gap, forgetting in observed_data:
                pred = np.exp(-alpha * ratio) * np.exp(-beta * gap)
                total_loss += (pred - forgetting) ** 2
            return total_loss

        # Bounded minimization starting from the current parameters
        result = opt.minimize(
            loss,
            x0=[self.alpha, self.beta],
            bounds=[(0.1, 10), (0.01, 5)]
        )

        self.alpha, self.beta = result.x

        return self.alpha, self.beta

    def compute_capacity_schedule(self, total_capacity, n_tasks,
                                 task_sample_counts):
        """
        Compute a per-task capacity schedule.

        Targets the same forgetting level (10%) for every task.
        """
        # Base forgetting per task; task t has gap (n_tasks - 1 - t)
        # to the most recent task, matching task_gap = T - i above
        gaps = n_tasks - 1 - np.arange(n_tasks)
        base_forgetting = np.exp(-self.beta * gaps)

        # Invert the model to get the capacity each task needs
        capacities = []
        for t in range(n_tasks):
            target_forgetting = 0.1  # 10%

            # B/m = (1/α) * ln(ΔL_base / target)
            if base_forgetting[t] > target_forgetting:
                ratio = (1 / self.alpha) * np.log(
                    base_forgetting[t] / target_forgetting
                )
                capacity = int(ratio * task_sample_counts[t])
            else:
                capacity = 0

            capacities.append(capacity)

        # Scale down proportionally if the total exceeds the available capacity
        total_needed = sum(capacities)
        if total_needed > total_capacity:
            scale = total_capacity / total_needed
            capacities = [int(c * scale) for c in capacities]

        return capacities

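5. Optimal Replay Strategies

5.1 Stochastic Replay vs. Deterministic Replay

Theorem 9 (advantage of stochastic replay): By injecting noise in parameter space, stochastic replay helps the optimizer escape local minima:

5.2 Theoretical Basis of Prioritized Replay

Theorem 10 (uncertainty-based priority): Let $u(x)$ be the uncertainty of sample $x$. Then the optimal sampling probability satisfies:

$$p^*(x) \propto u(x)^{1/(1+\beta)}$$

As $\beta \to \infty$ this degenerates to uniform sampling, and as $\beta \to 0$ it concentrates on high-uncertainty samples.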
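
The effect of β can be checked numerically. The sketch below evaluates the rule p*(x) ∝ u(x)^(1/(1+β)) given in the `uncertainty_sampling` docstring of this section for a few values of β; `sampling_probs` and the uncertainty values are hypothetical and serve only as an illustration.

```python
def sampling_probs(uncertainties, beta):
    """Normalize u ** (1 / (1 + beta)) into a probability vector."""
    weights = [u ** (1.0 / (1.0 + beta)) for u in uncertainties]
    total = sum(weights)
    return [w / total for w in weights]


u = [0.1, 0.4, 1.6]  # made-up per-sample uncertainties (must be positive)
for beta in (0.0, 0.5, 10.0, 1e6):
    print(beta, [round(p, 3) for p in sampling_probs(u, beta)])
```

At β = 0 the probabilities are directly proportional to the uncertainties, and as β grows they flatten toward the uniform distribution.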

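5.3 Balanced Replay

Theorem 11 (class-balanced replay): To prevent forgetting caused by class imbalance, the optimal strategy is to sample every class with equal probability. With $C$ classes in the buffer:

$$p(y = c) = \frac{1}{C}, \qquad c = 1, \dots, C$$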

class OptimalReplayStrategy:
    """
    Replay sample-selection strategies.

    Assumes the model exposes get_embedding().
    """

    def __init__(self, model, buffer_capacity):
        self.model = model
        self.buffer_capacity = buffer_capacity

    def uncertainty_sampling(self, candidates, device='cuda'):
        """
        Uncertainty-based priority sampling.

        p*(x) ∝ u(x)^(1/(1+β))
        """
        self.model.eval()
        uncertainties = []

        with torch.no_grad():
            for x, _ in candidates:
                x = x.to(device)
                logits = self.model(x)
                probs = torch.softmax(logits, dim=-1)

                # Predictive entropy serves as the uncertainty u(x)
                entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
                uncertainties.append(entropy.mean().item())

        uncertainties = np.array(uncertainties)

        # Convert uncertainties into sampling probabilities
        beta = 0.5  # hyperparameter
        weights = uncertainties ** (1 / (1 + beta))
        total = weights.sum()
        if total <= 0:  # all uncertainties are zero: fall back to uniform
            return np.full(len(weights), 1.0 / len(weights))
        probs = weights / total

        return probs

    def diversity_sampling(self, candidates, n_select, device='cuda'):
        """
        Diversity-based selection.

        Greedily picks the candidate farthest from the mean embedding of the
        samples selected so far.
        """
        self.model.eval()

        with torch.no_grad():
            # Embed every candidate
            embeddings = []
            for x, _ in candidates:
                x = x.to(device)
                z = self.model.get_embedding(x)
                embeddings.append(z.cpu())
            embeddings = torch.stack(embeddings, dim=0)

        n_candidates = len(candidates)
        selected = []
        remaining = list(range(n_candidates))

        for _ in range(n_select):
            if not remaining:
                break

            best_score = -float('inf')
            best_idx = None

            for idx in remaining:
                if not selected:
                    score = 0  # all scores tie, so the first candidate is taken
                else:
                    # Distance to the mean of the selected embeddings
                    selected_embs = embeddings[selected]
                    dist = torch.norm(
                        embeddings[idx].unsqueeze(0) - selected_embs.mean(dim=0, keepdim=True)
                    ).item()
                    score = dist

                if score > best_score:
                    best_score = score
                    best_idx = idx

            selected.append(best_idx)
            remaining.remove(best_idx)

        return [candidates[i] for i in selected]
    
    def herding_selection(self, dataset, n_select, device='cuda'):
        """
        Herding-style selection: keep samples whose running mean embedding
        stays close to the class mean.

        This is a single-pass approximation of the herding exemplar selection
        used in Rebuffi et al., "iCaRL: Incremental Classifier and
        Representation Learning" (CVPR 2017), not the exact greedy procedure.
        """
        self.model.eval()

        # Embed every sample and record its label
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for x, y in dataset:
                x = x.to(device)
                z = self.model.get_embedding(x)
                all_embeddings.append(z.cpu())
                all_labels.append(y)

        all_embeddings = torch.stack(all_embeddings, dim=0)
        all_labels = torch.tensor(all_labels)

        # Mean embedding of each class
        classes = torch.unique(all_labels)
        class_means = {}

        for c in classes:
            mask = all_labels == c
            class_means[c.item()] = all_embeddings[mask].mean(dim=0)

        # Per-class quota and running mean of the selected embeddings
        selected = []
        class_counts = {c.item(): 0 for c in classes}
        class_targets = {c.item(): n_select // len(classes) for c in classes}

        running_means = {c.item(): torch.zeros_like(class_means[c.item()])
                       for c in classes}

        # Visit the samples in random order
        indices = list(range(len(dataset)))
        random.shuffle(indices)

        for idx in indices:
            c = all_labels[idx].item()

            if class_counts[c] >= class_targets[c]:
                continue

            # Running mean if this sample were added
            new_mean = (running_means[c] * class_counts[c] + all_embeddings[idx]) / (
                class_counts[c] + 1
            )

            # Does adding the sample move the running mean toward the class mean?
            old_dist = torch.norm(running_means[c] - class_means[c])
            new_dist = torch.norm(new_mean - class_means[c])

            # Accept if it helps, or unconditionally until half the quota is filled
            if new_dist < old_dist or class_counts[c] < class_targets[c] * 0.5:
                selected.append(idx)
                running_means[c] = new_mean
                class_counts[c] += 1

        return [dataset[i] for i in selected]

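6. Theoretical Analysis of Gradient-Constrained Replay

6.1 Generalization Bound for A-GEM

A-GEM (Averaged Gradient Episodic Memory) constrains the angle between the current-task gradient and the old-task gradient.

Loss function

Theorem 12 (A-GEM forgetting bound) [3]: Let $g_{\text{ref}}$ be the reference gradient computed from the buffer. Then the forgetting of A-GEM satisfies:

where the bound depends on the angle between the two gradients.

6.2 Analysis of Gradient Projection

Theorem 13 (convergence of projected gradient descent): Let $\tilde{g}$ be the gradient after projection. Then:

and the loss decreases monotonically when this condition holds.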
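
A two-dimensional example makes the projection step concrete. The sketch below uses the same formula as the `project_gradient` method in this section, g' = g - (<g, g_ref> / <g_ref, g_ref>) g_ref, on plain Python lists; `agem_project` and the vectors are illustrative assumptions.

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def agem_project(grad, ref_grad):
    """Return grad unchanged if it does not conflict with ref_grad, else project it."""
    inner = dot(grad, ref_grad)
    if inner >= 0:  # acute or right angle: the update does not hurt the buffer loss
        return list(grad)
    # Remove the component of grad that points against ref_grad.
    scale = inner / dot(ref_grad, ref_grad)
    return [g - scale * r for g, r in zip(grad, ref_grad)]


ref = [0.0, 1.0]
conflicting = agem_project([1.0, -2.0], ref)   # inner product is -2, so it is projected
compatible = agem_project([1.0, 2.0], ref)     # inner product is +2, left unchanged
print(conflicting, dot(conflicting, ref))
print(compatible)
```

After projection the conflicting gradient is orthogonal to the reference gradient, so to first order the update no longer increases the loss on the buffer samples.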

class AGRESSiveReplay:
    """
    Combines A-GEM gradient projection with an experience replay buffer.
    """

    def __init__(self, model, buffer_capacity, device='cuda'):
        self.model = model
        self.buffer = ReplayBuffer(buffer_capacity)
        self.device = device

    def compute_reference_gradient(self, buffer_samples):
        """
        Compute the reference gradient from buffer data.

        buffer_samples is an iterable of (x, y) batches.
        """
        self.model.zero_grad()

        total_loss = 0
        for x, y in buffer_samples:
            x, y = x.to(self.device), y.to(self.device)
            output = self.model(x)
            loss = nn.functional.cross_entropy(output, y)
            total_loss += loss

        total_loss.backward()

        # Flatten and concatenate the gradients of all parameters
        ref_grad = torch.cat([
            p.grad.flatten()
            for p in self.model.parameters()
            if p.grad is not None
        ])

        return ref_grad

    def project_gradient(self, grad, ref_grad, eps=1e-3):
        """
        Gradient projection.

        If the current gradient forms an obtuse angle with the reference
        gradient (inner product below -eps), project it onto the constraint
        boundary.
        """
        dot_product = torch.dot(grad, ref_grad)

        if dot_product >= -eps:
            # Acute or right angle: no projection needed
            return grad

        # Project onto the hyperplane orthogonal to the reference gradient
        # g' = g - <g, g_ref> / <g_ref, g_ref> * g_ref
        projected_grad = grad - (dot_product / (ref_grad.norm() ** 2 + 1e-8)) * ref_grad

        return projected_grad

    def apply_projected_gradient(self, loss, ref_grad, lr=0.1):
        """
        Apply one update step with the projected gradient.
        """
        self.model.zero_grad()
        loss.backward()

        # Parameters that received a gradient. The same list is used to
        # flatten and to write back, so the offsets always line up. This
        # assumes the buffer loss touched the same parameters.
        params = [p for p in self.model.parameters() if p.grad is not None]
        current_grad = torch.cat([p.grad.flatten() for p in params])

        # Project
        projected_grad = self.project_gradient(current_grad, ref_grad)

        # Write the projected gradient back and take a plain SGD step
        param_idx = 0
        with torch.no_grad():
            for p in params:
                numel = p.numel()
                p.grad = projected_grad[param_idx:param_idx + numel].view_as(p).clone()
                p -= lr * p.grad
                param_idx += numel

        return projected_grad.norm().item()

    def train_step(self, current_batch, buffer_batch, lr=0.1):
        """
        One training step.
        """
        # Reference gradient from the buffer
        ref_grad = self.compute_reference_gradient(buffer_batch)

        # Current-task loss
        x_curr, y_curr = current_batch
        x_curr, y_curr = x_curr.to(self.device), y_curr.to(self.device)

        loss_curr = nn.functional.cross_entropy(
            self.model(x_curr), y_curr
        )

        # Update with the projected gradient
        grad_norm = self.apply_projected_gradient(loss_curr, ref_grad, lr)

        return loss_curr.item(), grad_norm

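7. Theoretical Analysis of Generative Replay

7.1 Problem Setup for Generative Replay

Generative replay: instead of storing real samples, a generative model (such as a GAN or VAE) produces pseudo-samples of old tasks.

Advantages

  • Can generate an unlimited number of diverse samples
  • Saves storage space

Challenges

  • Effectiveness depends on generation quality
  • Additional computational overhead

7.2 Generalization Bound for Generative Replay

Theorem 14 (generalization bound for generative replay): Let $p_G$ be the distribution of the generative model and $\hat{p}_{\mathcal{M}}$ the buffer distribution. Then:

where the bound includes the empirical loss on the generated samples.

7.3 Requirements on Generation Quality

Theorem 15 (generation quality threshold): Let $\epsilon$ be the target forgetting. Then the generative model must satisfy: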

class GenerativeReplayBuffer:
    """
    Generative replay buffer.

    Assumes the generator exposes a latent_dim attribute and can be called
    as generator(z) or generator(z, y).
    """

    def __init__(self, generator, buffer_capacity, device='cuda'):
        self.generator = generator  # generative model
        self.buffer_capacity = buffer_capacity
        self.device = device

        # Keep a small number of real (x, y) samples as anchors
        self.anchor_samples = deque(maxlen=buffer_capacity // 10)

    def add_anchor_samples(self, samples):
        """Add real (x, y) samples as anchors."""
        for sample in samples:
            self.anchor_samples.append(sample)

    def generate_samples(self, n_samples, class_label=None):
        """
        Generate pseudo-samples as (x, label) pairs.

        If class_label is given, samples of that class are generated;
        otherwise the label of each generated sample is None.
        """
        generated = []

        with torch.no_grad():
            for _ in range(n_samples):
                # Sample latent noise
                z = torch.randn(1, self.generator.latent_dim, device=self.device)

                # Conditional generation when a class label is available
                if class_label is not None:
                    y = torch.tensor([class_label], device=self.device)
                    generated_sample = self.generator(z, y)
                else:
                    generated_sample = self.generator(z)

                # Store the label with the sample so anchors and generated
                # samples share the same (x, y) layout
                generated.append((generated_sample.cpu(), class_label))

        return generated

    def get_replay_samples(self, n_samples, class_balanced=True):
        """
        Return replay samples (anchors plus generated samples).
        """
        # Generate as many samples as the anchors do not cover
        n_generated = n_samples - len(self.anchor_samples)
        if n_generated > 0:
            generated = self.generate_samples(n_generated)
        else:
            generated = []

        # Mix anchors and generated samples
        replay_samples = list(self.anchor_samples) + generated
        if not replay_samples:
            return []

        if class_balanced:
            # Group by label; unlabeled generated samples fall under None
            samples_by_class = {}
            for x, y in replay_samples:
                key = y.item() if torch.is_tensor(y) else y
                samples_by_class.setdefault(key, []).append((x, y))

            # Draw an equal share from each class
            n_per_class = n_samples // len(samples_by_class)
            balanced_samples = []

            for y, samples in samples_by_class.items():
                balanced_samples.extend(random.sample(samples, min(n_per_class, len(samples))))

            return balanced_samples
        else:
            return random.sample(replay_samples, min(n_samples, len(replay_samples)))

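8. Summary

Core theorems

| Theorem | Content | Practical significance |
| Theorem 1 | Replay increases mutual information | Replay helps preserve representations of old tasks |
| Theorem 4 | Single-task sample bound | Buffer size affects generalization |
| Theorem 7 | Capacity lower bound | Minimum capacity needed to prevent forgetting |
| Theorem 8 | Exponential forgetting model | How forgetting relates to capacity and task gap |
| Theorem 10 | Prioritized sampling | Uncertainty sampling outperforms uniform sampling |
| Theorem 12 | A-GEM forgetting bound | Effectiveness of gradient constraints |

Practical recommendations

  1. Buffer capacity: use Theorem 7 to estimate the critical capacity
  2. Sampling strategy: combine uncertainty sampling with diversity sampling
  3. Class balance: prevent forgetting caused by class imbalance
  4. Generative replay: use it when storage is limited and generation quality is sufficient

Theoretical insights

  • The effectiveness of replay stems from information preservation and gradient constraints
  • Sample complexity and buffer capacity both have explicit lower bounds
  • The optimal strategy depends on task structure and computational resources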

Footnotes

  1. Lopez-Paz & Ranzato (2017). Gradient episodic memory for continual learning. NeurIPS.

  2. Chaudhry et al. (2019). Efficient lifelong learning with A-GEM. ICLR.

  3. Rolnick et al. (2019). Experience replay for continual learning. NeurIPS.