TD3 (Twin Delayed Deep Deterministic Policy Gradient)

1. 概述

TD3(Twin Delayed DDPG)是 Fujimoto 等人于2018年提出的针对 DDPG 算法过估计问题的改进算法。1 通过三项核心技术大幅提升了 DDPG 的稳定性和性能,被广泛认为是连续控制任务的 SOTA 算法之一。

1.1 DDPG的问题

标准 DDPG (Deep Deterministic Policy Gradient) 存在三个主要问题:

问题描述影响
过估计偏差Q值被系统性地高估策略退化
方差放大过估计导致梯度方差增大训练不稳定
策略退化策略更新方向错误性能下降

1.2 TD3的核心思想

TD3 提出了三项核心技术来解决上述问题:

  1. 双Q学习 (Clipped Double Q-Learning):使用两个Q网络,取较小值
  2. 延迟策略更新 (Delayed Policy Updates):减少策略更新频率
  3. 目标策略平滑 (Target Policy Smoothing):添加噪声正则化

2. 问题分析:DDPG的过估计

2.1 过估计的来源

在标准 Q-learning 中:

问题是 操作会系统性地高估真实Q值

数学分析

假设存在估计误差

那么:

这意味着即使每个动作的Q值都是无偏估计, 操作仍然会引入正向偏差

2.2 过估计的级联效应

过估计 Q值
    ↓
策略选择被误导
    ↓
学习次优策略
    ↓
更严重的过估计
    ↓
策略崩溃

2.3 Double Q-Learning的启发

Van Hasselt 等人提出的 Double Q-Learning 通过两个网络交替更新来解决过估计。2

但直接应用于连续动作空间存在挑战:无法离散化所有动作来取 max。


3. TD3三项核心技术

3.1 双Q学习 (Clipped Double Q-Learning)

核心思想:使用两个独立的Q网络,估计真实Q值的下界。

目标值计算

为什么取最小值有效

假设 是独立同分布的估计,有:

这提供了对真实Q值的下界估计,抵消了 引入的过估计。

直观理解

  • 如果两个Q网络都高估了,则取较小值减轻高估
  • 如果一个高估一个低估,取最小值更保守
  • 如果两个都准确,最小值略微低估但影响不大

3.2 延迟策略更新 (Delayed Policy Updates)

核心思想:Q网络更新更频繁,策略网络更新较慢。

实现

if train_step % d == 0:  # d通常是2
    # 更新策略网络
    policy_optimizer.step()
    # 软更新目标网络
    soft_update(policy)
    soft_update(q_network1)
    soft_update(q_network2)

为什么延迟有效

  1. 避免策略被错误Q值误导:Q网络更新两次后更稳定
  2. 减少策略更新的频率:给予Q网络更多时间收敛
  3. 减轻过估计的影响:不急于用不稳定Q值更新策略

延迟参数选择

延迟d效果
d=1等同于DDPG,不稳定
d=2推荐值,平衡稳定性和学习速度
d=3-4更稳定但学习可能变慢

3.3 目标策略平滑 (Target Policy Smoothing)

核心思想:在目标动作上添加小噪声,隐式正则化。

实现细节

def target_policy_smoothing(q_target, action, noise_std=0.2, noise_clip=0.5):
    noise = torch.randn_like(action) * noise_std
    noise = torch.clamp(noise, -noise_clip, noise_clip)
    action_smoothed = action + noise
    action_smoothed = torch.tanh(action_smoothed)  # 如果需要
    return q_target(action_smoothed)

物理意义

  1. 动作扰动 奖励平滑:相似的动作应该有相似的Q值
  2. 正则化效果:鼓励策略在相似状态下选择相似动作
  3. 减少过拟合:防止策略在离散点上学到错误值

噪声参数选择

参数典型值说明
(标准差)0.2噪声幅度
(裁剪)0.5噪声裁剪范围

4. 算法流程

4.1 完整算法

Algorithm: TD3 (Twin Delayed DDPG)

1. 初始化:
   - 策略网络 π_ψ 和目标策略网络 π_ψ'
   - Q网络 Q_φ1, Q_φ2 和目标Q网络 Q_φ1', Q_φ2'
   - Replay Buffer D
   - 目标网络软更新系数 τ
   - 延迟参数 d
   - 目标策略平滑参数 σ, c

