元强化学习理论

元强化学习(Meta-RL)旨在让智能体学会”如何学习”,通过从多个任务中提取可迁移的知识,实现快速适应新任务。1

问题定义

元RL框架

元RL假设存在一个任务分布 ,每个任务 有自己的奖励函数和转移概率。

目标:学习一个策略初始化,使得在新任务上只需少量梯度步骤就能达到高性能。

与多任务学习的对比

方面多任务学习元强化学习
目标联合优化所有任务学会快速适应
评估所有任务平均性能新任务的适应速度
泛化任务内泛化任务间泛化

MAML算法

Finn et al. (2017)

MAML(Model-Agnostic Meta-Learning)是元RL领域最具影响力的算法。2

核心思想

学习一组”好”的初始参数 ,使得在新任务 上只需少量梯度步骤就能达到高性能。

数学形式化

关键点

  • :内层学习率(任务适应步长)
  • :任务特定梯度

算法流程

import torch
import torch.nn as nn
import torch.optim as optim
 
class MAML:
    def __init__(self, model, inner_lr=0.01, outer_lr=0.001):
        self.model = model
        self.inner_lr = inner_lr
        self.optimizer = optim.Adam(model.parameters(), lr=outer_lr)
    
    def inner_update(self, task_support_x, task_support_y, theta=None):
        """内层更新:在支持集上适应"""
        if theta is None:
            theta = {name: param.clone() 
                     for name, param in self.model.named_parameters()}
        
        # 计算支持集损失
        loss = self.model.compute_loss(task_support_x, task_support_y)
        
        # 计算梯度
        grads = torch.autograd.grad(loss, theta.values(), 
                                     create_graph=True)
        
        # 一步梯度更新
        theta_prime = {
            name: param - self.inner_lr * grad 
            for (name, param), grad in zip(theta.items(), grads)
        }
        
        return theta_prime
    
    def meta_loss(self, theta_prime, query_x, query_y):
        """元损失:在查询集上评估"""
        # 使用更新后的参数
        with torch.no_grad():
            # 先用原始参数前传获取需要梯度的参数
            pass
        
        # 应用更新后的参数
        for name, param in self.model.named_parameters():
            param.data = theta_prime[name]
        
        query_loss = self.model.compute_loss(query_x, query_y)
        
        # 恢复原始参数
        for name, param in self.model.named_parameters():
            param.data = theta[name]
        
        return query_loss
    
    def train_step(self, task_batch):
        """
        task_batch: [(support_x, support_y, query_x, query_y), ...]
        """
        meta_loss_sum = 0
        
        for support_x, support_y, query_x, query_y in task_batch:
            # 1. 内层更新
            theta_prime = self.inner_update(support_x, support_y)
            
            # 2. 计算元损失
            meta_loss = self.meta_loss(theta_prime, query_x, query_y)
            meta_loss_sum += meta_loss
        
        # 3. 外层更新
        self.optimizer.zero_grad()
        meta_loss_sum.backward()
        self.optimizer.step()

Reptile算法

Nichol et al. (2018)

Reptile是MAML的简化版本,只使用一阶梯度信息。3

核心思想

通过反复采样任务、在任务上训练,并朝该方向移动参数。

更新公式

其中 是在任务 上多步SGD后的参数。

