离线强化学习算法分类
1. 算法分类体系
离线RL算法可以按照不同的设计哲学进行分类1:
离线强化学习算法
├── 策略约束方法
│ ├── Behavior Cloning (BC)
│ ├── TD3+BC
│ └── Critic Regularized Regression (CRR)
├── 悲观主义方法
│ ├── Conservative Q-Learning (CQL)
│ └── Implicit Q-Learning (IQL)
├── 基于模型的方法
│ ├── MOPO
│ └── COMBO
└── 序列建模方法
├── Decision Transformer
└── Trajectory Transformer
2. 策略约束方法
2.1 Behavior Cloning (BC)
行为克隆是最简单的离线RL方法,本质上是监督学习:
优点:
- 简单高效
- 理论基础清晰
缺点:
- 无法超越行为策略
- 复合误差:单步错误会导致后续状态分布偏移
2.2 TD3+BC
TD3+BC2在TD3的基础上加入了一个简单的行为约束:
# 核心思想:BC作为正则项
policy_loss = -Q(s, π_θ(s)) + α * λ * BC_loss
# 其中 BC_loss = E[(π_θ(a|s) - π_β(a|s))²]伪代码:
# TD3+BC 核心步骤
for each batch B:
# 1. 标准TD3更新
a1, a2 = π_θ(s') + noise
Q1, Q2 = critic(s, a)
critic_loss = (Q - y)²
# 2. 加入BC正则项(只在策略更新时)
if update_policy:
policy_loss = -Q1(s, π_θ(s)) + α * λ * (π_θ(a|s) - π_β(a|s))²D4RL基准性能对比:
| 数据集 | BC | 10% BC | TD3+BC |
|---|---|---|---|
| halfcheetah-medium | 42.6 | 42.5 | 48.3 |
| hopper-medium | 52.9 | 56.9 | 59.3 |
| walker2d-medium | 75.3 | 75.0 | 83.7 |
2.3 Critic Regularized Regression (CRR)
CRR3通过优势函数加权来选择性地学习:
def compute_advantage(batch, V, gamma, lam):
"""计算优势函数"""
advantages = []
for traj in batch.trajectories:
gae = 0
for t in reversed(range(len(traj))):
delta = traj.r[t] + gamma * V(traj.s[t+1]) - V(traj.s[t])
gae = delta + gamma * lam * gae
advantages.insert(0, gae)
return advantages
def crr_loss(batch, π, Q, V, λ):
"""
CRR损失函数
λ: 过滤阈值,高于该阈值的样本才被学习
"""
advantages = compute_advantage(batch, V, 0.99, 0.95)
log_probs = π.log_prob(batch.a, batch.s)
# 指数加权策略
weights = torch.exp(torch.clamp(advantages / λ, max=50))
# 加权BC损失
loss = -(weights * log_probs).mean()
return loss3. 悲观主义方法
3.1 Conservative Q-Learning (CQL)
CQL4的核心思想是惩罚Q值的过度估计:
完整目标:
直观理解:
CQL前: CQL后:
┌─────────────────┐ ┌─────────────────┐
│ Q(s,a) 高估 │ │ Q(s,a) 保守 │
│ ↑ │ │ ↓ │
│ 真实Q OOD动作 │ │ 真实Q OOD动作 │
│ █ │ │ █ │
│ ████ │ │ ████ │
│ │ │ │
└─────────────────┘ └─────────────────┘
PyTorch实现:
class CQL:
def __init__(self, alpha=1.0):
self.alpha = alpha # 保守系数
def update(self, batch, policy, q1, q2, target_q):
# 1. 标准TD损失
with torch.no_grad():
next_actions = policy(batch.next_state)
target = batch.reward + gamma * target_q(batch.next_state, next_actions)
current_q = q1(batch.state, batch.action)
td_loss = F.mse_loss(current_q, target)
# 2. CQL保守损失
# 随机采样动作计算Q值期望
random_actions = torch.rand_like(batch.action) * 2 - 1 # uniform[-1,1]
q_random = q1(batch.state, random_actions)
# 数据集中的动作
q_data = q1(batch.state, batch.action)
# CQL目标:提升数据动作的Q值,降低随机动作的Q值
cql_loss = self.alpha * (q_random.mean() - q_data.mean())
return td_loss + cql_loss3.2 Implicit Q-Learning (IQL)
IQL5的核心创新是避免评估OOD动作:
设计动机:
- CQL需要采样OOD动作来计算保守损失
- 这引入了一个新的估计问题
解决方案:使用**分位数回归(Expectile Regression)**来隐式学习最优策略
IQL定义两个关键函数:
- V函数:状态值函数,使用expectile回归学习
- A函数:优势函数,通过V和Q的关系得到
def expectile_loss(x, tau=0.7):
"""
分位数损失(Expectile Loss)
τ: 分位数参数(0.5为中位数)
"""
diff = x # Q - V
return torch.abs(tau - (diff < 0).float()) * diff ** 2
class IQL:
def __init__(self, tau=0.7, beta=3.0):
self.tau = tau # expectile参数
self.beta = beta # 优势函数截断参数
def compute_v(self, batch, q):
"""
学习V函数:关注高回报的状态
"""
with torch.no_grad():
# 在数据动作上计算Q值
q_values = q(batch.state, batch.action)
# Expectile回归:τ接近1时,V倾向于高Q值
v_target = expectile_loss(q_values - self.v(batch.state), self.tau)
return v_target.mean()
def compute_advantage(self, batch, q, v):
"""优势函数 A(s,a) = Q(s,a) - V(s)"""
return q(batch.state, batch.action) - v(batch.state)
def extract_policy(self, batch, q, v):
"""
隐式策略提取:不需要显式计算OOD动作
"""
advantages = self.compute_advantage(batch, q, v)
# 使用指数加权:只保留高优势的动作
weights = torch.exp(advantages / self.beta)
weights = weights / weights.sum(dim=-1, keepdim=True)
# 返回加权后的策略
return weightsIQL vs CQL 对比:
| 特性 | CQL | IQL |
|---|---|---|
| OOD动作处理 | 显式采样 | 隐式避免 |
| 计算复杂度 | O(n)采样 | O(1) |
| 参数敏感度 | α需调节 | τ、β较稳定 |
| 性能 | 良好 | 相当或更好 |
D4RL性能对比:
| 数据集 | BC | CQL | IQL |
|---|---|---|---|
| halfcheetah-medium | 42.6 | 44.0 | 47.4 |
| hopper-medium | 52.9 | 58.5 | 66.3 |
| walker2d-medium | 75.3 | 72.5 | 78.3 |
4. 算法选择指南
4.1 场景匹配
| 场景 | 推荐算法 |
|---|---|
| 数据质量高(expert级别) | BC、TD3+BC |
| 数据多样(multi-level) | CQL、IQL |
| 数据稀缺 | IQL |
| 需要稳定收敛 | TD3+BC |
4.2 超参数建议
| 算法 | 关键参数 | 建议范围 |
|---|---|---|
| CQL | α (保守系数) | 0.1 - 10.0 |
| IQL | τ (expectile) | 0.6 - 0.9 |
| IQL | β (优势截断) | 1.0 - 10.0 |
| TD3+BC | λ (BC权重) | 0.5 - 2.0 |
5. 实践注意事项
5.1 数据预处理
def preprocess_offline_data(dataset):
"""离线RL数据预处理"""
# 1. 去除异常奖励
dataset = clip_rewards(dataset, low=-100, high=100)
# 2. 归一化状态和动作
dataset.state = normalize(dataset.state)
dataset.action = normalize(dataset.action)
# 3. 数据增强(可选)
if use_augmentation:
dataset = mixup(dataset)
return dataset5.2 训练技巧
- 使用双 Critic:减少Q值过估计
- 目标网络:稳定训练
- Early Stopping:防止过度拟合到OOD区域
- 梯度裁剪:避免梯度爆炸
6. 参考文献
Footnotes
-
“A Survey on Offline Reinforcement Learning: Taxonomy, Review, and Open Problems” IEEE Transactions on Neural Networks (2023) ↩
-
Fujimoto & Gu. “A Minimalist Approach to Offline Reinforcement Learning” NeurIPS 2021 ↩
-
Wang et al. “Critic Regularized Regression” NeurIPS 2020 ↩
-
Kumar et al. “Conservative Q-Learning for Offline Reinforcement Learning” NeurIPS 2020 ↩
-
Kostrikov et al. “Offline Reinforcement Learning with Implicit Q-Learning” ICLR 2022 ↩