2. for episode = 1 to M:
   3.     s = env.reset()
   4.     for t = 1 to T:
   5.         a = π_ψ(s) + N(0, σ_explore)  # 带探索噪声
   6.         s', r, done = env.step(a)
   7.         D.push(s, a, r, s', done)
   8.         
   9.         if t % d == 0 and |D| > batch_size:
   10.            # ========== 更新Q网络 ==========
   11.            从D采样批次 (s, a, r, s', d)
   12.            
   13.            # 目标策略平滑
   14.            a' = π_ψ'(s')
   15.            noise = clip(N(0, σ), -c, +c)
   16.            a' = clip(a' + noise, a_low, a_high)
   17.            
   18.            # 目标Q值 (双Q最小值)
   19.            y = r + γ * min(Q_φ1'(s', a'), Q_φ2'(s', a'))
   20.            
   21.            # 更新Q网络
   22.            L(φ_i) = E[(Q_φ_i(s,a) - y)²]
   23.            
   24.            # ========== 延迟更新策略 ==========
   25.            if step % d == 0:
   26.                # 策略梯度 (只使用Q1)
   27.                J(ψ) = -E[Q_φ1(s, π_ψ(s))]
   28.                ∇_ψ J(ψ) ≈ E[∇_a Q_φ1(s,a) |_{a=π_ψ(s)} ∇_ψ π_ψ(s)]
   29.                
   30.                # 软更新目标网络
   31.                ψ' ← τψ + (1-τ)ψ'
   32.                φ_i' ← τφ_i + (1-τ)φ_i'
   33.            
   34.        s = s'

4.2 伪代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
 
 
class ReplayBuffer:
    """经验回放缓冲区"""
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
    
    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))
    
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.FloatTensor(np.array(actions)),
            torch.FloatTensor(rewards).unsqueeze(1),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones).unsqueeze(1)
        )
    
    def __len__(self):
        return len(self.buffer)
 
 
class Actor(nn.Module):
    """TD3的策略网络 (Actor)"""
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh()  # 输出 [-1, 1]
        )
    
    def forward(self, state):
        return self.net(state)
 
 
class Critic(nn.Module):
    """TD3的Q网络 (Critic)"""
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Q1网络
        self.q1_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        # Q2网络
        self.q2_net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state, action):
        x = torch.cat([state, action], dim=1)
        return self.q1_net(x), self.q2_net(x)
    
    def q1(self, state, action):
        x = torch.cat([state, action], dim=1)
        return self.q1_net(x)
 
 
class TD3:
    """Twin Delayed DDPG (TD3)"""
    
    def __init__(self, state_dim, action_dim, hidden_dim=256,
                 lr=3e-4, gamma=0.99, tau=0.005,
                 policy_delay=2, noise_std=0.2, noise_clip=0.5,
                 policy_noise_std=0.2, policy_noise_clip=0.5):
        
        self.gamma = gamma
        self.tau = tau
        self.policy_delay = policy_delay
        self.noise_std = policy_noise_std
        self.noise_clip = policy_noise_clip
        self.total_it = 0
        
        # 策略网络
        self.actor = Actor(state_dim, action_dim, hidden_dim)
        self.actor_target = Actor(state_dim, action_dim, hidden_dim)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        
        # Q网络 (双Q)
        self.critic = Critic(state_dim, action_dim, hidden_dim)
        self.critic_target = Critic(state_dim, action_dim, hidden_dim)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)
        
        # 动作范围
        self.action_dim = action_dim
    
    def get_action(self, state, deterministic=False, noise_scale=0.1):
        """获取动作"""
        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0)
            action = self.actor(state).cpu().numpy()[0]
            
            if deterministic:
                return action
            
            # 添加探索噪声
            noise = np.random.normal(0, noise_scale, size=action.shape)
            action = np.clip(action + noise, -1, 1)
            
            return action
    
    def update(self, states, actions, rewards, next_states, dones):
        """更新网络"""
        self.total_it += 1
        
        # ========== 1. 更新Q网络 ==========
        with torch.no_grad():
            # 目标策略平滑
            next_actions = self.actor_target(next_states)
            noise = torch.randn_like(next_actions) * self.noise_std
            noise = torch.clamp(noise, -self.noise_clip, self.noise_clip)
            next_actions = torch.clamp(next_actions + noise, -1, 1)
            
            # 目标Q值 (双Q取最小)
            target_q1, target_q2 = self.critic_target(next_states, next_actions)
            target_q = torch.min(target_q1, target_q2)
            target_q = rewards + self.gamma * (1 - dones) * target_q
        
        # 当前Q值
        current_q1, current_q2 = self.critic(states, actions)
        
        # Q损失
        q1_loss = F.mse_loss(current_q1, target_q)
        q2_loss = F.mse_loss(current_q2, target_q)
        critic_loss = q1_loss + q2_loss
        
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        # 梯度裁剪
        nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=10.0)
        self.critic_optimizer.step()
        
        # ========== 2. 延迟更新策略 ==========
        if self.total_it % self.policy_delay == 0:
            # 策略梯度: 最大化Q1
            policy_actions = self.actor(states)
            q1 = self.critic.q1(states, policy_actions)
            policy_loss = -q1.mean()
            
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=10.0)
            self.actor_optimizer.step()
            
            # 软更新目标网络
            self._soft_update(self.actor, self.actor_target)
            self._soft_update(self.critic, self.critic_target)
            
            return {
                'q1_loss': q1_loss.item(),
                'q2_loss': q2_loss.item(),
                'policy_loss': policy_loss.item()
            }
        
        return {
            'q1_loss': q1_loss.item(),
            'q2_loss': q2_loss.item(),
            'policy_loss': None
        }
    
    def _soft_update(self, source, target):
        """软更新"""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )
 
 
