MARL策略梯度方法
1. 从单智能体到多智能体
1.1 标准策略梯度
在单智能体RL中,策略梯度为:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\left[\nabla_\theta \log \pi_\theta(a\mid s)\, Q^{\pi_\theta}(s,a)\right]$$
1.2 多智能体策略梯度
当 $n$ 个智能体同时学习时,智能体 $i$ 的梯度变为:

$$\nabla_{\theta_i} J(\theta_i) = \mathbb{E}_{\boldsymbol{\pi}}\left[\nabla_{\theta_i} \log \pi_{\theta_i}(a^i\mid o^i)\, Q^{\boldsymbol{\pi}}(s, a^1, \dots, a^n)\right]$$

其中联合策略 $\boldsymbol{\pi} = (\pi_{\theta_1}, \dots, \pi_{\theta_n})$。
1.3 核心问题
信用分配问题:在团队奖励下,$Q^{\boldsymbol{\pi}}(s, a^1, \dots, a^n)$ 是联合Q值,无法直接衡量智能体 $i$ 的贡献。
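下面用一个极简的数值例子说明这一点(纯属示意):三个智能体共享同一个团队回报 $R$,朴素策略梯度给每个智能体的权重完全相同,"搭便车"的智能体与真正做出贡献的智能体拿到一样的学习信号。

```python
import torch

# 三个智能体在一条轨迹上所选动作的 log π_i(示意数值)
log_probs = torch.tensor([-0.2, -1.5, -0.7], requires_grad=True)
R = torch.tensor(5.0)                 # 团队回报,对所有智能体相同

# 朴素做法:每个智能体都用同一个 R 加权自己的 log π_i
loss = -(R * log_probs).sum()
loss.backward()
print(log_probs.grad)                 # tensor([-5., -5., -5.]):三个梯度权重完全相同
```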
2. COMA:Counterfactual Multi-Agent Policy Gradients
2.1 核心思想
COMA[^1]提出**反事实基线(Counterfactual Baseline)**来解决信用分配问题:
当智能体 $i$ 采取动作 $a^i$ 时,如果保持其他智能体的动作不变、只把 $a^i$ 换成别的动作,整体回报会怎样变化?这个"反事实"比较就是衡量智能体 $i$ 贡献的基线。
2.2 反事实优势函数

$$A^i(s, \mathbf{a}) = Q(s, \mathbf{a}) - \sum_{a'^i} \pi^i(a'^i \mid o^i)\, Q\big(s, (\mathbf{a}^{-i}, a'^i)\big)$$

直观理解:
- 固定其他智能体的动作,对智能体 $i$ 的所有可能动作按其策略概率加权平均,得到基线Q值并从联合Q值中减去
- 只保留智能体 $i$ 贡献的"增量"
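在给出完整实现前,先用一个两智能体、各两个离散动作的小例子演示反事实基线的计算(Q值为随手假设的数字):

```python
import torch

# Q[a1, a2]:某个固定状态下的联合Q值表(数值为假设,实际由中心化Critic给出)
Q = torch.tensor([[1.0, 2.0],
                  [4.0, 3.0]])
pi_1 = torch.tensor([0.5, 0.5])      # 智能体1的策略 π¹(·|o¹)
a1, a2 = 1, 0                        # 实际执行的联合动作

# 反事实基线:固定 a²=0,对智能体1的所有动作按 π¹ 加权平均
baseline = (pi_1 * Q[:, a2]).sum()   # 0.5*1.0 + 0.5*4.0 = 2.5
advantage = Q[a1, a2] - baseline     # 4.0 - 2.5 = 1.5
print(advantage)                     # 智能体1选择 a¹=1 带来的"增量"
```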
2.3 PyTorch实现
```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F


class COMA(nn.Module):
    """
    Counterfactual Multi-Agent Policy Gradients
    """
    def __init__(self, n_agents, state_dim, action_dims):
        super().__init__()
        self.n_agents = n_agents
        self.agents = nn.ModuleList([
            Actor(state_dim, action_dims[i], hidden=64)   # Actor为各智能体的策略网络(假设已定义)
            for i in range(n_agents)
        ])
        # 中心化Critic:可以访问全局状态和所有动作(假设已定义)
        self.critic = CentralizedCritic(state_dim, action_dims)
        # 反事实基线计算:需要访问各智能体的策略来对动作求期望
        self.coma_critic = COMACritic(self.critic, self.agents, n_agents, action_dims)

    def update(self, batch):
        """
        batch包含:states, actions, rewards, next_states, dones
        """
        # 计算反事实优势
        advantages = self.coma_critic.compute_advantage(
            batch.states,
            batch.actions,
            batch.next_states
        )
        # 对每个智能体更新策略
        total_loss = 0
        for i, agent in enumerate(self.agents):
            # 策略梯度项:log π_i(a_i | s_i)
            log_probs = agent.get_log_prob(batch.states[:, i], batch.actions[:, i])
            # 用反事实优势加权
            loss = -(advantages[:, i] * log_probs).mean()
            total_loss += loss
        # 更新Critic(TD误差等由CentralizedCritic自身实现)
        critic_loss = self.critic.update(batch)
        return total_loss + critic_loss


class COMACritic(nn.Module):
    """COMA的Critic包装器:计算反事实优势"""
    def __init__(self, centralized_critic, agents, n_agents, action_dims):
        super().__init__()
        self.critic = centralized_critic
        self.agents = agents          # 计算基线需要各智能体的策略
        self.n_agents = n_agents
        self.action_dims = action_dims

    def compute_advantage(self, states, actions, next_states=None):
        """
        计算每个智能体的反事实优势
        states:  [batch, n_agents, state_dim] 或 [batch, state_dim]
        actions: [batch, n_agents]
        """
        batch_size = states.size(0)
        # 联合Q值 Q(s, a)
        joint_q = self.critic(states, actions)  # [batch, 1]
        advantages = torch.zeros(batch_size, self.n_agents, device=states.device)
        for i in range(self.n_agents):
            action_dim = self.action_dims[i]
            # 反事实基线:固定其他智能体的动作,只改变智能体i的动作,按策略概率加权
            q_counterfactual_sum = torch.zeros(batch_size, device=states.device)
            # 智能体i的策略概率 π_i(·|s_i)
            probs = self.agents[i].get_probs(states[:, i] if states.dim() > 2 else states)
            # 遍历智能体i的所有可能动作
            for a in range(action_dim):
                counterfactual_actions = actions.clone()
                counterfactual_actions[:, i] = a
                q_a = self.critic(states, counterfactual_actions)
                q_counterfactual_sum += probs[:, a] * q_a.squeeze(-1)
            # 反事实优势 A_i = Q(s, a) - Σ_a' π_i(a'|s_i) Q(s, (a^{-i}, a'))
            advantages[:, i] = joint_q.squeeze(-1) - q_counterfactual_sum
        return advantages.detach()  # 优势只作为加权系数,不参与梯度回传
```

2.4 COMA的局限性
- 计算成本高:每个样本需要约 $\sum_i |A^i|$ 次Critic前向传播(每个智能体的每个候选动作各一次);例如5个智能体、每人10个动作就是约50次
- Critic表示限制:需要特殊设计(如让Critic一次性输出智能体 $i$ 在所有动作下的Q值)才能高效计算反事实基线
3. MADDPG:Multi-Agent Actor-Critic
3.1 核心思想
MADDPG[^2]为每个智能体配备中心化Critic + 去中心化Actor,即"集中训练、分散执行"(CTDE):
MADDPG架构:
训练阶段(中心化):
┌──────────────────────────────────────────────┐
│ Critic Q(s, a¹, ..., aⁿ) │
│ ↑ ↑ ↑ │
│ Actor1 Actor2 Actor3 │
│ ↑ ↑ ↑ │
│ π¹(·|o¹) π²(·|o²) π³(·|o³) │
└──────────────────────────────────────────────┘
执行阶段(去中心化):
┌──────────────────────────────────────────────┐
│ 智能体1 智能体2 智能体3 │
│ π¹(·|o¹) π²(·|o²) π³(·|o³) │
└──────────────────────────────────────────────┘
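结合上图,外层训练/执行循环大致如下。这是一个简化示意:`MADDPGAgent` 见下一节 3.2,`env`、`ReplayBuffer` 是示意用的假设接口,`buffer.sample` 假定返回键为 `obs / actions / rewards / next_obs / dones` 的张量批次。

```python
# 简化示意:MADDPG的"集中训练、分散执行"外层循环(接口均为假设)
n_agents, obs_dim, action_dim = 3, 10, 2
agents = [MADDPGAgent(obs_dim, action_dim, n_agents) for _ in range(n_agents)]
buffer = ReplayBuffer(capacity=100_000)
batch_size = 256

obs = env.reset()                                         # [n_agents, obs_dim]
for step in range(100_000):
    # 执行阶段(去中心化):每个智能体只用自己的观察选动作
    actions = [agent.get_action(torch.as_tensor(obs[i], dtype=torch.float32))
               for i, agent in enumerate(agents)]
    next_obs, rewards, dones = env.step(actions)
    buffer.add(obs, actions, rewards, next_obs, dones)
    obs = next_obs

    # 训练阶段(中心化):Critic能看到所有智能体的观察与动作
    if len(buffer) >= batch_size:
        batch = buffer.sample(batch_size)
        for i, agent in enumerate(agents):
            critic_loss, actor_loss = agent.update(i, agents, batch)
            # 实际使用时这里还要各自 optimizer.step() 并软更新目标网络
```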
3.2 实现
```python
class MADDPGAgent(nn.Module):
    """MADDPG单个智能体"""
    def __init__(self, obs_dim, action_dim, n_agents, hidden=64, gamma=0.95):
        super().__init__()
        self.gamma = gamma
        # Actor:去中心化,只输入自己的观察
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Tanh()
        )
        self.target_actor = copy.deepcopy(self.actor)
        # Critic:中心化,输入所有智能体的观察和动作
        self.critic = nn.Sequential(
            nn.Linear(obs_dim * n_agents + action_dim * n_agents, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
        self.target_critic = copy.deepcopy(self.critic)

    def get_action(self, obs, noise=0.1):
        """执行时:只用Actor,加高斯噪声探索"""
        action = self.actor(obs)
        action = action + noise * torch.randn_like(action)
        return action.clamp(-1, 1)

    def update(self, agent_id, all_agents, batch):
        """
        MADDPG更新(只计算损失,优化器步进在外部完成)
        batch['obs'] / batch['next_obs']: [batch, n_agents, obs_dim]
        batch['actions']:                 [batch, n_agents, action_dim]
        """
        B = batch['obs'].size(0)
        rewards = batch['rewards'][:, agent_id]   # 该智能体自己的奖励
        dones = batch['dones']

        # ===== Critic更新 =====
        with torch.no_grad():
            # 目标动作:使用所有智能体的目标Actor
            target_actions = torch.stack([
                agent.target_actor(batch['next_obs'][:, i])
                for i, agent in enumerate(all_agents)
            ], dim=1)
            # 拼接所有观察和目标动作(展平成一个向量)
            target_input = torch.cat([batch['next_obs'].reshape(B, -1),
                                      target_actions.reshape(B, -1)], dim=1)
            target_q = self.target_critic(target_input)
            # TD目标
            y = rewards.unsqueeze(-1) + self.gamma * target_q * (1 - dones.unsqueeze(-1))

        # 当前Q值
        current_input = torch.cat([batch['obs'].reshape(B, -1),
                                   batch['actions'].reshape(B, -1)], dim=1)
        current_q = self.critic(current_input)
        critic_loss = F.mse_loss(current_q, y)

        # ===== Actor更新 =====
        # 自己的动作来自当前策略(保留梯度),其他智能体的动作直接用batch中的动作
        policy_actions = torch.stack([
            agent.actor(batch['obs'][:, i]) if i == agent_id
            else batch['actions'][:, i].detach()
            for i, agent in enumerate(all_agents)
        ], dim=1)
        policy_input = torch.cat([batch['obs'].reshape(B, -1),
                                  policy_actions.reshape(B, -1)], dim=1)
        # 用当前Critic评估(梯度只会传到自己的Actor)
        actor_loss = -self.critic(policy_input).mean()

        return critic_loss, actor_loss
```

4. MAVEN:Multi-Agent Variational Exploration
4.1 核心思想
MAVEN[^3]引入潜在变量 $z$ 来协调多智能体的探索:
MAVEN架构:
z (潜在变量,控制协调模式)
│
├──→ [Mixture Network] ──→ Q_tot
│ ↑
│ ┌─────────────────────────┘
│ │
│ ├──→ Q₁(s, a¹ | z)
│ ├──→ Q₂(s, a² | z)
│ └──→ Qₙ(s, aⁿ | z)
4.2 潜在空间分解
MAVEN把联合策略分解为两层:先根据初始观察采样潜在变量 $z$(在整个episode内保持不变),再由以 $z$ 为条件的个体Q网络和混合网络(类QMIX结构)决定联合行为。训练目标除TD损失外,还鼓励不同的 $z$ 产生可区分的联合探索模式,大致形式为最大化

$$\mathbb{E}\big[R \mid z\big] + \lambda\, I(\boldsymbol{\tau};\, z)$$

其中 $I(\boldsymbol{\tau}; z)$ 是联合轨迹与 $z$ 之间的互信息,原论文用变分下界估计(这也是"Variational"一词的来源);下面的实现用一个简化的VAE式KL正则来近似这一思想。
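为说明"潜在变量控制一局内的协调模式",下面是按episode采样并固定 $z$ 的简化示意(`maven` 为4.3节的 `MAVEN` 实例,`env` 为假设的环境接口):

```python
# 简化示意:每个episode采样一个z并保持不变,z决定这一局的联合探索模式
obs = env.reset()                                                  # [n_agents, obs_dim]
obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)     # [1, n_agents, obs_dim]
z, _, _ = maven.sample_z(obs_t[:, 0])                              # 整个episode固定同一个z

done = False
while not done:
    # 每个智能体基于以z为条件的Q值贪心选动作(实际还会叠加ε-greedy探索)
    actions = [maven.q_nets[i](obs_t[:, i], z).argmax(dim=-1).item()
               for i in range(maven.n_agents)]
    obs, reward, done = env.step(actions)
    obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
```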
4.3 实现
```python
class MAVEN(nn.Module):
    def __init__(self, n_agents, obs_dim, action_dims, z_dim=8, gamma=0.99):
        super().__init__()
        self.n_agents = n_agents
        self.z_dim = z_dim
        self.gamma = gamma
        # 潜在编码器:输出z的均值和log标准差
        self.z_encoder = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, z_dim * 2)
        )
        # 每个智能体的Q网络(以z为条件,假设ConditionalQNetwork已定义)
        self.q_nets = nn.ModuleList([
            ConditionalQNetwork(obs_dim, action_dims[i], z_dim, hidden=64)
            for i in range(n_agents)
        ])
        # 混合网络(假设QMIXMixer已定义,条件输入为"全局状态 + z")
        self.mixer = QMIXMixer(n_agents, obs_dim + z_dim)

    def sample_z(self, obs):
        """从近似后验采样潜在变量(重参数化技巧)"""
        z_params = self.z_encoder(obs)
        mean, log_std = z_params.chunk(2, dim=-1)
        std = torch.exp(log_std)
        z = mean + std * torch.randn_like(std)
        return z, mean, log_std

    def forward(self, states, actions=None, z=None):
        if z is None:
            # 用智能体0的观察近似全局信息,从中推断z
            z = self.sample_z(states[:, 0])[0]
        # 每个智能体计算条件Q值
        individual_qs = []
        for i, q_net in enumerate(self.q_nets):
            q_i = q_net(states[:, i], z)                              # [batch, action_dim]
            if actions is not None:
                q_i = q_i.gather(-1, actions[:, i:i+1]).squeeze(-1)   # 取实际动作的Q值
            else:
                q_i = q_i.max(dim=-1)[0]                              # 计算目标时取贪心Q值
            individual_qs.append(q_i)
        individual_qs = torch.stack(individual_qs, dim=1)             # [batch, n_agents]
        # 混合:mixer以"全局状态(用智能体0的观察近似)+ z"为条件
        state_with_z = torch.cat([states[:, 0], z], dim=-1)
        q_tot = self.mixer(individual_qs, state_with_z)
        return q_tot, individual_qs, z

    def update(self, batch):
        z, mean, log_std = self.sample_z(batch['obs'][:, 0])
        # KL散度正则项(对角高斯 vs 标准正态;注意编码器输出的是log_std而非log_var)
        kl_loss = -0.5 * (1 + 2 * log_std - mean.pow(2) - (2 * log_std).exp()).mean()
        # Q学习
        q_tot, _, _ = self.forward(batch['obs'], batch['actions'], z)
        with torch.no_grad():
            next_z, _, _ = self.sample_z(batch['next_obs'][:, 0])
            target_q, _, _ = self.forward(batch['next_obs'], None, next_z)   # 贪心目标
            target = batch['rewards'].sum(dim=-1, keepdim=True) + \
                     self.gamma * (1 - batch['dones'].unsqueeze(-1)) * target_q
        q_loss = F.mse_loss(q_tot, target)
        return q_loss + 0.1 * kl_loss  # 0.1为平衡两项的系数
```

5. 通信增强的策略梯度
5.1 DIAL:Differentiable Inter-Agent Learning
DIAL[^4]让智能体之间传递的消息在训练时走可微的实值通道,从而允许梯度跨智能体传播:
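先用一个两智能体的最小示意说明"梯度跨智能体传播"(`DIALAgent` 的实现见紧随其后的代码块,损失仅作演示):

```python
# 示意:agent_a 的消息进入 agent_b,对 agent_b 的损失求梯度时会一路回传到 agent_a 的参数
agent_a = DIALAgent(obs_dim=4, action_dim=2)
agent_b = DIALAgent(obs_dim=4, action_dim=2)
obs_a, obs_b = torch.randn(1, 4), torch.randn(1, 4)

_, msg_a = agent_a(obs_a)                  # agent_a 发出可微消息
logits_b, _ = agent_b(obs_b, msg_a)        # agent_b 接收消息并输出策略logits
loss_b = logits_b.pow(2).mean()            # 任意依赖logits_b的损失,仅作演示
loss_b.backward()
print(agent_a.encoder.weight.grad is not None)   # True:梯度穿过消息传到了agent_a
```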
```python
class DIALAgent(nn.Module):
    """DIAL:消息通过可微通道传递,梯度可以跨智能体回传"""
    def __init__(self, obs_dim, action_dim, message_dim=8, hidden=64):
        super().__init__()
        # 同时编码自身观察与收到的消息
        self.encoder = nn.Linear(obs_dim + message_dim, hidden)
        # 动作头与消息头共享编码
        self.action_head = nn.Linear(hidden, action_dim)
        self.message_head = nn.Linear(hidden, message_dim)

    def forward(self, obs, incoming_message=None):
        # 没有传入消息时用零向量占位
        if incoming_message is None:
            incoming_message = torch.zeros(obs.size(0), self.message_head.out_features,
                                           device=obs.device)
        # 编码观察 + 传入消息
        encoded = F.relu(self.encoder(torch.cat([obs, incoming_message], dim=-1)))
        # 生成发送给其他智能体的消息(训练时保持可微,梯度可经消息回传)
        message = self.message_head(encoded)
        # 策略输出(logits)
        action_logits = self.action_head(encoded)
        return action_logits, message
```

5.2 通信协议学习
```python
class CommNet(nn.Module):
    """
    通信网络:学习何时通信(通信强度门控)以及传递什么信息(消息内容)
    """
    def __init__(self, n_agents, obs_dim, message_dim=8, hidden=64):
        super().__init__()
        self.n_agents = n_agents
        self.encoder = nn.Linear(obs_dim, hidden)
        # 消息传递层(假设已定义:聚合其他智能体的隐藏状态并映射为message_dim维消息)
        self.message_pass = MessagePassingLayer(hidden, message_dim)
        # 用收到的消息更新隐藏状态
        self.update_layer = nn.Linear(hidden + message_dim, hidden)
        # 决定通信强度(0~1)
        self.gate = nn.Linear(hidden, 1)

    def forward(self, observations, steps=3):
        """
        多轮消息传递
        observations: 长度为n_agents的列表,每个元素为[batch, obs_dim]
        """
        hidden = torch.stack([
            F.relu(self.encoder(obs)) for obs in observations
        ], dim=1)  # [batch, n_agents, hidden]
        for _ in range(steps):
            # 消息传递:每个智能体收到来自其他智能体的聚合消息
            messages = self.message_pass(hidden)              # [batch, n_agents, message_dim]
            # 用消息更新隐藏状态
            hidden = F.relu(self.update_layer(torch.cat([hidden, messages], dim=-1)))
        # 决定通信强度
        comm_weights = torch.sigmoid(self.gate(hidden))       # [batch, n_agents, 1]
        return hidden, comm_weights
```

6. 算法对比
| 算法 | Critic类型 | 信用分配 | 通信 | 适用场景 |
|---|---|---|---|---|
| COMA | 中心化 | 反事实基线 | 无 | 合作式 |
| MADDPG | 中心化 | 无 | 无 | 混合博弈 |
| MAVEN | 中心化 | 潜在变量 | 无 | 合作式 |
| DIAL | 中心化 | 无 | 可微 | 需要通信 |
7. 参考文献
[^1]: Foerster et al. "Counterfactual Multi-Agent Policy Gradients", AAAI 2018.
[^2]: Lowe et al. "Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments", NeurIPS 2017.
[^3]: Mahajan et al. "MAVEN: Multi-Agent Variational Exploration", NeurIPS 2019.
[^4]: Foerster et al. "Learning to Communicate with Deep Multi-Agent Reinforcement Learning", NeurIPS 2016.