MAML:模型无关元学习

MAML(Model-Agnostic Meta-Learning)由Chelsea Finn等人于2017年提出,是元学习领域最具影响力的算法之一。1 MAML的核心思想是:找到一个好的初始参数,使得模型能在少量梯度步内快速适应新任务

核心思想

与传统微调的区别

传统微调:
模型(随机初始化 or 预训练)→ 大量梯度步 → 适应新任务

MAML:
多个相关任务训练 → 学习好的初始参数 → 少量梯度步 → 快速适应新任务

关键洞察

MAML不学习固定的特征表示,而是学习一个初始化点 ,从该点出发,少量梯度更新就能取得好效果。


数学推导

优化目标

给定任务分布 ,MAML优化:

其中:

  • 是任务 上的快速适应参数
  • 是任务损失

双层优化结构

MAML采用双层优化(Bi-Level Optimization):

内层优化(Inner Loop)

在支持集上执行快速适应:

这里 是内层学习率,每个任务独立更新

外层优化(Outer Loop)

在查询集上优化初始参数:

这里 是元学习率。

梯度计算

外层梯度是整个优化过程的关键:

展开

这涉及二阶梯度计算,计算量大。


FOMAML(一阶MAML)

近似策略

FOMAML的核心近似:忽略二阶梯度项

即:直接在适应后参数 上计算梯度,而非回传到

对比

方法梯度项计算复杂度
MAML
FOMAML

实验结论

Finn等人的实验表明,FOMAML与MAML在大多数任务上性能相当,但训练更稳定、速度更快。


算法流程

输入:任务分布 p(τ),学习率 α, β,任务数 K
输出:最优初始化参数 θ*

1: 随机初始化 θ
2: while not converged do
3:   从 p(τ) 采样 K 个任务 {τ_i}
4:   
5:   for each τ_i do
6:     // 内层:计算快速适应参数
7:     θ'_i = θ - α * ∇_θ L(τ_i, f_θ)     // 在Support集上
8:   end for
9:   
10:  // 外层:更新元参数
11:  θ = θ - β * ∇_θ Σ_i L(τ_i, f_θ'_i)   // 在Query集上
12: 
13: end while
14: return θ

代码实现

基础MAML实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
 
class MAML:
    """
    MAML元学习器
    
    支持任意的任务分布和模型架构
    """
    def __init__(
        self,
        model: nn.Module,
        lr_inner: float = 0.01,
        lr_meta: float = 0.001,
        inner_steps: int = 5,
        device: str = 'cuda'
    ):
        self.model = model
        self.lr_inner = lr_inner
        self.lr_meta = lr_meta
        self.inner_steps = inner_steps
        self.device = device
        self.model.to(device)
        
        # 元优化器
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr_meta)
    
    def inner_update(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        theta: dict = None
    ) -> Tuple[nn.Module, dict]:
        """
        内层优化:在支持集上快速适应
        
        Args:
            support_x: 支持集特征 (N, feature_dim)
            support_y: 支持集标签 (N,)
            theta: 可选的参数副本
        
        Returns:
            adapted_model: 适应后的模型
            adapted_state: 适应后的参数状态
        """
        # 使用指定参数或当前模型参数
        if theta is not None:
            # 临时加载参数
            orig_state = self.model.state_dict()
            self.model.load_state_dict(theta)
        
        # 克隆模型用于快速适应
        adapted_model = type(self.model)(**self.model_kwargs())
        adapted_model.to(self.device)
        adapted_model.load_state_dict(self.model.state_dict())
        
        # 快速适应循环
        for _ in range(self.inner_steps):
            # 计算支持集损失
            logits = adapted_model(support_x)
            loss = F.cross_entropy(logits, support_y)
            
            # 计算梯度
            grads = torch.autograd.grad(
                loss, 
                adapted_model.parameters(),
                retain_graph=False
            )
            
            # 手动更新参数(不保存梯度到meta网络)
            with torch.no_grad():
                for (name, param), grad in zip(adapted_model.named_parameters(), grads):
                    if grad is not None:
                        param -= self.lr_inner * grad
        
        # 恢复原始模型参数(如果有临时加载)
        if theta is not None:
            self.model.load_state_dict(orig_state)
        
        return adapted_model, adapted_model.state_dict()
    
    def meta_loss(
        self,
        query_x: torch.Tensor,
        query_y: torch.Tensor,
        adapted_state: dict
    ) -> torch.Tensor:
        """
        计算元损失:在查询集上评估适应后的模型
        """
        self.model.load_state_dict(adapted_state)
        logits = self.model(query_x)
        return F.cross_entropy(logits, query_y)
    
    def train_step(
        self,
        tasks: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]
    ) -> dict:
        """
        单步元训练
        
        Args:
            tasks: 任务列表,每个任务包含 (support_x, support_y, query_x, query_y)
        
        Returns:
            meta_loss: 元损失值
        """
        self.optimizer.zero_grad()
        
        meta_loss_sum = 0.0
        for support_x, support_y, query_x, query_y in tasks:
            # 移到设备
            support_x = support_x.to(self.device)
            support_y = support_y.to(self.device)
            query_x = query_x.to(self.device)
            query_y = query_y.to(self.device)
            
            # 内层:快速适应
            _, adapted_state = self.inner_update(support_x, support_y)
            
            # 外层:计算元损失
            loss = self.meta_loss(query_x, query_y, adapted_state)
            meta_loss_sum += loss
        
        # 平均元损失
        meta_loss = meta_loss_sum / len(tasks)
        
        # 反向传播更新元参数
        meta_loss.backward()
        self.optimizer.step()
        
        return {'meta_loss': meta_loss.item()}
    
    def model_kwargs(self) -> dict:
        """获取模型构造参数(用于克隆)"""
        # 根据实际模型结构调整
        return {}

