元学习中的优化方法

元学习中的”学习如何学习”本质上是一个优化问题。本文档探讨基于优化的元学习方法,重点关注如何设计更好的优化器来适应新任务。

优化视角回顾

三种元学习范式的优化视角

范式优化对象方法
度量学习嵌入空间学习距离函数
模型学习网络结构快速参数更新
优化学习优化器本身学习优化器

双层优化的挑战

元学习的核心是双层优化(Bi-Level Optimization):


LSTM元学习器

核心思想

使用LSTM作为元学习器,学习如何更新神经网络的参数:

为什么用LSTM?

LSTM的门控机制类似于优化器的行为:

  • 遗忘门:决定保留多少历史梯度信息
  • 输入门:控制新梯度的影响
  • 输出门:决定更新方向

代码实现

class LSTMOptimizer(nn.Module):
    """
    LSTM元学习器
    
    学习一个LSTM来预测参数更新
    """
    def __init__(self, model_dim, hidden_dim=20):
        super().__init__()
        
        # 元学习器的LSTM
        self.lstm = nn.LSTMCell(input_size=model_dim * 2, hidden_size=hidden_dim)
        
        # 输出层:预测更新方向和幅度
        self.output = nn.Linear(hidden_dim, 2 * model_dim)  # 方向和log lr
        
        self.model_dim = model_dim
        self.hidden_dim = hidden_dim
    
    def forward(self, grads, params, state=None):
        """
        Args:
            grads: 梯度向量 (d,)
            params: 参数向量 (d,)
            state: (h, c) LSTM状态
        
        Returns:
            update: 更新向量 (d,)
            new_state: 更新后的LSTM状态
        """
        # 拼接梯度和参数
        x = torch.cat([grads, params], dim=-1)
        
        # LSTM前向
        if state is None:
            batch_size = grads.size(0)
            h = torch.zeros(batch_size, self.hidden_dim, device=grads.device)
            c = torch.zeros(batch_size, self.hidden_dim, device=grads.device)
        else:
            h, c = state
        
        h_new, c_new = self.lstm(x, (h, c))
        
        # 预测更新
        out = self.output(h_new)
        update, log_lr = out.chunk(2, dim=-1)
        
        # 缩放更新
        lr = torch.exp(log_lr)
        update = torch.tanh(update) * lr
        
        return update, (h_new, c_new)
 
 
class MetaOptimizer:
    """
    基于LSTM的元优化器
    """
    def __init__(self, model, lr_meta=0.001):
        self.model = model
        self.lr_meta = lr_meta
        
        # 元学习器
        self.optimizer_lstm = LSTMOptimizer(
            model_dim=sum(p.numel() for p in model.parameters()),
            hidden_dim=32
        )
        self.optimizer = torch.optim.Adam(
            self.optimizer_lstm.parameters(),
            lr=lr_meta
        )
    
    def step(self, loss, retain_graph=False):
        """
        执行元优化步骤
        """
        self.optimizer.zero_grad()
        
        # 获取当前参数和梯度
        params = torch.cat([p.data.view(-1) for p in self.model.parameters()])
        grads = torch.cat([
            p.grad.view(-1) for p in self.model.parameters() 
            if p.grad is not None
        ])
        
        # LSTM更新
        update, _ = self.optimizer_lstm(grads, params)
        
        # 应用更新
        offset = 0
        for p in self.model.parameters():
            numel = p.numel()
            p.data.add_(update[offset:offset+numel].view(p.shape))
            offset += numel
        
        return update.norm().item()

近似方法

一阶梯度估计

移位函数

不使用显式梯度,使用移位函数(Shift Function)近似:

其中 是学习的移位函数。

REINFORCE

使用REINFORCE估计梯度:

二阶梯度的近似

对角近似

忽略Hessian矩阵的非对角元素:

K-FAC近似

使用Kronecker分解:

其中 分别是输入和输出的协方差矩阵。


任务自适应学习率

核心思想

不同任务可能需要不同的学习率。MAML++ 等方法引入了任务自适应学习率

实现方式

