MARL值函数分解方法
1. 核心问题
在合作式MARL中,我们需要解决**信用分配(Credit Assignment)**问题:
其中 是某种可分解的操作。
1.1 为什么要分解?
- 维度灾难:联合Q函数随智能体数指数增长
- 去中心化执行:每个智能体只能访问自己的Q函数
- 可解释性:理解每个智能体的贡献
1.2 分解的理想性质
| 性质 | 描述 |
|---|---|
| 个体一致性(IA) | 贪婪选择个体最优导致联合最优 |
| 团队可加性(TA) | 团队奖励等于个体奖励之和 |
| 联合贪婪性 |
2. VDN:Value Decomposition Networks
2.1 核心思想
VDN1使用最简单的线性分解:
2.2 网络架构
┌─────────────────────────────────────────────────────┐
│ VDN 架构 │
│ │
│ 状态s │
│ │ │
│ ├──→ Q₁(s, a¹) ──────┐ │
│ │ │ │
│ ├──→ Q₂(s, a²) ───→ │ + → Q_tot(s, a) │
│ │ │ │
│ └──→ Qₙ(s, aⁿ) ──────┘ │
│ │
│ 训练信号从 Q_tot 反向传播到每个 Q_i │
└─────────────────────────────────────────────────────┘
2.3 PyTorch实现
class VDN(nn.Module):
"""Value Decomposition Networks"""
def __init__(self, state_dim, action_dims, hidden_dim=64):
super().__init__()
self.n_agents = len(action_dims)
# 每个智能体有自己的Q网络
self.q_nets = nn.ModuleList([
MLP(state_dim, action_dims[i], hidden_dim)
for i in range(self.n_agents)
])
def forward(self, states, actions=None):
"""
states: [batch, n_agents, state_dim]
actions: [batch, n_agents] (one-hot) or None
"""
individual_qs = []
for i, q_net in enumerate(self.q_nets):
# 每个智能体根据自己观察到的状态计算Q值
# 实际中,状态可能是局部观察
q_i = q_net(states[:, i]) # [batch, action_dim_i]
if actions is not None:
# 如果提供了动作,计算对应Q值
q_i = (q_i * actions[:, i]).sum(dim=-1, keepdim=True)
individual_qs.append(q_i)
# 分解:直接求和
total_q = sum(individual_qs)
return total_q, individual_qs
def get_epsilon_greedy_action(self, states, epsilon=0.1):
"""ε-贪婪策略:分解后每个智能体独立选择"""
actions = []
for i, q_net in enumerate(self.q_nets):
q_i = q_net(states[:, i])
if random.random() < epsilon:
a_i = torch.randint(0, q_i.size(-1), (q_i.size(0),))
else:
a_i = q_i.argmax(dim=-1)
actions.append(a_i)
return torch.stack(actions, dim=1)2.4 局限性
VDN的线性加和假设过于简单:
# VDN假设:团队Q = Σ 个体Q
# 问题:无法表示非线性交互
# 例如:
# Q(s, left, left) = 10 # 两人同时向左(碰撞)
# Q(s, right, right) = 10 # 两人同时向右(碰撞)
# Q(s, left, right) = 100 # 两人分开(成功)
# VDN无法捕捉这个交互:
# VDN会预测 Q(s, left, left) = Q₁ + Q₁ # 总是最小
# Q(s, left, right) = Q₁ + Q₂ # 可能不大3. QMIX:Monotonic Value Function Factorisation
3.1 核心思想
QMIX2放宽了VDN的线性假设,使用单调混合网络:
其中混合网络 满足单调性:
3.2 为什么需要单调性?
单调性的重要性:
如果 Q_tot 对每个 Q_i 单调递增,
那么:
argmax_{a} Q_tot(s, a)
= argmax_{a} (f_mix(Q_1(s, a.1), ...))
= (argmax_{a1} Q_1(s, a1), argmax_{a2} Q_2(s, a2), ...)
即:个体贪婪 = 联合贪婪 ✓
3.3 网络架构
┌─────────────────────────────────────────────────────────┐
│ QMIX 架构 │
│ │
│ 状态s ──→ HyperNetwork ──→ 混合网络权重 │
│ ↓ │
│ ┌──────────────────────────────────────────────────┐ │
│ │ 混合网络 f_mix │ │
│ │ │ │
│ │ Q₁ ──→ [W₁] ─┐ │ │
│ │ ↓ │ │ │
│ │ Q₂ ──→ [W₂] ─→ [W₃] ──→ Q_tot │ │
│ │ ↗ │ │ │
│ │ Qₙ ──→ [Wₙ] ─┘ │ │
│ │ │ │
│ │ 所有权重 W_i ≥ 0 (保证单调性) │ │
│ └──────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
3.4 混合网络实现
class QMixNet(nn.Module):
"""QMIX混合网络"""
def __init__(self, n_agents, state_dim, hidden_dim=32):
super().__init__()
self.n_agents = n_agents
self.state_dim = state_dim
# 生成第一层权重的超网络
self.hyper_w1 = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_agents * hidden_dim)
)
# 生成第一层偏置的超网络
self.hyper_b1 = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 生成第二层权重
self.hyper_w2 = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim)
)
# 生成第二层偏置
self.hyper_b2 = nn.Sequential(
nn.Linear(state_dim, hidden_dim // 2),
nn.ReLU(),
nn.Linear(hidden_dim // 2, 1)
)
# 第二层输出维度
self.v2 = nn.Parameter(torch.zeros(1, hidden_dim // 2))
def forward(self, agent_qs, states):
"""
agent_qs: [batch, n_agents]
states: [batch, state_dim]
"""
batch_size = agent_qs.size(0)
# 生成第一层参数 (确保非负以保证单调性)
w1 = torch.abs(self.hyper_w1(states)) # [batch, n_agents * hidden_dim]
w1 = w1.view(batch_size, self.n_agents, -1) # [batch, n_agents, hidden_dim]
b1 = self.hyper_b1(states) # [batch, hidden_dim]
# 生成第二层参数
w2 = torch.abs(self.hyper_w2(states)) # [batch, hidden_dim]
w2 = w2.view(batch_size, -1, 1) # [batch, hidden_dim, 1]
b2 = self.hyper_b2(states) # [batch, hidden_dim // 2]
# 第一层: ReLU(Q * W + b)
hidden = torch.bmm(agent_qs.unsqueeze(1), w1) + b1.unsqueeze(1)
hidden = F.relu(hidden) # [batch, 1, hidden_dim]
# 第二层
v2 = self.v2.expand(batch_size, -1, 1) # [batch, hidden_dim // 2, 1]
# 合并v2到w2
w2_full = torch.cat([w2[:, :hidden.size(-1)//2],
v2.expand(-1, hidden.size(-1)//2, 1)], dim=1)
q_tot = torch.bmm(hidden, w2_full) + b2.unsqueeze(2)
return q_tot.squeeze(-1) # [batch, 1]3.5 训练算法
class QMIX:
def __init__(self, n_agents, state_dim, action_dims):
self.agent_net = VDNQNetwork(n_agents, state_dim, action_dims)
self.mixer = QMixNet(n_agents, state_dim)
self.target_agent_net = copy.deepcopy(self.agent_net)
self.target_mixer = copy.deepcopy(self.mixer)
def update(self, batch):
states, actions, rewards, next_states, dones = batch
# 计算当前Q值
agent_qs, _ = self.agent_net(states, actions)
q_tot = self.mixer(agent_qs, states)
# 计算目标Q值
with torch.no_grad():
next_agent_qs, _ = self.target_agent_net(next_states)
next_q_tot = self.target_mixer(next_agent_qs, next_states)
target_q = rewards + (1 - dones) * self.gamma * next_q_tot
# 损失
loss = F.mse_loss(q_tot, target_q)
# 梯度反向传播
self.optim.zero_grad()
loss.backward()
self.optim.step()
def update_target(self, tau=0.005):
"""软更新目标网络"""
for target, source in zip(
self.target_agent_net.parameters(),
self.agent_net.parameters()
):
target.data.copy_(tau * source + (1 - tau) * target)4. QTRAN:General Value Function Factorisation
4.1 核心思想
QTRAN3解决VDN和QMIX的结构限制:
- VDN:只能表示可加性
- QMIX:只能表示单调性
- QTRAN:可以表示任意分解
4.2 变换方法
QTRAN的核心是将原始Q函数变换为可分解的形式:
使得:
4.3 损失函数
QTRAN包含两个损失项:
class QTRANLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, q_tot, individual_qs, target_q_tot,
states, actions):
"""
Args:
q_tot: 联合Q值 [batch]
individual_qs: 个体Q值列表 [[batch, dim_i]]
target_q_tot: 目标联合Q值 [batch]
"""
batch_size = q_tot.size(0)
# 1. 重建损失:保证变换后的Q_tot与原始相近
recon_loss = F.mse_loss(q_tot, target_q_tot)
# 2. 分解损失:个体Q的贪婪选择等于联合贪婪
# 计算每个智能体在自己的Q上贪婪选择的联合Q
greedy_joint_q = 0
for i, q_i in enumerate(individual_qs):
greedy_a_i = q_i.argmax(dim=-1) # [batch]
# 这里应该用实际的贪婪Q值
greedy_joint_q += q_i.gather(-1, greedy_a_i.unsqueeze(-1)).squeeze(-1)
# 3. 差异损失:保证变换后的联合贪婪等于真实联合贪婪
difference_loss = F.relu(q_tot - greedy_joint_q).mean()
return recon_loss + difference_loss5. QTRAN vs VDN vs QMIX
5.1 表达能力对比
| 算法 | 表达能力 | 约束 | 去中心化贪婪 |
|---|---|---|---|
| VDN | 线性加和 | 无 | ✓ |
| QMIX | 单调函数 | 单调性 | ✓ |
| QTRAN | 任意函数 | 无 | ✓ |
5.2 计算复杂度
| 算法 | 混合网络参数量 | 训练复杂度 |
|---|---|---|
| VDN | O(1) | 低 |
| QMIX | O(n × h) | 中 |
| QTRAN | O(n × h + 额外网络) | 高 |
5.3 适用场景
VDN:
✓ 智能体独立工作,低交互
✗ 高度协调任务
QMIX:
✓ 大多数合作任务
✗ 复杂的反协调游戏
QTRAN:
✓ 任意合作结构
✗ 计算成本高
6. Weighted QMIX
6.1 动机
QMIX的单调性约束限制了表达能力:
# 一个单调性无法表示的例子:
# 智能体A: 动作1比动作2好
# 智能体B: 动作2比动作1好
# 最优联合: (1,2) 或 (2,1),但QMIX无法表示这种"反协调"6.2 加权QMIX
Weighted QMIX4通过放松单调性来扩展表达能力:
class WeightedQMIXMixer(nn.Module):
"""加权QMIX:允许负权重但加权重构"""
def __init__(self, n_agents, state_dim):
super().__init__()
# 不再限制权重非负
self.hyper_w1 = nn.Linear(state_dim, n_agents * 32)
self.hyper_b1 = nn.Linear(state_dim, 32)
self.hyper_w2 = nn.Linear(state_dim, 32)
self.hyper_b2 = nn.Linear(state_dim, 1)
self.v = nn.Parameter(torch.rand(1, 32))
def forward(self, agent_qs, states):
batch_size = agent_qs.size(0)
# 权重可以取负值
w1 = self.hyper_w1(states).view(batch_size, self.n_agents, -1)
b1 = self.hyper_b1(states)
w2 = self.hyper_w2(states).view(batch_size, -1, 1)
b2 = self.hyper_b2(states)
# 第一层
hidden = F.elu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1.unsqueeze(1))
# 第二层
q_tot = torch.bmm(hidden, w2) + b2.unsqueeze(2)
return q_tot.squeeze(-1)7. 参考文献
Footnotes
-
Sunehag et al. “Value-Decomposition Networks For Cooperative Multi-Agent Learning” AAMAS 2017 ↩
-
Rashid et al. “QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning” ICML 2018 ↩
-
Son et al. “QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning” ICML 2019 ↩
-
Rashid et al. “Weighted QMIX: Expanding Monotonic Value Function Factorisation” NeurIPS 2020 ↩