Reptile算法
Reptile由Alex Nichol等人于2018年提出,是一种简化的一阶元学习算法[^1]。与MAML相比,Reptile去除了复杂的二阶梯度计算,仅通过多次SGD迭代来学习良好的初始化参数。
核心思想
MAML的复杂性
MAML需要计算二阶梯度(或者FOMAML的一阶梯度近似):

$$\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}\big(f_{\theta'_i}\big)$$

其中 $\theta'_i = \theta - \alpha \nabla_\theta \mathcal{L}_{\tau_i}(f_\theta)$。外层梯度需要对内层更新过程求导,因此产生二阶项。
Reptile的简化
Reptile的核心观察:从初始化参数 $\theta$ 出发,朝向某个任务 $\tau$ 的(近似)局部最优参数 $\tilde{\theta}_\tau$ 移动一步。
即:沿着从 $\theta$ 到 $\tilde{\theta}_\tau$ 的方向更新:$\theta \leftarrow \theta + \beta\,(\tilde{\theta}_\tau - \theta)$。
数学推导
优化目标
Reptile最小化:

$$\min_\theta \; \mathbb{E}_{\tau \sim p(\tau)}\Big[\mathcal{L}_\tau\big(U_\tau^K(\theta)\big)\Big]$$

其中 $U_\tau^K(\theta)$ 表示从 $\theta$ 出发在任务 $\tau$ 上执行 $K$ 步SGD后得到的参数,$\mathcal{L}_\tau$ 是任务 $\tau$ 上的损失。
更新规则
从任务分布中采样任务 $\tau \sim p(\tau)$:
- 内层循环:对任务 $\tau$ 执行 $K$ 步SGD:
  $$\theta_i = \theta_{i-1} - \alpha \, \nabla_\theta \mathcal{L}_\tau\big(f_{\theta_{i-1}}\big), \quad i = 1, \dots, K$$
- 外层循环:更新初始化参数:
  $$\theta \leftarrow \theta + \beta\,(\theta_K - \theta_0)$$
与MAML的关系
展开Reptile的更新方向:$\theta_K - \theta_0 = -\alpha \sum_{i=1}^{K} \nabla_\theta \mathcal{L}_\tau\big(f_{\theta_{i-1}}\big)$。
对各步梯度做泰勒展开,可以证明Reptile的期望更新同时包含任务损失的梯度项与不同小批量梯度的内积项(后者促进任务内泛化)。
当 $K = 1$ 时,Reptile退化为在采样任务上的普通SGD(联合训练);当 $K \geq 2$ 时,Reptile在泰勒展开的前两阶与FOMAML一致。
$K$ 越大,更新中引入的高阶项越多,与MAML的行为更接近。
算法流程
输入:任务分布 p(τ),学习率 α, β,迭代次数 K
输出:最优初始化参数 θ
1: 随机初始化 θ
2: while not converged do
3: 从 p(τ) 采样任务 τ
4:
5: // 内层:执行 K 步 SGD
6: θ₀ = θ
7: for i = 1 to K do
8: θᵢ = θᵢ₋₁ - α ∇_θ L_τ(f_{θᵢ₋₁})
9: end for
10: θ̃ = θ_K
11:
12: // 外层:更新初始化参数
13: θ = θ + β (θ̃ - θ₀)
14:
15: end while
16: return θ
与MAML的关键区别
| 方面 | MAML | Reptile |
|---|---|---|
| 内层更新 | 精确梯度下降 | 多次SGD |
| 梯度计算 | 需要区分内层/外层 | 普通SGD |
| 二阶梯度 | MAML需要,FOMAML不需要 | 不需要 |
| 任务感知 | 是 | 否(批处理相同) |
代码实现
基础Reptile实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, List, Tuple
import copy
class Reptile:
    """Reptile meta-learner.

    A simplified first-order meta-learning algorithm: run K steps of plain
    SGD on each task, then move the meta-initialization towards the adapted
    parameters (theta <- theta + beta * (theta_adapted - theta)).

    Args:
        model: model whose initialization is meta-learned.
        lr_inner: inner-loop SGD learning rate (alpha).
        lr_meta: outer-loop / meta learning rate (beta).
        inner_steps: number of inner-loop SGD steps (K).
        device: device string, e.g. 'cuda' or 'cpu'.
    """

    def __init__(
        self,
        model: nn.Module,
        lr_inner: float = 0.01,
        lr_meta: float = 0.001,
        inner_steps: int = 5,
        device: str = 'cuda'
    ):
        self.model = model
        self.lr_inner = lr_inner
        self.lr_meta = lr_meta
        self.inner_steps = inner_steps
        self.device = device
        self.model.to(device)
        # Meta optimizer: used only by the FOMAML path; the Reptile path
        # updates parameters directly with lr_meta.
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr_meta)

    def inner_loop(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        theta: dict = None
    ) -> dict:
        """Inner loop: run `inner_steps` steps of SGD on the support set.

        Args:
            support_x: support-set inputs.
            support_y: support-set labels.
            theta: optional starting parameters (name -> tensor); if None,
                the current model parameters are used.

        Returns:
            state_dict of the adapted model.
        """
        # deepcopy clones the full module (parameters AND buffers) instead
        # of re-constructing the class from a config dict, which required
        # every model to be re-buildable from _get_model_config().
        adapted_model = copy.deepcopy(self.model)
        if theta is not None:
            # strict=False: `theta` may contain parameters only (no buffers).
            adapted_model.load_state_dict(
                {name: t.clone() for name, t in theta.items()}, strict=False
            )
        for _ in range(self.inner_steps):
            logits = adapted_model(support_x)
            loss = F.cross_entropy(logits, support_y)
            grads = torch.autograd.grad(loss, adapted_model.parameters())
            # Plain SGD step, outside autograd.
            with torch.no_grad():
                for param, grad in zip(adapted_model.parameters(), grads):
                    param.sub_(self.lr_inner * grad)
        return adapted_model.state_dict()

    def meta_loss(
        self,
        query_x: torch.Tensor,
        query_y: torch.Tensor,
        adapted_state: dict
    ) -> torch.Tensor:
        """Query-set cross-entropy loss of the adapted parameters.

        Evaluates on a throwaway copy so that self.model (the meta
        initialization) is never overwritten — the original implementation
        loaded the adapted state into self.model, corrupting theta before
        the Reptile update was applied.
        """
        eval_model = copy.deepcopy(self.model)
        eval_model.load_state_dict(adapted_state)
        logits = eval_model(query_x)
        return F.cross_entropy(logits, query_y)

    def meta_step(
        self,
        tasks: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
        use_reptile: bool = True
    ) -> dict:
        """One meta-update over a batch of tasks.

        Args:
            tasks: list of (support_x, support_y, query_x, query_y) tuples.
            use_reptile: True -> Reptile update, False -> FOMAML update.

        Returns:
            {'meta_loss': average query loss over the task batch}.
        """
        self.optimizer.zero_grad()
        meta_loss_sum = 0.0
        adapted_states = []
        for support_x, support_y, query_x, query_y in tasks:
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)
            adapted_state = self.inner_loop(support_x, support_y)
            if use_reptile:
                # Query loss is monitoring only; the Reptile update uses
                # parameter differences, not gradients.
                loss = self.meta_loss(query_x, query_y, adapted_state)
                meta_loss_sum += loss.item()
                adapted_states.append(adapted_state)
            else:
                # FOMAML: gradient of the query loss evaluated at the
                # adapted parameters, applied to the initialization.
                # (inner_loop returns a single dict — the old code tuple-
                # unpacked it and crashed.)
                eval_model = copy.deepcopy(self.model)
                eval_model.load_state_dict(adapted_state)
                loss = F.cross_entropy(eval_model(query_x), query_y)
                grads = torch.autograd.grad(loss, eval_model.parameters())
                for param, grad in zip(self.model.parameters(), grads):
                    # Accumulate across tasks.
                    param.grad = grad.clone() if param.grad is None else param.grad + grad
                meta_loss_sum += loss.item()
        if use_reptile:
            # Reuse the adapted states computed above instead of running
            # every inner loop a second time.
            self._reptile_update(tasks, adapted_states)
        else:
            self.optimizer.step()
        return {'meta_loss': meta_loss_sum / len(tasks)}

    def _reptile_update(
        self,
        tasks: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
        adapted_states: List[dict] = None
    ):
        """Reptile update: theta <- theta + beta * mean_tau(theta_tau - theta).

        Args:
            tasks: the task batch (used to recompute adaptations when
                `adapted_states` is not supplied).
            adapted_states: optional precomputed adapted state_dicts, one
                per task, to avoid re-running the inner loops.
        """
        init_params = {name: param.clone()
                       for name, param in self.model.named_parameters()}
        if adapted_states is None:
            adapted_states = []
            for support_x, support_y, _, _ in tasks:
                support_x = support_x.to(self.device)
                support_y = support_y.to(self.device)
                adapted_states.append(
                    self.inner_loop(support_x, support_y, init_params))
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                # Average the parameter displacement over the task batch.
                delta = sum(state[name] - init_params[name]
                            for state in adapted_states) / len(adapted_states)
                param.add_(delta, alpha=self.lr_meta)

    def _get_model_config(self) -> dict:
        """Model config for cloning (unused after the deepcopy rewrite;
        kept for backward compatibility)."""
        return {}

