Overview

PPO (Proximal Policy Optimization), proposed by Schulman et al. in 2017, is one of the most widely used reinforcement learning algorithms.[1]

Core idea: limit the size of each policy update by clipping the probability ratio, keeping training stable.

PPO achieves strong performance on Atari, MuJoCo, and robotic-control benchmarks, and is used in practical systems such as OpenAI Five and RLHF fine-tuning of large language models.

Problem Background

Challenges of Policy Gradients

Standard policy-gradient methods face two problems:

  1. Large update steps cause collapse: a small parameter change can drastically shift the action distribution
  2. Low sample efficiency: each sample is used only once (on-policy)

Trust-Region Methods

TRPO (Trust Region Policy Optimization) constrains each update with a KL-divergence bound:

maximize_θ   E_t[ (π_θ(a_t|s_t) / π_θ_old(a_t|s_t)) · Â_t ]
subject to   E_t[ KL(π_θ_old(·|s_t) || π_θ(·|s_t)) ] ≤ δ

However, TRPO is computationally heavy: it requires conjugate gradients and Hessian-vector products.

The Core Idea of PPO

Clipping Instead of a Constraint

PPO replaces the KL constraint with a clipped objective:

L^CLIP(θ) = E_t[ min(r_t(θ) Â_t, clip(r_t(θ), 1-ε, 1+ε) Â_t) ]

where r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio.

Intuition

When A > 0 (the action was good):
   - if r rises too far (the probability increases too much), clipping stops further increase
   - the probability is only encouraged to increase moderately

When A < 0 (the action was bad):
   - if r falls too far (the probability decreases too much), clipping stops further decrease
   - the probability is only encouraged to decrease moderately
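
A small standalone snippet (illustrative only) makes the clipping behavior concrete: once the ratio leaves [1-ε, 1+ε] in the direction the advantage favors, the objective goes flat, so its gradient with respect to the ratio vanishes and the update stops pushing.

import torch

eps = 0.2
advantages = torch.tensor([+1.0, -1.0])  # one good action, one bad action

for r in [0.5, 0.8, 1.0, 1.2, 1.5]:
    ratio = torch.tensor([r, r])
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    objective = torch.min(surr1, surr2)
    # For A > 0 the objective plateaus at (1+eps)*A once r > 1.2;
    # for A < 0 it plateaus at (1-eps)*A once r < 0.8.
    print(f"r={r:.1f}  objective={objective.tolist()}")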

Algorithm Details

PPO-Penalty vs PPO-Clip

| Variant | Mechanism |
| --- | --- |
| PPO-Penalty | adaptive KL penalty |
| PPO-Clip | probability-ratio clipping |

In practice, PPO-Clip is far more common.

Full Algorithm Flow

1. Initialize policy parameters θ_0 and value-function parameters φ_0

2. For each iteration:
   a) Collect T timesteps of data with the current policy:
      - store (s_t, a_t, r_t, v_t, log π_θ(a_t|s_t))
   
   b) Compute GAE advantage estimates Â_t
   
   c) Fit the value function by minimizing:
      L^V(φ) = Σ_t (V_φ(s_t) - R̂_t)^2,   where R̂_t = Â_t + v_t is the return target
   
   d) Optimize the policy by maximizing (several epochs of mini-batch updates):
      L^CLIP(θ) = Σ_t min(r_t(θ) Â_t, clip(r_t(θ), 1-ε, 1+ε) Â_t)
   
   e) θ_old ← θ
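
The GAE quantities in step (b) follow the backward recursion below, which is exactly what the compute_gae method in the implementation computes:

δ_t = r_t + γ·V(s_{t+1})·(1 - done_t) - V(s_t)
Â_t = δ_t + γλ·(1 - done_t)·Â_{t+1},   with Â_T = 0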

Python Implementation

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
class ActorCritic(nn.Module):
    """Actor-Critic网络"""
    
    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        
        # Actor
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
        
        # Critic
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x):
        return self.actor(x), self.critic(x)
 
 
class PPOAgent:
    """PPO智能体"""
    
    def __init__(
        self,
        state_dim,
        action_dim,
        lr=3e-4,
        gamma=0.99,
        lambd=0.95,
        epsilon=0.2,
        k_epochs=4,
        n_workers=8,
        n_steps=2048,
        batch_size=64,
        entropy_coef=0.01,
        value_coef=0.5,
        max_grad_norm=0.5
    ):
        self.gamma = gamma
        self.lambd = lambd
        self.epsilon = epsilon
        self.k_epochs = k_epochs
        self.batch_size = batch_size
        self.entropy_coef = entropy_coef
        self.value_coef = value_coef
        self.max_grad_norm = max_grad_norm
        
        self.policy = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        
        # rollout storage
        self.buffer = {
            'states': [],
            'actions': [],
            'rewards': [],
            'values': [],
            'log_probs': [],
            'dones': []
        }
    
    def choose_action(self, state, training=True):
        state = torch.FloatTensor(state).unsqueeze(0)
        
        with torch.no_grad():
            probs, value = self.policy(state)
        
        dist = torch.distributions.Categorical(probs)
        action = dist.sample() if training else probs.argmax()
        log_prob = dist.log_prob(action)
        
        return action.item(), log_prob, value.item()
    
    def store(self, state, action, reward, value, log_prob, done):
        self.buffer['states'].append(state)
        self.buffer['actions'].append(action)
        self.buffer['rewards'].append(reward)
        self.buffer['values'].append(value)
        self.buffer['log_probs'].append(log_prob)
        self.buffer['dones'].append(done)
    
    def compute_gae(self, rewards, values, dones, gamma, lambd):
        """Compute Generalized Advantage Estimation (GAE)."""
        advantages = []
        gae = 0
        
        # bootstrap value for the step after the rollout (0 here; assumes the
        # rollout ends at, or close to, an episode boundary)
        values = values + [0]
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
            gae = delta + gamma * lambd * (1 - dones[t]) * gae
            advantages.insert(0, gae)
        
        advantages = torch.FloatTensor(advantages)
        returns = advantages + torch.FloatTensor(values[:-1])
        
        return advantages, returns
    
    def update(self):
        """PPO更新"""
        # 转换为张量
        states = torch.FloatTensor(np.array(self.buffer['states']))
        actions = torch.LongTensor(self.buffer['actions'])
        old_log_probs = torch.FloatTensor([lp.item() for lp in self.buffer['log_probs']])
        
        # compute GAE advantages and returns
        rewards = self.buffer['rewards']
        values = self.buffer['values']
        dones = self.buffer['dones']
        advantages, returns = self.compute_gae(rewards, values, dones, self.gamma, self.lambd)
        
        # normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # several epochs of mini-batch updates
        for _ in range(self.k_epochs):
            # shuffle sample indices
            indices = torch.randperm(len(states))
            
            for start in range(0, len(states), self.batch_size):  # mini-batches
                end = start + self.batch_size
                idx = indices[start:end]
                
                batch_states = states[idx]
                batch_actions = actions[idx]
                batch_old_log_probs = old_log_probs[idx]
                batch_advantages = advantages[idx]
                batch_returns = returns[idx]
                
                # evaluate the current policy on the batch
                probs, values_pred = self.policy(batch_states)
                dist = torch.distributions.Categorical(probs)
                
                # log-probabilities under the current policy
                new_log_probs = dist.log_prob(batch_actions)
                
                # probability ratio r_t(θ), computed in log space
                ratio = torch.exp(new_log_probs - batch_old_log_probs)
                
                # clipped surrogate objective
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()
                
                # value-function loss
                value_loss = nn.functional.mse_loss(values_pred.squeeze(-1), batch_returns)
                
                # entropy bonus (encourages exploration)
                entropy = dist.entropy().mean()
                
                # total loss
                loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
                
                # gradient step
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
                self.optimizer.step()
        
        # clear the buffer
        for key in self.buffer:
            self.buffer[key] = []
        
        return policy_loss.item(), value_loss.item(), entropy.item()
    
    def collect_rollout(self, env, n_steps):
        """Collect a fixed-length rollout (classic gym (<0.26) reset/step API)."""
        state = env.reset()
        
        for _ in range(n_steps):
            action, log_prob, value = self.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            
            self.store(state, action, reward, value, log_prob, done)
            
            state = next_state
            if done:
                state = env.reset()
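
A minimal smoke test (hypothetical dimensions, random transitions instead of a real environment) exercises the full store/update cycle:

agent = PPOAgent(state_dim=4, action_dim=2)

for _ in range(256):
    state = np.random.randn(4).astype(np.float32)
    action, log_prob, value = agent.choose_action(state)
    agent.store(state, action, reward=1.0, value=value, log_prob=log_prob, done=False)

policy_loss, value_loss, entropy = agent.update()
print(policy_loss, value_loss, entropy)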

PPO-Penalty (Adaptive KL Penalty)

PPO-Penalty uses an adaptive KL penalty coefficient instead of clipping; a sketch of the key pieces:

class PPOPenaltyAgent:
    """PPO with Adaptive KL Penalty"""
    
    def __init__(self, state_dim, action_dim):
        # ... (same network/optimizer setup as PPOAgent)
        self.kl_target = 0.01  # target KL divergence
        self.beta = 1.0        # initial penalty coefficient
        self.beta_min = 1e-4
        self.beta_max = 20.0
        self.kl_coef = 1.5     # multiplicative adaptation factor
    
    def compute_kl_divergence(self, old_probs, new_probs):
        """KL(old || new), averaged over states."""
        return (old_probs * (old_probs / (new_probs + 1e-8)).log()).sum(dim=-1).mean()
    
    def update(self):
        # ... after collecting a rollout (as in PPOAgent.update)
        
        for _ in range(self.k_epochs):
            # KL divergence between the old and new policies
            new_probs, _ = self.policy(states)
            kl = self.compute_kl_divergence(old_probs, new_probs)
            
            # adapt beta toward the KL target
            if kl < self.kl_target / 1.5:
                self.beta /= self.kl_coef
            elif kl > self.kl_target * 1.5:
                self.beta *= self.kl_coef
            self.beta = np.clip(self.beta, self.beta_min, self.beta_max)
            
            # surrogate loss with a KL penalty instead of clipping
            loss = policy_loss + self.beta * kl
            
            # gradient step ...
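
The adaptation rule is easy to sanity-check in isolation. The snippet below (illustrative KL values, same constants as above) shows beta shrinking when the policy moves too little and growing when it moves too much:

import numpy as np

beta, kl_target, kl_coef = 1.0, 0.01, 1.5

for kl in [0.002, 0.008, 0.030, 0.050]:
    if kl < kl_target / 1.5:
        beta /= kl_coef   # policy barely moved: weaken the penalty
    elif kl > kl_target * 1.5:
        beta *= kl_coef   # policy moved too much: strengthen the penalty
    beta = np.clip(beta, 1e-4, 20.0)
    print(f"kl={kl:.3f} -> beta={beta:.3f}")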

PPO Hyperparameters

| Hyperparameter | Typical value | Notes |
| --- | --- | --- |
| Clip range ε | 0.1-0.2 | bounds the size of each policy update |
| Learning rate | 3e-4 | often annealed over training |
| Mini-batches | 4-32 | number of mini-batch updates per epoch |
| Epochs (K) | 3-10 | number of passes over each batch of data |
| GAE λ | 0.95 | bias-variance trade-off |
| Entropy coefficient | 0-0.01 | strength of the exploration bonus |
| Value coefficient | 0.5-1.0 | weight of the critic loss |
| Gradient clipping | 0.5-1.0 | stabilizes training |
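
As a concrete starting point, the values below (one reasonable pick from the ranges above; nothing here is prescribed by the paper) plug directly into the PPOAgent constructor defined earlier:

ppo_config = dict(
    lr=3e-4,
    gamma=0.99,
    lambd=0.95,        # GAE lambda
    epsilon=0.2,       # clip range
    k_epochs=4,
    batch_size=64,
    entropy_coef=0.01,
    value_coef=0.5,
    max_grad_norm=0.5,
)
agent = PPOAgent(state_dim=4, action_dim=2, **ppo_config)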

Comparison with Other Algorithms

| Aspect | PPO | TRPO | A2C | DDPG |
| --- | --- | --- | --- | --- |
| Policy update | clipping | KL constraint | synchronous gradient | off-policy |
| Sample efficiency | medium | medium | low | high |
| Implementation complexity | low | high | low | medium |
| Continuous actions | supported | supported | supported | excellent |
| Discrete actions | excellent | excellent | excellent | not supported |
| Hyperparameter sensitivity | low | medium | medium | high |

Worked Example: CartPole

import gym
 
def train_cartpole():
    env = gym.make('CartPole-v1')  # classic gym (<0.26) reset/step API
    
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    
    agent = PPOAgent(state_dim, action_dim)
    
    n_episodes = 500
    max_steps = 500
    
    for episode in range(n_episodes):
        state = env.reset()
        
        for step in range(max_steps):
            action, log_prob, value = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            
            agent.store(state, action, reward, value, log_prob, done)
            
            state = next_state
            if done:
                break
        
        # run a PPO update once enough data has accumulated
        if len(agent.buffer['states']) >= 2048:
            agent.update()
        
        # log the episode reward (equal to the episode length: +1 per step)
        episode_reward = step + 1
        if (episode + 1) % 10 == 0:
            print(f"Episode {episode+1}, Reward: {episode_reward}")
        
        # CartPole-v1's solved threshold is an average reward of 475
        if episode_reward >= 475:
            print(f"Solved in {episode+1} episodes!")
            break
 
 
if __name__ == "__main__":
    train_cartpole()
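
For evaluation, choose_action(state, training=False) takes the argmax action instead of sampling. A small helper built on the same classic gym API as above:

def evaluate(agent, env, n_episodes=10):
    """Average reward over greedy (argmax) rollouts."""
    total_reward = 0.0
    for _ in range(n_episodes):
        state = env.reset()
        done = False
        while not done:
            action, _, _ = agent.choose_action(state, training=False)
            state, reward, done, _ = env.step(action)
            total_reward += reward
    return total_reward / n_episodes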

Distributed PPO (DPPO)

For large-scale training, rollout collection can be parallelized across workers. A sketch follows; ParallelEnv and PPOAgent.load_data are hypothetical stand-ins for a vectorized-environment wrapper and a buffer-loading helper:

class DistributedPPO:
    """Distributed PPO sketch (ParallelEnv and load_data are assumed helpers)"""
    
    def __init__(self, state_dim, action_dim, n_workers=16, n_envs_per_worker=8):
        self.workers = [
            ParallelEnv(n_envs=n_envs_per_worker)
            for _ in range(n_workers)
        ]
        self.global_agent = PPOAgent(state_dim, action_dim)
    
    def rollout_phase(self):
        """Collect data from all workers in parallel."""
        all_data = []
        for worker in self.workers:
            data = worker.collect(n_steps=128)
            all_data.extend(data)
        return all_data
    
    def training_phase(self, data):
        """Update the global network on the gathered data."""
        self.global_agent.load_data(data)
        for _ in range(10):  # 10 epochs
            self.global_agent.update()
    
    def train(self):
        while True:
            data = self.rollout_phase()
            self.training_phase(data)


Footnotes

  1. Schulman et al., "Proximal Policy Optimization Algorithms", arXiv:1707.06347, 2017.