MARL Policy Gradient Methods

1. From Single-Agent to Multi-Agent

1.1 The Standard Policy Gradient

In single-agent RL, the policy gradient is:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi_\theta}\!\big[\nabla_\theta \log \pi_\theta(a \mid s)\, Q^{\pi}(s, a)\big]$$

1.2 The Multi-Agent Policy Gradient

When $n$ agents learn simultaneously, the gradient for agent $i$ becomes:

$$\nabla_{\theta^i} J(\theta^i) = \mathbb{E}_{\boldsymbol{\pi}}\!\big[\nabla_{\theta^i} \log \pi^i_{\theta^i}(a^i \mid o^i)\, Q^{\boldsymbol{\pi}}(s, a^1, \dots, a^n)\big]$$

where the joint policy factorizes as $\boldsymbol{\pi}(\mathbf{a} \mid s) = \prod_{i=1}^{n} \pi^i_{\theta^i}(a^i \mid o^i)$.

1.3 The Core Problem

The credit assignment problem: under a shared team reward, $Q^{\boldsymbol{\pi}}(s, \mathbf{a})$ is a *joint* Q-value, so it cannot directly measure the contribution of an individual agent $i$. If, say, two agents share a reward of +10, an idle agent's policy gradient is weighted by exactly the same return as that of the agent that actually earned the reward.

2. COMA: Counterfactual Multi-Agent Policy Gradients

2.1 Core Idea

COMA[^1] proposes a **counterfactual baseline** to solve credit assignment:

Holding the other agents' actions fixed, how much better is the action $a^i$ that agent $i$ actually took than what agent $i$'s own policy would have done on average?

2.2 The Counterfactual Advantage Function

$$A^i(s, \mathbf{a}) = Q(s, \mathbf{a}) - \sum_{a'^i} \pi^i\!\left(a'^i \mid o^i\right) Q\!\big(s, (\mathbf{a}^{-i}, a'^i)\big)$$

Intuition:

  • Subtract the Q-value averaged over all of agent $i$'s possible actions, weighted by agent $i$'s own policy
  • What remains is exactly the "increment" contributed by agent $i$'s chosen action (a small numeric sketch follows below)
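
To make this concrete, here is a tiny numeric sketch of the counterfactual advantage for a single agent with three actions. All numbers are made up for illustration; `q_joint` stands in for the centralized critic's outputs as only agent $i$'s action is varied.

```python
import torch

# Q(s, (a^{-i}, a')) as agent i's action a' varies, other agents' actions fixed
q_joint = torch.tensor([2.0, 5.0, 1.0])   # Q for agent i's actions 0, 1, 2
pi_i    = torch.tensor([0.2, 0.5, 0.3])   # agent i's policy probabilities

a_taken = 1                               # agent i actually chose action 1
baseline  = (pi_i * q_joint).sum()        # 0.4 + 2.5 + 0.3 = 3.2
advantage = q_joint[a_taken] - baseline   # 5.0 - 3.2 = 1.8
print(advantage)                          # tensor(1.8000)
```

A positive advantage means agent $i$'s chosen action was better than what its own policy would have achieved on average, with everyone else's behavior held constant.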

2.3 PyTorch Implementation

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Actor and CentralizedCritic are project-specific modules assumed to be
# defined elsewhere: an actor exposing get_probs / get_log_prob, and a
# critic mapping (state, joint action) to a scalar Q-value.

class COMA:
    """
    Counterfactual Multi-Agent Policy Gradients
    """

    def __init__(self, n_agents, state_dim, action_dims):
        self.n_agents = n_agents
        self.agents = nn.ModuleList([
            Actor(state_dim, action_dims[i], hidden=64)
            for i in range(n_agents)
        ])

        # Centralized critic: sees the global state and all agents' actions
        self.critic = CentralizedCritic(state_dim, action_dims)

        # Counterfactual-baseline computation
        # (needs the actors to obtain each agent's policy probabilities)
        self.coma_critic = COMACritic(self.critic, self.agents, n_agents, action_dims)

    def update(self, batch):
        """
        batch contains: states, actions, rewards, next_states, dones
        """
        # Counterfactual advantages, one per agent
        advantages = self.coma_critic.compute_advantage(
            batch.states,
            batch.actions
        )

        # Policy-gradient update for each agent
        total_loss = 0
        for i, agent in enumerate(self.agents):
            log_probs = agent.get_log_prob(batch.states[:, i], batch.actions[:, i])

            # Weight the log-probabilities by the counterfactual advantage
            loss = -(advantages[:, i] * log_probs).mean()
            total_loss += loss

        # Critic update (ordinary TD learning on the joint Q-value)
        critic_loss = self.critic.update(batch)

        return total_loss + critic_loss


class COMACritic(nn.Module):
    """Computes COMA's counterfactual advantage from the centralized critic."""

    def __init__(self, centralized_critic, agents, n_agents, action_dims):
        super().__init__()
        self.critic = centralized_critic
        self.agents = agents  # the actors, needed for policy probabilities
        self.n_agents = n_agents
        self.action_dims = action_dims

    def compute_advantage(self, states, actions):
        """
        Counterfactual advantage for every agent.
        states:  [batch, n_agents, state_dim] or [batch, state_dim]
        actions: [batch, n_agents]
        """
        batch_size = states.size(0)
        n_agents = self.n_agents

        # Joint Q-value of the actions actually taken
        joint_q = self.critic(states, actions)  # [batch, 1]

        advantages = torch.zeros(batch_size, n_agents, device=states.device)

        for i in range(n_agents):
            action_dim = self.action_dims[i]

            # Counterfactual baseline: hold the other agents' actions fixed
            # and average Q over agent i's possible actions under its policy
            q_counterfactual_sum = torch.zeros(batch_size, device=states.device)

            # Agent i's policy probabilities
            probs = self.agents[i].get_probs(states[:, i] if states.dim() > 2 else states)

            # Enumerate every possible action of agent i
            for a in range(action_dim):
                counterfactual_actions = actions.clone()
                counterfactual_actions[:, i] = a

                q_a = self.critic(states, counterfactual_actions)
                q_counterfactual_sum += probs[:, a] * q_a.squeeze(-1)

            # Counterfactual advantage
            advantages[:, i] = joint_q.squeeze(-1) - q_counterfactual_sum

        return advantages.detach()  # used as a fixed weight; no gradient into the critic
```
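
A sketch of how the pieces above would be driven, assuming an attribute-style batch; the `Batch` container and all dimensions are illustrative, and `Actor`/`CentralizedCritic` must be defined for this to run:

```python
from collections import namedtuple

Batch = namedtuple('Batch', ['states', 'actions', 'rewards', 'next_states', 'dones'])

coma = COMA(n_agents=3, state_dim=10, action_dims=[5, 5, 5])
batch = Batch(
    states=torch.randn(32, 3, 10),         # [batch, n_agents, state_dim]
    actions=torch.randint(0, 5, (32, 3)),  # [batch, n_agents]
    rewards=torch.randn(32),
    next_states=torch.randn(32, 3, 10),
    dones=torch.zeros(32),
)
loss = coma.update(batch)  # counterfactual PG loss + critic loss
```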

2.4 Limitations of COMA

  1. High computational cost: the baseline needs one critic forward pass per candidate action of each agent, on the order of $\sum_i |A^i|$ passes per update (a common fix is sketched below)
  2. Critic design constraints: the critic output needs a dedicated structure to compute the counterfactual baseline efficiently
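
The remedy adopted by COMA's authors is a critic whose head outputs the Q-values for *all* of agent $i$'s actions in a single forward pass, given the state and the other agents' actions. Below is a minimal sketch of such a head; the `CounterfactualCritic` name and the layer sizes are illustrative, not from the original text.

```python
import torch
import torch.nn as nn

class CounterfactualCritic(nn.Module):
    """Outputs Q(s, (a^{-i}, a')) for every action a' of agent i in one pass."""

    def __init__(self, state_dim, n_agents, action_dim, hidden=128):
        super().__init__()
        # Input: global state + one-hot actions of the other agents + agent id
        in_dim = state_dim + (n_agents - 1) * action_dim + n_agents
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),  # one Q-value per action of agent i
        )

    def forward(self, state, other_actions_onehot, agent_id_onehot):
        x = torch.cat([state, other_actions_onehot, agent_id_onehot], dim=-1)
        return self.net(x)  # [batch, action_dim]
```

With this head, the baseline is a single `(probs * q_all).sum(-1)` over one output tensor, so the per-update cost drops from roughly $\sum_i |A^i|$ critic passes to $n$.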

3. MADDPG: Multi-Agent Deep Deterministic Policy Gradient

3.1 Core Idea

MADDPG[^2] gives every agent a centralized critic plus a decentralized actor. Since the critic conditions on all agents' actions, the environment looks stationary from its perspective even while the other agents' policies change:

MADDPG architecture:

```
Training (centralized):
┌──────────────────────────────────────────────┐
│   Critic Q(s, a¹, ..., aⁿ)                   │
│          ↑      ↑          ↑                 │
│      Actor 1  Actor 2   Actor 3              │
│          ↑      ↑          ↑                 │
│      π¹(·|o¹) π²(·|o²) π³(·|o³)              │
└──────────────────────────────────────────────┘

Execution (decentralized):
┌──────────────────────────────────────────────┐
│   Agent 1     Agent 2     Agent 3            │
│   π¹(·|o¹)   π²(·|o²)   π³(·|o³)             │
└──────────────────────────────────────────────┘
```

3.2 Implementation

```python
import copy

class MADDPGAgent(nn.Module):
    """A single MADDPG agent."""

    def __init__(self, obs_dim, action_dim, n_agents, gamma=0.99, hidden=64):
        super().__init__()
        self.gamma = gamma

        # Actor: decentralized, sees only this agent's own observation
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
            nn.Tanh()
        )
        self.target_actor = copy.deepcopy(self.actor)

        # Critic: centralized, sees all observations and all actions
        self.critic = nn.Sequential(
            nn.Linear(obs_dim * n_agents + action_dim * n_agents, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1)
        )
        self.target_critic = copy.deepcopy(self.critic)

    def get_action(self, obs, noise=0.1):
        """At execution time only the actor is used."""
        action = self.actor(obs)
        action += noise * torch.randn_like(action)  # Gaussian exploration noise
        return action.clamp(-1, 1)

    def update(self, agent_id, all_agents, batch):
        """
        One MADDPG update step; returns the two losses.
        batch['obs'], batch['next_obs']: [batch, n_agents, obs_dim]
        batch['actions']:                [batch, n_agents, action_dim]
        """
        batch_size = batch['obs'].size(0)
        all_obs = batch['obs'].reshape(batch_size, -1)
        all_actions = batch['actions'].reshape(batch_size, -1)
        rewards = batch['rewards'][:, agent_id]  # this agent's own reward
        next_obs = batch['next_obs']
        dones = batch['dones']

        # ===== Critic update =====
        # Target actions come from every agent's *target* actor
        with torch.no_grad():
            target_actions = torch.stack([
                agent.target_actor(next_obs[:, i])
                for i, agent in enumerate(all_agents)
            ], dim=1)

            # Flatten observations and target actions into one critic input
            target_input = torch.cat([
                next_obs.reshape(batch_size, -1),
                target_actions.reshape(batch_size, -1)
            ], dim=-1)
            target_q = self.target_critic(target_input)

            # TD target
            y = rewards.unsqueeze(-1) + self.gamma * target_q * (1 - dones.unsqueeze(-1))

        # Current Q-value
        current_q = self.critic(torch.cat([all_obs, all_actions], dim=-1))
        critic_loss = F.mse_loss(current_q, y)

        # ===== Actor update =====
        # This agent's action comes from its current actor (gradient flows);
        # the other agents' actions are taken from the batch and held fixed
        policy_actions = torch.stack([
            self.actor(batch['obs'][:, i]) if i == agent_id
            else batch['actions'][:, i].detach()
            for i in range(len(all_agents))
        ], dim=1)

        # Evaluate with the current critic (gradient flows into the actor)
        policy_input = torch.cat([all_obs, policy_actions.reshape(batch_size, -1)], dim=-1)
        actor_loss = -self.critic(policy_input).mean()

        return critic_loss, actor_loss
```
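
A minimal sketch of how these pieces fit into one training step, assuming a replay buffer that yields batches in the shapes documented above; the buffer and optimizer setup are illustrative, not part of the original code:

```python
n_agents, obs_dim, action_dim = 3, 10, 2
agents = [MADDPGAgent(obs_dim, action_dim, n_agents) for _ in range(n_agents)]
actor_opts  = [torch.optim.Adam(a.actor.parameters(),  lr=1e-3) for a in agents]
critic_opts = [torch.optim.Adam(a.critic.parameters(), lr=1e-3) for a in agents]

batch = replay_buffer.sample(256)  # hypothetical replay buffer
for i, agent in enumerate(agents):
    critic_loss, actor_loss = agent.update(i, agents, batch)

    critic_opts[i].zero_grad()
    critic_loss.backward()
    critic_opts[i].step()

    actor_opts[i].zero_grad()
    actor_loss.backward()
    actor_opts[i].step()
    # a soft update of the target networks would follow here
```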

4. MAVEN: Multi-Agent Variational Exploration

4.1 Core Idea

MAVEN[^3] introduces a latent variable $z$ to coordinate exploration across the agents:

MAVEN architecture:

```
z (latent variable, selects the coordination mode)
  │
  ├──→ [Mixing Network] ──→ Q_tot
  │                            ↑
  │    ┌───────────────────────┘
  │    │
  │    ├──→ Q₁(s, a¹ | z)
  │    ├──→ Q₂(s, a² | z)
  │    └──→ Qₙ(s, aⁿ | z)
```

4.2 Latent-Space Factorization
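
In the implementation below, each agent's utility $Q_i(s, a^i \mid z)$ is conditioned on a shared latent $z$, and a mixing network combines them into $Q_{\text{tot}}$. A variational posterior $q_\phi(z \mid s_0)$ is inferred from the initial observation and regularized toward a standard-normal prior, so that different samples of $z$ index different coordinated behavior modes. Reading the objective off the code that follows:

$$\mathcal{L} = \mathcal{L}_{\mathrm{TD}}(Q_{\mathrm{tot}}) + \beta\, D_{\mathrm{KL}}\!\big(q_\phi(z \mid s_0)\,\big\|\,\mathcal{N}(0, I)\big), \qquad \beta = 0.1$$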

4.3 Implementation

```python
class MAVEN(nn.Module):
    def __init__(self, n_agents, obs_dim, action_dims, z_dim=8, gamma=0.99):
        super().__init__()
        self.n_agents = n_agents
        self.z_dim = z_dim
        self.gamma = gamma

        # Latent encoder: outputs mean and log-variance of q(z | s0)
        self.z_encoder = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, z_dim * 2)  # mean and log-variance
        )

        # One Q-network per agent, conditioned on z
        # (ConditionalQNetwork and QMIXMixer are assumed defined elsewhere)
        self.q_nets = nn.ModuleList([
            ConditionalQNetwork(obs_dim, action_dims[i], z_dim, hidden=64)
            for i in range(n_agents)
        ])

        # Mixing network
        self.mixer = QMIXMixer(n_agents, obs_dim + z_dim)

    def sample_z(self, obs):
        """Sample the latent variable from the approximate posterior."""
        z_params = self.z_encoder(obs)
        mean, log_var = z_params.chunk(2, dim=-1)
        std = torch.exp(0.5 * log_var)

        # Reparameterization trick
        z = mean + std * torch.randn_like(std)
        return z, mean, log_var

    def forward(self, states, actions=None, z=None):
        if z is None:
            # Infer z from the first agent's observation
            z = self.sample_z(states[:, 0])[0]

        # Per-agent conditional Q-values
        individual_qs = []
        for i, q_net in enumerate(self.q_nets):
            q_i = q_net(states[:, i], z)  # [batch, action_dim]
            if actions is not None:
                q_i = q_i.gather(-1, actions[:, i:i+1]).squeeze(-1)
            else:
                q_i = q_i.max(dim=-1)[0]  # greedy per-agent value
            individual_qs.append(q_i)

        individual_qs = torch.stack(individual_qs, dim=1)  # [batch, n_agents]

        # Mix into Q_tot
        z_expanded = z.unsqueeze(1).expand(-1, self.n_agents, -1)
        states_with_z = torch.cat([states, z_expanded], dim=-1)
        q_tot = self.mixer(individual_qs, states_with_z)

        return q_tot, individual_qs, z

    def update(self, batch):
        z, mean, log_var = self.sample_z(batch['obs'][:, 0])

        # KL regularizer toward the standard-normal prior
        kl_loss = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp()).mean()

        # TD learning on Q_tot
        q_tot, _, _ = self.forward(batch['obs'], batch['actions'], z)

        with torch.no_grad():
            next_z, _, _ = self.sample_z(batch['next_obs'][:, 0])
            # With actions=None, forward takes the per-agent greedy max before mixing
            target_q, _, _ = self.forward(batch['next_obs'], None, next_z)
            target = batch['rewards'].sum(dim=-1, keepdim=True) + \
                     self.gamma * (1 - batch['dones'].unsqueeze(-1)) * target_q

        q_loss = F.mse_loss(q_tot, target)

        return q_loss + 0.1 * kl_loss  # beta = 0.1 balances the TD and KL terms
```
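
One practical detail worth making explicit: in MAVEN the latent $z$ is sampled once at the start of an episode and held fixed until the episode ends, so an entire trajectory is explored under a single coordination mode. A minimal rollout sketch under that assumption, with a hypothetical multi-agent `env`:

```python
maven = MAVEN(n_agents=3, obs_dim=10, action_dims=[5, 5, 5])

obs = env.reset()                    # hypothetical env; obs: [n_agents, obs_dim]
z, _, _ = maven.sample_z(obs[0:1])   # sample z once per episode

done = False
while not done:
    with torch.no_grad():
        actions = [
            q_net(obs[i:i+1], z).argmax(dim=-1)  # greedy action under this z
            for i, q_net in enumerate(maven.q_nets)
        ]
    obs, reward, done, _ = env.step(actions)     # z stays fixed all episode
```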

5. Communication-Augmented Policy Gradients

5.1 DIAL: Differentiable Inter-Agent Learning

DIAL[^4] lets gradients propagate across agents through a differentiable communication channel:

```python
class DIALAgent(nn.Module):
    def __init__(self, obs_dim, action_dim, message_dim=8, hidden=64):
        super().__init__()
        self.message_dim = message_dim
        # Encode the agent's observation together with the incoming message
        self.encoder = nn.Linear(obs_dim + message_dim, hidden)
        # Heads: the outgoing message (kept differentiable) and the policy
        self.message_head = nn.Linear(hidden, message_dim)
        self.actor = nn.Linear(hidden, action_dim)
        # Centralized critic used during training (CentralizedCritic as before)
        self.critic = CentralizedCritic(obs_dim + message_dim, action_dim)

    def forward(self, obs, incoming_message=None):
        # No incoming message on the first step: use a zero placeholder
        if incoming_message is None:
            incoming_message = obs.new_zeros(obs.size(0), self.message_dim)

        # Encode observation and incoming message
        encoded = F.relu(self.encoder(torch.cat([obs, incoming_message], dim=-1)))

        # Message sent to the other agents; gradients flow through it
        message = self.message_head(encoded)

        # Policy output
        action = self.actor(encoded)

        return action, message
```
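
The point of DIAL is that the message is a differentiable tensor, so a loss computed at the *receiving* agent backpropagates into the *sender's* parameters. A minimal sketch, assuming the `CentralizedCritic` helper used throughout this post is defined; the loss here is a dummy placeholder just to expose the gradient path:

```python
sender = DIALAgent(obs_dim=10, action_dim=4)
receiver = DIALAgent(obs_dim=10, action_dim=4)

obs_s, obs_r = torch.randn(32, 10), torch.randn(32, 10)

_, message = sender(obs_s)              # sender produces a message
action_r, _ = receiver(obs_r, message)  # receiver consumes it

loss = action_r.pow(2).mean()           # dummy loss at the receiver
loss.backward()

# The gradient has flowed through the message back into the sender:
print(sender.encoder.weight.grad is not None)  # True
```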

5.2 Learning a Communication Protocol

```python
class CommNet(nn.Module):
    """
    Communication network: learns what to pass between agents
    and how strongly each agent should communicate.
    (MessagePassingLayer is assumed defined; a sketch follows below.)
    """

    def __init__(self, n_agents, obs_dim, message_dim=8, hidden=64):
        super().__init__()
        self.n_agents = n_agents
        self.encoder = nn.Linear(obs_dim, hidden)
        self.message_pass = MessagePassingLayer(hidden, message_dim)
        # Combine the hidden state with the received message
        self.update = nn.Linear(hidden + message_dim, hidden)
        self.gate = nn.Linear(hidden, 1)  # per-agent communication strength

    def forward(self, observations, steps=3):
        """
        Several rounds of message passing.
        observations: list of n_agents tensors, each [batch, obs_dim]
        """
        hidden = torch.stack([
            self.encoder(obs) for obs in observations
        ], dim=1)  # [batch, n_agents, hidden]

        for _ in range(steps):
            # Message passing between agents
            messages = self.message_pass(hidden)  # [batch, n_agents, message_dim]

            # Update each agent's hidden state with its received message
            hidden = F.relu(self.update(torch.cat([hidden, messages], dim=-1)))

        # Per-agent communication strength
        comm_weights = torch.sigmoid(self.gate(hidden))

        return hidden, comm_weights
```
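
`MessagePassingLayer` is referenced but not defined in the snippet above. In CommNet-style architectures the message each agent receives is the mean of the *other* agents' hidden states, passed through a linear map; a minimal sketch under that assumption:

```python
class MessagePassingLayer(nn.Module):
    """Each agent receives the mean of the other agents' hidden states."""

    def __init__(self, hidden_dim, message_dim):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, message_dim)

    def forward(self, hidden):
        # hidden: [batch, n_agents, hidden_dim]
        n = hidden.size(1)
        total = hidden.sum(dim=1, keepdim=True)         # [batch, 1, hidden_dim]
        others_mean = (total - hidden) / max(n - 1, 1)  # exclude the agent itself
        return self.proj(others_mean)                   # [batch, n_agents, message_dim]
```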

6. Algorithm Comparison

| Algorithm | Critic type | Credit assignment | Communication | Setting |
|-----------|-------------|-------------------|---------------|---------|
| COMA | Centralized | Counterfactual baseline | — | Cooperative |
| MADDPG | Centralized | — | — | Mixed cooperative-competitive |
| MAVEN | Centralized | Latent variable | — | Cooperative |
| DIAL | Centralized | — | Differentiable | Communication-dependent tasks |

7. References

Footnotes

[^1]: Foerster et al., “Counterfactual Multi-Agent Policy Gradients,” AAAI 2018.

[^2]: Lowe et al., “Multi-Agent Actor-Critic for Mixed Cooperative-Competitive Environments,” NeurIPS 2017.

[^3]: Mahajan et al., “MAVEN: Multi-Agent Variational Exploration,” NeurIPS 2019.

[^4]: Foerster et al., “Learning to Communicate with Deep Multi-Agent Reinforcement Learning,” NeurIPS 2016.