MARL通信与协调机制
1. 通信的重要性
在多智能体系统中,通信是协调行动的关键机制:
无通信: 有通信:
┌────────────────┐ ┌────────────────┐
│ 智能体A │ │ 智能体A ──→ 智能体B│
│ ? │ │ 传递意图 │
│智能体B │ │智能体B ←── 智能体A │
│ ? │ │ 传递观察 │
└────────────────┘ └────────────────┘
问题:不知道其他智能体要做什么 解决:协调行动
2. 通信分类
2.1 通信模式
| 模式 | 描述 | 示例 |
|---|---|---|
| 显式通信 | 智能体主动发送消息 | DIAL、CommNet |
| 隐式通信 | 通过环境中的信号间接通信 | 跟随、观察行为 |
| 混合通信 | 结合显式和隐式 | TarMAC |
2.2 通信结构
完全连接: 图结构:
A ←→ B ←→ C A ──→ B
↑ ↓ ↑ ↓ ↓
└─────┴─────┘ C ←── D
广播: 星型:
A → B, C, D A → B, C, D
(A为协调者)
3. 显式通信方法
3.1 CommNet:连续通信
CommNet[^1]使用连续消息进行通信:
class CommNet(nn.Module):
    """CommNet: multi-agent communication via continuous messages.

    Each agent encodes its observation, then performs several rounds of
    message passing in which every agent receives the mean hidden state
    of the *other* agents (Sukhbaatar et al., 2016), and finally decodes
    its hidden state into an output.
    """

    def __init__(self, input_dim, hidden_dim, n_agents, n_steps=3):
        super().__init__()
        self.n_agents = n_agents
        self.n_steps = n_steps
        # Observation encoder.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU()
        )
        # One communication layer per message-passing round.
        self.comm_layer = nn.ModuleList([
            CommunicationLayer(hidden_dim)
            for _ in range(n_steps)
        ])
        # Decoder back to the output space (same width as the input here).
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def forward(self, observations):
        """
        Args:
            observations: [batch, n_agents, obs_dim]

        Returns:
            Per-agent decoded outputs (actions or values),
            [batch, n_agents, input_dim].
        """
        # Encode observations -> [batch, n_agents, hidden_dim].
        h = self.encoder(observations)
        # Multiple rounds of message passing.
        for step in range(self.n_steps):
            # Message to agent i = mean hidden state of the OTHER agents.
            # (Bug fix: the original averaged over all agents, so each
            # agent "communicated" with itself; CommNet excludes the
            # sender from its own message.)
            if self.n_agents > 1:
                messages = (h.sum(dim=1, keepdim=True) - h) / (self.n_agents - 1)
            else:
                messages = torch.zeros_like(h)
            h = self.comm_layer[step](h, messages)
        # Decode to the final output.
        return self.decoder(h)
class CommunicationLayer(nn.Module):
    """One round of CommNet message passing: fuse an agent's hidden
    state with the aggregated message through a GRU cell."""

    def __init__(self, hidden_dim):
        super().__init__()
        # The GRU input is [hidden ; message], i.e. 2 * hidden_dim wide.
        # (Bug fix: the original declared GRUCell(hidden_dim, hidden_dim),
        # which raises a size-mismatch error on the concatenated input
        # built in forward().)
        self.gru = nn.GRUCell(hidden_dim * 2, hidden_dim)
        # NOTE(review): this layer is never used in forward(); kept only
        # so existing checkpoints still load.
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, hidden, message):
        """
        Args:
            hidden:  [batch, n_agents, hidden_dim] current hidden states.
            message: [batch, n_agents, hidden_dim] aggregated messages.
        """
        # Concatenate own hidden state with the received message.
        combined = torch.cat([hidden, message], dim=-1)
        # GRUCell expects 2-D input, so flatten the agent dimension.
        new_hidden = self.gru(
            combined.view(-1, hidden.size(-1) * 2),
            hidden.view(-1, hidden.size(-1))
        )
        return new_hidden.view_as(hidden)