class TaskAdaptiveLR(nn.Module):
    """
    任务自适应学习率
    
    为每个任务学习个性化的学习率缩放
    """
    def __init__(self, model_dim, num_tasks):
        super().__init__()
        # 每个任务的学习率缩放因子
        self.lr_scalers = nn.Embedding(num_tasks, model_dim)
        nn.init.ones_(self.lr_scalers.weight)
    
    def get_lr(self, task_id):
        """获取特定任务的学习率"""
        return self.lr_scalers(task_id)
 
 
class MAMLPlusPlus(nn.Module):
    """
    MAML++ 实现
    
    改进点:
    1. 任务自适应学习率
    2. 多步损失
    3. 余弦衰减
    """
    def __init__(self, model, num_tasks=1000):
        super().__init__()
        self.model = model
        self.task_lr = TaskAdaptiveLR(
            model_dim=sum(p.numel() for p in model.parameters()),
            num_tasks=num_tasks
        )
        
        # 元学习率
        self.meta_lr = 0.001
        self.num_inner_steps = 5
    
    def inner_update(self, support_x, support_y, task_id, first_order=True):
        """
        任务内更新(带自适应学习率)
        """
        params = {name: p.clone() for name, p in self.model.named_parameters()}
        lr_scalers = self.task_lr.get_lr(task_id)
        lr_base = 0.01
        
        losses = []
        for step in range(self.num_inner_steps):
            # 前向
            logits = self._forward(params, support_x)
            loss = F.cross_entropy(logits, support_y)
            losses.append(loss)
            
            # 梯度
            grads = torch.autograd.grad(
                loss, params.values(),
                retain_graph=(step < self.num_inner_steps - 1)
            )
            
            # 自适应学习率更新
            offset = 0
            for name, param in params.items():
                numel = param.numel()
                lr = lr_base * lr_scalers[offset:offset+numel].mean()
                param.data.sub_(lr * grads[list(params.keys()).index(name)])
                offset += numel
        
        # 多步损失
        return sum(losses) / len(losses)
    
    def meta_update(self, tasks):
        """
        元更新
        """
        meta_grads = None
        
        for task_id, (support_x, support_y, query_x, query_y) in enumerate(tasks):
            # 内层更新
            inner_loss = self.inner_update(support_x, support_y, task_id)
            
            # 外层梯度
            grads = torch.autograd.grad(
                inner_loss,
                self.model.parameters(),
                retain_graph=True
            )
            
            if meta_grads is None:
                meta_grads = grads
            else:
                meta_grads = [g1 + g2 for g1, g2 in zip(meta_grads, grads)]
        
        # 更新
        with torch.no_grad():
            for param, grad in zip(self.model.parameters(), meta_grads):
                param -= self.meta_lr * grad / len(tasks)

无梯度方法

进化策略

使用进化算法优化元参数:

class EvolutionaryMetaLearning:
    """
    基于进化策略的元学习
    """
    def __init__(self, model, pop_size=50, lr=0.01):
        self.model = model
        self.pop_size = pop_size
        self.lr = lr
    
    def evolve_step(self, task):
        """
        一步进化
        """
        # 当前参数
        theta = self._get_params()
        
        # 采样噪声
        noises = [
            {name: torch.randn_like(p) for name, p in self.model.named_parameters()}
            for _ in range(self.pop_size)
        ]
        
        # 评估
        fitness = []
        for noise in noises:
            # 扰动参数
            perturbed = {name: theta[name] + self.lr * noise[name] for name in theta}
            
            # 计算适应度
            fitness.append(self._evaluate(perturbed, task))
        
        # 计算梯度估计
        grad_est = {name: torch.zeros_like(p) for name, p in theta.items()}
        for noise, fit in zip(noises, fitness):
            for name in theta:
                grad_est[name] += fit * noise[name]
        
        grad_est = {name: g / (self.pop_size * self.lr) for name, g in grad_est.items()}
        
        # 更新
        for name in theta:
            theta[name] -= self.lr * grad_est[name]
        
        self._set_params(theta)
        
        return sum(fitness) / len(fitness)
    
    def _get_params(self):
        return {name: p.clone() for name, p in self.model.named_parameters()}
    
    def _set_params(self, params):
        for name, p in self.model.named_parameters():
            p.data.copy_(params[name])
    
    def _evaluate(self, params, task):
        self._set_params(params)
        loss = self._task_loss(task)
        return -loss.item()

参考文献

相关文章