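Constrained Reinforcement Learning

Constrained reinforcement learning, usually formalized as a constrained Markov decision process (CMDP), maximizes cumulative reward while satisfying a set of constraints on cumulative cost. It is widely used in safety-critical settings. [1]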

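Problem Definition

CMDP Framework

A CMDP extends the standard MDP $(S, A, P, R, \gamma)$ with $m$ cost functions $C_1, \dots, C_m$ and matching constraint thresholds $d_1, \dots, d_m$. The goal is a policy that maximizes the expected return while keeping every expected cumulative cost at or below its threshold:

$$\max_{\pi} \; J_R(\pi) \quad \text{s.t.} \quad J_{C_i}(\pi) \le d_i, \quad i = 1, \dots, m$$

where $J_R(\pi) = \mathbb{E}_{\pi}\left[\sum_t \gamma^t R(s_t, a_t)\right]$, $J_{C_i}(\pi) = \mathbb{E}_{\pi}\left[\sum_t \gamma^t C_i(s_t, a_t)\right]$, and:

  • $C_i$: the $i$-th cost function
  • $d_i$: the threshold for the $i$-th cost constraint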
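As a rough illustration of this definition, the sketch below estimates the discounted reward return and each discounted cost return from sampled trajectories, then checks the estimates against the thresholds. The trajectory layout (a dict with `rewards` and one cost list per constraint) and the function names are assumptions made for this example, not part of any library.

```python
def discounted_sum(xs, gamma):
    """Return sum_t gamma^t * xs[t], accumulated from the last step backwards."""
    total = 0.0
    for x in reversed(xs):
        total = x + gamma * total
    return total


def estimate_cmdp_objectives(trajectories, gamma=0.99):
    """Monte Carlo estimates of J_R and of every J_{C_i}.

    Each trajectory is a dict (hypothetical layout):
        {"rewards": [r_0, r_1, ...],
         "costs":   [[c_1 per step], [c_2 per step], ...]}
    """
    n = len(trajectories)
    num_constraints = len(trajectories[0]["costs"])
    j_r = sum(discounted_sum(t["rewards"], gamma) for t in trajectories) / n
    j_c = [
        sum(discounted_sum(t["costs"][i], gamma) for t in trajectories) / n
        for i in range(num_constraints)
    ]
    return j_r, j_c


def is_feasible(j_c, thresholds):
    """A policy is feasible if every estimated cost return is within its threshold."""
    return all(c <= d for c, d in zip(j_c, thresholds))


# Two short trajectories with a single cost function.
demo = [
    {"rewards": [1.0, 1.0], "costs": [[0.0, 1.0]]},
    {"rewards": [1.0, 0.0], "costs": [[1.0, 1.0]]},
]
j_r, j_c = estimate_cmdp_objectives(demo, gamma=0.5)
print(j_r, j_c)                      # 1.25 [1.0]
print(is_feasible(j_c, [1.0]))       # True
```

With only a handful of trajectories these estimates are noisy, so in practice the feasibility check is a statistical statement rather than a guarantee.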

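Constraint Types

| Type | Example | Application scenarios |
| --- | --- | --- |
| Cumulative constraint | $\mathbb{E}\left[\sum_t \gamma^t c(s_t, a_t)\right] \le d$ | Energy limits, time budgets |
| Terminal-state constraint | $s_T \in S_{\text{safe}}$ | Safe regions |
| Instantaneous constraint | $c(s_t, a_t) \le d$ for all $t$ | Physical limits, hardware constraints |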
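The three constraint types above differ mainly in when they are evaluated. The following sketch checks each one on a single recorded episode; the function names and the example numbers are hypothetical and chosen only to make the distinction concrete.

```python
def cumulative_ok(costs, limit, gamma=1.0):
    """Cumulative constraint: the (discounted) sum of costs stays within a budget."""
    total = sum((gamma ** t) * c for t, c in enumerate(costs))
    return total <= limit


def terminal_ok(final_state, is_safe):
    """Terminal-state constraint: only the last state is checked."""
    return is_safe(final_state)


def instantaneous_ok(costs, limit):
    """Instantaneous constraint: every single step must respect the limit."""
    return all(c <= limit for c in costs)


# Energy used at each step of one episode (made-up values).
energy = [2.0, 3.0, 1.0]

print(cumulative_ok(energy, limit=6.0))                # True: total is 6.0
print(instantaneous_ok(energy, limit=2.5))             # False: step 1 uses 3.0
print(terminal_ok(0.4, lambda x: abs(x) <= 0.5))       # True: ends inside the region
```

The same episode can satisfy a cumulative budget and still break an instantaneous limit, which is why the two are usually handled by different mechanisms.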

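The CPO Algorithm

Constrained Policy Optimization

Achiam et al. (2017) proposed Constrained Policy Optimization (CPO), the first general-purpose policy search algorithm for constrained RL with guarantees of near-constraint satisfaction at each iteration. [2]

Core Idea

CPO builds on the trust-region approach of TRPO: each update is restricted to policies within a small KL divergence of the current one, which gives theoretical guarantees for constrained RL.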
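To make the trust-region idea concrete, here is a minimal sketch for a single categorical policy over three actions. It shrinks a proposed change to the logits until the KL divergence from the old policy falls under a limit. This is only the backtracking part of a trust-region method; the helper names, the proposed direction, and the shrink factor are assumptions for illustration.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def kl_categorical(p, q):
    """KL(p || q) for two categorical distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def backtracking_step(logits, direction, max_kl=0.01, shrink=0.5, max_tries=20):
    """Scale `direction` down until KL(old || new) <= max_kl.

    Returns the accepted logits and the accepted step scale; if no scale
    is accepted, the old logits are returned with a scale of 0.0.
    """
    old_probs = softmax(logits)
    step = 1.0
    for _ in range(max_tries):
        new_logits = [z + step * d for z, d in zip(logits, direction)]
        if kl_categorical(old_probs, softmax(new_logits)) <= max_kl:
            return new_logits, step
        step *= shrink
    return list(logits), 0.0


old_logits = [0.0, 0.0, 0.0]
new_logits, step = backtracking_step(old_logits, direction=[1.0, 0.0, -1.0])
kl = kl_categorical(softmax(old_logits), softmax(new_logits))
print(step, kl)
```

TRPO and CPO choose the direction with a natural-gradient computation and average the KL over sampled states; the sketch keeps only the acceptance test so the role of `max_kl` is easy to see.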

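Primal-Dual Optimization

CPO approaches the constrained problem through Lagrangian duality. In its simplest form, the constrained problem is relaxed into an unconstrained primal-dual (min-max) problem:

$$\min_{\lambda \ge 0} \max_{\theta} \; L(\theta, \lambda) = J_R(\pi_\theta) - \lambda \left( J_C(\pi_\theta) - d \right)$$

where $J_R$ is the cumulative reward objective, $J_C$ is the cumulative cost, and $d$ is the constraint threshold. CPO itself applies this idea to a local trust-region subproblem at each iteration rather than learning one global multiplier by gradient steps.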
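A toy problem can show how the primal-dual iteration behaves. The sketch below assumes a scalar "policy parameter" $x$ with $J_R(x) = -(x - 2)^2$, $J_C(x) = x$, and $d = 1$; these functions and the step sizes are invented for illustration. The unconstrained optimum $x = 2$ violates the constraint, so the iteration should settle at the constrained optimum $x = 1$ with multiplier $\lambda = 2$ (from the stationarity condition $-2(x - 2) - \lambda = 0$).

```python
def primal_dual(alpha=0.05, beta=0.05, steps=5000):
    """Gradient ascent on x and projected gradient ascent on lambda.

    Toy problem (assumed for illustration):
        maximize  J_R(x) = -(x - 2)^2
        subject to J_C(x) = x <= d, with d = 1
    """
    x, lam = 0.0, 0.0
    for _ in range(steps):
        # Primal step: ascend the Lagrangian L = J_R(x) - lam * (J_C(x) - d).
        grad_x = -2.0 * (x - 2.0) - lam
        x += alpha * grad_x
        # Dual step: raise lam while the constraint is violated, keep lam >= 0.
        lam = max(0.0, lam + beta * (x - 1.0))
    return x, lam


x_star, lam_star = primal_dual()
print(x_star, lam_star)  # approaches x = 1, lambda = 2
```

In RL the gradients are noisy estimates, so the same iteration tends to oscillate around the constraint boundary instead of converging this cleanly.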

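Implementation

import torch
import torch.nn as nn
import torch.optim as optim


class CPO:
    """Simplified primal-dual sketch in the spirit of CPO.

    Full CPO solves a constrained trust-region subproblem with conjugate
    gradient and a line search. This class instead takes a plain gradient
    step on a Lagrangian-weighted surrogate and backs off until the KL
    limit holds. value_net and cost_net are assumed to be trained elsewhere.
    """

    def __init__(self, policy, value_net, cost_net,
                 lambda_lr=0.1, max_kl=0.01, step_size=0.01):
        self.policy = policy        # maps states to a torch.distributions object
        self.value_net = value_net  # reward value function V_R(s)
        self.cost_net = cost_net    # cost value function V_C(s)
        self.lambda_lr = lambda_lr
        self.max_kl = max_kl
        self.step_size = step_size
        self.lambda_param = torch.tensor(1.0, requires_grad=True)
        self.lambda_optimizer = optim.Adam([self.lambda_param], lr=lambda_lr)

    def compute_advantages(self, rewards, values, gamma=0.99, lam=0.95):
        """GAE for one trajectory; `values` needs len(rewards) + 1 entries."""
        advantages = torch.zeros(len(rewards))
        gae = 0.0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * values[t + 1] - values[t]
            gae = delta + gamma * lam * gae
            advantages[t] = gae
        return advantages

    def surrogate_loss(self, states, actions, old_log_probs, advantages):
        """Importance-weighted policy loss (TRPO style, no clipping)."""
        dist = self.policy(states)
        log_probs = dist.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)
        return -(ratio * advantages).mean()

    def trust_region_update(self, states, actions, old_log_probs, advantages):
        """Gradient step, halved until the mean KL is within max_kl."""
        with torch.no_grad():
            old_dist = self.policy(states)
        loss = self.surrogate_loss(states, actions, old_log_probs, advantages)
        params = list(self.policy.parameters())
        grads = torch.autograd.grad(loss, params)
        old_params = [p.detach().clone() for p in params]

        step = self.step_size
        for _ in range(10):  # backtracking on the step size
            with torch.no_grad():
                for p, p_old, g in zip(params, old_params, grads):
                    p.copy_(p_old - step * g)
                kl = torch.distributions.kl_divergence(
                    old_dist, self.policy(states)).mean()
            if kl <= self.max_kl:
                return
            step *= 0.5

        # No step satisfied the KL limit: restore the old parameters.
        with torch.no_grad():
            for p, p_old in zip(params, old_params):
                p.copy_(p_old)

    def update(self, trajectories, cost_limit):
        """
        trajectories: list of objects, one per episode, with tensor fields
        states, actions, rewards, costs, next_states.
        """
        # 1. Prepare data
        states = torch.cat([t.states for t in trajectories])
        actions = torch.cat([t.actions for t in trajectories])

        # 2. Per-trajectory GAE for rewards and costs. The value of the last
        #    next state is used as the bootstrap, i.e. episodes are treated
        #    as truncated rather than terminated.
        reward_adv, cost_adv = [], []
        with torch.no_grad():
            for t in trajectories:
                last = t.next_states[-1:]
                values = torch.cat([self.value_net(t.states).reshape(-1),
                                    self.value_net(last).reshape(-1)])
                cost_values = torch.cat([self.cost_net(t.states).reshape(-1),
                                         self.cost_net(last).reshape(-1)])
                reward_adv.append(self.compute_advantages(t.rewards, values))
                cost_adv.append(self.compute_advantages(t.costs, cost_values))
            old_log_probs = self.policy(states).log_prob(actions)
        reward_adv = torch.cat(reward_adv)
        cost_adv = torch.cat(cost_adv)

        # 3. Lagrangian-weighted advantage
        lam = self.lambda_param.detach()
        advantages = (reward_adv - lam * cost_adv) / (1.0 + lam)

        # 4. Policy update inside the trust region
        self.trust_region_update(states, actions, old_log_probs, advantages)

        # 5. Update lambda by gradient ascent on the constraint violation,
        #    measured as the mean undiscounted cost per episode.
        episode_costs = torch.stack([t.costs.sum() for t in trajectories])
        cost_violation = episode_costs.mean() - cost_limit

        lambda_loss = -self.lambda_param * cost_violation
        self.lambda_optimizer.zero_grad()
        lambda_loss.backward()
        self.lambda_optimizer.step()

        # Keep lambda non-negative
        with torch.no_grad():
            self.lambda_param.clamp_(min=0.0)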
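The GAE recursion in `compute_advantages` can be checked by hand on a tiny input. The sketch below is a plain-Python restatement of that recursion (the function name `gae` and the numbers are assumptions for this example). With $\gamma = 1$ and $\lambda = 1$, each advantage reduces to the remaining return minus the value estimate, which makes the result easy to verify. The same function applies unchanged to cost signals.

```python
def gae(rewards, values, gamma=0.99, lam=0.95):
    """Generalized advantage estimation for one trajectory.

    `values` holds one estimate per state plus a final bootstrap value,
    so it must have len(rewards) + 1 entries.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD error at time t.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Exponentially weighted sum of future TD errors.
        running = delta + gamma * lam * running
        advantages[t] = running
    return advantages


rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5, 0.0]   # last entry is the bootstrap value
print(gae(rewards, values, gamma=1.0, lam=1.0))  # [2.5, 1.5, 0.5]
```

Here the remaining returns are 3, 2, and 1, and subtracting the value estimate of 0.5 gives the printed advantages.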

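PPO-Lagrangian

One of the Most Widely Used Methods in Practice

PPO-Lagrangian combines PPO with the Lagrangian method. It is simple to implement and often performs well in practice.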

class PPOLagrangian:
    def __init__(self, policy, value_net, cost_net,
                 lr=3e-4, lambda_lr=0.1, eps=0.2):
        self.policy = policy
        self.value_net = value_net
        self.cost_net = cost_net

        self.policy_optimizer = optim.Adam(policy.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(value_net.parameters(), lr=lr)
        self.cost_optimizer = optim.Adam(cost_net.parameters(), lr=lr)

        self.lambda_lr = lambda_lr
        self.lambda_param = torch.tensor(1.0)
        self.eps = eps

    def ppo_loss(self, states, actions, old_log_probs, advantages):
        """PPO clipped surrogate loss."""
        dist = self.policy(states)
        log_probs = dist.log_prob(actions)
        ratio = torch.exp(log_probs - old_log_probs)

        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantages

        return -torch.min(surr1, surr2).mean()

    def update(self, batch, cost_limit):
        """
        batch: (states, actions, old_log_probs, advantages,
                cost_advantages, episode_costs), all tensors.
        episode_costs holds the total cost of each collected episode.
        """
        (states, actions, old_log_probs, advantages,
         cost_advantages, episode_costs) = batch

        # 1. Update the policy on the Lagrangian-weighted advantage
        mixed_adv = (advantages - self.lambda_param * cost_advantages) \
            / (1.0 + self.lambda_param)
        policy_loss = self.ppo_loss(states, actions, old_log_probs, mixed_adv)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # 2. Update the reward value network
        with torch.no_grad():
            returns = advantages + self.value_net(states).squeeze(-1)
        value_loss = nn.functional.mse_loss(
            self.value_net(states).squeeze(-1), returns)
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        # 3. Update the cost value network (targets use cost advantages)
        with torch.no_grad():
            cost_returns = cost_advantages + self.cost_net(states).squeeze(-1)
        cost_loss = nn.functional.mse_loss(
            self.cost_net(states).squeeze(-1), cost_returns)
        self.cost_optimizer.zero_grad()
        cost_loss.backward()
        self.cost_optimizer.step()

        # 4. Update the Lagrange multiplier from the measured episode cost
        cost_violation = episode_costs.mean() - cost_limit

        # Gradient ascent on lambda, projected onto lambda >= 0
        with torch.no_grad():
            self.lambda_param = torch.clamp(
                self.lambda_param + self.lambda_lr * cost_violation, min=0.0)
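The multiplier only influences the policy if it enters the policy objective. One common way to do that, sketched below, is to subtract the cost advantage weighted by $\lambda$ from the reward advantage before computing the clipped loss. The division by $1 + \lambda$ is a rescaling that some implementations use to keep the magnitude of the combined signal stable as $\lambda$ grows; treat it, and the function name, as assumptions of this example.

```python
def lagrangian_advantages(reward_adv, cost_adv, lam):
    """Combine reward and cost advantages into one training signal.

    Each entry is (A_R - lam * A_C) / (1 + lam): actions that raise cost
    are penalized in proportion to the current multiplier.
    """
    scale = 1.0 + lam
    return [(r - lam * c) / scale for r, c in zip(reward_adv, cost_adv)]


reward_adv = [1.0, 2.0]
cost_adv = [0.5, -1.0]

print(lagrangian_advantages(reward_adv, cost_adv, lam=0.0))  # [1.0, 2.0]
print(lagrangian_advantages(reward_adv, cost_adv, lam=1.0))  # [0.25, 1.5]
```

With `lam=0.0` the cost signal is ignored and the method reduces to plain PPO; as `lam` grows, the second action (which lowers cost) is favored more strongly relative to the first.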

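Constraint Satisfaction Strategies

Summary of the Lagrangian Method

| Aspect | Description |
| --- | --- |
| Primal problem | $\max_\theta J_R(\pi_\theta)$ s.t. $J_C(\pi_\theta) \le d$ |
| Dual variable | $\lambda \ge 0$ |
| Dual function | $g(\lambda) = \max_\theta \left[ J_R(\pi_\theta) - \lambda \left( J_C(\pi_\theta) - d \right) \right]$ |
| Policy update | $\theta \leftarrow \theta + \alpha \nabla_\theta \left[ J_R(\pi_\theta) - \lambda J_C(\pi_\theta) \right]$ |
| Multiplier update | $\lambda \leftarrow \max\left(0, \lambda + \beta \left( J_C(\pi_\theta) - d \right)\right)$ |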

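Adaptive Lagrangian

class AdaptiveLagrangian:
    def __init__(self, lambda_init=1.0,
                 kl_target=0.01,
                 lambda_lr_factor=1.5):
        self.lambda_param = lambda_init
        self.kl_target = kl_target  # stored, not used by the methods below
        self.lr_factor = lambda_lr_factor
        self.cost_history = []

    def update_lambda(self, actual_cost, target_cost):
        """Update lambda with a violation-dependent step size."""
        cost_violation = actual_cost - target_cost
        self.cost_history.append(cost_violation)

        # The step size grows with the size of the violation and is
        # asymmetric: lambda rises faster than it falls.
        if cost_violation > 0:
            # Constraint violated: use the full factor
            lr = self.lr_factor * abs(cost_violation)
        else:
            # Constraint satisfied: use half the factor
            lr = self.lr_factor / 2 * abs(cost_violation)

        # Projected gradient ascent on lambda
        previous = self.lambda_param
        proposed = max(0.0, previous + lr * cost_violation)

        # Smooth the update by blending in a little of the previous value
        # (the 0.99 / 0.01 weights are an assumption).
        self.lambda_param = 0.99 * proposed + 0.01 * previous

    def get_adaptive_lambda(self, iteration):
        """Return lambda scaled by an iteration-dependent decay factor."""
        # The factor shrinks as training proceeds, which damps large swings
        # in lambda but also weakens the penalty over time.
        decay = 0.999 ** iteration
        return self.lambda_param * decay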

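Applications of Safe RL

Robot Control

  • Collision avoidance
  • Torque limits
  • Joint angle constraints

Autonomous Driving

  • Maintaining a safe following distance
  • Speed limits
  • Lane departure constraints

Medical Decision-Making

  • Drug dosage limits
  • Treatment risk control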

Footnotes

  1. Altman, “Constrained Markov Decision Processes”, 1999

  2. Achiam et al., “Constrained Policy Optimization”, ICML, 2017