Soft Actor-Critic (SAC)

1. 概述

Soft Actor-Critic (SAC) 是由 Haarnoja 等人于2018年提出的最大熵强化学习算法。1 它结合了无模型强化学习的样本效率和最大熵框架的稳定探索性,成为连续控制任务中最广泛使用的算法之一。

1.1 与标准RL的关键区别

方面标准RLSAC (最大熵RL)
目标最大化期望累计奖励最大化奖励 + 熵
探索依赖探索策略熵自动鼓励探索
稳定性容易发散更加稳定
收敛性难以保证有理论保证

1.2 核心思想

SAC 的核心思想是在传统的强化学习目标中加入熵正则项:

其中:

  • 是轨迹
  • 是给定状态 下策略的熵
  • 是温度参数,控制熵的重要性

2. 最大熵框架

2.1 熵的物理意义

策略的熵 衡量在状态 下动作选择的不确定性:

直观理解

  • 高熵策略:在不同状态下倾向于尝试各种动作,探索性强
  • 低熵策略:动作选择更加确定,可能陷入局部最优

2.2 最大熵原则

在强化学习中,最大熵原则有明确的动机:

  1. 探索-利用平衡:熵项自动平衡探索新动作和利用已知好动作
  2. 鲁棒性:最大熵策略对模型误差更鲁棒
  3. 隐式正则化:防止策略变得过于确定性

2.3 温度参数α

温度参数 控制熵正则化的强度:

  • 过大:策略接近均匀分布,奖励被忽视
  • 过小:熵的贡献可忽略,接近标准RL
  • 自适应调整:SAC 可以自动调整

3. 软价值函数

3.1 软状态价值函数

标准RL中的状态价值函数:

SAC中的软状态价值函数:

即同时考虑Q值和策略的熵。

3.2 软Q价值函数

定义 软Q函数 为:

代入 的定义:

3.3 软贝尔曼方程

SAC中的软贝尔曼方程为:

与标准贝尔曼方程的区别:期望项包含了熵项。


4. 算法推导

4.1 策略更新

SAC的策略更新可以写成:

其中 是归一化常数。

推导

从最大熵目标出发,使用策略梯度:

使用Reparameterization Trick,将动作表示为:

其中

4.2 策略参数化

SAC使用重参数化高斯策略

实际实现中,使用神经网络输出

4.3 温度参数更新

SAC 自动调整温度参数

其中 是目标熵(通常设为动作空间维度的负值,如 )。


5. 算法流程

5.1 整体框架

Algorithm: Soft Actor-Critic

1. 初始化:
   - Q网络: Q_θ1, Q_θ2 (双Q网络)
   - 目标Q网络: Q_φ1, Q_φ2
   - 策略网络: π_ψ
   - 温度参数: α
   - Replay Buffer: D