# 简化的Reptile实现(PyTorch风格)
class SimpleReptile(nn.Module):
    """Simplified Reptile in native-PyTorch style.

    Meta-parameters live in an nn.ParameterDict keyed by parameter name.
    NOTE(review): `model_kwargs` plays a double role here — it is passed to
    `model_class(**model_kwargs)` AND its tuple-valued entries define the
    shapes of the meta-parameters. The parameter names must therefore match
    the model's state_dict keys; confirm against the intended model class.

    Args:
        model_class: class of the task model.
        model_kwargs: constructor kwargs; tuple values double as shapes.
        lr_inner: inner-loop SGD learning rate.
        lr_meta: outer-loop (Reptile) learning rate.
        inner_steps: number of inner SGD steps per task.
    """

    def __init__(self, model_class, model_kwargs, lr_inner=0.01, lr_meta=0.001, inner_steps=5):
        super().__init__()
        self.lr_inner = lr_inner
        self.lr_meta = lr_meta
        self.inner_steps = inner_steps
        # nn.Parameter cannot wrap a dict — use a ParameterDict so each
        # meta-parameter is registered individually with autograd/optim.
        self.theta = nn.ParameterDict({
            name: nn.Parameter(tensor)
            for name, tensor in self._init_params(model_kwargs).items()
        })
        self.model_class = model_class
        self.model_kwargs = model_kwargs

    def _init_params(self, kwargs):
        """Build initial tensors: randn * 0.01 for each tuple-valued entry."""
        params = {}
        for name, shape in kwargs.items():
            if isinstance(shape, tuple):
                params[name] = torch.randn(*shape) * 0.01
        return params

    def inner_update(self, params, support_x, support_y):
        """Run `inner_steps` SGD steps from `params`; return new params."""
        new_params = {k: v.detach().clone() for k, v in params.items()}
        model = self.model_class(**self.model_kwargs)
        # Load by NAME (strict=False tolerates a params-only subset) —
        # zipping state_dict keys against dict values relies on ordering
        # and silently mismatches parameters.
        model.load_state_dict(new_params, strict=False)
        for _ in range(self.inner_steps):
            logits = model(support_x)
            loss = F.cross_entropy(logits, support_y)
            grads = torch.autograd.grad(loss, model.parameters())
            with torch.no_grad():
                for param, grad in zip(model.parameters(), grads):
                    param.sub_(self.lr_inner * grad)
        return {name: p.detach().clone()
                for name, p in model.named_parameters() if name in new_params}

    def forward(self, tasks):
        """One Reptile meta-update over a batch of tasks.

        tasks: [(support_x, support_y, query_x, query_y), ...]
        """
        # Snapshot of the initial meta-parameters.
        theta_0 = {k: v.detach().clone() for k, v in self.theta.items()}
        # Adapt independently on every task.
        adapted_params_list = [
            self.inner_update(theta_0, support_x, support_y)
            for support_x, support_y, _, _ in tasks
        ]
        # Reptile update: move theta towards the mean adapted parameters.
        with torch.no_grad():
            for name, param in self.theta.items():
                delta = sum(adapted[name] - theta_0[name]
                            for adapted in adapted_params_list) / len(tasks)
                param.add_(self.lr_meta * delta)