3.2 DIAL:可微分通信
DIAL[^2]允许通信梯度反向传播:
class DIAL(nn.Module):
    """
    DIAL: Differentiable Inter-Agent Learning.

    Core idea: messages are real-valued during training so gradients can
    flow back through the communication channel between agents.
    """

    def __init__(self, obs_dim, action_dim, message_dim=8):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.message_dim = message_dim
        # Observation encoder.
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )
        # Message head: outputs mean and log-std of a Gaussian message.
        self.message_net = nn.Sequential(
            nn.Linear(32 + message_dim, message_dim * 2)  # mean + log-std
        )
        # Action head (categorical policy over actions).
        self.actor = nn.Sequential(
            nn.Linear(32 + message_dim, 32),
            nn.ReLU(),
            nn.Linear(32, action_dim),
            nn.Softmax(dim=-1)
        )
        # Centralized critic (defined elsewhere in this project).
        self.critic = CentralizedCritic(obs_dim, action_dim, message_dim)

    def send_message(self, obs_encoding, incoming_messages):
        """Generate the outgoing message via the reparameterization trick.

        Returns:
            (message, mean, log_std)
        """
        if incoming_messages is None:
            # No incoming messages yet: use a zero vector of the MESSAGE
            # width. (Bug fix: the original used zeros_like(obs_encoding),
            # whose last dim is 32, which breaks message_net's expected
            # input size whenever message_dim != 32.)
            incoming_messages = obs_encoding.new_zeros(
                *obs_encoding.shape[:-1], self.message_dim
            )
        combined = torch.cat([obs_encoding, incoming_messages], dim=-1)
        message_params = self.message_net(combined)
        # Reparameterize: message = mean + eps * std.
        mean, log_std = message_params.chunk(2, dim=-1)
        # Bug fix: the original sampled torch.randn_like(std) with `std`
        # undefined (NameError); sample noise shaped like the mean.
        message = mean + torch.randn_like(mean) * torch.exp(log_std)
        return message, mean, log_std

    def receive_message(self, messages, agent_indices):
        """Aggregate incoming messages (simple average).

        NOTE(review): assumes dim 1 of `messages` indexes the senders —
        confirm against the caller's tensor layout.
        """
        return messages.mean(dim=1)

    def forward(self, obs, incoming_messages=None, training=True):
        """Full forward pass: encode, send, receive, act.

        Returns:
            (action, message, mean, log_std) when training,
            otherwise (action, message).
        """
        # Encode the observation.
        encoding = self.obs_encoder(obs)
        # Generate this agent's outgoing message.
        message, msg_mean, msg_log_std = self.send_message(encoding, incoming_messages)
        # NOTE(review): this averages the agent's own outgoing message
        # rather than messages from other agents — verify intent.
        received = self.receive_message(message, None)
        # Condition the policy on observation encoding + received message.
        combined = torch.cat([encoding, received], dim=-1)
        action = self.actor(combined)
        if training:
            return action, message, msg_mean, msg_log_std
        return action, message

3.3 TarMAC:目标感知多智能体通信
TarMAC[^3]使用基于签名的注意力:
class TarMAC(nn.Module):
    """
    TarMAC: Targeted Multi-Agent Communication.

    Each agent broadcasts a message consisting of a "signature" (the
    key/address of the message) and a value; receivers attend over the
    signatures with their own queries and aggregate the values.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Signature (message address) generator.
        self.signature_net = nn.Linear(hidden_dim, hidden_dim)
        # Attention projections.
        self.query_net = nn.Linear(hidden_dim, hidden_dim)
        self.key_net = nn.Linear(hidden_dim, hidden_dim)
        self.value_net = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: [batch, n_agents, hidden_dim]
        """
        batch_size, n_agents = hidden_states.shape[:2]
        # Message signatures ("addresses" chosen by the senders).
        signatures = torch.tanh(self.signature_net(hidden_states))
        # Queries: what each receiving agent is looking for.
        queries = self.query_net(hidden_states)  # [batch, n_agents, hidden]
        # Keys are derived from the signatures. (Bug fix: the original
        # computed `signatures` but never used them, keying attention on
        # the raw hidden states instead — defeating the point of TarMAC.)
        keys = self.key_net(signatures)  # [batch, n_agents, hidden]
        # Values: the message contents.
        values = self.value_net(hidden_states)  # [batch, n_agents, hidden]
        # Scaled dot-product attention over all agents' messages
        # (includes the agent's own message; there is no self-masking).
        attention_weights = torch.matmul(queries, keys.transpose(-2, -1))
        attention_weights = attention_weights / np.sqrt(self.hidden_dim)
        attention_weights = F.softmax(attention_weights, dim=-1)
        # Weighted aggregation of the message values.
        aggregated = torch.matmul(attention_weights, values)
        return aggregated

