元强化学习理论
元强化学习(Meta-RL)旨在让智能体学会”如何学习”,通过从多个任务中提取可迁移的知识,实现快速适应新任务。1
问题定义
元RL框架
元RL假设存在一个任务分布 ,每个任务 有自己的奖励函数和转移概率。
目标:学习一个策略初始化,使得在新任务上只需少量梯度步骤就能达到高性能。
与多任务学习的对比
| 方面 | 多任务学习 | 元强化学习 |
|---|---|---|
| 目标 | 联合优化所有任务 | 学会快速适应 |
| 评估 | 所有任务平均性能 | 新任务的适应速度 |
| 泛化 | 任务内泛化 | 任务间泛化 |
MAML算法
Finn et al. (2017)
MAML(Model-Agnostic Meta-Learning)是元RL领域最具影响力的算法。2
核心思想
学习一组”好”的初始参数 ,使得在新任务 上只需少量梯度步骤就能达到高性能。
数学形式化
关键点:
- :内层学习率(任务适应步长)
- :任务特定梯度
算法流程
import torch
import torch.nn as nn
import torch.optim as optim
class MAML:
def __init__(self, model, inner_lr=0.01, outer_lr=0.001):
self.model = model
self.inner_lr = inner_lr
self.optimizer = optim.Adam(model.parameters(), lr=outer_lr)
def inner_update(self, task_support_x, task_support_y, theta=None):
"""内层更新:在支持集上适应"""
if theta is None:
theta = {name: param.clone()
for name, param in self.model.named_parameters()}
# 计算支持集损失
loss = self.model.compute_loss(task_support_x, task_support_y)
# 计算梯度
grads = torch.autograd.grad(loss, theta.values(),
create_graph=True)
# 一步梯度更新
theta_prime = {
name: param - self.inner_lr * grad
for (name, param), grad in zip(theta.items(), grads)
}
return theta_prime
def meta_loss(self, theta_prime, query_x, query_y):
"""元损失:在查询集上评估"""
# 使用更新后的参数
with torch.no_grad():
# 先用原始参数前传获取需要梯度的参数
pass
# 应用更新后的参数
for name, param in self.model.named_parameters():
param.data = theta_prime[name]
query_loss = self.model.compute_loss(query_x, query_y)
# 恢复原始参数
for name, param in self.model.named_parameters():
param.data = theta[name]
return query_loss
def train_step(self, task_batch):
"""
task_batch: [(support_x, support_y, query_x, query_y), ...]
"""
meta_loss_sum = 0
for support_x, support_y, query_x, query_y in task_batch:
# 1. 内层更新
theta_prime = self.inner_update(support_x, support_y)
# 2. 计算元损失
meta_loss = self.meta_loss(theta_prime, query_x, query_y)
meta_loss_sum += meta_loss
# 3. 外层更新
self.optimizer.zero_grad()
meta_loss_sum.backward()
self.optimizer.step()Reptile算法
Nichol et al. (2018)
Reptile是MAML的简化版本,只使用一阶梯度信息。3
核心思想
通过反复采样任务、在任务上训练,并朝该方向移动参数。
更新公式
其中 是在任务 上多步SGD后的参数。
class Reptile:
def __init__(self, model, lr=0.1, inner_steps=5):
self.model = model
self.lr = lr
self.inner_steps = inner_steps
def inner_loop(self, task_x, task_y):
"""在任务上执行多步SGD"""
theta = {name: param.clone()
for name, param in self.model.named_parameters()}
for _ in range(self.inner_steps):
loss = self.model.compute_loss(task_x, task_y, theta)
grads = torch.autograd.grad(loss, theta.values())
theta = {
name: param - self.lr * grad
for (name, param), grad in zip(theta.items(), grads)
}
return theta
def train_step(self, task_batch):
# 保存原始参数
theta_0 = {name: param.clone()
for name, param in self.model.named_parameters()}
# 对每个任务执行内层循环
new_thetas = []
for task_x, task_y in task_batch:
theta_i = self.inner_loop(task_x, task_y)
new_thetas.append(theta_i)
# 计算平均更新方向
update = {
name: torch.zeros_like(param)
for name, param in theta_0.items()
}
for theta_i in new_thetas:
for name in update:
update[name] += (theta_i[name] - theta_0[name])
for name in update:
update[name] /= len(new_thetas)
# 应用更新
for name, param in self.model.named_parameters():
param.data += update[name]与In-Context Learning的联系
形式对应
| 元RL | In-Context Learning (ICL) |
|---|---|
| 任务分布 | 提示中的示例分布 |
| 内层梯度更新 | Transformer注意力操作 |
| 快速适应 | 少样本泛化 |
| 外层优化 | 预训练 |
ICL作为隐式元学习
最近的研究表明,Transformer的ICL能力可以理解为一种隐式元学习。4
class ICLAsMetaLearning:
"""
Transformer ICL 机制可以被视为一种隐式的元学习过程
"""
def forward(self, context_tokens, query_token):
"""
context_tokens: 支持集示例 [(x_i, y_i), ...]
query_token: 查询输入 x_q
过程类似于:
1. 在上下文tokens上"训练"(注意力聚合)
2. 在查询token上"评估"(预测 y_q)
"""
# 1. 上下文编码
context_repr = self.attention_to_context(context_tokens)
# 2. 查询处理(类似于任务适应)
query_repr = self.process_query(
query_token,
context_repr # 条件于上下文
)
# 3. 预测
return self.predict(query_repr)Meta-SGD
Li et al. (2017)
Meta-SGD扩展MAML,同时学习初始参数和每个参数的学习率。5
关键创新
除了学习 ,还学习每个参数的学习率 :
class MetaSGD(MAML):
def __init__(self, model, inner_lr=0.01):
super().__init__(model)
# 为每个参数学习一个学习率
self.alpha = nn.Parameter(
torch.ones_like(
torch.nn.utils.parameters_to_vector(model.parameters())
) * inner_lr
)
def inner_update(self, task_support_x, task_support_y, theta=None, alpha=None):
"""使用参数级学习率更新"""
if theta is None:
theta = {name: param.clone()
for name, param in self.model.named_parameters()}
loss = self.model.compute_loss(task_support_x, task_support_y, theta)
grads = torch.autograd.grad(loss, theta.values(), create_graph=True)
# 使用参数级学习率
alpha_i = alpha if alpha is not None else self.alpha
theta_prime = {
name: param - alpha_i[i] * grad
for i, (name, param, grad) in enumerate(
zip(theta.keys(), theta.values(), grads)
)
}
return theta_prime基准测试
Meta-World
Meta-World是元RL的标准基准,包含50个机械臂操作任务。6
评估协议
| 设置 | 描述 |
|---|---|
| ML-1 | 1个新任务,10个演示 |
| ML-10 | 10个训练任务,新测试任务 |
| ML-50 | 50个训练任务,大量新测试任务 |
参考资料
相关链接
- 上下文强化学习 — ICL与RL的联系
- 探索-利用权衡 — 元RL中的探索问题
- Actor-Critic — 策略优化基础
Footnotes
-
Beck et al., “A Survey of Meta-Reinforcement Learning”, arXiv:2301.08028, 2023 ↩
-
Finn et al., “Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks”, ICML, 2017 ↩
-
Nichol et al., “On First-Order Meta-Learning Algorithms”, arXiv:1803.02999, 2018 ↩
-
Garg et al., “What Can Transformers Learn In-Context?”, NeurIPS, 2022 ↩
-
Li et al., “Meta-SGD: Learning to Learn Quickly for Few-Shot Learning”, arXiv:1703.03400, 2017 ↩
-
Yu et al., “Meta-World: A Benchmark and Evaluation for Multi-Task and Meta Reinforcement Learning”, CoRL, 2019 ↩