Communication and Coordination Mechanisms in MARL

1. Why Communication Matters

In a multi-agent system, communication is the key mechanism for coordinating actions.

Without communication:           With communication:
┌──────────────┐             ┌─────────────────────┐
│ Agent A      │             │ Agent A ──→ Agent B │
│   ?          │             │   shares its intent │
│ Agent B      │             │ Agent B ←── Agent A │
│   ?          │             │ shares observations │
└──────────────┘             └─────────────────────┘

Problem: no agent knows what     Solution: coordinated action
the others are about to do

2. A Taxonomy of Communication

2.1 Communication modes

| Mode     | Description                                  | Examples                      |
|----------|----------------------------------------------|-------------------------------|
| Explicit | Agents actively send messages                | DIAL, CommNet                 |
| Implicit | Indirect signaling through the environment   | following, observing behavior |
| Hybrid   | Combines explicit and implicit communication | TarMAC                        |

2.2 Communication topologies

Fully connected:             Graph-structured:
A ←→ B ←→ C                  A ──→ B
↑    ↓    ↑                  ↓     ↓
└────┴────┘                  C ←── D

Broadcast:                   Star:
A → B, C, D                  A → B, C, D
                             (A is the coordinator)
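These topologies can be written down as adjacency matrices, which is also how most of the code below consumes them. A minimal sketch (the agent ordering A=0, B=1, C=2, D=3 is illustrative):

```python
import numpy as np

n = 4  # agents A, B, C, D

# Fully connected: every agent talks to every other agent
full = np.ones((n, n)) - np.eye(n)

# Broadcast: agent 0 (A) sends to everyone else, one-way
broadcast = np.zeros((n, n))
broadcast[0, 1:] = 1

# Star: agent 0 (A) is the coordinator; messages flow both ways
star = np.zeros((n, n))
star[0, 1:] = 1
star[1:, 0] = 1

assert full.sum() == 12       # n*(n-1) directed links
assert broadcast.sum() == 3   # n-1 one-way links
assert star.sum() == 6        # 2*(n-1) links through the hub
```

The entry `adj[i, j] = 1` means agent i sends to agent j; graph-structured topologies are just arbitrary 0/1 patterns in the same matrix.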

3. Explicit Communication Methods

3.1 CommNet: Continuous Communication

CommNet [1] communicates with continuous messages:

import torch
import torch.nn as nn
import torch.nn.functional as F


class CommNet(nn.Module):
    """
    CommNet: multi-agent communication via continuous messages.
    """
    
    def __init__(self, input_dim, hidden_dim, n_agents, n_steps=3):
        super().__init__()
        self.n_agents = n_agents
        self.n_steps = n_steps
        
        # Observation encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Communication layers, one per message-passing round
        self.comm_layer = nn.ModuleList([
            CommunicationLayer(hidden_dim)
            for _ in range(n_steps)
        ])
        
        # Decoder (outputs action logits or values)
        self.decoder = nn.Linear(hidden_dim, input_dim)
    
    def forward(self, observations):
        """
        observations: [batch, n_agents, obs_dim]
        """
        # Encode observations
        h = self.encoder(observations)  # [batch, n_agents, hidden_dim]
        
        # Several rounds of message passing
        for step in range(self.n_steps):
            # Message aggregation: every agent receives the mean of all hidden states
            messages = h.mean(dim=1, keepdim=True).expand(-1, self.n_agents, -1)
            
            # Message passing
            h = self.comm_layer[step](h, messages)
        
        # Decode into actions or values
        return self.decoder(h)
 
 
class CommunicationLayer(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # The input is the concatenation of the own hidden state and the
        # received message, hence hidden_dim * 2
        self.gru = nn.GRUCell(hidden_dim * 2, hidden_dim)
    
    def forward(self, hidden, message):
        """
        hidden: [batch, n_agents, hidden_dim]
        message: [batch, n_agents, hidden_dim]
        """
        # Concatenate the own hidden state with the received message
        combined = torch.cat([hidden, message], dim=-1)
        
        # Update the hidden state (GRUCell expects 2-D input)
        new_hidden = self.gru(
            combined.view(-1, hidden.size(-1) * 2),
            hidden.view(-1, hidden.size(-1))
        )
        
        return new_hidden.view_as(hidden)
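CommNet's core aggregation step — every agent receives the mean of all hidden states — can be checked in isolation:

```python
import torch

batch, n_agents, hidden = 2, 3, 4
h = torch.arange(batch * n_agents * hidden, dtype=torch.float32)
h = h.reshape(batch, n_agents, hidden)

# Every agent receives the same message: the mean over the agent dimension
messages = h.mean(dim=1, keepdim=True).expand(-1, n_agents, -1)

assert messages.shape == (batch, n_agents, hidden)
# Within one batch element, all agents see an identical message
assert torch.equal(messages[0, 0], messages[0, 1])
```

Because the aggregation is a plain mean, CommNet is permutation-invariant in the agents; this is also why it scales to a variable number of agents without retraining.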

3.2 DIAL: Differentiable Communication

DIAL [2] lets gradients flow backwards through the communication channel:

class DIAL(nn.Module):
    """
    DIAL: Differentiable Inter-Agent Learning
    Core idea: messages can carry gradient information, so the sender
    is trained by the receiver's learning signal.
    """
    
    def __init__(self, obs_dim, action_dim, message_dim=8):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.message_dim = message_dim
        
        # Observation encoder
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )
        
        # Message generator (outputs mean and log-std)
        self.message_net = nn.Linear(32 + message_dim, message_dim * 2)
        
        # Action network
        self.actor = nn.Sequential(
            nn.Linear(32 + message_dim, 32),
            nn.ReLU(),
            nn.Linear(32, action_dim),
            nn.Softmax(dim=-1)
        )
        
        # Centralized critic (CentralizedCritic is assumed to be defined elsewhere)
        self.critic = CentralizedCritic(obs_dim, action_dim, message_dim)
    
    def send_message(self, obs_encoding, received):
        """
        Generate the outgoing message from the own encoding and the
        aggregated incoming message.
        """
        combined = torch.cat([obs_encoding, received], dim=-1)
        message_params = self.message_net(combined)
        
        # Reparameterization trick: sample = mean + eps * std
        mean, log_std = message_params.chunk(2, dim=-1)
        message = mean + torch.randn_like(mean) * torch.exp(log_std)
        
        return message, mean, log_std
    
    def receive_message(self, messages):
        """
        Aggregate the messages received from the other agents.
        Simple implementation: average them; an attention mechanism
        could instead pick out the important ones.
        messages: [batch, n_senders, message_dim]
        """
        return messages.mean(dim=1)
    
    def forward(self, obs, incoming_messages=None, training=True):
        """
        obs: [batch, obs_dim]
        incoming_messages: [batch, n_senders, message_dim] or None
        """
        # Encode the observation
        encoding = self.obs_encoder(obs)
        
        # Aggregate received messages (zeros at the first timestep)
        if incoming_messages is None:
            received = encoding.new_zeros(encoding.size(0), self.message_dim)
        else:
            received = self.receive_message(incoming_messages)
        
        # Generate the outgoing message
        message, msg_mean, msg_log_std = self.send_message(encoding, received)
        
        # Concatenate and output the action distribution
        combined = torch.cat([encoding, received], dim=-1)
        action = self.actor(combined)
        
        if training:
            return action, message, msg_mean, msg_log_std
        return action, message
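The reparameterized sampling inside send_message is what makes the channel differentiable; it can be checked standalone (shapes here are arbitrary):

```python
import torch

torch.manual_seed(0)
mean = torch.zeros(5, 8, requires_grad=True)
log_std = torch.zeros(5, 8, requires_grad=True)

# Reparameterization trick: sample = mean + eps * std, eps ~ N(0, 1)
message = mean + torch.randn_like(mean) * torch.exp(log_std)

# The sampled message is differentiable w.r.t. both distribution parameters,
# so the receiver's loss can train the sender
message.sum().backward()
assert mean.grad is not None and log_std.grad is not None
assert message.shape == (5, 8)
```

Sampling from N(mean, std) directly would block the gradient; writing the sample as a deterministic function of (mean, log_std) plus external noise is exactly what lets DIAL backpropagate through communication.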

3.3 TarMAC: Targeted Multi-Agent Communication

TarMAC [3] uses signature-based attention:

class TarMAC(nn.Module):
    """
    TarMAC: Targeted Multi-Agent Communication
    Core idea: agents send messages with a "signature"; receivers
    attend selectively to the messages addressed to them.
    """
    
    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        # Signature generator
        self.signature_net = nn.Linear(hidden_dim, hidden_dim)
        
        # Attention
        self.query_net = nn.Linear(hidden_dim, hidden_dim)
        self.key_net = nn.Linear(hidden_dim, hidden_dim)
        self.value_net = nn.Linear(hidden_dim, hidden_dim)
    
    def forward(self, hidden_states):
        """
        hidden_states: [batch, n_agents, hidden_dim]
        """
        # Signatures: the "address" part of each outgoing message
        signatures = torch.tanh(self.signature_net(hidden_states))
        
        # Queries: what the receiving agent wants to know
        queries = self.query_net(hidden_states)  # [batch, n_agents, hidden]
        
        # Keys are derived from the signatures
        keys = self.key_net(signatures)  # [batch, n_agents, hidden]
        
        # Values: the message content
        values = self.value_net(hidden_states)  # [batch, n_agents, hidden]
        
        # Scaled dot-product attention: for each agent, score the
        # messages of all other agents
        attention_weights = torch.matmul(queries, keys.transpose(-2, -1))
        attention_weights = attention_weights / (self.hidden_dim ** 0.5)
        attention_weights = F.softmax(attention_weights, dim=-1)
        
        # Weighted aggregation of the messages
        aggregated = torch.matmul(attention_weights, values)
        
        return aggregated
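The attention step above is standard scaled dot-product attention over the agent dimension; a quick standalone check that each agent's weights form a probability distribution over the incoming messages:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, n_agents, hidden = 2, 3, 8
queries = torch.randn(batch, n_agents, hidden)
keys = torch.randn(batch, n_agents, hidden)
values = torch.randn(batch, n_agents, hidden)

# Scaled dot-product attention, as in TarMAC's forward pass
attn = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(hidden)
attn = F.softmax(attn, dim=-1)
aggregated = torch.matmul(attn, values)

# Each agent's attention over messages sums to 1
assert torch.allclose(attn.sum(dim=-1), torch.ones(batch, n_agents))
assert aggregated.shape == (batch, n_agents, hidden)
```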

4. Implicit Communication Methods

4.1 Inferring intent from behavior

Implicit communication needs no explicit message passing:

class ImplicitCommunication:
    """
    Implicit communication: agents infer each other's intent by
    observing behavior.
    (IntentPredictor and self.policy are assumed to be defined elsewhere.)
    """
    
    def __init__(self):
        self.intent_predictor = IntentPredictor()
    
    def infer_intent(self, obs, action):
        """
        Infer the other agents' intent from observations and actions.
        """
        # A learned network predicts the other agents' intent
        predicted_intent = self.intent_predictor(obs, action)
        return predicted_intent
    
    def coordinate(self, obs, inferred_intents):
        """
        Coordinate the own action based on the inferred intents.
        """
        # Feed the inferred intents to the policy as extra input
        action = self.policy(torch.cat([obs, inferred_intents], dim=-1))
        return action

4.2 Coordination via attention

class AttentionCoordination(nn.Module):
    """
    Implicit coordination through self-attention.
    """
    
    def __init__(self, hidden_dim, n_heads=4):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=n_heads,
            batch_first=True
        )
    
    def forward(self, agent_states):
        """
        agent_states: [batch, n_agents, hidden_dim]
        """
        # Coordinate via self-attention over the agents
        attended, _ = self.multihead_attn(
            agent_states, agent_states, agent_states
        )
        
        # Residual connection yields the coordinated states
        coordinated_states = agent_states + attended
        
        return coordinated_states
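Since the module is just nn.MultiheadAttention plus a residual, it can be exercised directly (the dimensions are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
states = torch.randn(2, 5, 16)  # [batch, n_agents, hidden_dim]

attended, weights = attn(states, states, states)
coordinated = states + attended  # residual connection, as in the module above

assert coordinated.shape == (2, 5, 16)
assert weights.shape == (2, 5, 5)  # each agent attends over all 5 agents
```

The returned weights matrix is a readable byproduct: row i shows how much agent i relied on each other agent, which is useful for inspecting emergent coordination.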

5. Communication with Graph Neural Networks

5.1 Modeling agents as a graph

A multi-agent system from the graph perspective:

Nodes: agents
Edges: communication links
Features: observations, actions, intents

    A ──B
   /│\  │  ← graph structure
  C │ D │
   \│/  │
    E ──F

5.2 GraphAttentionComm

class GraphAttentionComm(nn.Module):
    """
    Multi-agent communication via graph attention.
    (GATLayer is assumed to be defined elsewhere, e.g. a standard
    graph attention layer.)
    """
    
    def __init__(self, node_dim, edge_dim, hidden_dim, n_heads=4):
        super().__init__()
        self.node_encoder = nn.Linear(node_dim, hidden_dim)
        self.edge_encoder = nn.Linear(edge_dim, hidden_dim)
        
        self.gat_layers = nn.ModuleList([
            GATLayer(hidden_dim, n_heads)
            for _ in range(3)  # three graph-attention layers
        ])
        
        self.policy_head = nn.Linear(hidden_dim, 1)
    
    def forward(self, node_features, edge_index, edge_features=None):
        """
        node_features: [batch, n_agents, node_dim]
        edge_index: [2, n_edges], or an adjacency matrix
        edge_features: [batch, n_edges, edge_dim]
        """
        # Encode node features
        h = self.node_encoder(node_features)
        
        # Graph-attention layers
        for gat_layer in self.gat_layers:
            h = gat_layer(h, edge_index, edge_features)
        
        # Policy output
        logits = self.policy_head(h)
        
        return logits, h
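GATLayer is not defined above. A minimal single-head sketch over a dense adjacency matrix conveys the idea (this is an illustrative simplification, not the standard multi-head GAT; the class name and interface are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseGATLayer(nn.Module):
    """Minimal single-head graph attention over a dense adjacency matrix."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.attn = nn.Linear(2 * hidden_dim, 1, bias=False)

    def forward(self, h, adj):
        # h: [batch, n, hidden]; adj: [n, n] with 1 where an edge exists
        n = h.size(1)
        Wh = self.W(h)
        # Pairwise features (i, j) for the attention scores
        hi = Wh.unsqueeze(2).expand(-1, -1, n, -1)
        hj = Wh.unsqueeze(1).expand(-1, n, -1, -1)
        e = F.leaky_relu(self.attn(torch.cat([hi, hj], dim=-1)).squeeze(-1))
        # Mask non-edges, then normalize over neighbors
        e = e.masked_fill(adj == 0, float('-inf'))
        alpha = F.softmax(e, dim=-1)
        return F.elu(torch.matmul(alpha, Wh))

torch.manual_seed(0)
adj = torch.ones(4, 4)  # fully connected, self-loops included
layer = DenseGATLayer(8)
out = layer(torch.randn(2, 4, 8), adj)
assert out.shape == (2, 4, 8)
assert torch.isfinite(out).all()
```

Restricting the softmax to the adjacency mask is what distinguishes graph attention from the all-to-all attention of Section 4.2: agents only aggregate from their communication neighbors.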

5.3 TGCNet: Transformer-Based Graph Communication

TGCNet [4] uses a Transformer for communication over a dynamically learned graph:

class TGCNet(nn.Module):
    """
    Transformer-based Graph Coarsening Network
    Core idea: learn the communication topology dynamically.
    """
    
    def __init__(self, n_agents, obs_dim, hidden_dim=128):
        super().__init__()
        self.n_agents = n_agents
        
        # Observation encoder
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        
        # Transformer encoder (models inter-agent relations)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)
        
        # Dynamic adjacency-matrix learner
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
        # Policy network
        self.policy = nn.Linear(hidden_dim, 1)
    
    def predict_edges(self, hidden_states):
        """
        Predict which agents should communicate.
        Returns a dynamic adjacency matrix.
        """
        # Build all ordered pairs of agents
        h_i = hidden_states.unsqueeze(2).expand(-1, -1, self.n_agents, -1)
        h_j = hidden_states.unsqueeze(1).expand(-1, self.n_agents, -1, -1)
        
        # Score each edge
        h_pair = torch.cat([h_i, h_j], dim=-1)
        edge_weights = self.edge_predictor(h_pair).squeeze(-1)
        
        # Mask the diagonal (no agent messages itself)
        mask = torch.eye(self.n_agents, device=hidden_states.device).unsqueeze(0)
        edge_weights = edge_weights * (1 - mask)
        
        return edge_weights
    
    def forward(self, observations):
        """
        observations: [batch, n_agents, obs_dim]
        """
        # Encode observations
        h = self.obs_encoder(observations)
        
        # Predict the communication topology
        edge_weights = self.predict_edges(h)
        
        # Message passing with the Transformer
        # (in the full model the predicted topology would mask the
        # attention; here the encoder attends over all agents)
        h_transformed = self.transformer(h)
        
        # Policy output
        action_logits = self.policy(h_transformed)
        
        return action_logits, edge_weights, h_transformed
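The pairwise edge-prediction step can be checked in isolation with a stand-in scorer (any map to (0, 1) will do for the shape check; the real model uses the learned edge_predictor MLP):

```python
import torch

torch.manual_seed(0)
n_agents, hidden = 4, 6
h = torch.randn(2, n_agents, hidden)

# All ordered pairs (i, j) of agent states
h_i = h.unsqueeze(2).expand(-1, -1, n_agents, -1)
h_j = h.unsqueeze(1).expand(-1, n_agents, -1, -1)
pairs = torch.cat([h_i, h_j], dim=-1)  # [batch, n, n, 2*hidden]

# Stand-in for edge_predictor: any map to a scalar in (0, 1)
edge_weights = torch.sigmoid(pairs.sum(dim=-1))

# Zero out the diagonal: agents do not message themselves
mask = torch.eye(n_agents)
edge_weights = edge_weights * (1 - mask)

assert edge_weights.shape == (2, n_agents, n_agents)
assert torch.all(edge_weights.diagonal(dim1=-2, dim2=-1) == 0)
```

Note the ordered-pair construction gives a directed graph: edge_weights[b, i, j] and edge_weights[b, j, i] are scored independently.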

6. Communication Efficiency and Bandwidth Limits

6.1 Discrete communication

class DiscreteCommunication(nn.Module):
    """
    Discrete communication: messages are discrete symbols.
    """
    
    def __init__(self, message_dim, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.message_dim = message_dim
        
        # Continuous → discrete
        self.message_encoder = nn.Sequential(
            nn.Linear(message_dim, 64),
            nn.ReLU(),
            nn.Linear(64, vocab_size)  # logits over discrete symbols
        )
        
        # Discrete → continuous
        self.message_decoder = nn.Embedding(vocab_size, message_dim)
    
    def forward(self, hidden):
        # Generate a one-hot message via straight-through Gumbel-softmax
        logits = self.message_encoder(hidden)
        message = F.gumbel_softmax(logits, tau=0.5, hard=True)
        
        # Decode back to a continuous vector; multiplying by the embedding
        # matrix (rather than argmax + lookup) keeps the gradient path intact
        decoded = message @ self.message_decoder.weight
        
        return decoded, message
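The Gumbel-softmax trick is the crux here: the forward pass emits a hard one-hot symbol, yet gradients still reach the logits. A standalone check (dimensions are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, message_dim = 10, 4
logits = torch.randn(3, vocab_size, requires_grad=True)

# Straight-through Gumbel-softmax: one-hot forward, soft gradient backward
message = F.gumbel_softmax(logits, tau=0.5, hard=True)
assert torch.allclose(message.sum(dim=-1), torch.ones(3))  # one symbol each

# Decode via the embedding matrix; the matmul keeps the gradient path
embedding = torch.randn(vocab_size, message_dim)
decoded = message @ embedding

decoded.sum().backward()
assert logits.grad is not None  # gradients flow through the discrete channel
```

Lower temperatures tau make the soft backward distribution closer to the hard one-hot at the cost of higher gradient variance.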

6.2 Communication scheduling

class CommunicationScheduler(nn.Module):
    """
    Learn when to communicate.
    """
    
    def __init__(self, hidden_dim):
        super().__init__()
        self.comm_decider = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )
    
    def should_communicate(self, hidden, threshold=0.5):
        """
        Decide whether to send a message this step.
        """
        comm_prob = self.comm_decider(hidden)
        
        if self.training:
            # Training: sample the gate stochastically for exploration
            return torch.bernoulli(comm_prob)
        else:
            # Inference: greedy thresholding
            return (comm_prob > threshold).float()

7. Communication Security and Robustness

7.1 Robust communication under adversarial noise

class RobustCommunication(nn.Module):
    """
    Communication that is robust to adversarial noise.
    (NoiseEstimator is assumed to be defined elsewhere.)
    """
    
    def __init__(self, base_comm):
        super().__init__()
        self.base_comm = base_comm
        self.noise_estimator = NoiseEstimator()
    
    def denoise(self, message, estimated_noise):
        """
        Denoise by subtracting the estimated noise.
        """
        return message - estimated_noise
    
    def forward(self, hidden, noisy_message=None):
        if noisy_message is not None:
            # Estimate the noise
            noise = self.noise_estimator(noisy_message, hidden)
            # Denoise
            message = self.denoise(noisy_message, noise)
        else:
            message = self.base_comm(hidden)
        
        return message

8. References

  1. Sukhbaatar et al., "Learning Multiagent Communication with Backpropagation", NeurIPS 2016.

  2. Foerster et al., "Learning to Communicate with Deep Multi-Agent Reinforcement Learning", NeurIPS 2016.

  3. Das et al., "TarMAC: Targeted Multi-Agent Communication", ICML 2019.

  4. Liu et al., "Bridging Training and Execution via Dynamic Directed Graph-Based Communication", 2024.