MARL值函数分解方法

1. 核心问题

在合作式MARL中,我们需要解决**信用分配(Credit Assignment)**问题:

其中 是某种可分解的操作。

1.1 为什么要分解?

  1. 维度灾难:联合Q函数随智能体数指数增长
  2. 去中心化执行:每个智能体只能访问自己的Q函数
  3. 可解释性:理解每个智能体的贡献

1.2 分解的理想性质

性质描述
个体一致性(IA)贪婪选择个体最优导致联合最优
团队可加性(TA)团队奖励等于个体奖励之和
联合贪婪性

2. VDN:Value Decomposition Networks

2.1 核心思想

VDN1使用最简单的线性分解:

2.2 网络架构

┌─────────────────────────────────────────────────────┐
│                 VDN 架构                            │
│                                                     │
│   状态s                                             │
│     │                                               │
│     ├──→ Q₁(s, a¹) ──────┐                        │
│     │                     │                        │
│     ├──→ Q₂(s, a²) ───→ │ + → Q_tot(s, a)       │
│     │                     │                        │
│     └──→ Qₙ(s, aⁿ) ──────┘                        │
│                                                     │
│   训练信号从 Q_tot 反向传播到每个 Q_i               │
└─────────────────────────────────────────────────────┘

2.3 PyTorch实现

class VDN(nn.Module):
    """Value Decomposition Networks"""
    
    def __init__(self, state_dim, action_dims, hidden_dim=64):
        super().__init__()
        self.n_agents = len(action_dims)
        
        # 每个智能体有自己的Q网络
        self.q_nets = nn.ModuleList([
            MLP(state_dim, action_dims[i], hidden_dim)
            for i in range(self.n_agents)
        ])
    
    def forward(self, states, actions=None):
        """
        states: [batch, n_agents, state_dim]
        actions: [batch, n_agents] (one-hot) or None
        """
        individual_qs = []
        
        for i, q_net in enumerate(self.q_nets):
            # 每个智能体根据自己观察到的状态计算Q值
            # 实际中,状态可能是局部观察
            q_i = q_net(states[:, i])  # [batch, action_dim_i]
            
            if actions is not None:
                # 如果提供了动作,计算对应Q值
                q_i = (q_i * actions[:, i]).sum(dim=-1, keepdim=True)
            
            individual_qs.append(q_i)
        
        # 分解:直接求和
        total_q = sum(individual_qs)
        
        return total_q, individual_qs
    
    def get_epsilon_greedy_action(self, states, epsilon=0.1):
        """ε-贪婪策略:分解后每个智能体独立选择"""
        actions = []
        for i, q_net in enumerate(self.q_nets):
            q_i = q_net(states[:, i])
            
            if random.random() < epsilon:
                a_i = torch.randint(0, q_i.size(-1), (q_i.size(0),))
            else:
                a_i = q_i.argmax(dim=-1)
            
            actions.append(a_i)
        
        return torch.stack(actions, dim=1)

2.4 局限性

VDN的线性加和假设过于简单:

# VDN假设:团队Q = Σ 个体Q
# 问题:无法表示非线性交互
 
# 例如:
# Q(s, left, left) = 10  # 两人同时向左(碰撞)
# Q(s, right, right) = 10  # 两人同时向右(碰撞)
# Q(s, left, right) = 100  # 两人分开(成功)
 
# VDN无法捕捉这个交互:
# VDN会预测 Q(s, left, left) = Q₁ + Q₁  # 总是最小
#           Q(s, left, right) = Q₁ + Q₂  # 可能不大

3. QMIX:Monotonic Value Function Factorisation

3.1 核心思想

QMIX2放宽了VDN的线性假设,使用单调混合网络

其中混合网络 满足单调性

3.2 为什么需要单调性?

单调性的重要性:

如果 Q_tot 对每个 Q_i 单调递增,
那么:
  argmax_{a} Q_tot(s, a) 
= argmax_{a} (f_mix(Q_1(s, a.1), ...))
= (argmax_{a1} Q_1(s, a1), argmax_{a2} Q_2(s, a2), ...)

即:个体贪婪 = 联合贪婪 ✓

3.3 网络架构

