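Distributed Reinforcement Learning Algorithms

1. Overview

Distributed reinforcement learning speeds up learning by collecting experience and training in parallel [1]. This document covers several important algorithms used in this setting, including the quantile-regression methods (QR-DQN, IQN), which learn a distribution over returns, and the combined method Rainbow.

1.1 Distributed RL Architecture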

┌───────────────────────────────────────────────────────────┐
│                Distributed RL Architecture                │
├───────────────────────────────────────────────────────────┤
│                                                           │
│   ┌─────────┐  ┌─────────┐  ┌─────────┐  ┌─────────┐      │
│   │ Actor 1 │  │ Actor 2 │  │ Actor 3 │  │ Actor N │      │
│   │(collect)│  │(collect)│  │(collect)│  │(collect)│      │
│   └────┬────┘  └────┬────┘  └────┬────┘  └────┬────┘      │
│        │            │            │            │           │
│        └────────────┴─────┬──────┴────────────┘           │
│                           │                               │
│                     ┌─────▼─────┐                         │
│                     │  Replay   │                         │
│                     │  Buffer   │                         │
│                     └─────┬─────┘                         │
│                           │                               │
│                     ┌─────▼─────┐                         │
│                     │  Learner  │                         │
│                     │  (train)  │                         │
│                     └─────┬─────┘                         │
│                           │                               │
│        ┌──────────────────┼──────────────────┐            │
│        ▼                  ▼                  ▼            │
│  ┌───────────┐      ┌───────────┐      ┌───────────┐      │
│  │ Q-Network │◄─────│ Q-Network │◄─────│ Q-Network │      │
│  │   (θ_1)   │      │   (θ_2)   │      │   (θ_N)   │      │
│  └───────────┘      └───────────┘      └───────────┘      │
│                                                           │
└───────────────────────────────────────────────────────────┘
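
The data flow in the diagram can be sketched with standard-library threads: several actor threads push transitions into a shared queue that stands in for the replay buffer, and a learner drains it. This is only a minimal sketch. The names `actor` and `run` and the fake transition tuple are hypothetical, and a real system would run actors in separate processes or machines against an actual environment and network.

```python
import queue
import threading


def actor(actor_id, n_steps, replay):
    """Collect n_steps placeholder transitions and push them to the shared buffer."""
    for step in range(n_steps):
        # Placeholder transition: (actor id, step index). A real actor would
        # store (state, action, reward, next_state, done).
        replay.put((actor_id, step))


def run(n_actors=4, n_steps=100):
    replay = queue.Queue()  # thread-safe stand-in for the replay buffer
    threads = [
        threading.Thread(target=actor, args=(i, n_steps, replay))
        for i in range(n_actors)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    # Learner side: drain the buffer. A real learner would sample batches
    # and run gradient updates while the actors keep collecting.
    transitions = []
    while not replay.empty():
        transitions.append(replay.get())
    return transitions


transitions = run()
print(len(transitions))  # 4 actors * 100 steps = 400
```

The queue decouples collection from training, which is the property the architecture relies on: actors never wait for a gradient step.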

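1.2 Advantages of Distributed RL

Aspect                   Single-process RL   Distributed RL
-----------------------  ------------------  --------------
Sample collection speed
GPU utilization          Can be low
Training stability       Moderate
System complexity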

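2. Quantile Regression Basics

2.1 Definition of a Quantile

For a random variable $Z$, its $\tau$-quantile is:

$$F_Z^{-1}(\tau) = \inf\{ z : F_Z(z) \ge \tau \}, \quad \tau \in (0, 1)$$

where $F_Z$ is the cumulative distribution function of $Z$.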
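
As a small worked example of this definition, the sketch below computes the quantile of a finite sample by taking the smallest value whose empirical CDF reaches the requested level. The function name `empirical_quantile` is a hypothetical helper, not part of any library.

```python
def empirical_quantile(samples, tau):
    """Smallest sample z whose empirical CDF value F(z) is at least tau."""
    ordered = sorted(samples)
    n = len(ordered)
    for k, z in enumerate(ordered, start=1):
        # k / n is the fraction of samples that are <= z
        if k / n >= tau:
            return z
    return ordered[-1]


samples = [4, 1, 3, 2]
print(empirical_quantile(samples, 0.5))   # 2, because F(2) = 0.5
print(empirical_quantile(samples, 0.75))  # 3, because F(3) = 0.75
print(empirical_quantile(samples, 0.76))  # 4, because F(3) < 0.76
```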

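2.2 Quantile Regression Loss

Quantile regression estimates the $\tau$-quantile by minimizing the following loss over $\theta$:

$$\mathcal{L}_\tau(\theta) = \mathbb{E}_{Z}\left[ \rho_\tau(Z - \theta) \right]$$

where $\rho_\tau(u) = u \, (\tau - \mathbb{1}\{u < 0\})$ is the quantile loss function.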
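
A short derivation shows why the minimizer of this loss is the quantile. It is a sketch that assumes $Z$ has a continuous CDF $F_Z$ and writes the quantile loss as $\rho_\tau(u) = u(\tau - \mathbb{1}\{u < 0\})$, the same form the `quantile_loss` code in section 2.3 computes.

```latex
\begin{aligned}
\mathbb{E}\left[\rho_\tau(Z - \theta)\right]
  &= \tau \int_{\theta}^{\infty} (z - \theta)\, dF_Z(z)
   + (1 - \tau) \int_{-\infty}^{\theta} (\theta - z)\, dF_Z(z), \\
\frac{d}{d\theta}\, \mathbb{E}\left[\rho_\tau(Z - \theta)\right]
  &= -\tau \bigl(1 - F_Z(\theta)\bigr) + (1 - \tau)\, F_Z(\theta)
   = F_Z(\theta) - \tau .
\end{aligned}
```

Setting the derivative to zero gives $F_Z(\theta) = \tau$, so the minimizing $\theta$ is the $\tau$-quantile of $Z$.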

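2.3 Properties of the Quantile Loss

import torch


def quantile_loss(predictions, targets, tau):
    """
    Quantile (pinball) loss.
    predictions, targets: tensors of the same shape
    tau: quantile level (between 0 and 1)
    """
    u = targets - predictions
    # u * (tau - 1[u < 0]) is already non-negative: positive errors are
    # weighted by tau and negative errors by (1 - tau)
    loss = u * (tau - (u < 0).float())
    return loss.mean()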
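
Two properties of this loss are easy to check numerically: it penalizes under- and over-estimation asymmetrically, and its minimizer over a sample is the corresponding sample quantile. The following is a minimal pure-Python sketch; `pinball` and `mean_pinball` are hypothetical helper names and the sample values are made up for illustration.

```python
def pinball(u, tau):
    """Quantile loss rho_tau(u) = u * (tau - 1[u < 0])."""
    return u * (tau - (1.0 if u < 0 else 0.0))


def mean_pinball(theta, samples, tau):
    """Average quantile loss of the estimate theta over the samples."""
    return sum(pinball(z - theta, tau) for z in samples) / len(samples)


samples = [1.0, 2.0, 3.0, 4.0, 100.0]

# Asymmetry: with tau = 0.9, a target 1 above the estimate costs about 0.9,
# while a target 1 below the estimate costs about 0.1.
under = pinball(1.0, 0.9)
over = pinball(-1.0, 0.9)

# Search over the sample values for the estimate with the lowest loss.
best_median = min(samples, key=lambda t: mean_pinball(t, samples, 0.5))
best_q90 = min(samples, key=lambda t: mean_pinball(t, samples, 0.9))
print(best_median, best_q90)  # 3.0 100.0
```

The outlier 100.0 does not move the median estimate, but it is exactly what the 0.9-quantile estimate picks up.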

3. QR-DQN (Quantile Regression DQN)

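3.1 Motivation

Standard DQN uses a point estimate $Q(s, a)$ and therefore cannot capture the uncertainty in the return.

QR-DQN [2] addresses this by learning the distribution of the return, represented as a uniform mixture of $N$ Dirac deltas located at learned quantile values $\theta_i(s, a)$:

$$Z_\theta(s, a) = \frac{1}{N} \sum_{i=1}^{N} \delta_{\theta_i(s, a)}$$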

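3.2 Quantile Grid

QR-DQN places $N$ quantile levels uniformly on the interval $[0, 1]$:

$$\tau_i = \frac{i}{N}, \quad i = 0, 1, \dots, N, \qquad \hat{\tau}_i = \frac{\tau_{i-1} + \tau_i}{2} = \frac{2i - 1}{2N}, \quad i = 1, \dots, N$$

Each midpoint $\hat{\tau}_i$ corresponds to one learnable quantile value $\theta_i(s, a)$.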
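
To make the grid concrete, the sketch below lists the grid edges and the midpoints for a small number of quantiles. It assumes the convention of the QR-DQN paper, where quantile values are attached to the midpoints of a uniform grid; `quantile_grid` is a hypothetical helper name.

```python
def quantile_grid(n):
    """Return (edges, midpoints) of a uniform quantile grid on [0, 1]."""
    edges = [i / n for i in range(n + 1)]                  # tau_0 .. tau_N
    midpoints = [(2 * i + 1) / (2 * n) for i in range(n)]  # tau_hat_1 .. tau_hat_N
    return edges, midpoints


edges, midpoints = quantile_grid(4)
print(edges)      # [0.0, 0.25, 0.5, 0.75, 1.0]
print(midpoints)  # [0.125, 0.375, 0.625, 0.875]
```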

3.3 Implementation

import torch
import torch.nn as nn
import torch.optim as optim


class QuantileNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, n_quantiles=200):
        super().__init__()
        self.action_dim = action_dim
        self.n_quantiles = n_quantiles

        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim * n_quantiles)
        )

    def forward(self, state):
        # Raw output is (batch, action_dim * n_quantiles);
        # reshape to (batch, action_dim, n_quantiles)
        return self.net(state).view(-1, self.action_dim, self.n_quantiles)


class QRDQN:
    def __init__(self, state_dim, action_dim, n_quantiles=200, kappa=1.0):
        self.n_quantiles = n_quantiles
        self.kappa = kappa  # threshold of the Huber loss

        self.q_net = QuantileNetwork(state_dim, action_dim, n_quantiles)
        self.target_net = QuantileNetwork(state_dim, action_dim, n_quantiles)
        self.target_net.load_state_dict(self.q_net.state_dict())

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=1e-4)

    def quantile_huber_loss(self, theta, target_theta, tau):
        """
        Quantile Huber loss.
        theta: current quantiles of the taken action, (batch, n_quantiles)
        target_theta: target quantiles, (batch, n_quantiles)
        tau: quantile levels of theta, n_quantiles elements
        """
        # Pairwise TD errors between every target quantile j and every
        # current quantile i: td_errors[b, i, j] = target[b, j] - theta[b, i]
        td_errors = target_theta.unsqueeze(1) - theta.unsqueeze(2)  # (batch, N, N)

        # Quantile levels belong to the current quantiles (dimension i)
        tau = tau.view(1, -1, 1)  # (1, N, 1)

        # Huber loss
        huber_loss = torch.where(
            td_errors.abs() <= self.kappa,
            0.5 * td_errors.pow(2),
            self.kappa * (td_errors.abs() - 0.5 * self.kappa)
        )

        # Asymmetric quantile weighting, averaged over all pairs and the batch.
        # (The paper sums over i; this differs only by a constant factor N.)
        loss = (tau - (td_errors < 0).float()).abs() * huber_loss
        return loss.mean()

    def update(self, states, actions, rewards, next_states, dones, gamma=0.99):
        # actions: long tensor (batch,); rewards, dones: float tensors (batch,)
        # Compute the target quantiles
        with torch.no_grad():
            # Quantiles of the greedy action in the next state
            next_theta = self.target_net(next_states)  # (batch, action, n_quantiles)
            next_action = next_theta.mean(dim=2).argmax(dim=1)  # (batch,)

            # Gather the quantiles of that action
            idx = next_action.view(-1, 1, 1).expand(-1, 1, self.n_quantiles)
            target_theta = next_theta.gather(1, idx).squeeze(1)  # (batch, n_quantiles)

            # Distributional Bellman target
            target_theta = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * target_theta

        # Current quantiles of the actions that were taken
        theta = self.q_net(states)  # (batch, action, n_quantiles)
        idx = actions.view(-1, 1, 1).expand(-1, 1, self.n_quantiles)
        theta = theta.gather(1, idx).squeeze(1)  # (batch, n_quantiles)

        # Quantile grid: midpoints (2i + 1) / (2N), i = 0 .. N-1
        tau = (torch.arange(self.n_quantiles, dtype=torch.float32,
                            device=states.device) + 0.5) / self.n_quantiles

        # Compute the loss
        loss = self.quantile_huber_loss(theta, target_theta, tau)

        # Gradient step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def get_action(self, state):
        # state: (1, state_dim)
        with torch.no_grad():
            theta = self.q_net(state)  # (1, action, n_quantiles)
            q_values = theta.mean(dim=2)  # (1, action)
            return q_values.argmax(dim=1).item()
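
The pairwise structure of the quantile Huber loss is easier to follow on plain lists than on batched tensors. The sketch below is a dependency-free illustration for a single state-action pair, assuming quantile levels at the midpoints (2i + 1) / (2N) and averaging over all (current, target) pairs; the function names are hypothetical and this is not a replacement for the tensor version above.

```python
def huber(u, kappa=1.0):
    """Huber loss: quadratic near zero, linear beyond kappa."""
    a = abs(u)
    return 0.5 * u * u if a <= kappa else kappa * (a - 0.5 * kappa)


def quantile_huber_loss(theta, target, kappa=1.0):
    """Mean over all (i, j) of |tau_i - 1[u < 0]| * huber(u), u = target_j - theta_i."""
    n = len(theta)
    taus = [(2 * i + 1) / (2 * n) for i in range(n)]
    total = 0.0
    for tau_i, theta_i in zip(taus, theta):
        for target_j in target:
            u = target_j - theta_i
            weight = abs(tau_i - (1.0 if u < 0 else 0.0))
            total += weight * huber(u, kappa)
    return total / (n * len(target))


print(quantile_huber_loss([0.0], [2.0]))       # 0.75
print(quantile_huber_loss([0.0, 0.0], [0.5]))  # 0.0625
```

In the second call both current quantiles are below the target, and the higher quantile level (0.75) is penalized three times as much as the lower one (0.25), which is what pushes the quantiles apart during training.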

3.4 Distribution Representation in QR-DQN

Representation of the value distribution:

        Probability density
            │
            │    ╱╲
            │   ╱  ╲
            │  ╱    ╲
            │ ╱      ╲
            │╱        ╲________
            └──────────────────── Value
                                  ↑
                              Q(s,a) = E[value]
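
Since the quantile values form an equally weighted sample of the distribution, the scalar Q-value is simply their mean, and the spread is available as a by-product. A minimal sketch with made-up quantile values for two hypothetical actions:

```python
from statistics import mean, pstdev

# Hypothetical quantile values theta_i(s, a) for two actions in one state
quantiles = {
    "safe": [0.9, 1.0, 1.0, 1.1],
    "risky": [-2.0, 0.0, 2.0, 4.4],
}

q_values = {a: mean(v) for a, v in quantiles.items()}   # Q(s, a) = E[Z(s, a)]
spread = {a: pstdev(v) for a, v in quantiles.items()}   # only visible with a distribution
greedy = max(q_values, key=q_values.get)

print(greedy)  # risky
```

A point estimate would only report that "risky" has the slightly higher mean; the quantiles also show that its outcomes vary far more.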

4. IQN (Implicit Quantile Network)

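4.1 Differences from QR-DQN

Aspect             QR-DQN                       IQN
-----------------  ---------------------------  -----------------------------------
Quantile sampling  Fixed, uniform grid          Randomly sampled τ, learned mapping
Network structure  Shared                       Separate embedding for τ
Expressiveness     Limited by the uniform grid  Higher

4.2 Core Idea of IQN

IQN uses implicit quantiles instead of a fixed grid:

$$Z_\tau(s, a) = F_{Z(s, a)}^{-1}(\tau), \quad \tau \sim U([0, 1])$$

where $\tau$ is not fixed: it is sampled at random, and the network learns the mapping from $\tau$ to the quantile value.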
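
The idea of feeding a random quantile level through a quantile function is the same as inverse-transform sampling. The sketch below uses a known distribution (an exponential, chosen only because its inverse CDF has a closed form) in place of the learned network, and shows that averaging over sampled levels recovers the expected value.

```python
import math
import random


def exp_quantile(tau, rate=1.0):
    """Inverse CDF of an exponential distribution: -ln(1 - tau) / rate."""
    return -math.log(1.0 - tau) / rate


rng = random.Random(0)
taus = [rng.random() for _ in range(20000)]   # tau ~ U([0, 1))
samples = [exp_quantile(t) for t in taus]     # Z_tau = F^{-1}(tau)
estimate = sum(samples) / len(samples)        # approximates E[Z] = 1 / rate = 1
```

IQN replaces `exp_quantile` with a network conditioned on the state, and estimates Q(s, a) by the same kind of average over sampled levels.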

4.3 Network Architecture

import math

import torch
import torch.nn as nn
import torch.optim as optim


class IQN(nn.Module):
    def __init__(self, state_dim, action_dim, n_cos_embedding=64, hidden_dim=256):
        super().__init__()
        self.n_cos_embedding = n_cos_embedding
        self.hidden_dim = hidden_dim

        # State encoder
        self.state_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Frequencies pi * i, i = 0 .. n_cos_embedding - 1, of the cosine embedding
        self.register_buffer(
            'cos_freqs',
            math.pi * torch.arange(n_cos_embedding, dtype=torch.float32)
        )

        # tau encoder: cosine features -> hidden
        self.tau_net = nn.Sequential(
            nn.Linear(n_cos_embedding, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Output network applied to the merged features
        self.output_net = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, state, tau):
        """
        state: (batch, state_dim)
        tau: (batch, n_samples) quantile levels in [0, 1]
        """
        batch_size = state.shape[0]
        n_samples = tau.shape[1]

        # 1. Encode the state
        state_embedding = self.state_net(state)  # (batch, hidden)

        # 2. Cosine embedding of tau: cos(pi * i * tau)
        cos_features = torch.cos(tau.unsqueeze(-1) * self.cos_freqs)  # (batch, n_samples, n_cos)

        # 3. Encode tau
        cos_features = cos_features.reshape(-1, self.n_cos_embedding)
        tau_processed = self.tau_net(cos_features)  # (batch * n_samples, hidden)

        # 4. Repeat the state embedding n_samples times
        state_rep = state_embedding.unsqueeze(1).expand(-1, n_samples, -1)
        state_rep = state_rep.reshape(-1, self.hidden_dim)  # (batch * n_samples, hidden)

        # 5. Merge by concatenation (the IQN paper merges with an
        #    element-wise product instead)
        combined = torch.cat([state_rep, tau_processed], dim=1)  # (batch * n_samples, hidden*2)

        # 6. Output
        output = self.output_net(combined)  # (batch * n_samples, action)

        # Reshape
        output = output.view(batch_size, n_samples, -1)  # (batch, n_samples, action)

        return output

    def sample_quantiles(self, state, n_samples=32):
        """Evaluate the network at randomly sampled quantile levels."""
        tau = torch.rand(state.shape[0], n_samples, device=state.device)
        return self.forward(state, tau)


class IQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99):
        self.gamma = gamma
        self.network = IQN(state_dim, action_dim)
        self.target_network = IQN(state_dim, action_dim)
        self.target_network.load_state_dict(self.network.state_dict())

        self.optimizer = optim.Adam(self.network.parameters(), lr=1e-4)

    def update(self, batch, n_quantile_samples=32):
        # actions: long tensor (batch,); rewards, dones: float tensors (batch,)
        states, actions, rewards, next_states, dones = batch

        with torch.no_grad():
            # Sample quantiles from the target network
            next_quantiles = self.target_network.sample_quantiles(next_states, n_quantile_samples)
            next_values = next_quantiles.mean(dim=1)  # (batch, action)
            next_action = next_values.argmax(dim=1)

            # Quantiles of the greedy action
            idx = next_action.view(-1, 1, 1).expand(-1, n_quantile_samples, 1)
            target_quantiles = next_quantiles.gather(2, idx).squeeze(2)  # (batch, n_samples)

            target = rewards.unsqueeze(1) + self.gamma * (1 - dones.unsqueeze(1)) * target_quantiles

        # Current network, evaluated at freshly sampled quantile levels
        tau = torch.rand(states.shape[0], n_quantile_samples, device=states.device)
        current_quantiles = self.network(states, tau)  # (batch, n_samples, action)
        idx = actions.view(-1, 1, 1).expand(-1, n_quantile_samples, 1)
        current_quantiles = current_quantiles.gather(2, idx).squeeze(2)  # (batch, n_samples)

        # Pairwise TD errors: td_errors[b, j, i] = target[b, j] - current[b, i]
        td_errors = target.unsqueeze(2) - current_quantiles.unsqueeze(1)  # (batch, n_samples, n_samples)

        # Quantile weights, aligned with the current quantiles (last dimension)
        tau_expanded = tau.unsqueeze(1)  # (batch, 1, n_samples)
        quantile_weight = (tau_expanded - (td_errors < 0).float()).abs()  # (batch, n_samples, n_samples)

        # Huber loss
        kappa = 1.0
        huber_loss = torch.where(
            td_errors.abs() <= kappa,
            0.5 * td_errors.pow(2),
            kappa * (td_errors.abs() - 0.5 * kappa)
        )

        loss = (quantile_weight * huber_loss).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()
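
The IQN paper embeds a quantile level through cosine basis functions, cos(π i τ) for i = 0 .. n - 1, before a learned linear layer. The following dependency-free sketch shows only the fixed cosine features for a single level; `cosine_features` is a hypothetical helper name.

```python
import math


def cosine_features(tau, n_cos_embedding=64):
    """Return [cos(pi * i * tau) for i = 0 .. n_cos_embedding - 1]."""
    return [math.cos(math.pi * i * tau) for i in range(n_cos_embedding)]


print(cosine_features(0.0, 4))  # [1.0, 1.0, 1.0, 1.0]
print(cosine_features(1.0, 4))  # alternates between 1 and -1 (up to rounding)
```

Because the features vary smoothly with τ, nearby quantile levels get similar embeddings, which a lookup table over discretized levels would not guarantee.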

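5. Rainbow

5.1 Overview

Rainbow [3] combines six improvements to DQN:

Technique                      Source
-----------------------------  -----------------------------
Double Q-learning              Double DQN
Prioritized experience replay  Prioritized Experience Replay
Dueling network                Dueling Network
Multi-step learning            Multi-step Learning
Distributional RL              C51 / Distributional RL
Noisy networks                 Noisy Nets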
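
Of these components, multi-step learning is the one that the code in this section does not illustrate. As a minimal sketch, the n-step target accumulates n discounted rewards and then bootstraps from a value estimate; `n_step_target` and its arguments are hypothetical names, and the bootstrap value would come from the target network in practice.

```python
def n_step_target(rewards, bootstrap_value, gamma=0.99):
    """sum_{k < n} gamma^k * r_k + gamma^n * bootstrap_value, for n = len(rewards)."""
    target = bootstrap_value
    # Fold the rewards in from the last step back to the first
    for r in reversed(rewards):
        target = r + gamma * target
    return target


print(n_step_target([1.0, 1.0, 1.0], bootstrap_value=8.0, gamma=0.5))  # 2.75
```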

5.2 Rainbow Architecture

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class RainbowDQN(nn.Module):
    def __init__(self, state_dim, action_dim, n_atoms=51, noisy=True):
        super().__init__()
        self.n_atoms = n_atoms
        self.action_dim = action_dim

        # Feature extractor
        self.feature = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU()
        )

        # Value stream (V)
        if noisy:
            self.value_stream = NoisyLinear(256, n_atoms)
        else:
            self.value_stream = nn.Linear(256, n_atoms)

        # Advantage stream (A)
        if noisy:
            self.advantage_stream = NoisyLinear(256, action_dim * n_atoms)
        else:
            self.advantage_stream = nn.Linear(256, action_dim * n_atoms)

    def forward(self, state):
        features = self.feature(state)

        # Value logits
        v = self.value_stream(features)  # (batch, n_atoms)
        v = v.view(-1, 1, self.n_atoms)

        # Advantage logits
        a = self.advantage_stream(features)  # (batch, action * n_atoms)
        a = a.view(-1, self.action_dim, self.n_atoms)

        # Combine the logits (V + A - mean(A)), then normalize over atoms so
        # that each action gets a valid probability distribution
        q_logits = v + a - a.mean(dim=1, keepdim=True)
        q_dist = F.softmax(q_logits, dim=-1)  # (batch, action, n_atoms)

        return q_dist


class NoisyLinear(nn.Module):
    """Linear layer with learnable Gaussian noise on weights and biases."""
    def __init__(self, in_features, out_features, sigma_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_init = sigma_init

        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))

        # Noise samples are buffers: saved with the module but not trained
        self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        self.reset_parameters()
        self.sample_noise()

    def reset_parameters(self):
        mu_range = 1 / np.sqrt(self.in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.sigma_init / np.sqrt(self.in_features))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.sigma_init / np.sqrt(self.out_features))

    def sample_noise(self):
        # Draw fresh independent Gaussian noise; call this between updates
        with torch.no_grad():
            self.weight_epsilon.copy_(torch.randn(self.out_features, self.in_features))
            self.bias_epsilon.copy_(torch.randn(self.out_features))

    def forward(self, x):
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            # Evaluation uses the mean parameters only
            weight = self.weight_mu
            bias = self.bias_mu

        return F.linear(x, weight, bias)
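
The network above outputs probabilities over `n_atoms` atoms, so a scalar Q-value also needs the atom locations. In C51-style methods these form a fixed, evenly spaced support between `v_min` and `v_max` (the values used here match the configuration in section 7.2). A minimal sketch with hypothetical helper names:

```python
def make_support(v_min=-10.0, v_max=10.0, n_atoms=51):
    """Evenly spaced atom locations z_0 .. z_{n_atoms - 1}."""
    delta_z = (v_max - v_min) / (n_atoms - 1)
    return [v_min + i * delta_z for i in range(n_atoms)]


def expected_q(probs, support):
    """Q(s, a) = sum_i p_i * z_i for one action's atom probabilities."""
    return sum(p * z for p, z in zip(probs, support))


support = make_support()
uniform = [1.0 / 51] * 51           # no information: mass spread evenly
one_hot_top = [0.0] * 50 + [1.0]    # all mass on the highest atom
```

With the symmetric support, the uniform distribution has an expected value of zero, and putting all mass on the last atom gives `v_max`.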

5.3 Prioritized Experience Replay

import numpy as np
import torch


class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity
        self.alpha = alpha  # priority exponent
        self.beta = beta    # importance-sampling exponent
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority so that they are
        # sampled at least once
        max_priority = self.priorities.max() if self.buffer else 1.0

        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)

        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        # Sampling probabilities (float64 keeps the normalization accurate)
        priorities = self.priorities[:len(self.buffer)].astype(np.float64)
        probs = priorities ** self.alpha
        probs /= probs.sum()

        # Sample without replacement; requires batch_size <= len(self.buffer)
        indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=False)

        # Importance-sampling weights, normalized by the largest weight
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)

        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(dones),
            torch.FloatTensor(weights),
            indices
        )

    def update_priorities(self, indices, td_errors):
        for idx, error in zip(indices, td_errors):
            self.priorities[idx] = abs(error) + 1e-5  # avoid zero priority
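
A small worked example makes the two exponents concrete. The sketch below reproduces the probability and weight formulas from `sample` on three made-up priorities, using alpha = beta = 1 only so that the numbers are easy to check by hand; the function names are hypothetical.

```python
def per_probabilities(priorities, alpha):
    """P(i) = p_i^alpha / sum_k p_k^alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, beta):
    """w_i = (N * P(i))^(-beta), normalized by the largest weight."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    top = max(raw)
    return [w / top for w in raw]


probs = per_probabilities([1.0, 2.0, 4.0], alpha=1.0)
weights = importance_weights(probs, beta=1.0)
print(probs)    # approximately [1/7, 2/7, 4/7]
print(weights)  # approximately [1.0, 0.5, 0.25]
```

The transition that is sampled four times as often receives a quarter of the weight, which is how the importance weights compensate for the non-uniform sampling.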

6. Performance Comparison

6.1 Atari Benchmark

Algorithm        Human-normalized score
---------------  ----------------------
DQN              100%
Double DQN       117%
Prioritized DQN  126%
Dueling DQN      132%
Rainbow          223%
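
For reference, a human-normalized score rescales a raw game score so that a random policy maps to 0% and the human baseline to 100%. A minimal sketch of that formula follows; the raw scores in the example are invented purely for illustration.

```python
def human_normalized(agent, random_score, human):
    """100 * (agent - random) / (human - random), in percent."""
    return 100.0 * (agent - random_score) / (human - random_score)


print(human_normalized(agent=50.0, random_score=10.0, human=110.0))  # 40.0
```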

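6.2 Ablation Study

The ablation study in the Rainbow paper shows that:

  • Most components contribute positively, with prioritized replay and multi-step learning mattering most
  • The combination clearly outperforms any single technique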

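7. Practical Recommendations

7.1 Choosing an Algorithm

Scenario                          Recommended algorithm
--------------------------------  ---------------------
Limited compute                   Double DQN
Uncertainty estimates needed      QR-DQN / IQN
Large-scale distributed training  Ape-X
Best performance                  Rainbow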

7.2 Hyperparameter Settings

config = {
    'qr_dqn': {
        'n_quantiles': 200,
        'kappa': 1.0,
    },
    'iqn': {
        'n_cos_embedding': 64,
        'n_quantile_samples': 32,
    },
    'rainbow': {
        'n_atoms': 51,
        'v_min': -10,
        'v_max': 10,
        'noisy_std': 0.5,
    },
    'prioritized_replay': {
        'alpha': 0.6,
        'beta_start': 0.4,
        'beta_frames': 100000,
    }
}
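
The `beta_start` and `beta_frames` entries imply a schedule that raises the importance-sampling exponent toward 1 during training. The configuration does not say which schedule is meant, so the linear annealing below is an assumption (it is a common choice), and `beta_by_frame` is a hypothetical helper name.

```python
def beta_by_frame(frame, beta_start=0.4, beta_frames=100000):
    """Linearly anneal beta from beta_start to 1.0 over beta_frames frames."""
    return min(1.0, beta_start + frame * (1.0 - beta_start) / beta_frames)


for frame in (0, 50000, 100000, 200000):
    print(frame, beta_by_frame(frame))
```

The returned value would be assigned to the buffer's `beta` before each call to `sample`, so that the importance-sampling correction is complete by the end of the schedule.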

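8. Summary

The algorithms covered here substantially improve sample efficiency and performance:

  1. QR-DQN: estimates the return distribution with quantile regression
  2. IQN: uses implicit quantiles for greater expressiveness
  3. Rainbow: combines several techniques and reached state-of-the-art results on Atari when it was published

Key Takeaways

  • Distributional representation > point estimate
  • Prioritized sampling > uniform sampling
  • Combined techniques > any single technique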

Footnotes

  1. Horgan, D., Quan, J., Budden, D., et al. (2018). Distributed Prioritized Experience Replay. ICLR.

  2. Dabney, W., Rowland, M., Bellemare, M., & Munos, R. (2018). Distributional Reinforcement Learning with Quantile Regression. AAAI.

  3. Hessel, M., Modayil, J., van Hasselt, H., et al. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.