# 端到端训练示例
import torch
from torch.utils.data import DataLoader, TensorDataset
class SimpleCNN(nn.Module):
    """Simple CNN for few-shot classification.

    Global average pooling makes the classifier input independent of the
    spatial resolution. The original flatten + Linear(64, ...) only worked
    for inputs that happen to pool down to 1x1 (e.g. 7x7); for a 7x7 input
    the 1x1 adaptive pool is an identity, so behavior there is unchanged.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output classes (N-way).
    """

    def __init__(self, in_channels=1, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Parameter-free, so the state_dict layout is unchanged.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.pool(self.features(x))
        x = x.view(x.size(0), -1)  # (B, 64, 1, 1) -> (B, 64)
        return self.classifier(x)
def train_reptile():
    """Reptile training loop on synthetic random tasks (smoke example).

    The data here is random noise — it demonstrates the API, not learning.
    """
    # Episode configuration (N-way / K-shot) and hyperparameters.
    WAY = 5
    SHOT = 5
    QUERY = 15
    EPOCHS = 100
    LR_INNER = 0.01
    LR_META = 0.001
    INNER_STEPS = 5
    # Fall back to CPU when CUDA is unavailable — Reptile's default device
    # is 'cuda', which would crash on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Model and meta-learner.
    model = SimpleCNN(in_channels=1, num_classes=WAY)
    reptile = Reptile(
        model=model,
        lr_inner=LR_INNER,
        lr_meta=LR_META,
        inner_steps=INNER_STEPS,
        device=device
    )
    # Meta-training loop.
    for epoch in range(EPOCHS):
        # Random synthetic task (placeholder for a real task sampler).
        support_x = torch.randn(WAY * SHOT, 1, 7, 7)
        support_y = torch.randint(0, WAY, (WAY * SHOT,))
        query_x = torch.randn(WAY * QUERY, 1, 7, 7)
        query_y = torch.randint(0, WAY, (WAY * QUERY,))
        tasks = [(support_x, support_y, query_x, query_y)]
        loss_dict = reptile.meta_step(tasks, use_reptile=True)
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss={loss_dict['meta_loss']:.4f}")
if __name__ == '__main__':
    train_reptile()

# Reptile vs MAML vs FOMAML
算法对比
| 特性 | MAML | FOMAML | Reptile |
|---|---|---|---|
| 梯度类型 | 二阶 | 一阶(近似) | 零阶/伪梯度 |
| 内层计算 | 1步梯度下降 | 1步梯度下降 | K步SGD |
| 计算量 | 大 | 中 | 小 |
| 实现复杂度 | 高 | 中 | 低 |
| 收敛速度 | 快 | 快 | 中 |
| 最终性能 | 最好 | 接近MAML | 相当 |
何时选择Reptile?
选择 Reptile 当:
- 计算资源有限
- 需要简单的实现
- 任务的内层优化步数较多($K$ 较大)
选择 MAML/FOMAML 当:
- 需要更快的收敛
- 内层只需1步适应
- 需要与特定模型架构结合
扩展:Reptile + SGD
k-Shot Reptile
def reptile_kshot(model, support_x, support_y, query_x, query_y,
                  lr_inner=0.01, k=5):
    """Ensemble prediction over inner-loop horizons 1..k.

    For each step count k_i in 1..k, adapt a fresh copy of `model` on the
    support set with k_i SGD steps, predict on the query set, and return
    the averaged predictions.

    Args:
        model: base model; left unmodified (the original version loaded
            adapted weights into `model` in place, clobbering the caller's
            parameters, and relied on an undefined free function
            `inner_loop`).
        support_x, support_y: support set.
        query_x: query inputs.
        query_y: unused; kept for signature compatibility.
        lr_inner: inner-loop SGD learning rate.
        k: maximum number of inner adaptation steps.

    Returns:
        Tensor of query predictions averaged over the k adapted models.
    """
    predictions = []
    for k_i in range(1, k + 1):
        # Adapt a throwaway deep copy so the caller's model is untouched.
        adapted = copy.deepcopy(model)
        for _ in range(k_i):
            loss = F.cross_entropy(adapted(support_x), support_y)
            grads = torch.autograd.grad(loss, adapted.parameters())
            with torch.no_grad():
                for param, grad in zip(adapted.parameters(), grads):
                    param.sub_(lr_inner * grad)
        with torch.no_grad():
            predictions.append(adapted(query_x))
    # Average the k prediction tensors element-wise.
    return torch.stack(predictions).mean(0)

# 参考文献
相关文章
Footnotes
[^1]: Nichol, A., Achiam, J., & Schulman, J. (2018). "On First-Order Meta-Learning Algorithms". arXiv:1803.02999.