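Distributed and Distributional Reinforcement Learning Algorithms
1. Overview
Distributed RL speeds up learning by collecting experience and training in parallel [1]. Most of this document covers a different but related idea, distributional RL, which learns the full distribution of returns instead of only its expected value. It covers the quantile-regression methods QR-DQN and IQN, and Rainbow, which combines distributional learning with five other DQN improvements. The two ideas are orthogonal and can be combined, for example a distributional learner fed by many parallel actors.
1.1 Distributed RL architecture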
```
┌───────────────────────────────────────────────────────────────┐
│                  Distributed RL architecture                  │
├───────────────────────────────────────────────────────────────┤
│                                                               │
│  ┌──────────┐   ┌──────────┐   ┌──────────┐   ┌──────────┐    │
│  │ Actor 1  │   │ Actor 2  │   │ Actor 3  │   │ Actor N  │    │
│  │(collect) │   │(collect) │   │(collect) │   │(collect) │    │
│  └────┬─────┘   └────┬─────┘   └────┬─────┘   └────┬─────┘    │
│       │              │              │              │          │
│       └──────────────┴───────┬──────┴──────────────┘          │
│                              │                                │
│                        ┌─────▼─────┐                          │
│                        │  Replay   │                          │
│                        │  Buffer   │                          │
│                        └─────┬─────┘                          │
│                              │                                │
│                        ┌─────▼─────┐                          │
│                        │  Learner  │                          │
│                        │  (train)  │                          │
│                        └─────┬─────┘                          │
│                              │                                │
│            ┌─────────────────┼─────────────────┐              │
│            ▼                 ▼                 ▼              │
│      ┌───────────┐     ┌───────────┐     ┌───────────┐        │
│      │ Q-Network │◄────│ Q-Network │◄────│ Q-Network │        │
│      │   (θ_1)   │     │   (θ_2)   │     │   (θ_N)   │        │
│      └───────────┘     └───────────┘     └───────────┘        │
│                                                               │
└───────────────────────────────────────────────────────────────┘
```
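The diagram's data flow can be sketched in a few lines: actors push transitions, and a replay buffer sits between them and the learner. This is a toy that rests on several assumptions:

- Python threads and a bounded `queue.Queue` stand in for separate processes and a network transport.
- Random numbers stand in for environment steps.
- The learner only counts its updates.
- The parameter broadcast from the learner back to the actors' Q-networks is omitted.

`actor`, `learner` and `run` are hypothetical names.

```python
import collections
import queue
import random
import threading


def actor(actor_id, n_steps, out_queue):
    """Toy actor: random transitions stand in for env.step() with the current policy."""
    rng = random.Random(actor_id)
    for t in range(n_steps):
        state, action, reward = rng.random(), rng.randrange(4), rng.random()
        out_queue.put((actor_id, state, action, reward, t == n_steps - 1))
    out_queue.put(None)  # sentinel: this actor has finished


def learner(in_queue, n_actors, capacity, batch_size, stats):
    """Drains the queue into a FIFO replay buffer and samples one batch per new transition."""
    replay = collections.deque(maxlen=capacity)  # oldest transitions are evicted first
    finished = 0
    while finished < n_actors:
        item = in_queue.get()
        if item is None:
            finished += 1
            continue
        replay.append(item)
        stats["received"] += 1
        if len(replay) >= batch_size:
            batch = random.sample(list(replay), batch_size)
            # A real learner would compute a loss on `batch` and step the optimizer here.
            stats["updates"] += 1
    stats["buffer_size"] = len(replay)


def run(n_actors=4, n_steps=50, capacity=120, batch_size=32):
    transitions = queue.Queue(maxsize=64)  # bounded: actors block if the learner falls behind
    stats = {"received": 0, "updates": 0}
    learner_thread = threading.Thread(
        target=learner, args=(transitions, n_actors, capacity, batch_size, stats))
    actor_threads = [
        threading.Thread(target=actor, args=(i, n_steps, transitions)) for i in range(n_actors)]
    learner_thread.start()
    for t in actor_threads:
        t.start()
    for t in actor_threads:
        t.join()
    learner_thread.join()
    return stats


stats = run()
```

With 4 actors × 50 steps, the learner receives 200 transitions and keeps the newest 120. It performs one simulated update per transition once the buffer first holds a full batch. Bounding the queue provides back-pressure: fast actors block instead of flooding a slow learner.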
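1.2 Advantages of distributed RL
| Aspect | Single-process RL | Distributed RL |
|---|---|---|
| Sample collection speed | Slow | Fast |
| GPU utilization | Often low | High |
| Training stability | Moderate | High |
| System complexity | Low | High |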
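2. Quantile regression basics
2.1 Definition of a quantile
For a random variable $Z$ and a level $\tau \in (0, 1)$, the $\tau$-quantile of $Z$ is
$$F_Z^{-1}(\tau) = \inf\{\, z : F_Z(z) \ge \tau \,\},$$
where $F_Z$ is the cumulative distribution function of $Z$.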
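The definition carries over directly to samples. Using the empirical CDF, the τ-quantile is the smallest sample whose cumulative frequency reaches τ. Here is a minimal sketch; `empirical_quantile` is a hypothetical helper:

```python
def empirical_quantile(samples, tau):
    """inf{z : F(z) >= tau} under the empirical CDF of `samples`."""
    xs = sorted(samples)
    n = len(xs)
    for i, z in enumerate(xs):
        # (i + 1) / n is F(z) for distinct values; with repeated values it can
        # underestimate F(z), but a later copy of the same z then passes the test.
        if (i + 1) / n >= tau:
            return z
    return xs[-1]


print(empirical_quantile(range(1, 11), 0.25))  # F(2) = 0.2 < 0.25 <= F(3) = 0.3
```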
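2.2 Quantile regression loss
Quantile regression estimates the $\tau$-quantile by minimizing the expected quantile loss:
$$\theta_\tau = \arg\min_{\theta} \; \mathbb{E}_Z\big[\rho_\tau(Z - \theta)\big], \qquad \rho_\tau(u) = u\,\big(\tau - \mathbb{1}\{u < 0\}\big),$$
where $\rho_\tau$ is the quantile (pinball) loss. Its minimizer is the quantile $F_Z^{-1}(\tau)$.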
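The sketch below checks numerically that minimizing the pinball loss recovers the quantile, by minimizing the empirical loss over candidate values of θ. It assumes a finite sample. The loss is convex and piecewise linear in θ, so some sample point is always a minimizer and checking those points is enough. `pinball` and `fit_quantile` are hypothetical names.

```python
def pinball(u, tau):
    """rho_tau(u) = u * (tau - 1{u < 0})."""
    return u * (tau - (1.0 if u < 0 else 0.0))


def fit_quantile(samples, tau):
    """Return the theta minimizing the empirical loss mean_i rho_tau(z_i - theta)."""
    def empirical_loss(theta):
        return sum(pinball(z - theta, tau) for z in samples) / len(samples)
    # Convex, piecewise linear in theta: some sample point is always a minimizer.
    return min(sorted(set(samples)), key=empirical_loss)


data = list(range(1, 11))
print(fit_quantile(data, 0.25), fit_quantile(data, 0.75))
```

For the ten values 1, ..., 10 and τ = 0.25, this fit and the inverse-CDF definition both give 3. Below 3 the empirical loss is still decreasing (slope -0.05), and above 3 it increases (slope +0.05).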
2.3 Properties of the quantile loss
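Two properties of ρ_τ explain why minimizing it targets the τ-quantile:

- It is never negative.
- It is asymmetric. A target above the estimate (u > 0) is charged at slope τ, and a target below it (u < 0) at slope 1 - τ.

A quick numeric check in plain Python; `rho` is a hypothetical helper that mirrors the formula:

```python
import math


def rho(u, tau):
    """rho_tau(u) = u * (tau - 1{u < 0})."""
    return u * (tau - (1.0 if u < 0 else 0.0))


tau = 0.9
under = rho(2.0, tau)  # target 2 above the estimate: cost tau * 2
over = rho(-2.0, tau)  # target 2 below the estimate: cost (1 - tau) * 2
assert math.isclose(under, 1.8) and math.isclose(over, 0.2)
# Never negative, for either sign of u and any tau in [0, 1]
assert all(rho(u, t) >= 0 for u in (-3.0, -0.5, 0.0, 0.5, 3.0) for t in (0.0, 0.1, 0.5, 0.9, 1.0))
```

With τ = 0.9, under-estimating by 2 costs 1.8, while over-estimating by 2 costs only 0.2. The minimizer is therefore pushed up until only 10% of the mass lies above it.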
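```python
import torch

def quantile_loss(theta, target, tau):
    """
    Quantile (pinball) loss rho_tau(u) = u * (tau - 1{u < 0}), u = target - theta.
    theta: predicted quantile values; target: target samples; tau: level in (0, 1).
    """
    u = target - theta
    loss = u * (tau - (u < 0).float())  # non-negative by construction, no abs() needed
    return loss.mean()
```
3. QR-DQN (Quantile Regression DQN)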
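3.1 Motivation
Standard DQN learns only a point estimate of the return, $Q(s,a) = \mathbb{E}[Z(s,a)]$, so it cannot represent how spread out or uncertain the return is.
QR-DQN [2] addresses this by learning the distribution of the return itself, represented by $N$ equally weighted quantile values:
$$Z_\theta(s,a) = \frac{1}{N} \sum_{i=1}^{N} \delta_{\theta_i(s,a)}.$$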
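3.2 Quantile grid
QR-DQN fixes $N$ quantile levels spread uniformly over $[0, 1]$, using the midpoints of $N$ equal intervals:
$$\hat{\tau}_i = \frac{\tau_{i-1} + \tau_i}{2} = \frac{2i - 1}{2N}, \qquad \tau_i = \frac{i}{N}, \quad i = 1, \dots, N.$$
Each $\hat{\tau}_i$ corresponds to one learnable quantile value $\theta_i(s,a)$.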
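The grid is simple to compute. The sketch below (a hypothetical helper) returns the midpoints τ̂_i rather than the endpoints τ_i = i/N. In the QR-DQN paper [2], these midpoints are the levels that minimize the 1-Wasserstein error of an N-atom approximation. A grid built as `arange(0, 1, 1/N)` would instead start at τ = 0, i.e. at the minimum of the distribution.

```python
def quantile_midpoints(n):
    """tau_hat_i = (tau_{i-1} + tau_i) / 2 = (2i - 1) / (2n), for i = 1..n."""
    return [(2 * i - 1) / (2 * n) for i in range(1, n + 1)]


print(quantile_midpoints(4))  # four levels, one in the middle of each quarter of [0, 1]
```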
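3.3 Implementation
```python
import torch
import torch.nn as nn
import torch.optim as optim


class QuantileNetwork(nn.Module):
    """Outputs n_quantiles quantile values theta_i(s, a) for every action."""

    def __init__(self, state_dim, action_dim, n_quantiles=200):
        super().__init__()
        self.action_dim = action_dim
        self.n_quantiles = n_quantiles
        self.net = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim * n_quantiles),
        )

    def forward(self, state):
        # (batch, action_dim * n_quantiles) -> (batch, action_dim, n_quantiles)
        return self.net(state).view(-1, self.action_dim, self.n_quantiles)
```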
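The agent below trains these outputs with a pairwise quantile Huber loss. Every target quantile θ'_j is compared with every current estimate θ_i, and each error is weighted by |τ̂_i - 1{u<0}|. The broadcasting is easy to get wrong, so here is a shape-explicit NumPy sketch of the same loss. It assumes the QR-DQN form ρ^κ_τ(u) = |τ - 1{u<0}| L_κ(u)/κ, summed over i and averaged over j and the batch. `quantile_huber_loss_np` is a hypothetical name.

```python
import numpy as np


def quantile_huber_loss_np(theta, target, tau, kappa=1.0):
    """
    theta:  (batch, N)   current quantile values of the taken action
    target: (batch, N')  target values r + gamma * theta'_j (N' may differ from N)
    tau:    (N,)         quantile level of each current value
    """
    # td[b, i, j] = target[b, j] - theta[b, i]: every target against every estimate
    td = target[:, None, :] - theta[:, :, None]           # (batch, N, N')
    abs_td = np.abs(td)
    huber = np.where(abs_td <= kappa, 0.5 * td ** 2, kappa * (abs_td - 0.5 * kappa))
    # Asymmetric weight |tau_i - 1{u < 0}| broadcast along the target axis
    weight = np.abs(tau[None, :, None] - (td < 0).astype(float))
    # Sum over current quantiles i, average over targets j and the batch
    return (weight * huber / kappa).mean(axis=2).sum(axis=1).mean()
```

N' does not have to equal N, so the same function also covers IQN, where the number of target samples can differ from the number of current samples.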
```python
class QRDQN:
    def __init__(self, state_dim, action_dim, n_quantiles=200, kappa=1.0, lr=1e-4):
        self.n_quantiles = n_quantiles
        self.kappa = kappa  # Huber threshold
        self.q_net = QuantileNetwork(state_dim, action_dim, n_quantiles)
        self.target_net = QuantileNetwork(state_dim, action_dim, n_quantiles)
        self.target_net.load_state_dict(self.q_net.state_dict())
        self.optimizer = optim.Adam(self.q_net.parameters(), lr=lr)
        # Quantile midpoints tau_hat_i = (2i - 1) / (2N), i = 1..N
        self.tau = (torch.arange(n_quantiles, dtype=torch.float32) + 0.5) / n_quantiles

    def quantile_huber_loss(self, theta, target_theta):
        """
        theta:        current quantiles of the taken action (batch, N)
        target_theta: target quantiles (batch, N)
        """
        # Pairwise TD errors: td[b, i, j] = target_j - theta_i -> (batch, N, N)
        td_errors = target_theta.unsqueeze(1) - theta.unsqueeze(2)
        abs_td = td_errors.abs()
        huber_loss = torch.where(
            abs_td <= self.kappa,
            0.5 * td_errors.pow(2),
            self.kappa * (abs_td - 0.5 * self.kappa),
        )
        # Quantile weight |tau_i - 1{u < 0}|; tau_i indexes the current quantile (dim 1)
        tau = self.tau.to(theta.device).view(1, -1, 1)
        weight = (tau - (td_errors < 0).float()).abs()
        # Sum over current quantiles i, average over targets j and the batch
        return (weight * huber_loss / self.kappa).mean(dim=2).sum(dim=1).mean()

    def update(self, states, actions, rewards, next_states, dones, gamma=0.99):
        """actions: LongTensor (batch,); rewards, dones: FloatTensor (batch,)."""
        idx = torch.arange(states.shape[0], device=states.device)
        with torch.no_grad():
            next_theta = self.target_net(next_states)           # (batch, action, N)
            next_action = next_theta.mean(dim=2).argmax(dim=1)  # greedy w.r.t. the mean
            target_theta = next_theta[idx, next_action]         # (batch, N)
            target_theta = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * target_theta

        theta = self.q_net(states)[idx, actions]                # (batch, N)
        loss = self.quantile_huber_loss(theta, target_theta)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def get_action(self, state):
        """state: (1, state_dim); acts greedily on Q(s, a) = mean of the quantiles."""
        with torch.no_grad():
            theta = self.q_net(state)      # (1, action, N)
            q_values = theta.mean(dim=2)   # (1, action)
            return q_values.argmax(dim=1).item()
```
3.4 Distribution representation in QR-DQN
Representation of the value distribution:
```
Probability density
│
│         ╱╲
│        /  \
│       /    \
│      /      \
│     /        \________
└───────────────────────────── Return
              ↑
        Q(s,a) = E[Z(s,a)]
```
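As the figure indicates, QR-DQN acts only on the mean of the distribution. Q(s, a) is approximated by averaging the N quantile values, and the greedy action maximizes that average. A plain-Python sketch; the helper names and numbers are hypothetical:

```python
def q_values(quantiles_per_action):
    """Q(s, a) ~ mean of the N quantile values learned for action a."""
    return [sum(qs) / len(qs) for qs in quantiles_per_action]


def greedy_action(quantiles_per_action):
    q = q_values(quantiles_per_action)
    return max(range(len(q)), key=q.__getitem__)


quantiles = [
    [-1.0, 0.0, 2.0, 5.0],  # action 0: wide spread
    [0.5, 1.0, 1.5, 2.0],   # action 1: narrow spread
]
print(q_values(quantiles), greedy_action(quantiles))
```

Action 0 wins on the mean (1.5 vs 1.25) even though its returns are far more spread out. The spread is learned, but greedy action selection does not use it.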
4. IQN (Implicit Quantile Network)
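4.1 Differences from QR-DQN
| Aspect | QR-DQN | IQN |
|---|---|---|
| Quantile levels | Fixed uniform grid of N midpoints | τ ~ U(0, 1) sampled at every update |
| Network structure | One shared head outputs N quantiles per action | τ is a network input, passed through a cosine embedding |
| Expressiveness | Limited to the N grid points | Any quantile level can be queried |
4.2 Core idea of IQN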
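Instead of a fixed grid, IQN learns the quantile function itself. The quantile level $\tau$ is an input to the network, which returns the corresponding quantile value for every action:
$$Z_\tau(s,a) \approx f\big(\psi(s) \odot \phi(\tau)\big)_a, \qquad \tau \sim U(0, 1),$$
where $\psi$ embeds the state, $\phi$ embeds $\tau$, and $f$ maps their element-wise product to one value per action. The levels $\tau$ themselves are not learned. Fresh levels are sampled at every update, so training covers the whole interval $(0, 1)$ rather than $N$ fixed points.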
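The mean of a distribution is the integral of its quantile function, $\mathbb{E}[Z] = \int_0^1 F_Z^{-1}(\tau)\,d\tau$. IQN can therefore estimate Q(s, a) by averaging its outputs over K sampled levels τ. The sketch below replaces the network with a known quantile function. It assumes a hypothetical return distribution Z ~ Uniform(0, 2), whose quantile function is 2τ, so the estimate should approach 1:

```python
import random


def uniform_0_2_quantile(tau):
    """Quantile function of a hypothetical return Z ~ Uniform(0, 2): F^{-1}(tau) = 2 * tau."""
    return 2.0 * tau


def mc_q_value(quantile_fn, k, seed=0):
    """Q = E[Z] = integral of F^{-1}(tau) over (0, 1), estimated with k draws tau ~ U(0, 1)."""
    rng = random.Random(seed)
    return sum(quantile_fn(rng.random()) for _ in range(k)) / k


estimate = mc_q_value(uniform_0_2_quantile, k=10_000)  # close to the true mean 1.0
print(estimate)
```

In IQN the network plays the role of `quantile_fn`, with one output per action, and K is kept small because each τ costs one forward pass through the τ branch.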
4.3 Network architecture
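In the IQN paper, the quantile level is not looked up in a table. It is expanded into cosine features, $\phi_j(\tau) = \mathrm{ReLU}\big(\sum_{i=0}^{n-1} \cos(\pi i \tau)\, w_{ij} + b_j\big)$, and the result is multiplied element-wise with the state embedding ψ(s). The NumPy sketch below shows just this embedding. `cosine_features`, `tau_embedding` and the weight shapes are assumptions for illustration:

```python
import numpy as np


def cosine_features(tau, n_cos=64):
    """cos(pi * i * tau) for i = 0..n_cos-1; tau: (batch, n_samples) -> (batch, n_samples, n_cos)."""
    i = np.arange(n_cos)
    return np.cos(np.pi * i * tau[..., None])


def tau_embedding(tau, w, b):
    """phi(tau) = ReLU(cos_features @ w + b) with w: (n_cos, hidden), b: (hidden,)."""
    return np.maximum(cosine_features(tau, w.shape[0]) @ w + b, 0.0)


tau = np.array([[0.0, 0.5, 1.0]])          # one state, three sampled levels
print(cosine_features(tau, n_cos=4).round(3))
```

The i = 0 feature is constant, and larger i oscillate faster in τ. Nearby quantile levels therefore get similar but distinguishable embeddings.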
```python
import math

import torch
import torch.nn as nn
import torch.optim as optim


class IQN(nn.Module):
    def __init__(self, state_dim, action_dim, n_cos_embedding=64, hidden_dim=256):
        super().__init__()
        self.n_cos_embedding = n_cos_embedding
        self.hidden_dim = hidden_dim
        # psi(s): state encoder
        self.state_net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # phi(tau) = ReLU(sum_i cos(pi * i * tau) * w_ij + b_j)
        self.tau_net = nn.Sequential(
            nn.Linear(n_cos_embedding, hidden_dim),
            nn.ReLU(),
        )
        # Fixed frequencies pi * i, i = 0..n_cos_embedding-1 (a buffer, not a parameter)
        self.register_buffer(
            "cos_freqs", torch.arange(n_cos_embedding, dtype=torch.float32) * math.pi)
        # f: maps the combined embedding to one quantile value per action
        self.output_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, state, tau):
        """
        state: (batch, state_dim)
        tau:   (batch, n_samples) quantile levels in [0, 1]
        returns (batch, n_samples, action_dim)
        """
        state_embedding = self.state_net(state)                  # (batch, hidden)
        cos = torch.cos(tau.unsqueeze(-1) * self.cos_freqs)      # (batch, n_samples, n_cos)
        tau_embedding = self.tau_net(cos)                        # (batch, n_samples, hidden)
        # Element-wise (Hadamard) product psi(s) * phi(tau), as in the IQN paper
        combined = state_embedding.unsqueeze(1) * tau_embedding  # (batch, n_samples, hidden)
        return self.output_net(combined)                         # (batch, n_samples, action)

    def sample_quantiles(self, state, n_samples=32):
        """Evaluate the network at n_samples levels tau ~ U(0, 1)."""
        tau = torch.rand(state.shape[0], n_samples, device=state.device)
        return self.forward(state, tau)


class IQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-4, gamma=0.99, kappa=1.0):
        self.network = IQN(state_dim, action_dim)
        self.target_network = IQN(state_dim, action_dim)
        self.target_network.load_state_dict(self.network.state_dict())
        self.optimizer = optim.Adam(self.network.parameters(), lr=lr)
        self.gamma = gamma
        self.kappa = kappa

    def update(self, batch, n_quantile_samples=32):
        """batch: (states, actions[Long], rewards, next_states, dones[Float])."""
        states, actions, rewards, next_states, dones = batch
        n = n_quantile_samples
        with torch.no_grad():
            # Target quantiles at freshly sampled tau' ~ U(0, 1)
            next_quantiles = self.target_network.sample_quantiles(next_states, n)  # (batch, n, action)
            next_action = next_quantiles.mean(dim=1).argmax(dim=1)  # greedy on Q = E_tau[Z_tau]
            target_quantiles = next_quantiles.gather(
                2, next_action.view(-1, 1, 1).expand(-1, n, 1)).squeeze(2)  # (batch, n)
            target = rewards.unsqueeze(1) + self.gamma * (1 - dones.unsqueeze(1)) * target_quantiles

        # Current quantiles at independently sampled tau ~ U(0, 1)
        tau = torch.rand(states.shape[0], n, device=states.device)
        current = self.network(states, tau)  # (batch, n, action)
        current = current.gather(2, actions.view(-1, 1, 1).expand(-1, n, 1)).squeeze(2)  # (batch, n)

        # Pairwise TD errors: td[b, j, i] = target_j - current_i -> (batch, n, n)
        td_errors = target.unsqueeze(2) - current.unsqueeze(1)
        # Weight |tau_i - 1{u < 0}|; tau_i indexes the current quantiles (last dim)
        weight = (tau.unsqueeze(1) - (td_errors < 0).float()).abs()
        abs_td = td_errors.abs()
        huber_loss = torch.where(
            abs_td <= self.kappa,
            0.5 * td_errors.pow(2),
            self.kappa * (abs_td - 0.5 * self.kappa),
        )
        # Sum over current quantiles i, average over target samples j and the batch
        loss = (weight * huber_loss / self.kappa).sum(dim=2).mean(dim=1).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
```
5. Rainbow
5.1 Overview
Rainbow [3] combines six improvements to DQN:
| Technique | Source paper |
|---|---|
| Double Q-learning | Double DQN |
| Prioritized experience replay | Prioritized Experience Replay |
| Dueling network | Dueling Network |
| Multi-step learning | Multi-step Learning |
| Distributional RL | C51/Distributional RL |
| Noisy networks | Noisy Nets |
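Of the six components, multi-step learning is the easiest to show in isolation. Below is a sketch of the n-step target G = Σ_{k=0}^{n-1} γ^k r_{t+k} + γ^n · bootstrap, assuming a scalar bootstrap value. Rainbow uses n = 3 and applies the same reward sum and γ^n factor to the whole target distribution instead of a scalar. `n_step_target` is a hypothetical helper.

```python
def n_step_target(rewards, gamma, bootstrap_value, done):
    """
    rewards: r_t, ..., r_{t+n-1} (truncated early if the episode ended)
    bootstrap_value: value estimate at s_{t+n}, ignored when done is True
    """
    g = 0.0
    for r in reversed(rewards):  # G = r_t + gamma * (r_{t+1} + gamma * (...))
        g = r + gamma * g
    if not done:
        g += gamma ** len(rewards) * bootstrap_value
    return g


print(n_step_target([1.0, 1.0, 1.0], gamma=0.5, bootstrap_value=8.0, done=False))
```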
5.2 Rainbow architecture
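```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class RainbowDQN(nn.Module):
    def __init__(self, state_dim, action_dim, n_atoms=51, v_min=-10.0, v_max=10.0, noisy=True):
        super().__init__()
        self.n_atoms = n_atoms
        self.action_dim = action_dim
        # Fixed support z_1..z_K of the return distribution (C51)
        self.register_buffer("support", torch.linspace(v_min, v_max, n_atoms))
        # Feature extractor
        self.feature = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        )
        linear = NoisyLinear if noisy else nn.Linear
        self.value_stream = linear(256, n_atoms)                   # value stream (V)
        self.advantage_stream = linear(256, action_dim * n_atoms)  # advantage stream (A)

    def forward(self, state):
        """Returns per-action probabilities over atoms: (batch, action, n_atoms)."""
        features = self.feature(state)
        v = self.value_stream(features).view(-1, 1, self.n_atoms)
        a = self.advantage_stream(features).view(-1, self.action_dim, self.n_atoms)
        # Dueling combination on the logits, then a softmax over atoms for each action.
        # (Combining after the softmax would not give valid probability distributions.)
        logits = v + a - a.mean(dim=1, keepdim=True)
        return F.softmax(logits, dim=-1)

    def q_values(self, state):
        """Q(s, a) = sum_k p_k(s, a) * z_k."""
        return (self.forward(state) * self.support).sum(dim=-1)

    def reset_noise(self):
        for module in self.modules():
            if isinstance(module, NoisyLinear):
                module.sample_noise()
```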
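The forward pass yields probabilities over a fixed support z_1, ..., z_K. Training needs the target distribution of r + γZ(s', a*) on that same support, but shifting and scaling the atoms moves them off the grid. C51, and therefore Rainbow, projects each shifted atom back by splitting its mass between the two nearest grid points.

Below is a NumPy sketch of that projection. It assumes `next_probs` already holds the target network's probabilities for the greedy next action; Rainbow picks a* with the online network, as in Double DQN. `project_distribution` is a hypothetical name.

```python
import numpy as np


def project_distribution(next_probs, rewards, dones, gamma, v_min=-10.0, v_max=10.0):
    """
    next_probs: (batch, n_atoms) target-network probabilities for the greedy next action a*
    rewards, dones: (batch,) floats; returns the projected target, (batch, n_atoms)
    """
    batch, n_atoms = next_probs.shape
    support = np.linspace(v_min, v_max, n_atoms)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    # Apply the Bellman update to every atom, then clip to the support range
    tz = np.clip(rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :], v_min, v_max)
    b = np.clip((tz - v_min) / delta_z, 0, n_atoms - 1)  # fractional grid index
    lower, upper = np.floor(b).astype(int), np.ceil(b).astype(int)
    projected = np.zeros_like(next_probs)
    for n in range(batch):
        for j in range(n_atoms):
            l, u, p = lower[n, j], upper[n, j], next_probs[n, j]
            if l == u:  # landed exactly on a grid point
                projected[n, l] += p
            else:       # split mass in proportion to the distance to each neighbour
                projected[n, l] += p * (u - b[n, j])
                projected[n, u] += p * (b[n, j] - l)
    return projected
```

The training loss is then the cross-entropy between this projected target and the predicted probabilities of the taken action. The double loop keeps the mass-splitting rule readable. A vectorized version must still handle atoms that land exactly on a grid point, where floor and ceil coincide.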
```python
class NoisyLinear(nn.Module):
    """Noisy linear layer with factorized Gaussian noise (Noisy Nets), as used in Rainbow."""

    def __init__(self, in_features, out_features, sigma_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.sigma_init = sigma_init
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer("weight_epsilon", torch.empty(out_features, in_features))
        self.register_buffer("bias_epsilon", torch.empty(out_features))
        self.reset_parameters()
        self.sample_noise()

    def reset_parameters(self):
        mu_range = 1 / np.sqrt(self.in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.sigma_init / np.sqrt(self.in_features))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.sigma_init / np.sqrt(self.in_features))

    @staticmethod
    def _scaled_noise(size):
        # f(x) = sgn(x) * sqrt(|x|)
        x = torch.randn(size)
        return x.sign() * x.abs().sqrt()

    def sample_noise(self):
        """Resample the noise; call it (e.g. through reset_noise()) after each update step."""
        eps_in = self._scaled_noise(self.in_features)
        eps_out = self._scaled_noise(self.out_features)
        with torch.no_grad():
            self.weight_epsilon.copy_(torch.outer(eps_out, eps_in))
            self.bias_epsilon.copy_(eps_out)

    def forward(self, x):
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            weight, bias = self.weight_mu, self.bias_mu
        return F.linear(x, weight, bias)
```
5.3 Prioritized experience replay
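Proportional prioritization samples transition i with probability $P(i) = p_i^\alpha / \sum_k p_k^\alpha$. It corrects the resulting bias with importance-sampling weights $w_i = (N \cdot P(i))^{-\beta} / \max_j w_j$, where β is annealed toward 1 over training.

Below is a small NumPy sketch of both formulas plus a linear β schedule. The schedule is an assumption about how the `beta_start` / `beta_frames` settings are meant to be used, and the helper names are hypothetical:

```python
import numpy as np


def per_probs_and_weights(priorities, alpha, beta):
    """P(i) = p_i^alpha / sum_k p_k^alpha;  w_i = (N * P(i))^(-beta) / max_j w_j."""
    scaled = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = scaled / scaled.sum()
    weights = (len(probs) * probs) ** (-beta)
    return probs, weights / weights.max()


def beta_schedule(frame, beta_start=0.4, beta_frames=100_000):
    """Assumed schedule: beta grows linearly from beta_start to 1.0 over beta_frames, then stays at 1."""
    return min(1.0, beta_start + frame * (1.0 - beta_start) / beta_frames)


probs, weights = per_probs_and_weights([1.0, 1.0, 2.0], alpha=1.0, beta=1.0)
print(probs, weights)
```

With α = 0, sampling is uniform and every weight is 1. Raising α concentrates sampling on high-error transitions, and the weights shrink those transitions' gradients to compensate.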
```python
import numpy as np
import torch


class PrioritizedReplayBuffer:
    """Proportional prioritized replay (list-based O(N) sampling; a sum-tree is used at scale)."""

    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.capacity = capacity
        self.alpha = alpha  # priority exponent
        self.beta = beta    # importance-sampling exponent (usually annealed toward 1)
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current max priority so they are sampled at least once
        max_priority = self.priorities.max() if self.buffer else 1.0
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        # Sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha (float64 so they sum to 1)
        priorities = self.priorities[:len(self.buffer)].astype(np.float64)
        probs = priorities ** self.alpha
        probs /= probs.sum()
        # Sample with replacement so that P(i) is the actual sampling probability
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        # Importance-sampling weights w_i = (N * P(i))^-beta, normalized by the max
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = zip(*batch)
        return (
            torch.FloatTensor(np.array(states)),
            torch.LongTensor(actions),
            torch.FloatTensor(rewards),
            torch.FloatTensor(np.array(next_states)),
            torch.FloatTensor(np.array(dones, dtype=np.float32)),
            torch.FloatTensor(weights),
            indices,
        )

    def update_priorities(self, indices, td_errors):
        for idx, error in zip(indices, td_errors):
            self.priorities[idx] = abs(error) + 1e-5  # keep priorities strictly positive
```
6. Performance comparison
6.1 Atari benchmark
| Algorithm | Human-normalized score |
|---|---|
| DQN | 100% |
| Double DQN | 117% |
| Prioritized DQN | 126% |
| Dueling DQN | 132% |
| Rainbow | 223% |
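6.2 Ablation study
The ablation study in the Rainbow paper shows that:
- Prioritized replay and multi-step learning are the most important components; removing either causes the largest drop. Distributional RL ranks next.
- Removing noisy nets also hurts on many games. Removing Double Q-learning or the dueling architecture has only a small effect on median performance.
- The full combination clearly outperforms any single improvement.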
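7. Practical recommendations
7.1 Choosing an algorithm
| Scenario | Recommended algorithm |
|---|---|
| Limited compute | Double DQN |
| Return distribution needed, not just its mean | QR-DQN / IQN |
| Large-scale distributed training | Ape-X |
| Best performance | Rainbow |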
7.2 Hyperparameter settings
```python
config = {
    'qr_dqn': {
        'n_quantiles': 200,
        'kappa': 1.0,
    },
    'iqn': {
        'n_cos_embedding': 64,
        'n_quantile_samples': 32,
    },
    'rainbow': {
        'n_atoms': 51,
        'v_min': -10,
        'v_max': 10,
        'noisy_std': 0.5,
    },
    'prioritized_replay': {
        'alpha': 0.6,
        'beta_start': 0.4,
        'beta_frames': 100000,
    },
}
```
8. Summary
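Distributional RL algorithms improve sample efficiency and final performance over DQN:
- QR-DQN: estimates the return distribution with quantile regression on a fixed quantile grid
- IQN: learns the whole quantile function by sampling quantile levels, which makes it more expressive
- Rainbow: combines six complementary improvements and reached state-of-the-art Atari results when published
Key takeaways:
- Learning the return distribution > learning only its mean
- Prioritized sampling > uniform sampling
- Combining complementary improvements > any single improvement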
Related topics
- dqn: DQN fundamentals
- soft-actor-critic: maximum-entropy RL
- td3-twin-delayed-ddpg: continuous control
- exploration-exploitation-rl: exploration strategies
References
[1] Horgan, D., Quan, J., Budden, D., et al. (2018). Distributed Prioritized Experience Replay. ICLR.
[2] Dabney, W., Rowland, M., Bellemare, M., & Munos, R. (2018). Distributional Reinforcement Learning with Quantile Regression. AAAI.
[3] Hessel, M., Modayil, J., van Hasselt, H., et al. (2018). Rainbow: Combining Improvements in Deep Reinforcement Learning. AAAI.