MAML:模型无关元学习
MAML(Model-Agnostic Meta-Learning)由Chelsea Finn等人于2017年提出,是元学习领域最具影响力的算法之一。1 MAML的核心思想是:找到一个好的初始参数,使得模型能在少量梯度步内快速适应新任务。
核心思想
与传统微调的区别
传统微调:
模型(随机初始化 or 预训练)→ 大量梯度步 → 适应新任务
MAML:
多个相关任务训练 → 学习好的初始参数 → 少量梯度步 → 快速适应新任务
关键洞察
MAML不学习固定的特征表示,而是学习一个初始化点 ,从该点出发,少量梯度更新就能取得好效果。
数学推导
优化目标
给定任务分布 ,MAML优化:
其中:
- 是任务 上的快速适应参数
- 是任务损失
双层优化结构
MAML采用双层优化(Bi-Level Optimization):
内层优化(Inner Loop)
在支持集上执行快速适应:
这里 是内层学习率,每个任务独立更新。
外层优化(Outer Loop)
在查询集上优化初始参数:
这里 是元学习率。
梯度计算
外层梯度是整个优化过程的关键:
展开 :
这涉及二阶梯度计算,计算量大。
FOMAML(一阶MAML)
近似策略
FOMAML的核心近似:忽略二阶梯度项
即:直接在适应后参数 上计算梯度,而非回传到 。
对比
| 方法 | 梯度项 | 计算复杂度 |
|---|---|---|
| MAML | ||
| FOMAML |
实验结论
Finn等人的实验表明,FOMAML与MAML在大多数任务上性能相当,但训练更稳定、速度更快。
算法流程
输入:任务分布 p(τ),学习率 α, β,任务数 K
输出:最优初始化参数 θ*
1: 随机初始化 θ
2: while not converged do
3: 从 p(τ) 采样 K 个任务 {τ_i}
4:
5: for each τ_i do
6: // 内层:计算快速适应参数
7: θ'_i = θ - α * ∇_θ L(τ_i, f_θ) // 在Support集上
8: end for
9:
10: // 外层:更新元参数
11: θ = θ - β * ∇_θ Σ_i L(τ_i, f_θ'_i) // 在Query集上
12:
13: end while
14: return θ
代码实现
基础MAML实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
class MAML:
"""
MAML元学习器
支持任意的任务分布和模型架构
"""
def __init__(
self,
model: nn.Module,
lr_inner: float = 0.01,
lr_meta: float = 0.001,
inner_steps: int = 5,
device: str = 'cuda'
):
self.model = model
self.lr_inner = lr_inner
self.lr_meta = lr_meta
self.inner_steps = inner_steps
self.device = device
self.model.to(device)
# 元优化器
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr_meta)
def inner_update(
self,
support_x: torch.Tensor,
support_y: torch.Tensor,
theta: dict = None
) -> Tuple[nn.Module, dict]:
"""
内层优化:在支持集上快速适应
Args:
support_x: 支持集特征 (N, feature_dim)
support_y: 支持集标签 (N,)
theta: 可选的参数副本
Returns:
adapted_model: 适应后的模型
adapted_state: 适应后的参数状态
"""
# 使用指定参数或当前模型参数
if theta is not None:
# 临时加载参数
orig_state = self.model.state_dict()
self.model.load_state_dict(theta)
# 克隆模型用于快速适应
adapted_model = type(self.model)(**self.model_kwargs())
adapted_model.to(self.device)
adapted_model.load_state_dict(self.model.state_dict())
# 快速适应循环
for _ in range(self.inner_steps):
# 计算支持集损失
logits = adapted_model(support_x)
loss = F.cross_entropy(logits, support_y)
# 计算梯度
grads = torch.autograd.grad(
loss,
adapted_model.parameters(),
retain_graph=False
)
# 手动更新参数(不保存梯度到meta网络)
with torch.no_grad():
for (name, param), grad in zip(adapted_model.named_parameters(), grads):
if grad is not None:
param -= self.lr_inner * grad
# 恢复原始模型参数(如果有临时加载)
if theta is not None:
self.model.load_state_dict(orig_state)
return adapted_model, adapted_model.state_dict()
def meta_loss(
self,
query_x: torch.Tensor,
query_y: torch.Tensor,
adapted_state: dict
) -> torch.Tensor:
"""
计算元损失:在查询集上评估适应后的模型
"""
self.model.load_state_dict(adapted_state)
logits = self.model(query_x)
return F.cross_entropy(logits, query_y)
def train_step(
self,
tasks: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]
) -> dict:
"""
单步元训练
Args:
tasks: 任务列表,每个任务包含 (support_x, support_y, query_x, query_y)
Returns:
meta_loss: 元损失值
"""
self.optimizer.zero_grad()
meta_loss_sum = 0.0
for support_x, support_y, query_x, query_y in tasks:
# 移到设备
support_x = support_x.to(self.device)
support_y = support_y.to(self.device)
query_x = query_x.to(self.device)
query_y = query_y.to(self.device)
# 内层:快速适应
_, adapted_state = self.inner_update(support_x, support_y)
# 外层:计算元损失
loss = self.meta_loss(query_x, query_y, adapted_state)
meta_loss_sum += loss
# 平均元损失
meta_loss = meta_loss_sum / len(tasks)
# 反向传播更新元参数
meta_loss.backward()
self.optimizer.step()
return {'meta_loss': meta_loss.item()}
def model_kwargs(self) -> dict:
"""获取模型构造参数(用于克隆)"""
# 根据实际模型结构调整
return {}完整的Omniglot少样本学习示例
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchmeta.datasets import Omniglot
from torchmeta.transforms import ClassSplitter, Categorical
class OmniglotCNN(nn.Module):
"""
Omniglot分类器(用于MAML)
适合少样本学习的简单卷积网络
"""
def __init__(self, in_channels=1, out_features=5):
super().__init__()
self.features = nn.Sequential(
# Block 1
nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
# Block 2
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
# Block 3
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
# Block 4
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.classifier = nn.Linear(64, out_features)
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
return self.classifier(x)
def extract(self, x):
"""特征提取(用于原型网络等)"""
x = self.features(x)
return x.view(x.size(0), -1)
class FewShotClassifier(nn.Module):
"""
Few-shot分类器包装器
支持MAML和原型网络
"""
def __init__(self, backbone, num_classes, num_support, num_query):
super().__init__()
self.backbone = backbone
self.num_support = num_support
self.num_query = num_query
self.num_classes = num_classes
def forward(self, support_x, support_y, query_x, mode='maml'):
"""
Args:
support_x: (N_way * K_shot, C, H, W)
support_y: (N_way * K_shot,)
query_x: (N_way * query_per_class, C, H, W)
mode: 'maml' or 'proto'
"""
if mode == 'maml':
# MAML: 使用最后一层分类器
return self._maml_forward(support_x, support_y, query_x)
else:
# 原型网络
return self._proto_forward(support_x, support_y, query_x)
def _maml_forward(self, support_x, support_y, query_x):
# 在支持集上微调分类器
logits = self.backbone(support_x)
loss = F.cross_entropy(logits, support_y)
# 简化的适应(只更新分类器)
grads = torch.autograd.grad(loss, self.backbone.classifier.parameters())
with torch.no_grad():
for param, grad in zip(self.backbone.classifier.parameters(), grads):
if grad is not None:
param -= 0.01 * grad
# 在查询集上预测
return self.backbone(query_x)
def _proto_forward(self, support_x, support_y, query_x):
# 提取特征
support_emb = self.backbone.extract(support_x)
query_emb = self.backbone.extract(query_x)
# 计算每个类的原型
prototypes = torch.zeros(self.num_classes, support_emb.size(1))
for c in range(self.num_classes):
mask = (support_y == c)
prototypes[c] = support_emb[mask].mean(0)
# 计算距离并分类
dists = torch.cdist(query_emb, prototypes)
return -dists
def train_maml():
"""MAML训练流程"""
# 数据集
dataset = Omniglot(
"data",
num_classes_per_task=5,
transform=torchmeta.transforms.ToTensor(),
class_augmentations=[torchmeta.transforms.RandomRotation(90)],
meta_train=True,
download=True
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)
# 模型
model = OmniglotCNN(out_features=5)
maml = MAML(model, lr_inner=0.01, lr_meta=0.001, inner_steps=5)
# 训练循环
for epoch in range(100):
for batch in loader:
# 提取任务数据
support_x, support_y = batch['train']
query_x, query_y = batch['test']
# MAML训练步
loss_dict = maml.train_step([(
support_x, support_y, query_x, query_y
)])
if batch_idx % 100 == 0:
print(f"Epoch {epoch}, Meta Loss: {loss_dict['meta_loss']:.4f}")
if __name__ == '__main__':
train_maml()MAML的扩展
1. iMAML(Implicit MAML)
使用隐函数定理避免二阶梯度计算:
梯度通过求解线性系统得到。
2. MAML++
改进点:
- 任务相关自适应学习率:不同任务用不同学习率
- 多步损失:不仅用最后一步的损失
- 余弦衰减:学习率预热和衰减
3. 领域自适应MAML
- 添加领域判别器
- 对抗训练实现领域不变表示
优缺点分析
优点
| 优点 | 说明 |
|---|---|
| 模型无关 | 适用于任意可微模型 |
| 理论优雅 | 双层优化框架清晰 |
| 效果好 | 少样本学习性能领先 |
| 可扩展 | 易与其他技术结合 |
缺点
| 缺点 | 说明 |
|---|---|
| 计算开销大 | 需要计算二阶梯度(MAML)或多次梯度(FOMAML) |
| 任务相关超参 | 内层学习率、步数需要调优 |
| 鞍点问题 | 双层优化容易陷入局部最优 |
参考文献
相关文章
Footnotes
-
Finn, C., Abbeel, P., & Levine, S. (2017). “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”. International Conference on Machine Learning (ICML). ↩