Generative Adversarial Imitation Learning

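Generative Adversarial Imitation Learning (GAIL) casts imitation learning as a generative adversarial network (GAN) problem. It learns a policy directly from expert demonstrations through adversarial training, without first recovering a reward function.[1]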

Background

Challenges in imitation learning

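| Method           | Pros                 | Cons                                         |
|------------------|----------------------|----------------------------------------------|
| Behavior cloning | Simple               | Errors compound under distribution shift     |
| DAgger           | Improves iteratively | Requires online interaction with the expert  |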

GAIL's contribution

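GAIL introduces a discriminator that separates expert trajectories from trajectories generated by the learned policy. This turns imitation learning into an adversarial optimization problem: the policy is trained until its behavior is indistinguishable from the expert's.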


The GAIL framework

Core idea

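At the core of GAIL is a discriminator $D_\phi$ trained to tell expert trajectories $\tau_E$ apart from learner trajectories $\tau_\theta$, while the policy $\pi_\theta$ tries to fool the discriminator. In practice $D_\phi$ scores individual state-action pairs $(s, a)$ taken from those trajectories. $D_\phi(s, a)$ is read as the probability that the pair came from the expert.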

Game-theoretic view
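The game can be written as follows. This assumes the convention used in this article's code, where $D_\phi(s,a)$ outputs the probability that a pair comes from the expert. Here $\rho_\pi$ is the state-action occupancy measure of policy $\pi$, $H(\pi_\theta)$ is the policy's causal entropy and $\lambda \ge 0$ is its weight. The discriminator maximizes the value and the policy minimizes it:

$$\min_{\theta}\;\max_{\phi}\;\; \mathbb{E}_{(s,a)\sim\rho_{\pi_E}}\big[\log D_\phi(s,a)\big] + \mathbb{E}_{(s,a)\sim\rho_{\pi_\theta}}\big[\log\big(1 - D_\phi(s,a)\big)\big] - \lambda H(\pi_\theta)$$

Ho & Ermon (2016) write the same game with $D$ outputting the probability that a pair comes from the policy, which swaps the roles of the two log terms.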

Nash equilibrium
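The standard GAN-style argument applies under the same convention ($D$ is the probability of "expert", $\rho_\pi$ is the state-action occupancy measure of $\pi$). For a fixed policy, the inner maximization can be solved pointwise, which gives the optimal discriminator

$$D^*(s,a) = \frac{\rho_{\pi_E}(s,a)}{\rho_{\pi_E}(s,a) + \rho_{\pi_\theta}(s,a)}.$$

Ignoring the entropy term, the equilibrium is reached when $\rho_{\pi_\theta} = \rho_{\pi_E}$. There $D^*(s,a) = 1/2$ everywhere, so the discriminator can no longer distinguish expert and learner.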

Algorithm flow

┌────────────────────────────────────────────────────────────┐
│                     GAIL training loop                     │
├────────────────────────────────────────────────────────────┤
│                                                            │
│  ┌────────┐      ┌───────────────┐      ┌────────┐         │
│  │ Expert │      │ Discriminator │      │ Policy │         │
│  │  τ_E   │─────▶│      D_φ      │◀─────│  π_θ   │         │
│  └────────┘      └───────┬───────┘      └───┬────┘         │
│                          │                  │              │
│                          ▼                  ▼              │
│                  max_φ log D(τ_E)  r = -log(1-D(τ_θ))      │
│                   + log(1-D(τ_θ))           │              │
│                          ▲                  ▼              │
│                          │          PPO update of π_θ      │
│                          │                  │              │
│                          └──────────────────┘              │
│                                 repeat                     │
│                                                            │
└────────────────────────────────────────────────────────────┘

Mathematical derivation

Objective

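GAIL minimizes the Jensen-Shannon (JS) divergence between the learner's and the expert's state-action occupancy measures $\rho_{\pi_\theta}$ and $\rho_{\pi_E}$ (treated as normalized distributions over $(s, a)$). Plugging the optimal discriminator into the discriminator objective gives

$$\max_{D}\; \mathbb{E}_{\rho_{\pi_E}}\big[\log D(s,a)\big] + \mathbb{E}_{\rho_{\pi_\theta}}\big[\log\big(1 - D(s,a)\big)\big] = 2\, D_{\mathrm{JS}}\big(\rho_{\pi_E} \,\|\, \rho_{\pi_\theta}\big) - \log 4.$$

Training the policy to minimize this quantity (up to the entropy regularizer) therefore drives the JS divergence toward zero.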
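The standard identity behind this statement can be checked numerically. With the optimal discriminator $D^* = \rho_E / (\rho_E + \rho_\pi)$, the discriminator objective equals $2\,\mathrm{JS}(\rho_E \| \rho_\pi) - \log 4$. The sketch below uses small hand-picked discrete distributions as stand-ins for occupancy measures (an illustrative assumption). The helper names are hypothetical.

```python
import numpy as np


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; entries with p = 0 contribute nothing."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))


def js_divergence(p, q):
    """Jensen-Shannon divergence with natural logs (bounded by log 2)."""
    m = 0.5 * (p + q)
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def optimal_discriminator_value(rho_expert, rho_policy):
    """E_expert[log D*] + E_policy[log(1 - D*)] with D* = rho_E / (rho_E + rho_pi)."""
    value = 0.0
    for pe, pp in zip(rho_expert, rho_policy):
        if pe + pp == 0:
            continue  # (s, a) pair never visited by either policy
        d_star = pe / (pe + pp)  # optimal probability that the pair is expert data
        if pe > 0:
            value += pe * np.log(d_star)
        if pp > 0:
            value += pp * np.log(1.0 - d_star)
    return float(value)


# Toy occupancy measures over four (s, a) pairs
rho_expert = np.array([0.5, 0.3, 0.2, 0.0])
rho_policy = np.array([0.1, 0.2, 0.3, 0.4])
lhs = optimal_discriminator_value(rho_expert, rho_policy)
rhs = 2 * js_divergence(rho_expert, rho_policy) - np.log(4)
```

lhs and rhs agree up to floating-point error. When the policy matches the expert exactly, the JS term vanishes and the value is $-\log 4$.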

Reward function

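The reward extracted from the discriminator is large when $D_\phi$ judges a learner pair to be expert-like:

$$r(s, a) = -\log\big(1 - D_\phi(s, a)\big)$$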
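The Discriminator in the code below ends in a Sigmoid and computes $-\log(1 - D + 10^{-8})$. Suppose instead the network outputs a logit $z$ with $D = \sigma(z)$ (an assumption, not what the article's Discriminator does). Then $-\log(1 - \sigma(z)) = \log(1 + e^{z}) = \mathrm{softplus}(z)$, which needs no epsilon. The following NumPy sketch compares the two forms. The function names are hypothetical.

```python
import numpy as np


def gail_reward_from_logits(z):
    """r = -log(1 - sigmoid(z)) rewritten as softplus(z) = log(1 + exp(z))."""
    return np.logaddexp(0.0, z)


def gail_reward_naive(z, eps=1e-8):
    """Same reward computed from the probability, with an epsilon guard as in get_reward."""
    d = 1.0 / (1.0 + np.exp(-z))  # D = sigmoid(z)
    return -np.log(1.0 - d + eps)


logits = np.array([-5.0, 0.0, 5.0, 40.0])
stable = gail_reward_from_logits(logits)
naive = gail_reward_naive(logits)
```

For moderate logits the two forms agree. For a very confident discriminator ($z = 40$), however, $1 - D$ rounds to zero in float64. The naive version then saturates near $-\log(10^{-8}) \approx 18.4$, while the softplus form keeps growing linearly.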

Policy gradient

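Advantages are computed with GAE (generalized advantage estimation), using the discriminator reward $r_t = r(s_t, a_t)$ and the critic $V$:

$$\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t), \qquad \hat{A}_t = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}$$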
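GAE can be computed backwards over a rollout with the recursion $\hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}$, cutting the recursion at episode boundaries. The following is a minimal NumPy sketch. The argument layout (one extra bootstrap value, a dones flag array) and the default $\gamma = 0.99$ are assumptions. $\lambda = 0.95$ follows the hyperparameter table later in this article.

```python
import numpy as np


def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation for one rollout of length T.

    rewards: (T,) GAIL rewards r_t produced by the discriminator
    values:  (T + 1,) critic estimates V(s_0), ..., V(s_T); the last entry bootstraps
    dones:   (T,) 1.0 if the episode terminated after step t, else 0.0
    Returns (advantages, returns), where returns = advantages + V(s_t) are critic targets.
    """
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    dones = np.asarray(dones, dtype=float)

    advantages = np.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - dones[t]  # stop bootstrapping across episode ends
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae  # A_t = delta_t + gamma*lam*A_{t+1}
        advantages[t] = gae
    return advantages, advantages + values[:-1]
```

With $\lambda = 0$ this reduces to the one-step TD error $\delta_t$. With $\lambda = 1$ it becomes the (bootstrapped) discounted return minus $V(s_t)$.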

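The policy gradient then weights the score function by these advantages:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\big[\nabla_\theta \log \pi_\theta(a_t \mid s_t)\, \hat{A}_t\big]$$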
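The PPOActorCritic in this article is meant to be trained with PPO. PPO maximizes a clipped surrogate instead of following the plain policy gradient. At a probability ratio of 1, the surrogate's gradient coincides with the policy gradient. Below is a minimal NumPy sketch of the loss. The clip range 0.2 is an assumed common choice. In PyTorch, the same expression would use torch.exp, torch.clamp and torch.min.

```python
import numpy as np


def ppo_clip_loss(new_log_probs, old_log_probs, advantages, clip_eps=0.2):
    """PPO clipped surrogate loss (negated so that it can be minimized).

    new_log_probs: log pi_theta(a_t | s_t) under the policy being optimized
    old_log_probs: log pi_old(a_t | s_t) recorded when the rollout was collected
    advantages:    A_t, e.g. from GAE on the GAIL rewards
    """
    ratio = np.exp(np.asarray(new_log_probs) - np.asarray(old_log_probs))
    advantages = np.asarray(advantages, dtype=float)
    unclipped = ratio * advantages
    clipped = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Pessimistic (elementwise minimum) bound, averaged over the batch
    return float(-np.mean(np.minimum(unclipped, clipped)))
```

Clipping removes the incentive to push the ratio beyond $1 \pm \epsilon$ when doing so would increase the objective. This keeps each update close to the policy that collected the data.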


Implementation

PyTorch implementation

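import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class Discriminator(nn.Module):
    """Discriminator D_φ(s, a): probability that a state-action pair comes from the expert."""
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        
        # MLP over the concatenated (state, action) vector
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # output in (0, 1): estimated P(expert | s, a)
        )
    
    def forward(self, states, actions):
        """Inputs: states (N, state_dim) and actions (N, action_dim)."""
        # Concatenate state and action along the feature dimension
        x = torch.cat([states, actions], dim=-1)
        return self.encoder(x)
    
    def get_reward(self, states, actions):
        """GAIL reward r = -log(1 - D): grows as D judges the pair more expert-like."""
        prob = self(states, actions)
        return -torch.log(1 - prob + 1e-8)


class GAILAgent:
    def __init__(self, state_dim, action_dim, hidden_dim=128, lr=3e-4):
        # Policy network (PPO actor-critic, defined below)
        self.policy = PPOActorCritic(state_dim, action_dim, hidden_dim)
        
        # Discriminator
        self.discriminator = Discriminator(state_dim, action_dim, hidden_dim)
        
        # Optimizers
        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=lr)
    
    def update(self, expert_buffer, policy_buffer, ppo_epochs=10):
        """One GAIL iteration: discriminator steps followed by a policy update.

        Both buffers are assumed to provide get_states() and get_actions()
        returning arrays of shape (N, state_dim) and (N, action_dim).
        """
        all_expert_states = torch.FloatTensor(expert_buffer.get_states())
        all_expert_actions = torch.FloatTensor(expert_buffer.get_actions())
        all_policy_states = torch.FloatTensor(policy_buffer.get_states())
        all_policy_actions = torch.FloatTensor(policy_buffer.get_actions())
        
        # 1. Update the discriminator (several steps per policy update)
        for _ in range(5):
            expert_score = self.discriminator(all_expert_states, all_expert_actions)
            policy_score = self.discriminator(all_policy_states, all_policy_actions)
            
            # Binary cross-entropy: expert pairs labeled 1, policy pairs labeled 0
            disc_loss = (
                -torch.log(expert_score + 1e-8).mean() - 
                torch.log(1 - policy_score + 1e-8).mean()
            )
            
            self.discriminator_optimizer.zero_grad()
            disc_loss.backward()
            self.discriminator_optimizer.step()
        
        # 2. Compute GAIL rewards; the discriminator is treated as fixed here
        with torch.no_grad():
            rewards = self.discriminator.get_reward(
                all_policy_states,
                all_policy_actions
            ).squeeze(-1)  # (N, 1) -> (N,)
        
        # 3. Update the policy with PPO, using the GAIL rewards in place of environment rewards
        self.ppo_update(policy_buffer, rewards, ppo_epochs)

    def ppo_update(self, policy_buffer, rewards, ppo_epochs):
        """Placeholder for a standard PPO update (GAE advantages + clipped surrogate
        loss) driven by `rewards`; PPOActorCritic has no update method of its own."""
        raise NotImplementedError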

PPO actor-critic implementation

class PPOActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        
        # Actor
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )
        
        # Critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # State-independent log standard deviation of the Gaussian policy
        self.log_std = nn.Parameter(torch.zeros(action_dim))
    
    def forward(self, state):
        return self.actor(state), self.critic(state)
    
    def get_action(self, state):
        mean = self.actor(state)
        std = torch.exp(self.log_std)
        dist = torch.distributions.Normal(mean, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action, log_prob
    
    def evaluate(self, state, action):
        mean = self.actor(state)
        std = torch.exp(self.log_std)
        dist = torch.distributions.Normal(mean, std)
        log_prob = dist.log_prob(action).sum(dim=-1)
        value = self.critic(state)
        return log_prob, value

GAIL variants

InfoGAIL

InfoGAIL adds a mutual-information term so that the learned behavior becomes interpretable.[2]

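Modified objective

$$\min_{\theta,\,Q}\;\max_{\phi}\;\; \mathbb{E}_{\pi_E}\big[\log D_\phi(s,a)\big] + \mathbb{E}_{\pi_\theta}\big[\log\big(1 - D_\phi(s,a)\big)\big] - \lambda_1 L_I(\pi_\theta, Q) - \lambda_2 H(\pi_\theta)$$

$$L_I(\pi_\theta, Q) = \mathbb{E}_{c \sim p(c),\; a \sim \pi_\theta(\cdot \mid s, c)}\big[\log Q(c \mid \tau)\big] + H(c) \;\le\; I(c; \tau)$$

Here $c$ is a latent variable representing different behavior styles, and the policy $\pi_\theta(a \mid s, c)$ is conditioned on $c$. $Q(c \mid \tau)$ is an auxiliary network that tries to recover $c$ from the generated trajectory $\tau$.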
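The mutual-information lower bound $L_I = \mathbb{E}[\log Q(c \mid \tau)] + H(c)$ can be estimated by Monte Carlo. The following minimal NumPy sketch assumes a uniform categorical prior over $K$ latent codes, so that $H(c) = \log K$. It also assumes the posterior network $Q$ has already produced class probabilities for each generated trajectory. info_lower_bound is a hypothetical helper name.

```python
import numpy as np


def info_lower_bound(q_probs, codes):
    """Monte Carlo estimate of L_I = E[log Q(c | tau)] + H(c) for a uniform prior over K codes.

    q_probs: (N, K) array; row i holds Q(. | tau_i), the posterior predicted for trajectory i
    codes:   (N,) integer array; codes[i] is the latent code used to generate trajectory i
    """
    q_probs = np.asarray(q_probs, dtype=float)
    codes = np.asarray(codes)
    n, k = q_probs.shape
    log_q = np.log(q_probs[np.arange(n), codes] + 1e-12)  # log Q(c_i | tau_i)
    prior_entropy = np.log(k)  # H(c) for c ~ Uniform{0, ..., K-1}
    return float(np.mean(log_q) + prior_entropy)
```

A posterior that always identifies the code drives the bound to its maximum, $\log K$. A posterior that ignores the trajectory (uniform output) gives 0, meaning the trajectories carry no recoverable information about the style.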

Wasserstein GAIL (WGAIL)

WGAIL replaces the JS divergence with the Wasserstein distance to improve training stability.[3]

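class WGAN_GP_Discriminator(nn.Module):
    """WGAN critic with gradient penalty (WGAN-GP) over concatenated (state, action) inputs.

    Unlike the GAIL discriminator, it outputs an unbounded score, not a probability.
    """
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # no Sigmoid: real-valued critic score
        )
    
    def forward(self, state_action):
        return self.net(state_action)
    
    def gradient_penalty(self, real, fake):
        """Penalty (||grad f(x)||_2 - 1)^2 at random interpolates of real and fake.

        real, fake: (batch, state_dim + action_dim) tensors with the same batch size.
        """
        alpha = torch.rand(real.size(0), 1, device=real.device)
        # The interpolates must require grad so autograd.grad can differentiate w.r.t. them
        interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        
        mixed_scores = self.net(interpolated)
        grad = torch.autograd.grad(
            outputs=mixed_scores,
            inputs=interpolated,
            grad_outputs=torch.ones_like(mixed_scores),
            create_graph=True  # keep the graph so the penalty can be backpropagated
        )[0]
        
        grad_norm = grad.norm(2, dim=1)
        penalty = ((grad_norm - 1) ** 2).mean()
        return penalty
    
    def wasserstein_loss(self, real, fake):
        """Critic loss on (s, a) inputs: minimizing it raises expert scores and lowers policy scores."""
        return self.net(fake).mean() - self.net(real).mean()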
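The following toy example illustrates the motivation for the switch; it is not a GAIL experiment. Take two point masses on a 1-D grid. The JS divergence stays at $\log 2$ however far apart they are, so it gives no signal about how close the learner is. The Wasserstein-1 distance, by contrast, grows with the gap. The helper names are hypothetical and the distributions are hand-built NumPy arrays.

```python
import numpy as np


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    m = 0.5 * (p + q)

    def kl(a, b):
        mask = a > 0  # terms with a = 0 contribute nothing
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return float(0.5 * kl(p, m) + 0.5 * kl(q, m))


def wasserstein_1d(p, q, dx=1.0):
    """W1 distance for distributions on the same evenly spaced 1-D grid."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * dx)


def point_mass(index, size=10):
    """Distribution putting all mass on grid cell `index`."""
    dist = np.zeros(size)
    dist[index] = 1.0
    return dist


expert = point_mass(0)
for shift in (1, 3, 6):
    policy = point_mass(shift)
    # JS stays at log 2 for every shift; W1 equals the shift
    print(shift, js_divergence(expert, policy), wasserstein_1d(expert, policy))
```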

Point-GAIL

A variant aimed at sparse-reward settings.


Theoretical analysis

Convergence

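Theorem (Ho & Ermon, 2016): in the idealized setting of infinite samples and exact optimization, the GAIL objective is optimized when the learner's state-action occupancy measure equals the expert's. GAIL therefore recovers a policy whose behavior matches the expert's.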

Relation to IRL

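GAIL can be viewed as an implicit IRL algorithm:

  • The discriminator acts as a learned reward (cost) function.
  • The policy learns by maximizing this implicit reward.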

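Limitations

  1. Mode collapse: the policy may reproduce only some of the expert's behaviors.
  2. Instability: adversarial training can oscillate or diverge. For example, a discriminator that becomes too strong stops providing informative rewards.
  3. Sample efficiency: learning requires many environment interactions, because the policy is trained with on-policy RL.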

Practical tips

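Hyperparameters

| Parameter              | Recommended value | Notes                                 |
|------------------------|-------------------|---------------------------------------|
| Discriminator updates  | 3-5               | Discriminator steps per policy update |
| Learning rate          | 3e-4              | Same for policy and discriminator     |
| Entropy regularization | 0.01              | Prevents premature convergence        |
| GAE λ                  | 0.95              | Advantage-estimation parameter        |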

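Gradient penalty

Adding a gradient penalty to the discriminator loss can improve training stability. With the WGAN_GP_Discriminator defined above, the penalty is computed on concatenated state-action pairs: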

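expert_sa = torch.cat([expert_states, expert_actions], dim=-1)  # (s, a) pairs
policy_sa = torch.cat([policy_states, policy_actions], dim=-1)  # same batch size
gp = discriminator.gradient_penalty(expert_sa, policy_sa)
disc_loss = disc_loss + 10 * gp  # penalty weight λ = 10, as in WGAN-GP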

Footnotes

  1. Ho & Ermon, “Generative Adversarial Imitation Learning”, NeurIPS, 2016

  2. Li et al., “InfoGAIL: Interpretable Imitation Learning from Visual Demonstrations”, NeurIPS, 2017

  3. Gulrajani et al., “Improved Training of Wasserstein GANs”, NeurIPS, 2017