2. for episode in range(num_episodes):
   3. 采集轨迹:
      - 从环境中采样状态 s
      - 根据 π_ψ 选择动作 a = π_ψ(s) + 噪声
      - 获得奖励 r 和下一个状态 s'
      - 存储 (s, a, r, s') 到 D
   
   4. if 采集足够样本:
      5. 从 D 中采样批次 B
      
      6. 更新Q网络:
         - 计算目标Q值: y = r + γ * (min(Q_φ1, Q_φ2)(s', a') - α log π_ψ(a'|s'))
         - 最小化均方误差: L_Q = (Q_θ(s,a) - y)²
      
      7. 更新策略网络 (延迟更新):
         - 采样新动作: a_ψ = π_ψ(s) (reparameterized)
         - 最小化: L_π = α * log π_ψ(a_ψ|s) - min(Q_θ1, Q_θ2)(s, a_ψ)
      
      8. 更新温度参数:
         - L_α = α * (log π_ψ(a|s) + H_target)
      
      9. 软更新目标网络:
         - θ_target ← τ * θ + (1-τ) * θ_target

5.2 伪代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
 
class ReplayBuffer:
    """经验回放缓冲区"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.FloatTensor(np.array(actions)),
            torch.FloatTensor(rewards).unsqueeze(1),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones).unsqueeze(1)
        )
    
    def __len__(self):
        return len(self.buffer)
 
 
class SoftActorCritic:
    """Soft Actor-Critic 算法实现"""
    
    def __init__(self, state_dim, action_dim, hidden_dim=256, 
                 lr=3e-4, gamma=0.99, tau=0.005, alpha=0.2,
                 target_entropy=None):
        self.gamma = gamma
        self.tau = tau
        self.alpha = alpha
        
        # 目标熵:通常是动作空间维度的负值
        if target_entropy is None:
            self.target_entropy = -action_dim
        else:
            self.target_entropy = target_entropy
        
        # 双Q网络
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # 目标Q网络
        self.q1_target = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.q2_target = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        # 复制参数到目标网络
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        
        # 策略网络:输出均值和log标准差
        self.policy = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.log_std = nn.Linear(hidden_dim, action_dim)
        
        # 优化器
        self.q_optimizer = optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()), 
            lr=lr
        )
        self.policy_optimizer = optim.Adam(
            list(self.policy.parameters()) + 
            list(self.mean.parameters()) + 
            list(self.log_std.parameters()),
            lr=lr
        )
        
        # 自动温度调整的优化器
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=lr)
    
    def get_action(self, state, deterministic=False):
        """获取动作"""
        with torch.no_grad():
            # 编码状态
            h = self.policy(state)
            mean = self.mean(h)
            log_std = self.log_std(h)
            log_std = torch.clamp(log_std, -20, 2)
            
            if deterministic:
                return torch.tanh(mean)
            
            std = log_std.exp()
            dist = torch.distributions.Normal(mean, std)
            x = dist.rsample()  # 重参数化采样
            action = torch.tanh(x)
            
            return action
    
    def update(self, states, actions, rewards, next_states, dones):
        """更新网络参数"""
        batch_size = states.shape[0]
        
        # ========== 1. 更新Q网络 ==========
        with torch.no_grad():
            # 采样新动作计算目标Q值
            next_h = self.policy(next_states)
            next_mean = self.mean(next_h)
            next_log_std = self.log_std(next_h)
            next_log_std = torch.clamp(next_log_std, -20, 2)
            next_std = next_log_std.exp()
            
            next_dist = torch.distributions.Normal(next_mean, next_std)
            next_x = next_dist.rsample()
            next_action = torch.tanh(next_x)
            
            # 计算熵 -α * log π(a'|s')
            log_pi_next = next_dist.log_prob(next_x) - torch.log(1 - next_action.pow(2) + 1e-6)
            log_pi_next = log_pi_next.sum(dim=1, keepdim=True)
            
            # 目标Q值
            next_q1 = self.q1_target(torch.cat([next_states, next_action], dim=1))
            next_q2 = self.q2_target(torch.cat([next_states, next_action], dim=1))
            next_q = torch.min(next_q1, next_q2)
            next_value = next_q - self.alpha * log_pi_next
            
            target_q = rewards + self.gamma * (1 - dones) * next_value
        
        # 当前Q值
        current_q1 = self.q1(torch.cat([states, actions], dim=1))
        current_q2 = self.q2(torch.cat([states, actions], dim=1))
        
        # Q损失
        q_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        
        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()
        
        # ========== 2. 更新策略网络 ==========
        h = self.policy(states)
        mean = self.mean(h)
        log_std = self.log_std(h)
        log_std = torch.clamp(log_std, -20, 2)
        std = log_std.exp()
        
        dist = torch.distributions.Normal(mean, std)
        x = dist.rsample()
        action = torch.tanh(x)
        
        # 计算熵
        log_pi = dist.log_prob(x) - torch.log(1 - action.pow(2) + 1e-6)
        log_pi = log_pi.sum(dim=1, keepdim=True)
        
        # Q值
        q1_pi = self.q1(torch.cat([states, action], dim=1))
        q2_pi = self.q2(torch.cat([states, action], dim=1))
        q_pi = torch.min(q1_pi, q2_pi)
        
        # 策略损失:最大化 Q - α * log π
        policy_loss = (self.alpha * log_pi - q_pi).mean()
        
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()
        
        # ========== 3. 更新温度参数 ==========
        alpha_loss = self.log_alpha * (log_pi.detach() + self.target_entropy)
        alpha_loss = -alpha_loss.mean()
        
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        
        self.alpha = self.log_alpha.exp().item()
        
        # ========== 4. 软更新目标网络 ==========
        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)
        
        return {
            'q_loss': q_loss.item(),
            'policy_loss': policy_loss.item(),
            'alpha': self.alpha
        }
    
    def _soft_update(self, source, target):
        """软更新目标网络"""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )
 
 
def train_sac(env, num_episodes=1000, batch_size=256):
    """SAC训练主循环"""
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    agent = SoftActorCritic(state_dim, action_dim)
    buffer = ReplayBuffer(capacity=100000)
    
    rewards_history = []
    
    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False
        
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            action = agent.get_action(state_tensor).cpu().numpy()[0]
            
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            buffer.push(state, action, reward, next_state, done)
            
            state = next_state
            episode_reward += reward
            
            # 更新
            if len(buffer) >= batch_size:
                batch = buffer.sample(batch_size)
                agent.update(*batch)
        
        rewards_history.append(episode_reward)
        
        if episode % 10 == 0:
            avg_reward = np.mean(rewards_history[-10:])
            print(f"Episode {episode}: Avg Reward (last 10): {avg_reward:.2f}")
    
    return agent, rewards_history

6. 理论分析

6.1 收敛性分析

SAC的收敛性可以通过以下定理保证:

定理:在温和条件下(策略和Q函数被参数化近似、足够的探索、数据无限),SAC收敛到最优最大熵策略。

关键要素

  1. 双Q网络:减少过估计
  2. 目标网络:稳定目标
  3. 最大熵目标:避免策略退化

6.2 样本效率

SAC 相比其他无模型RL算法具有更高的样本效率:

  • 离策略学习:使用经验回放
  • 最大熵探索:更充分的状态空间覆盖
  • 双Q网络:减少方差

6.3 超参数敏感性

超参数典型值影响
学习率3e-4影响收敛速度
γ (折扣因子)0.99影响长期奖励权重
τ (软更新)0.005影响目标网络更新速度
α (温度)0.2 或 自动影响探索-利用平衡
Replay Buffer1e5 - 1e6影响历史信息利用

7. 与其他算法的对比

7.1 vs PPO

方面PPOSAC
策略更新信任域约束最小化KL散度
动作空间离散/连续连续
探索熵奖励最大熵框架
样本效率较低较高
稳定性中等较高

7.2 vs TD3

方面TD3SAC
动作空间连续连续
目标策略平滑加噪声通过熵
温度参数
策略类型确定性随机
探索额外探索噪声内在熵

7.3 适用场景

算法最佳场景
SAC连续控制、需要自动探索、多峰奖励
PPO离散动作、稳定性优先、大规模并行
TD3高维连续动作、确定性策略

8. 实践技巧

8.1 实现注意事项

  1. 重参数化技巧:确保使用 rsample() 而非 sample()
  2. Tanh激活:动作经过tanh压缩到[-1, 1]
  3. 数值稳定:log概率中加上小常数 1e-6
  4. 延迟更新:策略网络更新频率通常低于Q网络

8.2 调参建议

# 推荐的超参数配置
config = {
    'gamma': 0.99,           # 通常无需调整
    'tau': 0.005,            # 小值更稳定
    'lr': 3e-4,              # 常用值
    'hidden_dim': 256,        # 根据任务复杂度调整
    'batch_size': 256,        # 越大越稳定但慢
    'buffer_size': 1e6,       # 越大通常越好
    'target_entropy': -action_dim,  # 自动设置
}

8.3 常见问题与解决方案

问题原因解决方案
策略崩溃为均匀分布α过大减小目标熵或固定α
Q值持续下降目标网络更新太快减小τ
动作分布过窄熵奖励不够增大α或目标熵
训练不稳定批标准化问题检查数据归一化

9. 扩展与变体

9.1 SAC-Max

使用最大熵但固定温度参数,不自动调整。

9.2 Soft Actor-Critic with Expert Demonstrations (SAC-ED)

结合专家演示的SAC变体,加速学习。

9.3 RAMBO-RL

基于随机松弛的最大熵RL,适合离线设置。


10. 总结

SAC是现代强化学习中最重要的算法之一,它通过最大熵框架优雅地解决了探索-利用权衡问题。其核心特点包括:

  1. 最大熵目标:内置的探索激励机制
  2. 双Q网络:减少过估计,提高稳定性
  3. 自动温度调整:无需手动设置探索强度
  4. 离策略学习:高效的样本利用

SAC在连续控制任务中表现出色,是Mujoco、Robosuite等基准的默认算法之一。


参考资料


相关主题

Footnotes

  1. Haarnoja, T., Zhou, A., Abbeel, P., & Levine, S. (2018). Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor. International Conference on Machine Learning.