4. 隐式通信方法
4.1 通过行为推断
隐式通信不需要显式消息传递:
class ImplicitCommunication:
    """
    Implicit communication: agents infer each other's intent by
    observing behavior instead of exchanging explicit messages.
    """

    def __init__(self, policy=None):
        # Intent model (defined elsewhere in this project).
        self.intent_predictor = IntentPredictor()
        # Bug fix: coordinate() reads self.policy, but the original never
        # set it (guaranteed AttributeError). Accept it via injection;
        # the default keeps the old zero-argument constructor working.
        self.policy = policy

    def infer_intent(self, obs, action):
        """Predict another agent's intent from its observation/action."""
        predicted_intent = self.intent_predictor(obs, action)
        return predicted_intent

    def coordinate(self, obs, inferred_intents):
        """Choose an action conditioned on the inferred intents."""
        # Feed the inferred intents to the policy as extra input.
        action = self.policy(torch.cat([obs, inferred_intents], dim=-1))
        return action

4.2 注意力机制协调
class AttentionCoordination(nn.Module):
    """Implicit coordination via multi-head self-attention over agents."""

    def __init__(self, hidden_dim, n_heads=4):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=n_heads,
            batch_first=True,
        )

    def forward(self, agent_states, agent_ids=None):
        """
        Args:
            agent_states: [batch, n_agents, hidden_dim]
            agent_ids: unused; kept for interface compatibility.
        """
        # Self-attention: every agent attends over all agents' states.
        context, _attn_weights = self.multihead_attn(
            query=agent_states, key=agent_states, value=agent_states
        )
        # Residual connection yields the coordinated states.
        coordinated_states = agent_states + context
        return coordinated_states

5. 图神经网络通信
5.1 智能体建模为图
图视角下的多智能体系统:
节点:智能体
边:通信链接
特征:观察、动作、意图
A ──B
/│\ │ ← 图结构
C │ D │
\│/ │
E ──F
5.2 GraphAttentionComm
class GraphAttentionComm(nn.Module):
    """Multi-agent communication through stacked graph-attention layers."""

    def __init__(self, node_dim, edge_dim, hidden_dim, n_heads=4):
        super().__init__()
        self.node_encoder = nn.Linear(node_dim, hidden_dim)
        self.edge_encoder = nn.Linear(edge_dim, hidden_dim)
        # Three rounds of graph attention (GATLayer defined elsewhere).
        self.gat_layers = nn.ModuleList(
            GATLayer(hidden_dim, n_heads) for _ in range(3)
        )
        self.policy_head = nn.Linear(hidden_dim, 1)

    def forward(self, node_features, edge_index, edge_features=None):
        """
        Args:
            node_features: [batch, n_agents, node_dim]
            edge_index: [2, n_edges] or an adjacency matrix
            edge_features: optional [batch, n_edges, edge_dim]
        """
        # Encode node (agent) features, then propagate over the graph.
        h = self.node_encoder(node_features)
        for layer in self.gat_layers:
            h = layer(h, edge_index, edge_features)
        # Per-agent policy output.
        logits = self.policy_head(h)
        return logits, h

5.3 TGCNet:Transformer图协同网络
TGCNet[^4]使用Transformer进行动态图通信:
class TGCNet(nn.Module):
    """
    Transformer-based graph coordination network.

    Learns a dynamic communication topology (a soft adjacency matrix)
    while agents exchange information through a Transformer encoder.

    NOTE(review): the predicted adjacency is returned but never applied
    as an attention mask — the learned topology does not yet gate the
    Transformer's message passing.
    """

    def __init__(self, n_agents, obs_dim, hidden_dim=128):
        super().__init__()
        self.n_agents = n_agents
        # Per-agent observation encoder.
        self.obs_encoder = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        # Transformer encoder models agent-to-agent relations.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=8,
            dim_feedforward=hidden_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=3)
        # Scores each directed agent pair -> soft adjacency in (0, 1).
        self.edge_predictor = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        # Policy head.
        self.policy = nn.Linear(hidden_dim, 1)

    def predict_edges(self, hidden_states):
        """Score every ordered agent pair.

        Returns a [batch, n_agents, n_agents] soft adjacency matrix with
        a zeroed diagonal (agents do not message themselves).
        """
        n = self.n_agents
        # Broadcast every (sender, receiver) pair of hidden states.
        senders = hidden_states.unsqueeze(2).expand(-1, -1, n, -1)
        receivers = hidden_states.unsqueeze(1).expand(-1, n, -1, -1)
        pair_repr = torch.cat([senders, receivers], dim=-1)
        edge_weights = self.edge_predictor(pair_repr).squeeze(-1)
        # Mask the diagonal: no self-communication.
        eye = torch.eye(n, device=hidden_states.device).unsqueeze(0)
        edge_weights = edge_weights * (1 - eye)
        return edge_weights

    def forward(self, observations):
        """
        Args:
            observations: [batch, n_agents, obs_dim]
        """
        h = self.obs_encoder(observations)
        # Which pairs of agents should communicate?
        edge_weights = self.predict_edges(h)
        # Information exchange via self-attention.
        h_transformed = self.transformer(h)
        # Per-agent policy output.
        action_logits = self.policy(h_transformed)
        return action_logits, edge_weights, h_transformed

