元学习中的优化方法
元学习中的”学习如何学习”本质上是一个优化问题。本文档探讨基于优化的元学习方法,重点关注如何设计更好的优化器来适应新任务。
优化视角回顾
三种元学习范式的优化视角
| 范式 | 优化对象 | 方法 |
|---|---|---|
| 度量学习 | 嵌入空间 | 学习距离函数 |
| 模型学习 | 网络结构 | 快速参数更新 |
| 优化学习 | 优化器本身 | 学习优化器 |
双层优化的挑战
元学习的核心是双层优化(Bi-Level Optimization):
LSTM元学习器
核心思想
使用LSTM作为元学习器,学习如何更新神经网络的参数:
为什么用LSTM?
LSTM的门控机制类似于优化器的行为:
- 遗忘门:决定保留多少历史梯度信息
- 输入门:控制新梯度的影响
- 输出门:决定更新方向
代码实现
class LSTMOptimizer(nn.Module):
"""
LSTM元学习器
学习一个LSTM来预测参数更新
"""
def __init__(self, model_dim, hidden_dim=20):
super().__init__()
# 元学习器的LSTM
self.lstm = nn.LSTMCell(input_size=model_dim * 2, hidden_size=hidden_dim)
# 输出层:预测更新方向和幅度
self.output = nn.Linear(hidden_dim, 2 * model_dim) # 方向和log lr
self.model_dim = model_dim
self.hidden_dim = hidden_dim
def forward(self, grads, params, state=None):
"""
Args:
grads: 梯度向量 (d,)
params: 参数向量 (d,)
state: (h, c) LSTM状态
Returns:
update: 更新向量 (d,)
new_state: 更新后的LSTM状态
"""
# 拼接梯度和参数
x = torch.cat([grads, params], dim=-1)
# LSTM前向
if state is None:
batch_size = grads.size(0)
h = torch.zeros(batch_size, self.hidden_dim, device=grads.device)
c = torch.zeros(batch_size, self.hidden_dim, device=grads.device)
else:
h, c = state
h_new, c_new = self.lstm(x, (h, c))
# 预测更新
out = self.output(h_new)
update, log_lr = out.chunk(2, dim=-1)
# 缩放更新
lr = torch.exp(log_lr)
update = torch.tanh(update) * lr
return update, (h_new, c_new)
class MetaOptimizer:
"""
基于LSTM的元优化器
"""
def __init__(self, model, lr_meta=0.001):
self.model = model
self.lr_meta = lr_meta
# 元学习器
self.optimizer_lstm = LSTMOptimizer(
model_dim=sum(p.numel() for p in model.parameters()),
hidden_dim=32
)
self.optimizer = torch.optim.Adam(
self.optimizer_lstm.parameters(),
lr=lr_meta
)
def step(self, loss, retain_graph=False):
"""
执行元优化步骤
"""
self.optimizer.zero_grad()
# 获取当前参数和梯度
params = torch.cat([p.data.view(-1) for p in self.model.parameters()])
grads = torch.cat([
p.grad.view(-1) for p in self.model.parameters()
if p.grad is not None
])
# LSTM更新
update, _ = self.optimizer_lstm(grads, params)
# 应用更新
offset = 0
for p in self.model.parameters():
numel = p.numel()
p.data.add_(update[offset:offset+numel].view(p.shape))
offset += numel
return update.norm().item()近似方法
一阶梯度估计
移位函数
不使用显式梯度,使用移位函数(Shift Function)近似:
其中 是学习的移位函数。
REINFORCE
使用REINFORCE估计梯度:
二阶梯度的近似
对角近似
忽略Hessian矩阵的非对角元素:
K-FAC近似
使用Kronecker分解:
其中 和 分别是输入和输出的协方差矩阵。
任务自适应学习率
核心思想
不同任务可能需要不同的学习率。MAML++ 等方法引入了任务自适应学习率。
实现方式
class TaskAdaptiveLR(nn.Module):
"""
任务自适应学习率
为每个任务学习个性化的学习率缩放
"""
def __init__(self, model_dim, num_tasks):
super().__init__()
# 每个任务的学习率缩放因子
self.lr_scalers = nn.Embedding(num_tasks, model_dim)
nn.init.ones_(self.lr_scalers.weight)
def get_lr(self, task_id):
"""获取特定任务的学习率"""
return self.lr_scalers(task_id)
class MAMLPlusPlus(nn.Module):
"""
MAML++ 实现
改进点:
1. 任务自适应学习率
2. 多步损失
3. 余弦衰减
"""
def __init__(self, model, num_tasks=1000):
super().__init__()
self.model = model
self.task_lr = TaskAdaptiveLR(
model_dim=sum(p.numel() for p in model.parameters()),
num_tasks=num_tasks
)
# 元学习率
self.meta_lr = 0.001
self.num_inner_steps = 5
def inner_update(self, support_x, support_y, task_id, first_order=True):
"""
任务内更新(带自适应学习率)
"""
params = {name: p.clone() for name, p in self.model.named_parameters()}
lr_scalers = self.task_lr.get_lr(task_id)
lr_base = 0.01
losses = []
for step in range(self.num_inner_steps):
# 前向
logits = self._forward(params, support_x)
loss = F.cross_entropy(logits, support_y)
losses.append(loss)
# 梯度
grads = torch.autograd.grad(
loss, params.values(),
retain_graph=(step < self.num_inner_steps - 1)
)
# 自适应学习率更新
offset = 0
for name, param in params.items():
numel = param.numel()
lr = lr_base * lr_scalers[offset:offset+numel].mean()
param.data.sub_(lr * grads[list(params.keys()).index(name)])
offset += numel
# 多步损失
return sum(losses) / len(losses)
def meta_update(self, tasks):
"""
元更新
"""
meta_grads = None
for task_id, (support_x, support_y, query_x, query_y) in enumerate(tasks):
# 内层更新
inner_loss = self.inner_update(support_x, support_y, task_id)
# 外层梯度
grads = torch.autograd.grad(
inner_loss,
self.model.parameters(),
retain_graph=True
)
if meta_grads is None:
meta_grads = grads
else:
meta_grads = [g1 + g2 for g1, g2 in zip(meta_grads, grads)]
# 更新
with torch.no_grad():
for param, grad in zip(self.model.parameters(), meta_grads):
param -= self.meta_lr * grad / len(tasks)无梯度方法
进化策略
使用进化算法优化元参数:
class EvolutionaryMetaLearning:
"""
基于进化策略的元学习
"""
def __init__(self, model, pop_size=50, lr=0.01):
self.model = model
self.pop_size = pop_size
self.lr = lr
def evolve_step(self, task):
"""
一步进化
"""
# 当前参数
theta = self._get_params()
# 采样噪声
noises = [
{name: torch.randn_like(p) for name, p in self.model.named_parameters()}
for _ in range(self.pop_size)
]
# 评估
fitness = []
for noise in noises:
# 扰动参数
perturbed = {name: theta[name] + self.lr * noise[name] for name in theta}
# 计算适应度
fitness.append(self._evaluate(perturbed, task))
# 计算梯度估计
grad_est = {name: torch.zeros_like(p) for name, p in theta.items()}
for noise, fit in zip(noises, fitness):
for name in theta:
grad_est[name] += fit * noise[name]
grad_est = {name: g / (self.pop_size * self.lr) for name, g in grad_est.items()}
# 更新
for name in theta:
theta[name] -= self.lr * grad_est[name]
self._set_params(theta)
return sum(fitness) / len(fitness)
def _get_params(self):
return {name: p.clone() for name, p in self.model.named_parameters()}
def _set_params(self, params):
for name, p in self.model.named_parameters():
p.data.copy_(params[name])
def _evaluate(self, params, task):
self._set_params(params)
loss = self._task_loss(task)
return -loss.item()