Overview

Actor-Critic methods combine the strengths of policy gradients (the Actor) and value-function learning (the Critic), and are a widely used reinforcement learning framework.[^1]

Actor: learns the policy (outputs actions)
Critic: evaluates the policy (estimates the value function)
Joint training: the Critic gives the Actor a low-variance learning signal (the TD error), and the Actor generates the experience used to train the Critic

Why Actor-Critic?

The problem with REINFORCE

REINFORCE uses the Monte Carlo return G_t as its gradient estimator:

∇_θ J(θ) = E[ G_t · ∇_θ log π_θ(a_t|s_t) ],   where G_t = r_{t+1} + γ r_{t+2} + ... + γ^{T-t-1} r_T

Problems:

  1. High variance: each estimate needs a complete episode, and the variance grows rapidly with the length of the trajectory
  2. Poor sample efficiency: no update can happen until the episode ends

Solution

Replace the Monte Carlo return with an estimate built from a Critic-provided baseline:

∇_θ J(θ) ≈ E[ A(s_t, a_t) · ∇_θ log π_θ(a_t|s_t) ]

where A(s_t, a_t) = Q(s_t, a_t) - V(s_t) is the advantage function, which can be approximated by the TD error δ_t = r_t + γ V(s_{t+1}) - V(s_t).
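
To make the contrast concrete, here is a small sketch (added for illustration; the reward numbers and the values array standing in for a learned V(s) are made up) that computes both estimators for the same toy trajectory:

import numpy as np

# Toy trajectory: rewards and hypothetical critic estimates V(s_0) .. V(s_T)
rewards = np.array([1.0, 0.0, 1.0, 1.0])
values  = np.array([2.5, 2.0, 1.8, 1.0, 0.0])  # V(s_T) = 0 for a terminal state
gamma = 0.99

# Monte Carlo return G_t (what REINFORCE uses): needs the whole episode
returns = np.zeros_like(rewards)
g = 0.0
for t in reversed(range(len(rewards))):
    g = rewards[t] + gamma * g
    returns[t] = g

# TD-error advantage (what Actor-Critic uses): only needs (r_t, s_t, s_{t+1})
td_errors = rewards + gamma * values[1:] - values[:-1]

print(returns)    # high-variance, unbiased signal
print(td_errors)  # lower-variance signal, biased by the accuracy of V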

Framework structure

Actor-Critic framework:

                ┌──────────────┐
                │  Environment │
                └──────┬───────┘
                       │ s_t, r_t
                       ▼
   ┌───────────────────────────────────────┐
   │          Critic (evaluator)           │
   │  Value network V(s; θ_v)              │
   │  TD error: δ_t = r + γV(s') - V(s)    │
   └───────────────────┬───────────────────┘
                       │ δ_t (learning signal)
                       ▼
   ┌───────────────────────────────────────┐
   │            Actor (policy)             │
   │  Policy network π(a|s; θ_π)           │
   │  Update: δ_t · ∇ log π(a|s)           │
   └───────────────────┬───────────────────┘
                       │ a_t (action)
                       ▼
                ┌──────────────┐
                │  Environment │
                └──────────────┘

Algorithm details

Basic Actor-Critic loop

1. Initialize:
   - Actor parameters θ_π
   - Critic parameters θ_v

2. For each episode:
   a) Initialize the state s
   
   b) For each step:
      - Actor selects an action: a ~ π_θ(a|s)
      - Execute the action, observe r, s'
      
      - Critic computes the TD error:
        δ_t = r + γ V(s'; θ_v) - V(s; θ_v)
      
      - Critic update:
        θ_v ← θ_v + α_v · δ_t · ∇_θ_v V(s; θ_v)
      
      - Actor update:
        θ_π ← θ_π + α_π · δ_t · ∇_θ_π log π_θ(a|s)
      
      - s ← s'
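
For concreteness, one step of the loop with made-up numbers: with γ = 0.99, r = 1, V(s; θ_v) = 2.0 and V(s'; θ_v) = 2.5, the TD error is δ_t = 1 + 0.99 · 2.5 − 2.0 = 1.475. Because δ_t > 0, the Critic moves V(s) up toward the target and the Actor increases the log-probability of the action it just took; a negative δ_t would push both in the opposite direction.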

Python implementation

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
 
class Actor(nn.Module):
    """Actor网络:策略"""
    
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Softmax(dim=-1)
        )
    
    def forward(self, x):
        return self.net(x)
 
 
class Critic(nn.Module):
    """Critic网络:价值函数"""
    
    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x):
        return self.net(x)
 
 
class ActorCriticAgent:
    """Actor-Critic智能体"""
    
    def __init__(
        self,
        state_dim,
        action_dim,
        actor_lr=1e-3,
        critic_lr=1e-3,
        gamma=0.99,
        entropy_coef=0.01
    ):
        self.gamma = gamma
        self.entropy_coef = entropy_coef
        
        self.actor = Actor(state_dim, action_dim)
        self.critic = Critic(state_dim)
        
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
    
    def choose_action(self, state):
        state = torch.FloatTensor(state).unsqueeze(0)
        probs = self.actor(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        return action.item(), dist.log_prob(action)
    
    def update(self, state, action, reward, next_state, done):
        state = torch.FloatTensor(state).unsqueeze(0)
        next_state = torch.FloatTensor(next_state).unsqueeze(0)
        action = torch.LongTensor([action])
        
        # 1. Critic update: compute the TD target
        with torch.no_grad():
            if done:
                target = torch.FloatTensor([[reward]])
            else:
                target = reward + self.gamma * self.critic(next_state)
        
        current_value = self.critic(state)
        critic_loss = nn.MSELoss()(current_value, target)
        
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        
        # 2. Actor update: use the TD error as the advantage estimate
        td_error = (target - current_value).detach()
        
        probs = self.actor(state)
        log_probs = torch.log(probs + 1e-8)
        action_log_prob = log_probs[0, action]
        
        # Policy-gradient loss plus an entropy bonus (entropy = -Σ p·log p)
        entropy = -(probs * log_probs).sum()
        actor_loss = -(td_error * action_log_prob) - self.entropy_coef * entropy
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        return critic_loss.item(), actor_loss.item()
 
 
def train_actor_critic(env, agent, n_episodes=1000):
    """训练Actor-Critic"""
    rewards_history = []
    
    for episode in range(n_episodes):
        state = env.reset()
        done = False
        episode_reward = 0
        
        while not done:
            action, _ = agent.choose_action(state)
            next_state, reward, done, _ = env.step(action)
            
            agent.update(state, action, reward, next_state, done)
            
            state = next_state
            episode_reward += reward
        
        rewards_history.append(episode_reward)
        
        if (episode + 1) % 100 == 0:
            avg_reward = np.mean(rewards_history[-100:])
            print(f"Episode {episode+1}, Avg Reward: {avg_reward:.2f}")
    
    return rewards_history
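
A minimal way to run the agent, assuming the classic Gym CartPole-v1 environment and the pre-0.26 Gym API (env.reset() returns only the state and env.step() returns a 4-tuple), which is what the training loop above expects:

import gym  # assumes the classic gym API (reset() -> state, step() -> 4-tuple)

env = gym.make("CartPole-v1")
state_dim = env.observation_space.shape[0]   # 4 for CartPole
action_dim = env.action_space.n              # 2 for CartPole

agent = ActorCriticAgent(state_dim, action_dim)
rewards_history = train_actor_critic(env, agent, n_episodes=500)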

Advantage functions and GAE

N-step TD advantage estimation

The n-step advantage estimate is

A_t^(n) = r_t + γ r_{t+1} + ... + γ^{n-1} r_{t+n-1} + γ^n V(s_{t+n}) - V(s_t)
        = Σ_{k=0}^{n-1} γ^k δ_{t+k}

where δ_t = r_t + γ V(s_{t+1}) - V(s_t) is the TD error. Small n gives lower variance but more bias from the value estimate; larger n does the opposite.
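
As an illustration, a small helper for this estimator (the function n_step_advantage is added here and is not part of the original text):

import numpy as np

def n_step_advantage(rewards, values, t, n, gamma=0.99):
    """n-step advantage: sum_{k<n} gamma^k r_{t+k} + gamma^n V(s_{t+n}) - V(s_t).
    
    values must contain V(s_0), ..., V(s_T) (one more entry than rewards);
    n is truncated at the end of the trajectory.
    """
    T = len(rewards)
    n = min(n, T - t)
    discounted_rewards = sum(gamma ** k * rewards[t + k] for k in range(n))
    bootstrap = gamma ** n * values[t + n]
    return discounted_rewards + bootstrap - values[t]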

GAE(Generalized Advantage Estimation)

GAE balances bias and variance by taking an exponentially weighted average of all n-step TD advantages:[^2]

A_t^GAE(γ,λ) = Σ_{l=0}^{∞} (γλ)^l δ_{t+l}

With λ = 0 this reduces to the one-step TD error; with λ = 1 it becomes the Monte Carlo advantage estimate.

def compute_gae(rewards, values, gamma=0.99, lambd=0.95):
    """
    Compute GAE advantage estimates.
    
    Args:
        rewards: reward sequence r_0 ... r_{T-1}
        values: value estimates V(s_0) ... V(s_T) (one more entry than rewards;
                use 0 for the final entry if the episode terminated)
        gamma: discount factor
        lambd: GAE parameter λ in [0, 1]
    
    Returns:
        advantages: advantage estimates
        returns: targets for training the Critic (advantages + values)
    """
    values = np.asarray(values, dtype=np.float64)
    advantages = np.zeros(len(rewards))
    gae = 0.0
    
    # Walk the trajectory backwards, accumulating (γλ)-discounted TD errors
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lambd * gae
        advantages[t] = gae
    
    returns = advantages + values[:-1]  # exclude the bootstrap value of the final state
    
    return advantages, returns
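
A quick usage sketch with made-up numbers (the arrays below are purely illustrative):

rewards = [1.0, 0.0, 1.0]
values = [2.5, 2.0, 1.5, 0.0]   # V(s_0)..V(s_3); final state treated as terminal

advantages, returns = compute_gae(rewards, values, gamma=0.99, lambd=0.95)
print(advantages)  # advantage for each step, used in the Actor loss
print(returns)     # regression targets for the Critic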

Actor-Critic for continuous action spaces

SAC (Soft Actor-Critic)

SAC is a maximum-entropy, off-policy Actor-Critic algorithm: it maximizes J(π) = E[ Σ_t ( r_t + α · H(π(·|s_t)) ) ], where the temperature α trades off return against exploration. It uses twin Q-critics with target networks and can tune α automatically:

class SACAgent:
    """Soft Actor-Critic"""
    
    def __init__(self, state_dim, action_dim, action_bound,
                 actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4,
                 gamma=0.99, tau=0.005):
        self.gamma = gamma
        self.tau = tau
        self.action_bound = action_bound
        
        # Actor: a squashed Gaussian policy (a minimal sketch of GaussianPolicy
        # and QCritic is given after this class)
        self.actor = GaussianPolicy(state_dim, action_dim)
        
        # Twin Q-critics: unlike the state-value Critic above, these take (state, action)
        self.critic1 = QCritic(state_dim, action_dim)
        self.critic2 = QCritic(state_dim, action_dim)
        self.target_critic1 = QCritic(state_dim, action_dim)
        self.target_critic2 = QCritic(state_dim, action_dim)
        # Start the target networks from the same weights as the online critics
        self.target_critic1.load_state_dict(self.critic1.state_dict())
        self.target_critic2.load_state_dict(self.critic2.state_dict())
        
        # Automatically tuned entropy temperature
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.target_entropy = -action_dim
        
        # Optimizers
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic1_optimizer = optim.Adam(self.critic1.parameters(), lr=critic_lr)
        self.critic2_optimizer = optim.Adam(self.critic2.parameters(), lr=critic_lr)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
    
    def update(self, states, actions, rewards, next_states, dones):
        alpha = self.log_alpha.exp()
        
        # 1. Critic update: soft Bellman target computed with the target networks
        with torch.no_grad():
            next_actions, next_log_probs = self.actor.sample(next_states)
            q1_target = self.target_critic1(next_states, next_actions)
            q2_target = self.target_critic2(next_states, next_actions)
            q_target = torch.min(q1_target, q2_target)
            next_value = q_target - alpha * next_log_probs
            q_target = rewards + self.gamma * (1 - dones) * next_value
        
        q1 = self.critic1(states, actions)
        q2 = self.critic2(states, actions)
        
        critic1_loss = nn.MSELoss()(q1, q_target)
        critic2_loss = nn.MSELoss()(q2, q_target)
        
        self.critic1_optimizer.zero_grad()
        critic1_loss.backward()
        self.critic1_optimizer.step()
        
        self.critic2_optimizer.zero_grad()
        critic2_loss.backward()
        self.critic2_optimizer.step()
        
        # 2. Actor update: minimize α·log π - min(Q1, Q2) under the current policy
        actions_new, log_probs = self.actor.sample(states)
        q1_new = self.critic1(states, actions_new)
        q2_new = self.critic2(states, actions_new)
        q_new = torch.min(q1_new, q2_new)
        
        actor_loss = (alpha.detach() * log_probs - q_new).mean()
        
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        
        # 3. Temperature (α) update
        alpha_loss = -(self.log_alpha * (log_probs + self.target_entropy).detach()).mean()
        
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        
        # 4. Soft-update the target networks
        self.soft_update(self.critic1, self.target_critic1)
        self.soft_update(self.critic2, self.target_critic2)
    
    def soft_update(self, source, target):
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )
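
The SACAgent above relies on a GaussianPolicy and a QCritic that this section does not define. The sketch below (reusing the torch/nn imports from earlier) shows one common way to implement them: a tanh-squashed Gaussian policy sampled with the reparameterization trick, and a Q-network that concatenates state and action. Treat it as an illustrative assumption rather than the canonical SAC implementation:

class GaussianPolicy(nn.Module):
    """Tanh-squashed Gaussian policy for continuous actions (a sketch)."""
    
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.mu_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)
    
    def sample(self, states):
        h = self.net(states)
        mu = self.mu_head(h)
        log_std = self.log_std_head(h).clamp(-20, 2)
        std = log_std.exp()
        
        # Reparameterized sample, squashed into (-1, 1)
        dist = torch.distributions.Normal(mu, std)
        z = dist.rsample()
        action = torch.tanh(z)
        
        # Log-probability with the tanh change-of-variables correction
        log_prob = dist.log_prob(z) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob


class QCritic(nn.Module):
    """Q(s, a) network: concatenates state and action (a sketch)."""
    
    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, states, actions):
        return self.net(torch.cat([states, actions], dim=-1))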

A3C (Asynchronous Advantage Actor-Critic)

A3C trains several worker processes asynchronously; each worker interacts with its own copy of the environment and pushes updates to a shared global network:[^3]

import copy
import multiprocessing as mp
 
class A3CWorker(mp.Process):
    """A3C worker process"""
    
    def __init__(self, global_agent, worker_id, env, gamma=0.99, lambd=0.95):
        super().__init__()
        self.global_agent = global_agent
        self.worker_id = worker_id
        self.env = env
        self.gamma = gamma
        self.lambd = lambd
    
    def run(self):
        while True:
            # Sync a local copy of the global networks
            local_agent = self.sync_from_global()
            
            # Collect a short rollout (at most 20 steps)
            states, actions, rewards = [], [], []
            state = self.env.reset()
            done = False
            
            while not done and len(states) < 20:
                action, log_prob = local_agent.choose_action(state)
                next_state, reward, done, _ = self.env.step(action)
                
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                
                state = next_state
            
            # Value estimates V(s_0)..V(s_T); bootstrap with V(s_T) unless the episode ended
            values = [local_agent.critic(torch.FloatTensor(s)).item()
                      for s in states]
            values.append(0.0 if done else
                          local_agent.critic(torch.FloatTensor(state)).item())
            
            # Compute GAE advantages and Critic targets
            advantages, returns = compute_gae(rewards, values,
                                              self.gamma, self.lambd)
            
            # Push the update to the global networks (assumes the global agent
            # exposes a batch update over (states, actions, returns, advantages))
            self.global_agent.update(states, actions, returns, advantages)
    
    def sync_from_global(self):
        """Copy the current global parameters into a local agent
        (a full A3C implementation would share the parameters in memory
        instead of deep-copying them)"""
        local_agent = copy.deepcopy(self.global_agent)
        return local_agent
 
 
def train_a3c(env_fn, n_workers=8, n_steps=20, gamma=0.99, lambd=0.95):
    """Train A3C (sketch; assumes a Gym-style environment with a discrete action space)"""
    probe_env = env_fn()
    state_dim = probe_env.observation_space.shape[0]
    action_dim = probe_env.action_space.n
    
    global_agent = ActorCriticAgent(state_dim, action_dim)
    workers = [A3CWorker(global_agent, i, env_fn(), gamma, lambd)
               for i in range(n_workers)]
    
    for w in workers:
        w.start()
    
    for w in workers:
        w.join()

Algorithm comparison

| Algorithm    | Policy update | Critic type   | Notes                          |
|--------------|---------------|---------------|--------------------------------|
| Actor-Critic | on-policy     | single Critic | basic framework                |
| A2C          | synchronous   | single Critic | synchronous version of A3C     |
| A3C          | asynchronous  | single Critic | parallel workers               |
| GAE          | -             | -             | advantage-estimation technique |
| PPO          | on-policy     | varies        | clipped surrogate objective    |
| SAC          | off-policy    | twin Critics  | maximum entropy                |
| TD3          | off-policy    | twin Critics  | continuous control             |


Further topics

  • PPO: Proximal Policy Optimization
  • RLHF: Reinforcement Learning from Human Feedback
  • DQN: Deep Q-Networks

Footnotes

  1. Mnih et al., “Asynchronous Methods for Deep Reinforcement Learning”, ICML, 2016

  2. Schulman et al., “High-Dimensional Continuous Control Using Generalized Advantage Estimation”, ICLR, 2016

  3. Mnih et al., “Asynchronous Methods for Deep Reinforcement Learning”, ICML, 2016