class Reptile:
    def __init__(self, model, lr=0.1, inner_steps=5):
        self.model = model
        self.lr = lr
        self.inner_steps = inner_steps
    
    def inner_loop(self, task_x, task_y):
        """在任务上执行多步SGD"""
        theta = {name: param.clone() 
                 for name, param in self.model.named_parameters()}
        
        for _ in range(self.inner_steps):
            loss = self.model.compute_loss(task_x, task_y, theta)
            grads = torch.autograd.grad(loss, theta.values())
            theta = {
                name: param - self.lr * grad 
                for (name, param), grad in zip(theta.items(), grads)
            }
        
        return theta
    
    def train_step(self, task_batch):
        # 保存原始参数
        theta_0 = {name: param.clone() 
                   for name, param in self.model.named_parameters()}
        
        # 对每个任务执行内层循环
        new_thetas = []
        for task_x, task_y in task_batch:
            theta_i = self.inner_loop(task_x, task_y)
            new_thetas.append(theta_i)
        
        # 计算平均更新方向
        update = {
            name: torch.zeros_like(param)
            for name, param in theta_0.items()
        }
        
        for theta_i in new_thetas:
            for name in update:
                update[name] += (theta_i[name] - theta_0[name])
        
        for name in update:
            update[name] /= len(new_thetas)
        
        # 应用更新
        for name, param in self.model.named_parameters():
            param.data += update[name]

与In-Context Learning的联系

形式对应

元RLIn-Context Learning (ICL)
任务分布提示中的示例分布
内层梯度更新Transformer注意力操作
快速适应少样本泛化
外层优化预训练

ICL作为隐式元学习

最近的研究表明,Transformer的ICL能力可以理解为一种隐式元学习4

class ICLAsMetaLearning:
    """
    Transformer ICL 机制可以被视为一种隐式的元学习过程
    """
    
    def forward(self, context_tokens, query_token):
        """
        context_tokens: 支持集示例 [(x_i, y_i), ...]
        query_token: 查询输入 x_q
        
        过程类似于:
        1. 在上下文tokens上"训练"(注意力聚合)
        2. 在查询token上"评估"(预测 y_q)
        """
        
        # 1. 上下文编码
        context_repr = self.attention_to_context(context_tokens)
        
        # 2. 查询处理(类似于任务适应)
        query_repr = self.process_query(
            query_token, 
            context_repr  # 条件于上下文
        )
        
        # 3. 预测
        return self.predict(query_repr)

Meta-SGD

Li et al. (2017)

Meta-SGD扩展MAML,同时学习初始参数和每个参数的学习率。5

关键创新

除了学习 ,还学习每个参数的学习率

class MetaSGD(MAML):
    def __init__(self, model, inner_lr=0.01):
        super().__init__(model)
        # 为每个参数学习一个学习率
        self.alpha = nn.Parameter(
            torch.ones_like(
                torch.nn.utils.parameters_to_vector(model.parameters())
            ) * inner_lr
        )
    
    def inner_update(self, task_support_x, task_support_y, theta=None, alpha=None):
        """使用参数级学习率更新"""
        if theta is None:
            theta = {name: param.clone() 
                     for name, param in self.model.named_parameters()}
        
        loss = self.model.compute_loss(task_support_x, task_support_y, theta)
        grads = torch.autograd.grad(loss, theta.values(), create_graph=True)
        
        # 使用参数级学习率
        alpha_i = alpha if alpha is not None else self.alpha
        theta_prime = {
            name: param - alpha_i[i] * grad 
            for i, (name, param, grad) in enumerate(
                zip(theta.keys(), theta.values(), grads)
            )
        }
        
        return theta_prime

基准测试

Meta-World

Meta-World是元RL的标准基准,包含50个机械臂操作任务。6

评估协议

设置描述
ML-11个新任务,10个演示
ML-1010个训练任务,新测试任务
ML-5050个训练任务,大量新测试任务

参考资料


相关链接

Footnotes

  1. Beck et al., “A Survey of Meta-Reinforcement Learning”, arXiv:2301.08028, 2023

  2. Finn et al., “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”, ICML, 2017

  3. Nichol et al., “On First-Order Meta-Learning Algorithms”, arXiv:1803.02999, 2018

  4. Garg et al., “What Can Transformers Learn In-Context?”, NeurIPS, 2022

  5. Li et al., “Meta-SGD: Learning to Learn Quickly for Few-Shot Learning”, arXiv:1703.03400, 2017

  6. Yu et al., “Meta-World: A Benchmark and Evaluation for Multi-Task and Meta Reinforcement Learning”, CoRL, 2019