完整的Omniglot少样本学习示例

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchmeta.datasets import Omniglot
from torchmeta.transforms import ClassSplitter, Categorical
 
class OmniglotCNN(nn.Module):
    """
    Omniglot分类器(用于MAML)
    
    适合少样本学习的简单卷积网络
    """
    def __init__(self, in_channels=1, out_features=5):
        super().__init__()
        self.features = nn.Sequential(
            # Block 1
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Block 2
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Block 3
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            
            # Block 4
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        
        self.classifier = nn.Linear(64, out_features)
        
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
    
    def extract(self, x):
        """特征提取(用于原型网络等)"""
        x = self.features(x)
        return x.view(x.size(0), -1)
 
 
class FewShotClassifier(nn.Module):
    """
    Few-shot分类器包装器
    
    支持MAML和原型网络
    """
    def __init__(self, backbone, num_classes, num_support, num_query):
        super().__init__()
        self.backbone = backbone
        self.num_support = num_support
        self.num_query = num_query
        self.num_classes = num_classes
        
    def forward(self, support_x, support_y, query_x, mode='maml'):
        """
        Args:
            support_x: (N_way * K_shot, C, H, W)
            support_y: (N_way * K_shot,)
            query_x: (N_way * query_per_class, C, H, W)
            mode: 'maml' or 'proto'
        """
        if mode == 'maml':
            # MAML: 使用最后一层分类器
            return self._maml_forward(support_x, support_y, query_x)
        else:
            # 原型网络
            return self._proto_forward(support_x, support_y, query_x)
    
    def _maml_forward(self, support_x, support_y, query_x):
        # 在支持集上微调分类器
        logits = self.backbone(support_x)
        loss = F.cross_entropy(logits, support_y)
        
        # 简化的适应(只更新分类器)
        grads = torch.autograd.grad(loss, self.backbone.classifier.parameters())
        with torch.no_grad():
            for param, grad in zip(self.backbone.classifier.parameters(), grads):
                if grad is not None:
                    param -= 0.01 * grad
        
        # 在查询集上预测
        return self.backbone(query_x)
    
    def _proto_forward(self, support_x, support_y, query_x):
        # 提取特征
        support_emb = self.backbone.extract(support_x)
        query_emb = self.backbone.extract(query_x)
        
        # 计算每个类的原型
        prototypes = torch.zeros(self.num_classes, support_emb.size(1))
        for c in range(self.num_classes):
            mask = (support_y == c)
            prototypes[c] = support_emb[mask].mean(0)
        
        # 计算距离并分类
        dists = torch.cdist(query_emb, prototypes)
        return -dists
 
 
def train_maml():
    """MAML训练流程"""
    # 数据集
    dataset = Omniglot(
        "data",
        num_classes_per_task=5,
        transform=torchmeta.transforms.ToTensor(),
        class_augmentations=[torchmeta.transforms.RandomRotation(90)],
        meta_train=True,
        download=True
    )
    
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    
    # 模型
    model = OmniglotCNN(out_features=5)
    maml = MAML(model, lr_inner=0.01, lr_meta=0.001, inner_steps=5)
    
    # 训练循环
    for epoch in range(100):
        for batch in loader:
            # 提取任务数据
            support_x, support_y = batch['train']
            query_x, query_y = batch['test']
            
            # MAML训练步
            loss_dict = maml.train_step([(
                support_x, support_y, query_x, query_y
            )])
            
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch}, Meta Loss: {loss_dict['meta_loss']:.4f}")
 
 
if __name__ == '__main__':
    train_maml()

MAML的扩展

1. iMAML(Implicit MAML)

使用隐函数定理避免二阶梯度计算:

梯度通过求解线性系统得到。

2. MAML++

改进点:

  • 任务相关自适应学习率:不同任务用不同学习率
  • 多步损失:不仅用最后一步的损失
  • 余弦衰减:学习率预热和衰减

3. 领域自适应MAML

  • 添加领域判别器
  • 对抗训练实现领域不变表示

优缺点分析

优点

优点说明
模型无关适用于任意可微模型
理论优雅双层优化框架清晰
效果好少样本学习性能领先
可扩展易与其他技术结合

缺点

缺点说明
计算开销大需要计算二阶梯度(MAML)或多次梯度(FOMAML)
任务相关超参内层学习率、步数需要调优
鞍点问题双层优化容易陷入局部最优

参考文献

相关文章

Footnotes

  1. Finn, C., Abbeel, P., & Levine, S. (2017). “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”. International Conference on Machine Learning (ICML).