def train_td3(env, num_episodes=1000, batch_size=256, start_steps=10000):
    """TD3训练主循环"""
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    action_bound = env.action_space.high[0]  # 假设对称动作空间
    
    agent = TD3(state_dim, action_dim)
    buffer = ReplayBuffer(capacity=100000)
    
    rewards_history = []
    
    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False
        
        while not done:
            # 初始阶段使用随机动作 (探索)
            if len(buffer) < start_steps:
                action = env.action_space.sample()
            else:
                action = agent.get_action(state)
            
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            buffer.push(state, action, reward, next_state, done)
            
            state = next_state
            episode_reward += reward
            
            # 更新
            if len(buffer) >= batch_size:
                batch = buffer.sample(batch_size)
                agent.update(*batch)
        
        rewards_history.append(episode_reward)
        
        if episode % 10 == 0:
            avg_reward = np.mean(rewards_history[-10:])
            print(f"Episode {episode}: Avg Reward (last 10): {avg_reward:.2f}")
    
    return agent, rewards_history

5. 超参数分析

5.1 关键超参数

超参数典型值影响
(软更新)0.005目标网络更新速度
(折扣)0.99长期奖励权重
(延迟)2策略更新频率
0.2平滑噪声幅度
0.5噪声裁剪范围
0.1探索噪声

5.2 敏感性分析

性能
  ↑
  │     ┌───── 延迟d=2 (推荐)
  │   ╱ │  d=1 (不稳定)
  │ ╱   │
  │╱    └──────── 延迟d=3 (稳定但慢)
  └─────────────────────→ 延迟参数d

5.3 调参建议

# Mujoco环境的推荐配置
config = {
    'gamma': 0.99,
    'tau': 0.005,
    'policy_delay': 2,          # 关键参数
    'policy_noise_std': 0.2,    # 平滑噪声
    'policy_noise_clip': 0.5,   # 噪声裁剪
    'explore_noise_std': 0.1,   # 探索噪声
    'batch_size': 256,
    'hidden_dim': 256,
    'lr': 3e-4,
}

6. 理论分析

6.1 过估计抑制的理论保证

定理 (TD3过估计上界)

为双Q估计,,则在温和条件下:

其中 是与网络架构相关的常数。

推论:双Q取最小值将过估计的量级限制在 水平。

6.2 延迟更新的收敛性

延迟参数的理论选择

延迟 应该满足:

其中 是软更新系数, 是Q函数Hessian矩阵的最大特征值。

实践中, 是大多数情况的良好选择。

6.3 目标策略平滑的统计性质

平滑后的Q值估计满足:

这相当于对相似的动作取平均,减少了方差


7. 与其他算法对比

7.1 TD3 vs DDPG

方面DDPGTD3
Q网络单Q双Q
策略更新每步每d步
目标平滑
稳定性中等
性能较低显著提升

7.2 TD3 vs PPO

方面TD3PPO
动作空间连续离散/连续
策略类型确定性随机
探索噪声熵项
样本效率高 (离策略)中 (在策略)
稳定性中等

7.3 TD3 vs SAC

方面TD3SAC
策略类型确定性随机
动作选择argmax采样
探索机制显式噪声熵正则化
温度参数
性能相当相当

8. 实践技巧

8.1 实现注意事项

  1. 梯度裁剪:防止梯度爆炸
  2. 目标网络延迟:策略更新频率低于Q网络
  3. 噪声裁剪:确保添加的噪声在合理范围内
  4. 动作归一化:确保动作在 [-1, 1] 范围内

8.2 常见问题

问题原因解决方案
Q值持续上升目标网络更新太快减小τ
策略退化Q过估计严重确保d≥2
探索不足探索噪声太小增大探索噪声
训练发散学习率太高使用梯度裁剪

8.3 评估技巧

def evaluate_agent(agent, env, num_episodes=10):
    """评估智能体性能"""
    rewards = []
    for _ in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False
        
        while not done:
            action = agent.get_action(state, deterministic=True)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
        
        rewards.append(episode_reward)
    
    return np.mean(rewards), np.std(rewards)

9. 扩展与变体

9.1 DDPG with HER (Hindsight Experience Replay)

结合 hindsight 思想处理稀疏奖励。

9.2 TD3 with Prioritized Experience Replay

使用优先级采样提高样本效率。

9.3 Distributed TD3 (DTD3)

使用分布式采样加速训练。


10. 总结

TD3通过三项核心技术有效解决了DDPG的过估计问题:

  1. 双Q学习:取最小值抑制过估计
  2. 延迟策略更新:给予Q网络更多时间稳定
  3. 目标策略平滑:隐式正则化减少方差

TD3在连续控制任务中表现优异,是强化学习领域的重要里程碑。


参考资料


相关主题

Footnotes

  1. Fujimoto, S., van Hoof, H., & Meger, D. (2018). Addressing Function Approximation Error in Actor-Critic Methods. International Conference on Machine Learning.

  2. Van Hasselt, H., Guez, A., & Silver, D. (2016). Deep Reinforcement Learning with Double Q-learning. AAAI Conference on Artificial Intelligence.