概述
PPO(Proximal Policy Optimization,近端策略优化)由Schulman等人于2017年提出[^1],是目前最广泛使用的强化学习算法之一。
核心思想:通过裁剪概率比,限制策略更新幅度,确保训练稳定。
PPO在Atari、MuJoCo、机器人控制等任务上取得了优异性能,并被OpenAI Five、ChatGPT的RLHF训练等实际系统广泛采用。
问题背景
策略梯度的挑战
标准策略梯度方法面临两个问题:
- 更新步长难以控制:参数空间的微小变化可能导致动作分布剧烈变化,过大的更新会使策略性能崩溃
- 样本效率低:每个样本只能使用一次(on-policy)
信任域方法
TRPO(Trust Region Policy Optimization)通过KL散度约束限制更新:

max_θ E_t[ (π_θ(a_t|s_t) / π_θold(a_t|s_t)) · Â_t ]
s.t.  E_t[ KL(π_θold(·|s_t) ‖ π_θ(·|s_t)) ] ≤ δ
但TRPO计算复杂(需要共轭梯度、Hessian向量乘积)。
PPO核心思想
裁剪替代约束
PPO用裁剪目标替代KL约束:

L^CLIP(θ) = E_t[ min(r_t(θ) Â_t, clip(r_t(θ), 1-ε, 1+ε) Â_t) ]

其中 r_t(θ) = π_θ(a_t|s_t) / π_θold(a_t|s_t) 是新旧策略的概率比,ε 为裁剪范围。
直观理解
当 Â > 0(动作优于平均)时:
- 如果 r 超过 1+ε(概率增加过多),目标被裁剪为常数,梯度不再鼓励继续增加
- 即只鼓励适度增加该动作的概率
当 Â < 0(动作劣于平均)时:
- 如果 r 低于 1-ε(概率减少过多),目标被裁剪为常数,梯度不再鼓励继续减少
- 即只鼓励适度减少该动作的概率
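下面用一个极简的数值示例演示裁剪的效果(取 ε = 0.2,数值仅作示意,与上文公式一一对应):

import numpy as np

def clipped_objective(ratio, advantage, eps=0.2):
    """单步PPO-Clip目标:min(r·Â, clip(r, 1-ε, 1+ε)·Â)"""
    return np.minimum(ratio * advantage,
                      np.clip(ratio, 1 - eps, 1 + eps) * advantage)

# Â > 0:r 超过 1+ε 后目标被钳制在 (1+ε)·Â,继续增大 r 不再有收益
print(clipped_objective(np.array([0.8, 1.0, 1.2, 1.5]), advantage=1.0))
# 输出: [0.8 1.  1.2 1.2]

# Â < 0:r 低于 1-ε 后目标被钳制在 (1-ε)·Â,继续减小 r 不再有收益
print(clipped_objective(np.array([0.5, 0.8, 1.0, 1.2]), advantage=-1.0))
# 输出: [-0.8 -0.8 -1.  -1.2]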
算法详解
PPO-Penalty vs PPO-Clip
| 变体 | 方法 | 公式 |
|---|---|---|
| PPO-Penalty | 自适应KL惩罚 | L^KLPEN(θ) = E_t[ r_t(θ) Â_t − β·KL(π_θold(·|s_t) ‖ π_θ(·|s_t)) ] |
| PPO-Clip | 概率比裁剪 | L^CLIP(θ) = E_t[ min(r_t(θ) Â_t, clip(r_t(θ), 1-ε, 1+ε) Â_t) ] |
实践中PPO-Clip更常用。
完整算法流程
1. 初始化策略参数 θ_0 和价值函数参数 φ_0
2. 对每个epoch:
a) 用当前策略收集T个 timesteps 的数据:
- 存储: (s_t, a_t, r_t, v_t, log π_θ(a_t|s_t))
b) 计算GAE优势估计 Â_t
c) 优化价值函数:
L^V(φ) = Σ_t (V_φ(s_t) - R̂_t)^2,其中回报目标 R̂_t = Â_t + V(s_t)
d) 优化策略(多个epoch小批量更新):
L^CLIP(θ) = Σ_t min(r_t(θ) Â_t, clip(r_t(θ), 1-ε, 1+ε) Â_t)
e) θ_old ← θ
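其中步骤 b) 的GAE(Generalized Advantage Estimation)按如下方式从 t = T−1 逆序递推,与下文代码中的 compute_gae 一致:

δ_t = r_t + γ·V(s_{t+1})·(1 − done_t) − V(s_t)
Â_t = δ_t + γ·λ·(1 − done_t)·Â_{t+1},初始 Â_T = 0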
Python实现
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class ActorCritic(nn.Module):
"""Actor-Critic网络"""
def __init__(self, state_dim, action_dim, hidden_dim=64):
super().__init__()
# Actor
self.actor = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
# Critic
self.critic = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.actor(x), self.critic(x)
class PPOAgent:
"""PPO智能体"""
def __init__(
self,
state_dim,
action_dim,
lr=3e-4,
gamma=0.99,
lambd=0.95,
epsilon=0.2,
k_epochs=4,
n_workers=8,
n_steps=2048,
batch_size=64,
entropy_coef=0.01,
value_coef=0.5,
max_grad_norm=0.5
):
self.gamma = gamma
self.lambd = lambd
self.epsilon = epsilon
        self.k_epochs = k_epochs
        self.batch_size = batch_size
self.entropy_coef = entropy_coef
self.value_coef = value_coef
self.max_grad_norm = max_grad_norm
self.policy = ActorCritic(state_dim, action_dim)
self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
# 存储
self.buffer = {
'states': [],
'actions': [],
'rewards': [],
'values': [],
'log_probs': [],
'dones': []
}
def choose_action(self, state, training=True):
state = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
probs, value = self.policy(state)
dist = torch.distributions.Categorical(probs)
action = dist.sample() if training else probs.argmax()
log_prob = dist.log_prob(action)
return action.item(), log_prob, value.item()
def store(self, state, action, reward, value, log_prob, done):
self.buffer['states'].append(state)
self.buffer['actions'].append(action)
self.buffer['rewards'].append(reward)
self.buffer['values'].append(value)
self.buffer['log_probs'].append(log_prob)
self.buffer['dones'].append(done)
def compute_gae(self, rewards, values, dones, gamma, lambd):
"""计算GAE"""
advantages = []
gae = 0
        values = values + [0]  # 末状态价值以0作bootstrap(若轨迹被截断而非终止,可改用V(s_T))
for t in reversed(range(len(rewards))):
delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t]
gae = delta + gamma * lambd * (1 - dones[t]) * gae
advantages.insert(0, gae)
advantages = torch.FloatTensor(advantages)
returns = advantages + torch.FloatTensor(values[:-1])
return advantages, returns
def update(self):
"""PPO更新"""
# 转换为张量
states = torch.FloatTensor(np.array(self.buffer['states']))
actions = torch.LongTensor(self.buffer['actions'])
old_log_probs = torch.FloatTensor([lp.item() for lp in self.buffer['log_probs']])
# 计算GAE
rewards = self.buffer['rewards']
values = self.buffer['values']
dones = self.buffer['dones']
advantages, returns = self.compute_gae(rewards, values, dones, self.gamma, self.lambd)
# 标准化优势
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# 多个epoch更新
for _ in range(self.k_epochs):
# 随机打乱
indices = torch.randperm(len(states))
            for start in range(0, len(states), self.batch_size):  # 按batch_size切分mini-batch
                end = start + self.batch_size
idx = indices[start:end]
batch_states = states[idx]
batch_actions = actions[idx]
batch_old_log_probs = old_log_probs[idx]
batch_advantages = advantages[idx]
batch_returns = returns[idx]
# 当前策略
probs, values_pred = self.policy(batch_states)
dist = torch.distributions.Categorical(probs)
# 新log概率
new_log_probs = dist.log_prob(batch_actions)
# 概率比
ratio = torch.exp(new_log_probs - batch_old_log_probs)
# 裁剪目标
surr1 = ratio * batch_advantages
surr2 = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon) * batch_advantages
policy_loss = -torch.min(surr1, surr2).mean()
# 价值损失
value_loss = nn.functional.mse_loss(values_pred.squeeze(), batch_returns)
# 熵损失(鼓励探索)
entropy = dist.entropy().mean()
# 总损失
loss = policy_loss + self.value_coef * value_loss - self.entropy_coef * entropy
# 更新
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.optimizer.step()
# 清空buffer
for key in self.buffer:
self.buffer[key] = []
return policy_loss.item(), value_loss.item(), entropy.item()
def collect_rollout(self, env, n_steps):
"""收集经验"""
state = env.reset()
for _ in range(n_steps):
action, log_prob, value = self.choose_action(state)
next_state, reward, done, _ = env.step(action)
self.store(state, action, reward, value, log_prob, done)
state = next_state
if done:
                state = env.reset()
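上面定义的 collect_rollout 与 update 可以按如下方式组成一个最小训练循环(示例假设使用旧版 gym API,环境与维度仅为示意):

import gym

env = gym.make('CartPole-v1')
agent = PPOAgent(state_dim=4, action_dim=2)
for iteration in range(100):
    agent.collect_rollout(env, n_steps=2048)            # 收集2048步经验
    policy_loss, value_loss, entropy = agent.update()   # 做k_epochs轮PPO更新并清空buffer
    print(f"iter {iteration}: policy={policy_loss:.3f}, value={value_loss:.3f}, entropy={entropy:.3f}")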
PPO-Penalty(自适应KL惩罚)
PPO-Penalty使用自适应KL惩罚系数:
class PPOPenaltyAgent:
"""PPO with Adaptive KL Penalty"""
def __init__(self, state_dim, action_dim):
# ...
self.kl_target = 0.01 # 目标KL散度
self.beta = 1.0 # 初始惩罚系数
self.beta_min = 1e-4
self.beta_max = 20.0
self.kl_coef = 1.5
def compute_kl_divergence(self, old_probs, new_probs):
"""计算KL散度"""
return (old_probs * (old_probs / (new_probs + 1e-8)).log()).sum(dim=-1).mean()
def update(self):
# ... 收集数据后
for _ in range(self.k_epochs):
# 计算KL散度
new_probs, _ = self.policy(states)
kl = self.compute_kl_divergence(old_probs, new_probs)
# 自适应调整beta
if kl < self.kl_target / 1.5:
self.beta /= self.kl_coef
elif kl > self.kl_target * 1.5:
self.beta *= self.kl_coef
self.beta = np.clip(self.beta, self.beta_min, self.beta_max)
# 计算带KL惩罚的损失
loss = policy_loss + self.beta * kl
            # 更新...(zero_grad / backward / step,与PPO-Clip相同)

PPO的超参数
| 超参数 | 典型值 | 说明 |
|---|---|---|
| Clip范围 ε | 0.1-0.2 | 控制策略更新幅度 |
| 学习率 | 3e-4 | 可用学习率衰减 |
| Mini-batch数 | 4-32 | 每个epoch更新次数 |
| Epoch数 | 10-30 | 数据复用次数 |
| GAE λ | 0.95 | 偏差-方差权衡 |
| 熵系数 | 0-0.01 | 探索程度 |
| 价值系数 | 0.5-1.0 | Critic重要性 |
| 梯度裁剪 | 0.5-1.0 | 稳定训练 |
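一组与上表对应的典型配置示例,参数名沿用上文 PPOAgent 的构造函数(数值仅供参考,需按任务调整):

agent = PPOAgent(
    state_dim=state_dim,      # 由环境决定
    action_dim=action_dim,    # 由环境决定
    lr=3e-4,                  # 学习率
    gamma=0.99,               # 折扣因子
    lambd=0.95,               # GAE λ
    epsilon=0.2,              # clip范围
    k_epochs=10,              # 数据复用次数
    batch_size=64,            # mini-batch大小
    entropy_coef=0.01,        # 熵系数
    value_coef=0.5,           # 价值损失系数
    max_grad_norm=0.5,        # 梯度裁剪
)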
与其他算法对比
| 方面 | PPO | TRPO | A2C | DDPG |
|---|---|---|---|---|
| 策略更新 | 裁剪目标 | KL散度约束 | 普通策略梯度 | 确定性策略梯度(off-policy) |
| 采样效率 | 中等 | 中等 | 低 | 高 |
| 实现复杂度 | 低 | 高 | 低 | 中 |
| 连续动作 | 支持 | 支持 | 支持 | 优秀 |
| 离散动作 | 优秀 | 优秀 | 优秀 | 不支持 |
| 超参数敏感度 | 低 | 中 | 低 | 高 |
应用示例:CartPole
import gym
def train_cartpole():
    env = gym.make('CartPole-v1')  # 旧版gym API:reset()返回state,step()返回4元组
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPOAgent(state_dim, action_dim)
n_episodes = 500
max_steps = 500
for episode in range(n_episodes):
state = env.reset()
for step in range(max_steps):
action, log_prob, value = agent.choose_action(state)
next_state, reward, done, _ = env.step(action)
agent.store(state, action, reward, value, log_prob, done)
state = next_state
if done:
break
# PPO更新
if len(agent.buffer['states']) >= 2048:
agent.update()
# 记录奖励
episode_reward = step + 1
if (episode + 1) % 10 == 0:
print(f"Episode {episode+1}, Reward: {episode_reward}")
if episode_reward >= 475:
print(f"Solved in {episode+1} episodes!")
break
if __name__ == "__main__":
    train_cartpole()
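训练结束后,可以用确定性策略(training=False,取概率最大的动作)评估效果。下面是一个简单的评估函数示例,沿用上文的 PPOAgent 接口与旧版 gym API:

def evaluate(agent, env, n_episodes=10):
    """用argmax动作评估n_episodes回合的平均回报"""
    total_reward = 0.0
    for _ in range(n_episodes):
        state = env.reset()
        done = False
        while not done:
            action, _, _ = agent.choose_action(state, training=False)
            state, reward, done, _ = env.step(action)
            total_reward += reward
    return total_reward / n_episodes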
分布式PPO(DPPO)
对于大规模训练,可以使用分布式架构:
class DistributedPPO:
    """分布式PPO(示意代码:ParallelEnv、collect、load_data为占位接口)"""
    def __init__(self, state_dim, action_dim, n_workers=16, n_envs_per_worker=8):
        self.workers = [
            ParallelEnv(n_envs=n_envs_per_worker)
            for _ in range(n_workers)
        ]
        self.global_agent = PPOAgent(state_dim, action_dim)

    def rollout_phase(self):
        """Rollout阶段:各worker并行收集数据"""
        all_data = []
        for worker in self.workers:
            data = worker.collect(n_steps=128)
            all_data.extend(data)
        return all_data

    def training_phase(self, data):
        """Training阶段:用汇总的数据更新全局网络"""
        self.global_agent.load_data(data)
        for _ in range(10):  # 10个epoch
            self.global_agent.update()

    def run(self):
        """启动:交替执行数据收集与训练"""
        while True:
            data = self.rollout_phase()
            self.training_phase(data)

参考
后续主题
- RLHF:人类反馈强化学习
- Actor-Critic:Actor-Critic框架基础
- 策略梯度:REINFORCE算法
[^1]: Schulman et al., “Proximal Policy Optimization Algorithms”, arXiv:1707.06347, 2017