The Reptile Algorithm

Reptile, proposed by Alex Nichol et al. in 2018, is a simplified first-order meta-learning algorithm.[1] Compared with MAML, Reptile drops the expensive second-order gradient computation and learns a good parameter initialization purely through repeated SGD.

Core Idea

The Complexity of MAML

MAML must compute second-order gradients (or FOMAML's first-order approximation):

θ ← θ − β ∇_θ Σ_{τᵢ∼p(τ)} L_{τᵢ}(f_{θ'ᵢ})

where θ'ᵢ = θ − α ∇_θ L_{τᵢ}(f_θ) is the task-adapted parameter. Differentiating through θ'ᵢ is what introduces the second-order terms.

Reptile's Simplification

Reptile's core observation: starting from the current initialization θ, run several SGD steps on a sampled task τ to reach parameters θ̃ near that task's local optimum, then move the initialization a small step toward θ̃:

θ ← θ + β (θ̃ − θ)

That is, update along the direction from θ to θ̃.
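The interpolation step can be made concrete with a tiny numeric sketch (the tensors and step size below are made up for illustration):

```python
import torch

# One Reptile meta-step on a 2-D parameter vector.
theta = torch.tensor([0.0, 0.0])        # current initialization θ
theta_tilde = torch.tensor([1.0, 2.0])  # parameters after K SGD steps on a task
beta = 0.1                              # outer (meta) learning rate

# Move the initialization a fraction beta along (theta_tilde - theta)
theta = theta + beta * (theta_tilde - theta)
print(theta)  # tensor([0.1000, 0.2000])
```

Note that this is plain interpolation in parameter space; no gradients flow through the inner loop.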


Mathematical Derivation

Optimization Objective

Reptile minimizes the expected post-adaptation loss:

min_θ E_{τ∼p(τ)} [ L_τ(U_τ^K(θ)) ]

where U_τ^K(θ) denotes K steps of SGD on task τ starting from θ; as K grows, L_τ(U_τ^K(θ)) approaches the optimal loss on task τ.

Update Rule

Sample a task τ ∼ p(τ) from the task distribution, then:

  1. Inner loop: run K steps of SGD on task τ, starting from θ:

     θᵢ = θᵢ₋₁ − α ∇_θ L_τ(f_{θᵢ₋₁}),  θ₀ = θ,  θ̃ = θ_K

  2. Outer loop: update the initialization:

     θ ← θ + β (θ̃ − θ)
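Both loops can be checked end to end on a toy problem. A minimal sketch, assuming scalar tasks with loss L_τ(θ) = ½(θ − b_τ)² and b_τ ∈ {−1, +1} (made up for illustration); since each task's optimum is b_τ, Reptile should pull the initialization toward the mean of the optima, 0.0:

```python
import random

random.seed(0)
alpha, beta, K = 0.1, 0.1, 20  # inner lr, meta lr, inner steps
theta = 5.0                    # deliberately bad initialization
avg = 0.0

for t in range(5000):
    b = random.choice([-1.0, 1.0])    # sample a task tau ~ p(tau)
    phi = theta                       # theta_0 = theta
    for _ in range(K):                # inner loop: K steps of SGD
        phi -= alpha * (phi - b)      # gradient of 0.5*(phi - b)^2 is (phi - b)
    theta += beta * (phi - theta)     # outer loop: theta <- theta + beta*(theta_tilde - theta)
    if t >= 2500:
        avg += theta                  # average theta over the second half
avg /= 2500

print(round(avg, 2))  # stays near 0.0, the mean of the task optima
```

The initialization drifts from 5.0 to hover around 0.0, the point with the smallest expected adaptation distance to all tasks.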

Relation to MAML

Unrolling the Reptile update over the K inner steps:

θ̃ − θ = −α (g₁ + g₂ + … + g_K)

where gᵢ is the gradient taken at the i-th inner step. Taylor-expanding each gᵢ around the initial point θ introduces Hessian-gradient cross terms of order α².

When K = 1, there are no cross terms: the Reptile step is a plain SGD step on the sampled task, which in expectation matches joint training on the expected loss, the same leading term as FOMAML.

When K ≥ 2, the expansion contains second-order terms whose expectation rewards the inner product between gradients of different minibatches of the same task. This is the same term that dominates the MAML and FOMAML updates, so Reptile behaves increasingly like MAML.
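The unrolled form can be verified numerically. A minimal sketch with K = 2 and a made-up scalar quadratic loss, checking that θ̃ − θ = −α(g₁ + g₂):

```python
import torch

# Verify: after K = 2 inner SGD steps, theta_tilde - theta = -alpha*(g1 + g2),
# i.e. the Reptile direction is the sum of the inner-loop gradients.
alpha = 0.1
theta = torch.tensor(3.0, requires_grad=True)

def loss_fn(x):
    return 0.5 * (x - 1.0) ** 2  # toy task loss with optimum at 1.0

g1 = torch.autograd.grad(loss_fn(theta), theta)[0]    # gradient at theta_0
theta1 = theta - alpha * g1                           # first inner step
g2 = torch.autograd.grad(loss_fn(theta1), theta1)[0]  # gradient at theta_1
theta2 = theta1 - alpha * g2                          # second inner step

print(torch.allclose(theta2 - theta, -alpha * (g1 + g2)))  # True
```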


Algorithm

Input: task distribution p(τ); inner and meta learning rates α, β; inner-loop step count K
Output: learned initialization θ

1: randomly initialize θ
2: while not converged do
3:   sample a task τ from p(τ)
4:   
5:   // inner loop: K steps of SGD
6:   θ₀ = θ
7:   for i = 1 to K do
8:     θᵢ = θᵢ₋₁ - α ∇_θ L_τ(f_{θᵢ₋₁})
9:   end for
10:  θ̃ = θ_K
11:  
12:  // outer loop: update the initialization
13:  θ = θ + β (θ̃ - θ₀)
14: 
15: end while
16: return θ
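The pseudocode transcribes almost line for line into PyTorch. A minimal sketch, assuming a tiny linear model and a hypothetical `sample_task` generator (both made up for illustration):

```python
import copy
import torch
import torch.nn as nn

def reptile_train(model, sample_task, alpha=0.01, beta=0.1, K=5, iterations=100):
    """Line-for-line transcription of the pseudocode above."""
    for _ in range(iterations):
        x, y = sample_task()                        # 3: sample a task τ
        theta0 = copy.deepcopy(model.state_dict())  # 6: θ₀ = θ
        opt = torch.optim.SGD(model.parameters(), lr=alpha)
        for _ in range(K):                          # 7-9: K steps of SGD
            opt.zero_grad()
            loss = nn.functional.mse_loss(model(x), y)
            loss.backward()
            opt.step()
        with torch.no_grad():                       # 13: θ = θ₀ + β (θ̃ - θ₀)
            for name, p in model.named_parameters():
                p.copy_(theta0[name] + beta * (p - theta0[name]))
    return model

torch.manual_seed(0)
model = nn.Linear(1, 1)

def sample_task():
    # Hypothetical task family: y = a·x with a random slope per task
    a = torch.randn(1)
    x = torch.randn(20, 1)
    return x, a * x

reptile_train(model, sample_task)
```

Note that, unlike MAML, each task needs only one batch of data; there is no support/query split in the basic algorithm.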

Key Differences from MAML

| Aspect | MAML | Reptile |
| --- | --- | --- |
| Inner update | gradient descent, differentiated through | plain multi-step SGD |
| Gradient computation | must separate inner and outer gradients | ordinary SGD only |
| Second-order gradients | required (FOMAML avoids them) | not needed |
| Query-set awareness | yes, outer loss on a held-out query set | no, all batches are treated the same |

Code Implementation

A Basic Reptile Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, List, Tuple
import copy
 
class Reptile:
    """
    Reptile meta-learner.

    A simplified first-order meta-learning algorithm.
    """
    def __init__(
        self,
        model: nn.Module,
        lr_inner: float = 0.01,
        lr_meta: float = 0.001,
        inner_steps: int = 5,
        device: str = 'cuda'
    ):
        self.model = model
        self.lr_inner = lr_inner
        self.lr_meta = lr_meta
        self.inner_steps = inner_steps
        self.device = device
        self.model.to(device)
        
        # Meta-optimizer (used only by the FOMAML path;
        # the Reptile update is applied manually)
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr_meta)
    
    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        theta: dict = None
    ) -> dict:
        """
        Inner loop: run several SGD updates on the support set.
        
        Args:
            support_x: support-set features
            support_y: support-set labels
            theta: initial parameter state (defaults to the current model parameters)
        
        Returns:
            adapted_state: state dict of the adapted parameters
        """
        # Adapt a copy so the meta-parameters in self.model stay untouched
        adapted_model = copy.deepcopy(self.model)
        if theta is not None:
            # strict=False: theta may contain parameters only, no buffers
            adapted_model.load_state_dict(theta, strict=False)
        
        # Plain SGD -- no second-order bookkeeping is needed
        for _ in range(self.inner_steps):
            logits = adapted_model(support_x)
            loss = F.cross_entropy(logits, support_y)
            grads = torch.autograd.grad(loss, adapted_model.parameters())
            
            with torch.no_grad():
                for param, grad in zip(adapted_model.parameters(), grads):
                    param.sub_(self.lr_inner * grad)
        
        return adapted_model.state_dict()
    
    def meta_loss(
        self,
        query_x: torch.Tensor,
        query_y: torch.Tensor,
        adapted_state: dict
    ) -> torch.Tensor:
        """
        Query-set loss under the adapted parameters.
        
        Evaluated on a throwaway copy so that self.model keeps
        holding the meta-parameters θ.
        """
        eval_model = copy.deepcopy(self.model)
        eval_model.load_state_dict(adapted_state)
        logits = eval_model(query_x)
        return F.cross_entropy(logits, query_y)
    
    def meta_step(
        self,
        tasks: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
        use_reptile: bool = True
    ) -> dict:
        """
        One meta-update over a batch of tasks.
        
        Args:
            tasks: list of (support_x, support_y, query_x, query_y)
            use_reptile: True for the Reptile update, False for FOMAML
        """
        self.optimizer.zero_grad()
        meta_loss_sum = 0.0
        adapted_states = []
        
        for support_x, support_y, query_x, query_y in tasks:
            # Move to device
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)
            
            # The inner loop always starts from the current initialization θ
            adapted_state = self.inner_loop(support_x, support_y)
            
            if use_reptile:
                # Reptile never differentiates the query loss;
                # it is computed here only for monitoring
                with torch.no_grad():
                    loss = self.meta_loss(query_x, query_y, adapted_state)
                adapted_states.append(adapted_state)
            else:
                # FOMAML: gradient of the query loss taken at the adapted
                # parameters, then applied to the initialization
                eval_model = copy.deepcopy(self.model)
                eval_model.load_state_dict(adapted_state)
                loss = F.cross_entropy(eval_model(query_x), query_y)
                grads = torch.autograd.grad(loss, eval_model.parameters())
                for param, grad in zip(self.model.parameters(), grads):
                    param.grad = grad if param.grad is None else param.grad + grad
            
            meta_loss_sum += loss.item()
        
        if use_reptile:
            # Reptile: interpolate toward the adapted parameters
            self._reptile_update(adapted_states)
        else:
            # FOMAML: gradients are already accumulated
            self.optimizer.step()
        
        return {'meta_loss': meta_loss_sum / len(tasks)}
    
    def _reptile_update(self, adapted_states: List[dict]):
        """
        Reptile parameter update:
        
        θ ← θ + β · mean_τ(θ̃_τ − θ)
        """
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                mean_adapted = sum(s[name] for s in adapted_states) / len(adapted_states)
                param.add_(mean_adapted - param, alpha=self.lr_meta)

A Simplified Reptile Implementation (PyTorch Style)

class SimpleReptile(nn.Module):
    """
    Simplified Reptile (native PyTorch style).
    
    Holds the meta-parameters θ inside a template model and
    interpolates toward per-task adapted parameters.
    """
    def __init__(self, model_class, model_kwargs, lr_inner=0.01, lr_meta=0.001, inner_steps=5):
        super().__init__()
        self.lr_inner = lr_inner
        self.lr_meta = lr_meta
        self.inner_steps = inner_steps
        
        # Meta-parameters: the template model holds θ
        self.meta_model = model_class(**model_kwargs)
    
    def inner_update(self, support_x, support_y):
        """Inner-loop SGD on a copy, starting from the meta-parameters."""
        model = copy.deepcopy(self.meta_model)
        optimizer = torch.optim.SGD(model.parameters(), lr=self.lr_inner)
        
        for _ in range(self.inner_steps):
            optimizer.zero_grad()
            loss = F.cross_entropy(model(support_x), support_y)
            loss.backward()
            optimizer.step()
        
        return model.state_dict()
    
    def forward(self, tasks):
        """
        tasks: [(support_x, support_y, query_x, query_y), ...]
        """
        # Snapshot the initial parameters θ₀
        theta_0 = {k: v.clone() for k, v in self.meta_model.state_dict().items()}
        
        # Inner update for each task
        adapted_list = [
            self.inner_update(support_x, support_y)
            for support_x, support_y, _, _ in tasks
        ]
        
        # Reptile update: θ ← θ + β · mean(θ̃ − θ₀)
        with torch.no_grad():
            for name, param in self.meta_model.state_dict().items():
                direction = sum(a[name] - theta_0[name] for a in adapted_list) / len(tasks)
                param.add_(direction, alpha=self.lr_meta)

End-to-End Training Example

import torch
import torch.nn as nn
 
class SimpleCNN(nn.Module):
    """简单CNN用于少样本学习"""
    def __init__(self, in_channels=1, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Linear(64, num_classes)
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
 
 
def train_reptile():
    """Reptile training loop."""
    # Configuration
    WAY = 5
    SHOT = 5
    QUERY = 15
    EPOCHS = 100
    LR_INNER = 0.01
    LR_META = 0.001
    INNER_STEPS = 5
    
    # Model
    model = SimpleCNN(in_channels=1, num_classes=WAY)
    reptile = Reptile(
        model=model,
        lr_inner=LR_INNER,
        lr_meta=LR_META,
        inner_steps=INNER_STEPS,
        device='cuda' if torch.cuda.is_available() else 'cpu'
    )
    
    # Training
    for epoch in range(EPOCHS):
        # Generate a random task (placeholder: the labels are random, so the
        # loss will not decrease -- plug in a real episodic sampler here)
        support_x = torch.randn(WAY * SHOT, 1, 7, 7)
        support_y = torch.randint(0, WAY, (WAY * SHOT,))
        query_x = torch.randn(WAY * QUERY, 1, 7, 7)
        query_y = torch.randint(0, WAY, (WAY * QUERY,))
        
        # Meta-training step
        tasks = [(support_x, support_y, query_x, query_y)]
        loss_dict = reptile.meta_step(tasks, use_reptile=True)
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss={loss_dict['meta_loss']:.4f}")
 
 
if __name__ == '__main__':
    train_reptile()

Reptile vs MAML vs FOMAML

Algorithm Comparison

| Property | MAML | FOMAML | Reptile |
| --- | --- | --- | --- |
| Gradient type | second-order | first-order (approximation) | pseudo-gradient (parameter difference) |
| Inner computation | one gradient-descent step | one gradient-descent step | K steps of SGD |
| Compute cost | high | moderate | low |
| Implementation complexity | high | moderate | low |
| Final performance | best | close to MAML | comparable |

When to Choose Reptile?

Choose Reptile when:

  • compute resources are limited
  • a simple implementation is needed
  • tasks benefit from many inner optimization steps

Choose MAML/FOMAML when:

  • faster meta-convergence is needed
  • one inner adaptation step is enough
  • tight coupling with a specific model architecture is required

Extension: Reptile + SGD

k-Shot Reptile

def reptile_kshot(model, support_x, support_y, query_x, query_y, 
                  lr_inner=0.01, k=5):
    """
    k-step ensemble evaluation for Reptile.
    
    Adapts the model with 1..k inner steps and averages the resulting
    query predictions. Assumes an `inner_loop(model, x, y, steps, lr)`
    helper that runs `steps` SGD updates from the model's current
    parameters and returns the adapted state dict without modifying
    the model (cf. Reptile.inner_loop above).
    """
    original_state = copy.deepcopy(model.state_dict())
    predictions = []
    
    for k_i in range(1, k + 1):
        # Adapt with a different number of inner steps
        adapted_state = inner_loop(model, support_x, support_y,
                                   steps=k_i, lr=lr_inner)
        
        # Evaluate the adapted parameters on the query set
        model.load_state_dict(adapted_state)
        with torch.no_grad():
            predictions.append(model(query_x))
    
    # Restore the meta-parameters and average the predictions
    model.load_state_dict(original_state)
    return torch.stack(predictions).mean(0)

Footnotes

  1. Nichol, A., Achiam, J., & Schulman, J. (2018). "On First-Order Meta-Learning Algorithms". arXiv:1803.02999.