Decision Transformer与序列建模方法
1. 动机与核心思想
传统的离线RL方法通过值函数近似来学习策略,这不可避免地需要处理OOD动作的外推问题。
Decision Transformer (DT)1 的核心创新是:将RL问题重新定义为条件序列建模问题。
1.1 核心洞察
不再显式学习值函数和策略,而是将整个轨迹(历史状态、动作、奖励)建模为序列,利用语言模型的思路进行生成。
1.2 关键设计
输入序列格式:
┌─────────┬─────────┬─────────┬─────────┬─────────┬─────────┐
│ R_to_go │ s_t │ a_t │ s_{t-1} │ a_{t-1} │ ... │
│ (target)│ (state)│ (action)│ (state) │ (action) │ │
└─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘
↑ ↑
└── 预测位置:预测 a_t ────────┘
输出:给定状态下应采取的动作
三种token类型:
- State tokens :状态嵌入
- Action tokens :动作嵌入(离散或连续)
- Return-to-go tokens :累积回报的嵌入
2. 轨迹表示与嵌入
2.1 Return-to-Go计算
表示从时间步 到轨迹结束的累积折扣回报:
在输入中, 是期望达到的回报目标。
2.2 Embedding层
class DecisionTransformer(nn.Module):
def __init__(self, state_dim, action_dim, hidden_size=128, max_length=30):
super().__init__()
# 嵌入维度
self.hidden_size = hidden_size
# 线性嵌入层
self.state_embed = nn.Linear(state_dim, hidden_size)
self.action_embed = nn.Linear(action_dim, hidden_size)
self.return_embed = nn.Linear(1, hidden_size)
# 位置编码
self.pos_embed = nn.Parameter(torch.zeros(1, max_length, hidden_size))
# Transformer编码器
self.Transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=4,
dim_feedforward=hidden_size * 4,
dropout=0.1,
activation='gelu',
batch_first=True
),
num_layers=3
)
# 预测头
self.action_head = nn.Linear(hidden_size, action_dim)
# 初始化
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.pos_embed, std=0.02)
nn.init.xavier_uniform_(self.state_embed.weight)
nn.init.xavier_uniform_(self.action_embed.weight)2.3 序列构建
def build_sequence(self, states, actions, rewards, target_return):
"""
构建输入序列
格式: [R_target, s_t, a_t, s_{t-1}, a_{t-1}, ...]
"""
seq_len = len(states)
tokens = []
# 从后向前构建(因果掩码使得只能看到过去)
for t in range(seq_len - 1, -1, -1):
# Return-to-go
tokens.append(self.return_embed(rewards[t].unsqueeze(-1)))
# State
tokens.append(self.state_embed(states[t]))
# Action
if t < seq_len - 1: # 第一个状态不需要动作
tokens.append(self.action_embed(actions[t]))
# 堆叠并添加位置编码
x = torch.stack(tokens, dim=1) # [batch, seq_len, hidden]
x = x + self.pos_embed[:, :x.size(1)]
return x3. 训练与推理
3.1 训练目标
DT使用标准的自回归监督学习:
def forward(self, states, actions, rewards, targets=None):
"""
states: [batch, seq_len, state_dim]
actions: [batch, seq_len, action_dim]
rewards: [batch, seq_len, 1]
"""
batch_size, seq_len = states.shape[:2]
# 构建输入序列
x = self.build_sequence(states, actions, rewards)
# 通过Transformer
x = self.Transformer(x)
# 只在动作位置预测
# 动作位置在序列中是每隔一个(除了第一个)
action_indices = torch.arange(2, x.size(1), 3) # [0, 3, 6, ...]
action_logits = self.action_head(x[:, action_indices])
if targets is not None:
# 训练时:计算交叉熵损失
# actions的最后一个是用于预测的
loss = F.cross_entropy(
action_logits.reshape(-1, action_logits.size(-1)),
targets.reshape(-1)
)
return loss, action_logits
return action_logits3.2 推理过程
@torch.no_grad()
def generate(self, initial_state, target_return, env, max_steps=200):
"""
自回归生成动作序列
"""
# 初始化
states = [initial_state]
actions = []
rewards = []
returns = [target_return]
for t in range(max_steps):
# 构建当前序列
seq_states = torch.stack(states[-30:]).unsqueeze(0) # 固定窗口
seq_actions = torch.stack(actions[-29:]) if actions else torch.zeros(0)
seq_rewards = torch.tensor(rewards[-30:]).unsqueeze(-1)
# 预测下一个动作
action_logits = self.forward(seq_states, seq_actions, seq_rewards)
action = action_logits[0, -1] # 最后一个动作
# 在环境中执行
next_state, reward, done, _ = env.step(action)
# 更新
states.append(next_state)
actions.append(action)
rewards.append(reward)
if done:
break
return actions, states, rewards4. 与传统RL的对比
4.1 范式差异
| 特性 | 传统RL | Decision Transformer |
|---|---|---|
| 学习目标 | 值函数Q(s,a) | 条件生成p(a|s,R) |
| OOD处理 | 悲观估计/约束 | 序列建模自然避免 |
| 长程依赖 | Bellman方程传播 | Transformer直接建模 |
| 训练稳定性 | 需要target network | 监督学习,稳定 |
| 可扩展性 | 一般 | 强(受益于LLM进展) |
4.2 Stitching能力
DT的一个关键能力是Stitching:通过组合不同轨迹的片段来达到高回报区域。
原始轨迹1: A → B → C (中回报)
原始轨迹2: D → E → C (中回报)
原始轨迹3: C → F → G (高回报)
DT可以学到: D → E → C → F → G (高回报)
↑
└── 拼接点
5. 变体与扩展
5.1 Q-learning Decision Transformer (QDT)2
QDT在DT中引入动态规划的思想:
class QDT(nn.Module):
"""
QDT核心改进:
1. 学习Q函数而非直接预测动作
2. 结合TD学习进行值更新
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Q函数头
self.q_head = nn.Linear(self.hidden_size, 1)
def forward(self, states, actions, rewards, target_return):
x = self.build_sequence(states, actions, rewards)
x = self.Transformer(x)
# 预测Q值
action_indices = torch.arange(2, x.size(1), 3)
q_values = self.q_head(x[:, action_indices])
return q_values5.2 Trajectory Transformer3
将整个轨迹视为生成任务:
- 使用GPT-style的decoder架构
- 联合建模状态、动作、奖励
- 使用beam search进行规划
5.3 Critic-Guided Decision Transformer
在DT基础上加入Critic来引导生成4:
class CriticGuidedDT:
"""
训练时:DT + Critic联合训练
推理时:Critic过滤/重排DT生成的候选动作
"""
def __init__(self, dt, critic):
self.dt = dt
self.critic = critic
@torch.no_grad()
def act(self, state, target_return, k=5):
# 1. DT生成k个候选动作
candidates = self.dt.generate_candidates(state, target_return, k)
# 2. Critic评估并选择
q_values = self.critic(state, candidates)
best_idx = q_values.argmax()
return candidates[best_idx]6. 实验结果
6.1 D4RL基准对比
| 环境 | BC | CQL | IQL | DT |
|---|---|---|---|---|
| halfcheetah-medium | 42.6 | 44.0 | 47.4 | 42.6 |
| hopper-medium | 52.9 | 58.5 | 66.3 | 67.6 |
| walker2d-medium | 75.3 | 72.5 | 78.3 | 74.0 |
观察:DT在某些任务(hopper-medium)上表现优异,但在需要精确Q值估计的任务上可能不如IQL。
6.2 长程依赖任务
DT的优势在长序列决策任务中更明显:
┌─────────────────────────────────────────┐
│ Atari 游戏(长序列) │
│ ┌───────────────────────────────────┐ │
│ │ DT: 更擅长处理跨时间的依赖关系 │ │
│ │ IQL: 短视,容易陷入局部最优 │ │
│ └───────────────────────────────────┘ │
└─────────────────────────────────────────┘
7. 优缺点分析
7.1 优点
- 训练稳定:纯监督学习,无需Bellman迭代
- 长程建模能力强:Transformer天然适合长序列
- 可扩展性好:受益于大语言模型的研究进展
- 灵活的条件生成:可以指定不同的目标回报
7.2 缺点
- 计算成本高:Transformer的O(n²)复杂度
- 序列长度限制:需要截断历史
- Stitching能力受限:取决于数据的覆盖程度
- 超参数敏感:位置编码、模型深度等
8. 参考文献
Footnotes
-
Chen et al. “Decision Transformer: Reinforcement Learning via Sequence Modeling” NeurIPS 2021 ↩
-
Yamagata et al. “Q-learning Decision Transformer: Leveraging Dynamic Programming for Conditional Sequence Modelling in Offline RL” ICML 2023 ↩
-
Janner et al. “Trajectory Transformer: Reinforcement Learning with Transformer World Models” 2021 ↩
-
Wu et al. “Critic-Guided Decision Transformer for Offline Reinforcement Learning” AAAI 2024 ↩