Overview
Actor-Critic methods combine the strengths of policy gradients (the Actor) and learned value functions (the Critic), and form a widely used reinforcement learning framework. [1]

- Actor: learns the policy (outputs actions)
- Critic: evaluates the policy (estimates the value function)
- Joint training: the Critic supplies a low-variance learning signal (the TD error) that the Actor uses to update its policy
Why Actor-Critic?

Problems with REINFORCE
REINFORCE uses the Monte Carlo return $G_t$ as its gradient estimator:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\, G_t \,\nabla_\theta \log \pi_\theta(a_t \mid s_t) \right], \qquad G_t = \sum_{k=t}^{T} \gamma^{\,k-t}\, r_k$$
Problems:
- High variance: the return can only be estimated after a complete episode, and its variance grows rapidly with trajectory length
- Low sample efficiency: no update can happen until the episode ends
Solution
Replace the Monte Carlo return with an advantage estimated with the help of the Critic:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\, \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, A(s_t, a_t) \right]$$

where $A(s_t, a_t) = Q(s_t, a_t) - V(s_t)$ is the advantage function, which can be approximated by the TD error $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$.
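To make the difference concrete, here is a minimal sketch contrasting the full-episode Monte Carlo return with the one-step TD error that the Actor-Critic update below uses as its advantage estimate (function names are illustrative, not from the original code):

```python
def monte_carlo_return(rewards, gamma=0.99):
    """REINFORCE-style return G_t: requires the complete episode before any update."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g

def td_error(reward, value_s, value_s_next, gamma=0.99, done=False):
    """One-step TD error delta_t: available immediately after a single transition."""
    bootstrap = 0.0 if done else gamma * value_s_next
    return reward + bootstrap - value_s
```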
Framework Structure
```
                 ┌─────────────┐
                 │ Environment │
                 └──────┬──────┘
                        │ s_t, r_t
                        ▼
┌─────────────────────────────────────────────────┐
│                Critic (evaluator)               │
│  ┌──────────────┐      ┌───────────────────┐    │
│  │  Value net   │      │ TD error          │    │
│  │  V(s; θ_v)   │─────▶│ δ_t = r + γV(s')  │    │
│  └──────────────┘      │       - V(s; θ_v) │    │
│                        └───────────────────┘    │
└───────────────────────┬─────────────────────────┘
                        │ δ_t (learning signal)
                        ▼
┌─────────────────────────────────────────────────┐
│                 Actor (policy)                  │
│  ┌──────────────┐      ┌───────────────────┐    │
│  │ Policy net   │      │ Policy gradient   │    │
│  │ π(a|s; θ_π)  │◀─────│ δ_t·∇log π(a|s)   │    │
│  └──────────────┘      └───────────────────┘    │
└───────────────────────┬─────────────────────────┘
                        │ a_t (action)
                        ▼
                 ┌─────────────┐
                 │ Environment │
                 └─────────────┘
```
Algorithm Details

Basic Actor-Critic Procedure
```
1. Initialize:
   - Actor parameters θ_π
   - Critic parameters θ_v

2. For each episode:
   a) Initialize the state s
   b) For each step:
      - Actor selects an action:  a ~ π_θ(a|s)
      - Execute the action, observe r, s'
      - Critic computes the TD error:
            δ_t = r + γ V(s'; θ_v) - V(s; θ_v)
      - Critic update:
            θ_v ← θ_v + α_v · δ_t · ∇_θ_v V(s; θ_v)
      - Actor update:
            θ_π ← θ_π + α_π · δ_t · ∇_θ_π log π_θ(a|s)
      - s ← s'
```
Python Implementation

```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class Actor(nn.Module):
"""Actor网络:策略"""
def __init__(self, state_dim, action_dim, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim),
nn.Softmax(dim=-1)
)
def forward(self, x):
return self.net(x)
class Critic(nn.Module):
"""Critic网络:价值函数"""
def __init__(self, state_dim, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.net(x)
class ActorCriticAgent:
"""Actor-Critic智能体"""
def __init__(
self,
state_dim,
action_dim,
actor_lr=1e-3,
critic_lr=1e-3,
gamma=0.99,
entropy_coef=0.01
):
self.gamma = gamma
self.entropy_coef = entropy_coef
self.actor = Actor(state_dim, action_dim)
self.critic = Critic(state_dim)
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
def choose_action(self, state):
state = torch.FloatTensor(state).unsqueeze(0)
probs = self.actor(state)
dist = torch.distributions.Categorical(probs)
action = dist.sample()
return action.item(), dist.log_prob(action)
    def update(self, state, action, reward, next_state, done):
        state = torch.FloatTensor(state).unsqueeze(0)
        next_state = torch.FloatTensor(next_state).unsqueeze(0)
        # 1. Critic update: regress V(s) toward the TD target
        with torch.no_grad():
            if done:
                target = torch.FloatTensor([[reward]])
            else:
                target = reward + self.gamma * self.critic(next_state)
        current_value = self.critic(state)
        critic_loss = nn.MSELoss()(current_value, target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # 2. Actor update: use the TD error as the advantage estimate
        td_error = (target - current_value).detach().squeeze()
        probs = self.actor(state).squeeze(0)
        log_probs = torch.log(probs + 1e-8)
        action_log_prob = log_probs[action]
        # Policy-gradient loss plus entropy regularization (entropy = -sum p log p)
        entropy = -(probs * log_probs).sum()
        actor_loss = -td_error * action_log_prob - self.entropy_coef * entropy
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        return critic_loss.item(), actor_loss.item()
def train_actor_critic(env, agent, n_episodes=1000):
"""训练Actor-Critic"""
rewards_history = []
for episode in range(n_episodes):
state = env.reset()
done = False
episode_reward = 0
while not done:
action, _ = agent.choose_action(state)
next_state, reward, done, _ = env.step(action)
agent.update(state, action, reward, next_state, done)
state = next_state
episode_reward += reward
rewards_history.append(episode_reward)
if (episode + 1) % 100 == 0:
avg_reward = np.mean(rewards_history[-100:])
print(f"Episode {episode+1}, Avg Reward: {avg_reward:.2f}")
    return rewards_history
```
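A quick usage sketch, assuming the classic Gym API (reset returning the observation and a 4-tuple `step`) that the training loop above relies on; CartPole-v1 is used purely as an illustrative environment:

```python
import gym

env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]   # 4 observations for CartPole
action_dim = env.action_space.n              # 2 discrete actions

agent = ActorCriticAgent(state_dim, action_dim)
rewards_history = train_actor_critic(env, agent, n_episodes=500)
```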
Advantage Functions and GAE

N-step TD Advantage Estimation

$$\hat{A}_t^{(n)} = \sum_{l=0}^{n-1} \gamma^l \delta_{t+l} = r_t + \gamma r_{t+1} + \dots + \gamma^{n-1} r_{t+n-1} + \gamma^n V(s_{t+n}) - V(s_t)$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.
GAE (Generalized Advantage Estimation)

GAE balances bias and variance by taking an exponentially weighted average of all n-step TD advantages: [2]

$$\hat{A}_t^{GAE(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l \, \delta_{t+l}$$

```python
def compute_gae(rewards, values, gamma=0.99, lambd=0.95):
    """
    Compute GAE advantage estimates.

    Args:
        rewards: sequence of rewards, length T
        values: sequence of value estimates, length T + 1 (the last entry is the
                bootstrap value of the state after the final step; use 0 if terminal)
        gamma: discount factor
        lambd: GAE parameter in [0, 1]

    Returns:
        advantages: advantage estimates, length T
        returns: targets for training the Critic, length T
    """
    advantages = []
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lambd * gae
        advantages.insert(0, gae)
    advantages = np.array(advantages)
    returns = advantages + np.array(values[:-1])  # drop the bootstrap value
    return advantages, returns
```
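A small usage sketch with made-up numbers, just to show the expected input and output shapes (the last entry of `values` is the bootstrap value of the state reached after the final step):

```python
rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.6, 0.7, 0.0]   # length T + 1; a terminal state bootstraps to 0
advantages, returns = compute_gae(rewards, values, gamma=0.99, lambd=0.95)
print(advantages.shape, returns.shape)   # (3,) (3,)
```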
Actor-Critic for Continuous Action Spaces

SAC (Soft Actor-Critic)
SAC is a maximum-entropy Actor-Critic algorithm; it learns a stochastic policy alongside twin Q-functions:

```python
class SACAgent:
"""Soft Actor-Critic"""
def __init__(self, state_dim, action_dim, action_bound,
actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4,
gamma=0.99, tau=0.005):
self.gamma = gamma
self.tau = tau
self.action_bound = action_bound
# Actor
self.actor = GaussianPolicy(state_dim, action_dim)
        # Twin Critics (Q-networks over (state, action) pairs)
self.critic1 = Critic(state_dim, action_dim)
self.critic2 = Critic(state_dim, action_dim)
self.target_critic1 = Critic(state_dim, action_dim)
self.target_critic2 = Critic(state_dim, action_dim)
        # Automatically tuned entropy temperature
self.log_alpha = torch.zeros(1, requires_grad=True)
self.target_entropy = -action_dim
        # Optimizers
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=critic_lr)
self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=critic_lr)
self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
def update(self, states, actions, rewards, next_states, dones):
alpha = self.log_alpha.exp()
        # 1. Critic update
with torch.no_grad():
next_actions, next_log_probs = self.actor.sample(next_states)
q1_target = self.target_critic1(next_states, next_actions)
q2_target = self.target_critic2(next_states, next_actions)
q_target = torch.min(q1_target, q2_target)
next_value = q_target - alpha * next_log_probs
q_target = rewards + self.gamma * (1 - dones) * next_value
q1 = self.critic1(states, actions)
q2 = self.critic2(states, actions)
critic1_loss = nn.MSELoss()(q1, q_target)
critic2_loss = nn.MSELoss()(q2, q_target)
self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
self.critic2_optimizer.step()
        # 2. Actor update
actions_new, log_probs = self.actor.sample(states)
q1_new = self.critic1(states, actions_new)
q2_new = self.critic2(states, actions_new)
q_new = torch.min(q1_new, q2_new)
actor_loss = (alpha * log_probs - q_new).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
        # 3. Entropy temperature (alpha) update
alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
        # 4. Soft-update the target networks
self.soft_update(self.critic1, self.target_critic1)
self.soft_update(self.critic2, self.target_critic2)
def soft_update(self, source, target):
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
            )
```
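The SAC sketch above references a `GaussianPolicy` and Critics that take both state and action, neither of which is defined in this section. Below is a minimal sketch of what those components might look like; the class names, layer sizes, and log-std clamp range are assumptions for illustration, not part of the original code:

```python
import torch
import torch.nn as nn

class GaussianPolicy(nn.Module):
    """Squashed Gaussian policy: samples a tanh-bounded action and its log-probability."""
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.mu_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def sample(self, state):
        h = self.body(state)
        mu = self.mu_head(h)
        log_std = self.log_std_head(h).clamp(-20, 2)
        dist = torch.distributions.Normal(mu, log_std.exp())
        pre_tanh = dist.rsample()          # reparameterized sample
        action = torch.tanh(pre_tanh)
        # Change-of-variables correction for the tanh squashing
        log_prob = dist.log_prob(pre_tanh) - torch.log(1 - action.pow(2) + 1e-6)
        return action, log_prob.sum(dim=-1, keepdim=True)

class QCritic(nn.Module):
    """Q-network over (state, action), matching the Critic(state_dim, action_dim) calls above."""
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, state, action):
        return self.net(torch.cat([state, action], dim=-1))
```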
A3C (Asynchronous Advantage Actor-Critic)

A3C trains asynchronously with multiple parallel workers: [3]

```python
import copy
import multiprocessing as mp
class A3CWorker(mp.Process):
"""A3C工作线程"""
def __init__(self, global_agent, worker_id, env, gamma=0.99, lambd=0.95):
super().__init__()
self.global_agent = global_agent
self.worker_id = worker_id
self.env = env
self.gamma = gamma
self.lambd = lambd
def run(self):
while True:
            # Sync a local copy of the global networks
local_agent = self.sync_from_global()
            # Collect a short rollout
states, actions, rewards = [], [], []
state = self.env.reset()
done = False
while not done and len(states) < 20:
action, log_prob = local_agent.choose_action(state)
next_state, reward, done, _ = self.env.step(action)
states.append(state)
actions.append(action)
rewards.append(reward)
state = next_state
            # Compute GAE (append a bootstrap value for the state after the rollout)
            values = [local_agent.critic(torch.FloatTensor(s)).item()
                      for s in states]
            last_value = 0.0 if done else local_agent.critic(
                torch.FloatTensor(state)).item()
            values.append(last_value)
            advantages, returns = compute_gae(rewards, values,
                                              gamma=self.gamma, lambd=self.lambd)
            # Push the update to the global network (assumes a batched
            # update method on the global agent, not the per-step update above)
            self.global_agent.update(states, actions, returns, advantages)
def sync_from_global(self):
"""从全局网络同步参数"""
local_agent = copy.deepcopy(self.global_agent)
return local_agent
def train_a3c(env_fn, state_dim, action_dim, n_workers=8, gamma=0.99, lambd=0.95):
    """Train A3C: one shared global agent, several worker processes."""
    global_agent = ActorCriticAgent(state_dim, action_dim)
    workers = [A3CWorker(global_agent, i, env_fn(), gamma, lambd)
               for i in range(n_workers)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
```
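Note that the worker above only copies parameters from the global agent with `copy.deepcopy` and then calls a global `update`; it never explicitly pushes gradients. In typical A3C implementations the worker backpropagates a loss through its local copy and writes the resulting gradients onto the shared global parameters before stepping the global optimizer. A sketch of that pattern (the function and argument names are illustrative, not from the original code):

```python
import torch
import torch.nn as nn
import torch.optim as optim

def push_gradients(loss: torch.Tensor,
                   local_net: nn.Module,
                   global_net: nn.Module,
                   global_optimizer: optim.Optimizer) -> None:
    """Backprop on the local network, copy its gradients to the global one, and step."""
    global_optimizer.zero_grad()
    loss.backward()
    for local_p, global_p in zip(local_net.parameters(), global_net.parameters()):
        global_p.grad = None if local_p.grad is None else local_p.grad.clone()
    global_optimizer.step()
```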
Algorithm Comparison

| Algorithm | Policy updates | Critic type | Key feature |
|---|---|---|---|
| Actor-Critic | online (on-policy) | single Critic | basic framework |
| A2C | synchronous (on-policy) | single Critic | synchronous version of A3C |
| A3C | asynchronous (on-policy) | single Critic | parallel workers |
| GAE | - | - | advantage estimation technique |
| PPO | on-policy | single Critic (value function) | clipped surrogate objective |
| SAC | off-policy | twin Critics | maximum entropy |
| TD3 | off-policy | twin Critics | continuous control |