6. 通信效率与带宽限制
6.1 离散通信
class DiscreteCommunication(nn.Module):
    """
    Discrete communication: messages are symbols from a fixed vocabulary,
    kept differentiable with the straight-through Gumbel-Softmax.
    """

    def __init__(self, message_dim, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.message_dim = message_dim
        # Continuous hidden state -> logits over discrete symbols.
        self.message_encoder = nn.Sequential(
            nn.Linear(message_dim, 64),
            nn.ReLU(),
            nn.Linear(64, vocab_size)  # one logit per symbol
        )
        # Discrete symbol -> continuous embedding.
        self.message_decoder = nn.Embedding(vocab_size, message_dim)

    def forward(self, hidden):
        # Sample a straight-through one-hot message.
        logits = self.message_encoder(hidden)
        message = F.gumbel_softmax(logits, tau=0.5, hard=True)
        # Decode by a matmul with the embedding table so the gradient
        # flows through the one-hot message. (Bug fix: the original
        # decoded message.argmax(...), which detaches the message and
        # defeats the Gumbel-Softmax straight-through estimator; the
        # decoded values themselves are identical.)
        decoded = message @ self.message_decoder.weight
        return decoded, message

6.2 通信调度
class CommunicationScheduler(nn.Module):
    """Learns *when* an agent should send a message."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Hidden state -> probability of communicating.
        self.comm_decider = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def should_communicate(self, hidden, threshold=0.5):
        """Return a 0/1 communication gate per input row.

        Training samples the gate stochastically; inference thresholds
        greedily. NOTE(review): the `threshold` argument is currently
        unused — inference compares against a hard-coded 0.5.
        """
        comm_prob = self.comm_decider(hidden)
        if self.training:
            # Bug fix: the original comment promised stochastic gating
            # during training but the code thresholded deterministically;
            # sample instead so low-probability links are still explored.
            return torch.bernoulli(comm_prob)
        # Inference: greedy threshold.
        return (comm_prob > 0.5).float()

7. 通信安全与鲁棒性
7.1 对抗鲁棒通信
class RobustCommunication(nn.Module):
    """Communication wrapper that is robust to channel noise."""

    def __init__(self, base_comm):
        super().__init__()
        # Underlying (clean-channel) communication module.
        self.base_comm = base_comm
        # Noise model (defined elsewhere in this project).
        self.noise_estimator = NoiseEstimator()

    def denoise(self, message, estimated_noise):
        """Subtract the estimated noise from a received message."""
        return message - estimated_noise

    def forward(self, hidden, noisy_message=None):
        if noisy_message is None:
            # Sending side: produce a clean message.
            message = self.base_comm(hidden)
        else:
            # Receiving side: estimate the channel noise, then remove it.
            noise = self.noise_estimator(noisy_message, hidden)
            message = self.denoise(noisy_message, noise)
        return message

8. 参考文献
Footnotes

[^1]: Sukhbaatar et al. "Learning Multiagent Communication with Backpropagation", NeurIPS 2016.

[^2]: Foerster et al. "Learning to Communicate with Deep Multi-Agent Reinforcement Learning", NeurIPS 2016.

[^3]: Das et al. "TarMAC: Targeted Multi-Agent Communication", ICML 2019.

[^4]: Liu et al. "Bridging Training and Execution via Dynamic Directed Graph-Based Communication", 2024.