生成对抗模仿学习
生成对抗模仿学习(GAIL)将模仿学习框架为生成对抗网络(GAN),通过对抗训练直接从专家演示中学习策略。1
问题背景
模仿学习的挑战
| 方法 | 优点 | 缺点 |
|---|---|---|
| 行为克隆 | 简单 | 分布偏移累积 |
| DAgger | 迭代改进 | 需要在线交互 |
GAIL的创新
GAIL通过引入判别器来区分专家轨迹和学习策略产生的轨迹,将模仿学习转化为对抗优化问题。
GAIL框架
核心思想
GAIL的核心是训练一个判别器 来区分专家轨迹 和学习轨迹 ,同时策略 试图欺骗判别器。
博弈论视角:
纳什均衡:
算法流程
┌─────────────────────────────────────────────────────────────┐
│ GAIL 训练流程 │
├─────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
│ │ 专家 │ │ 判别器 │ │ 策略 │ │
│ │ 轨迹 │─────▶│ D_φ │◀────│ π_θ │ │
│ │ τ_E │ │ │ │ │ │
│ └─────────┘ └────┬────┘ └────┬────┘ │
│ │ │ │
│ ▼ ▼ │
│ log D(τ_E) log(1-D(τ_θ)) │
│ ▲ │ │
│ │ ▼ │
│ │ 奖励 = -log D(τ_θ) │
│ │ │ │
│ │ ▼ │
│ │ 策略梯度更新 │
│ │ │ │
│ └─────────────────┘ │
│ 循环 │
│ │
└─────────────────────────────────────────────────────────────┘
数学推导
目标函数
GAIL最小化学习轨迹与专家轨迹的JS散度:
奖励函数
从判别器提取的奖励:
策略梯度
使用GAE(广义优势估计)计算优势:
策略梯度:
算法实现
PyTorch实现
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class Discriminator(nn.Module):
"""判别器网络:判断轨迹是专家的还是策略生成的"""
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
# 轨迹编码器
self.encoder = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1),
nn.Sigmoid()
)
def forward(self, states, actions):
"""输入: 状态-动作序列"""
# 展平输入
x = torch.cat([states, actions], dim=-1)
return self.encoder(x)
def get_reward(self, states, actions):
"""获取GAIL奖励:r = -log(1 - D)"""
prob = self(states, actions)
return -torch.log(1 - prob + 1e-8)
class GAILAgent:
def __init__(self, state_dim, action_dim, hidden_dim=128, lr=3e-4):
# 策略网络(PPO Actor-Critic)
self.policy = PPOActorCritic(state_dim, action_dim, hidden_dim)
# 判别器
self.discriminator = Discriminator(state_dim, action_dim, hidden_dim)
# 优化器
self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr)
def update(self, expert_buffer, policy_buffer, ppo_epochs=10):
"""GAIL更新"""
all_expert_states = torch.FloatTensor(expert_buffer.get_states())
all_expert_actions = torch.FloatTensor(expert_buffer.get_actions())
all_policy_states = torch.FloatTensor(policy_buffer.get_states())
all_policy_actions = torch.FloatTensor(policy_buffer.get_actions())
# 1. 更新判别器
for _ in range(5): # 多次更新判别器
expert_score = self.discriminator(all_expert_states, all_expert_actions)
policy_score = self.discriminator(all_policy_states, all_policy_actions)
# 判别器损失
disc_loss = (
-torch.log(expert_score + 1e-8).mean() -
torch.log(1 - policy_score + 1e-8).mean()
)
self.discriminator_optimizer.zero_grad()
disc_loss.backward()
self.discriminator_optimizer.step()
# 2. 获取GAIL奖励并更新策略
rewards = self.discriminator.get_reward(
all_policy_states,
all_policy_actions
)
# 使用PPO更新策略
self.policy.update(policy_buffer, rewards, ppo_epochs)PPO Actor-Critic实现
class PPOActorCritic(nn.Module):
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
# Actor
self.actor = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
# Critic
self.critic = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
self.log_std = nn.Parameter(torch.zeros(action_dim))
def forward(self, state):
return self.actor(state), self.critic(state)
def get_action(self, state):
mean = self.actor(state)
std = torch.exp(self.log_std)
dist = torch.distributions.Normal(mean, std)
action = dist.sample()
log_prob = dist.log_prob(action).sum(dim=-1)
return action, log_prob
def evaluate(self, state, action):
mean = self.actor(state)
std = torch.exp(self.log_std)
dist = torch.distributions.Normal(mean, std)
log_prob = dist.log_prob(action).sum(dim=-1)
value = self.critic(state)
return log_prob, valueGAIL变体
InfoGAIL
InfoGAIL引入互信息来增加可解释性。2
改进:
其中 是潜在变量,表示不同的行为风格。
Wasserstein GAIL (WGAIL)
使用Wasserstein距离替代JS散度,提高训练稳定性。3
class WGAN_GP_Discriminator(nn.Module):
"""带梯度惩罚的WGAN判别器"""
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def gradient_penalty(self, real, fake):
"""计算梯度惩罚"""
alpha = torch.rand(real.size(0), 1)
interpolated = alpha * real + (1 - alpha) * fake
mixed_scores = self.net(interpolated)
grad = torch.autograd.grad(
outputs=mixed_scores,
inputs=interpolated,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True
)[0]
grad_norm = grad.norm(2, dim=1)
penalty = ((grad_norm - 1) ** 2).mean()
return penalty
def wasserstein_loss(self, real, fake):
return fake.mean() - real.mean()Point-GAIL
针对稀疏奖励场景的改进版本。
理论分析
收敛性
定理(Ho & Ermon, 2016):在无限样本和完美优化下,GAIL可以恢复一个与专家等价的策略。
与IRL的联系
GAIL可以视为隐式的IRL算法:
- 判别器学习奖励函数
- 策略通过最大化这个隐式奖励来学习
局限性
- 模式崩溃:可能只学习部分专家行为
- 不稳定性:对抗训练可能导致模式崩塌
- 样本效率:需要大量环境交互
实践技巧
超参数设置
| 参数 | 推荐值 | 说明 |
|---|---|---|
| 判别器更新次数 | 3-5 | 每策略更新一次判别器 |
| 学习率 | 3e-4 | 策略和判别器相同 |
| 熵正则化 | 0.01 | 防止过早收敛 |
| GAE λ | 0.95 | 优势估计参数 |
梯度惩罚
添加梯度惩罚可以提高训练稳定性:
gp = discriminator.gradient_penalty(expert_states, policy_states)
disc_loss += 10 * gp