┌─────────────────────────────────────────────────────────┐
│                    QMIX 架构                            │
│                                                          │
│   状态s ──→ HyperNetwork ──→ 混合网络权重               │
│                        ↓                                 │
│   ┌──────────────────────────────────────────────────┐  │
│   │           混合网络 f_mix                          │  │
│   │                                                   │  │
│   │   Q₁ ──→ [W₁] ─┐                                │  │
│   │               ↓  │                               │  │
│   │   Q₂ ──→ [W₂] ─→ [W₃] ──→ Q_tot                │  │
│   │               ↗  │                               │  │
│   │   Qₙ ──→ [Wₙ] ─┘                               │  │
│   │                                                   │  │
│   │   所有权重 W_i ≥ 0 (保证单调性)                   │  │
│   └──────────────────────────────────────────────────┘  │
│                                                          │
└─────────────────────────────────────────────────────────┘

3.4 混合网络实现

class QMixNet(nn.Module):
    """QMIX混合网络"""
    
    def __init__(self, n_agents, state_dim, hidden_dim=32):
        super().__init__()
        self.n_agents = n_agents
        self.state_dim = state_dim
        
        # 生成第一层权重的超网络
        self.hyper_w1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_agents * hidden_dim)
        )
        
        # 生成第一层偏置的超网络
        self.hyper_b1 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 生成第二层权重
        self.hyper_w2 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # 生成第二层偏置
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        
        # 第二层输出维度
        self.v2 = nn.Parameter(torch.zeros(1, hidden_dim // 2))
    
    def forward(self, agent_qs, states):
        """
        agent_qs: [batch, n_agents]
        states: [batch, state_dim]
        """
        batch_size = agent_qs.size(0)
        
        # 生成第一层参数 (确保非负以保证单调性)
        w1 = torch.abs(self.hyper_w1(states))  # [batch, n_agents * hidden_dim]
        w1 = w1.view(batch_size, self.n_agents, -1)  # [batch, n_agents, hidden_dim]
        
        b1 = self.hyper_b1(states)  # [batch, hidden_dim]
        
        # 生成第二层参数
        w2 = torch.abs(self.hyper_w2(states))  # [batch, hidden_dim]
        w2 = w2.view(batch_size, -1, 1)  # [batch, hidden_dim, 1]
        
        b2 = self.hyper_b2(states)  # [batch, hidden_dim // 2]
        
        # 第一层: ReLU(Q * W + b)
        hidden = torch.bmm(agent_qs.unsqueeze(1), w1) + b1.unsqueeze(1)
        hidden = F.relu(hidden)  # [batch, 1, hidden_dim]
        
        # 第二层
        v2 = self.v2.expand(batch_size, -1, 1)  # [batch, hidden_dim // 2, 1]
        
        # 合并v2到w2
        w2_full = torch.cat([w2[:, :hidden.size(-1)//2], 
                             v2.expand(-1, hidden.size(-1)//2, 1)], dim=1)
        
        q_tot = torch.bmm(hidden, w2_full) + b2.unsqueeze(2)
        return q_tot.squeeze(-1)  # [batch, 1]

3.5 训练算法

class QMIX:
    def __init__(self, n_agents, state_dim, action_dims):
        self.agent_net = VDNQNetwork(n_agents, state_dim, action_dims)
        self.mixer = QMixNet(n_agents, state_dim)
        self.target_agent_net = copy.deepcopy(self.agent_net)
        self.target_mixer = copy.deepcopy(self.mixer)
    
    def update(self, batch):
        states, actions, rewards, next_states, dones = batch
        
        # 计算当前Q值
        agent_qs, _ = self.agent_net(states, actions)
        q_tot = self.mixer(agent_qs, states)
        
        # 计算目标Q值
        with torch.no_grad():
            next_agent_qs, _ = self.target_agent_net(next_states)
            next_q_tot = self.target_mixer(next_agent_qs, next_states)
            target_q = rewards + (1 - dones) * self.gamma * next_q_tot
        
        # 损失
        loss = F.mse_loss(q_tot, target_q)
        
        # 梯度反向传播
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
    
    def update_target(self, tau=0.005):
        """软更新目标网络"""
        for target, source in zip(
            self.target_agent_net.parameters(),
            self.agent_net.parameters()
        ):
            target.data.copy_(tau * source + (1 - tau) * target)

4. QTRAN:General Value Function Factorisation

4.1 核心思想

QTRAN3解决VDN和QMIX的结构限制:

  • VDN:只能表示可加性
  • QMIX:只能表示单调性
  • QTRAN:可以表示任意分解

4.2 变换方法

QTRAN的核心是将原始Q函数变换为可分解的形式

使得:

4.3 损失函数

QTRAN包含两个损失项:

class QTRANLoss(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, q_tot, individual_qs, target_q_tot, 
                states, actions):
        """
        Args:
            q_tot: 联合Q值 [batch]
            individual_qs: 个体Q值列表 [[batch, dim_i]]
            target_q_tot: 目标联合Q值 [batch]
        """
        batch_size = q_tot.size(0)
        
        # 1. 重建损失:保证变换后的Q_tot与原始相近
        recon_loss = F.mse_loss(q_tot, target_q_tot)
        
        # 2. 分解损失:个体Q的贪婪选择等于联合贪婪
        # 计算每个智能体在自己的Q上贪婪选择的联合Q
        greedy_joint_q = 0
        for i, q_i in enumerate(individual_qs):
            greedy_a_i = q_i.argmax(dim=-1)  # [batch]
            # 这里应该用实际的贪婪Q值
            greedy_joint_q += q_i.gather(-1, greedy_a_i.unsqueeze(-1)).squeeze(-1)
        
        # 3. 差异损失:保证变换后的联合贪婪等于真实联合贪婪
        difference_loss = F.relu(q_tot - greedy_joint_q).mean()
        
        return recon_loss + difference_loss

5. QTRAN vs VDN vs QMIX

5.1 表达能力对比

算法表达能力约束去中心化贪婪
VDN线性加和
QMIX单调函数单调性
QTRAN任意函数

5.2 计算复杂度

算法混合网络参数量训练复杂度
VDNO(1)
QMIXO(n × h)
QTRANO(n × h + 额外网络)

5.3 适用场景

VDN:
  ✓ 智能体独立工作,低交互
  ✗ 高度协调任务

QMIX:
  ✓ 大多数合作任务
  ✗ 复杂的反协调游戏

QTRAN:
  ✓ 任意合作结构
  ✗ 计算成本高

6. Weighted QMIX

6.1 动机

QMIX的单调性约束限制了表达能力:

# 一个单调性无法表示的例子:
# 智能体A: 动作1比动作2好
# 智能体B: 动作2比动作1好
# 最优联合: (1,2) 或 (2,1),但QMIX无法表示这种"反协调"

6.2 加权QMIX

Weighted QMIX4通过放松单调性来扩展表达能力:

class WeightedQMIXMixer(nn.Module):
    """加权QMIX:允许负权重但加权重构"""
    
    def __init__(self, n_agents, state_dim):
        super().__init__()
        # 不再限制权重非负
        self.hyper_w1 = nn.Linear(state_dim, n_agents * 32)
        self.hyper_b1 = nn.Linear(state_dim, 32)
        self.hyper_w2 = nn.Linear(state_dim, 32)
        self.hyper_b2 = nn.Linear(state_dim, 1)
        
        self.v = nn.Parameter(torch.rand(1, 32))
    
    def forward(self, agent_qs, states):
        batch_size = agent_qs.size(0)
        
        # 权重可以取负值
        w1 = self.hyper_w1(states).view(batch_size, self.n_agents, -1)
        b1 = self.hyper_b1(states)
        
        w2 = self.hyper_w2(states).view(batch_size, -1, 1)
        b2 = self.hyper_b2(states)
        
        # 第一层
        hidden = F.elu(torch.bmm(agent_qs.unsqueeze(1), w1) + b1.unsqueeze(1))
        
        # 第二层
        q_tot = torch.bmm(hidden, w2) + b2.unsqueeze(2)
        return q_tot.squeeze(-1)

7. 参考文献

Footnotes

  1. Sunehag et al. “Value-Decomposition Networks For Cooperative Multi-Agent Learning” AAMAS 2017

  2. Rashid et al. “QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning” ICML 2018

  3. Son et al. “QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning” ICML 2019

  4. Rashid et al. “Weighted QMIX: Expanding Monotonic Value Function Factorisation” NeurIPS 2020