Decision Transformer与序列建模方法

1. 动机与核心思想

传统的离线RL方法通过值函数近似来学习策略,这不可避免地需要处理OOD动作的外推问题。

Decision Transformer (DT)1 的核心创新是:将RL问题重新定义为条件序列建模问题。

1.1 核心洞察

不再显式学习值函数和策略,而是将整个轨迹(历史状态、动作、奖励)建模为序列,利用语言模型的思路进行生成。

1.2 关键设计

输入序列格式:
┌─────────┬─────────┬─────────┬─────────┬─────────┬─────────┐
│ R_to_go │ s_t    │ a_t    │ s_{t-1} │ a_{t-1} │ ...     │
│ (target)│ (state)│ (action)│ (state) │ (action) │         │
└─────────┴─────────┴─────────┴─────────┴─────────┴─────────┘
          ↑                              ↑
          └── 预测位置:预测 a_t ────────┘

输出:给定状态下应采取的动作

三种token类型

  • State tokens :状态嵌入
  • Action tokens :动作嵌入(离散或连续)
  • Return-to-go tokens :累积回报的嵌入

2. 轨迹表示与嵌入

2.1 Return-to-Go计算

表示从时间步 到轨迹结束的累积折扣回报:

在输入中, 是期望达到的回报目标。

2.2 Embedding层

class DecisionTransformer(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_size=128, max_length=30):
        super().__init__()
        
        # 嵌入维度
        self.hidden_size = hidden_size
        
        # 线性嵌入层
        self.state_embed = nn.Linear(state_dim, hidden_size)
        self.action_embed = nn.Linear(action_dim, hidden_size)
        self.return_embed = nn.Linear(1, hidden_size)
        
        # 位置编码
        self.pos_embed = nn.Parameter(torch.zeros(1, max_length, hidden_size))
        
        # Transformer编码器
        self.Transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=4,
                dim_feedforward=hidden_size * 4,
                dropout=0.1,
                activation='gelu',
                batch_first=True
            ),
            num_layers=3
        )
        
        # 预测头
        self.action_head = nn.Linear(hidden_size, action_dim)
        
        # 初始化
        self._init_weights()
    
    def _init_weights(self):
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.xavier_uniform_(self.state_embed.weight)
        nn.init.xavier_uniform_(self.action_embed.weight)

2.3 序列构建

def build_sequence(self, states, actions, rewards, target_return):
    """
    构建输入序列
    格式: [R_target, s_t, a_t, s_{t-1}, a_{t-1}, ...]
    """
    seq_len = len(states)
    tokens = []
    
    # 从后向前构建(因果掩码使得只能看到过去)
    for t in range(seq_len - 1, -1, -1):
        # Return-to-go
        tokens.append(self.return_embed(rewards[t].unsqueeze(-1)))
        # State
        tokens.append(self.state_embed(states[t]))
        # Action
        if t < seq_len - 1:  # 第一个状态不需要动作
            tokens.append(self.action_embed(actions[t]))
    
    # 堆叠并添加位置编码
    x = torch.stack(tokens, dim=1)  # [batch, seq_len, hidden]
    x = x + self.pos_embed[:, :x.size(1)]
    
    return x

3. 训练与推理

3.1 训练目标

DT使用标准的自回归监督学习

def forward(self, states, actions, rewards, targets=None):
    """
    states: [batch, seq_len, state_dim]
    actions: [batch, seq_len, action_dim]
    rewards: [batch, seq_len, 1]
    """
    batch_size, seq_len = states.shape[:2]
    
    # 构建输入序列
    x = self.build_sequence(states, actions, rewards)
    
    # 通过Transformer
    x = self.Transformer(x)
    
    # 只在动作位置预测
    # 动作位置在序列中是每隔一个(除了第一个)
    action_indices = torch.arange(2, x.size(1), 3)  # [0, 3, 6, ...]
    action_logits = self.action_head(x[:, action_indices])
    
    if targets is not None:
        # 训练时:计算交叉熵损失
        # actions的最后一个是用于预测的
        loss = F.cross_entropy(
            action_logits.reshape(-1, action_logits.size(-1)),
            targets.reshape(-1)
        )
        return loss, action_logits
    
    return action_logits

3.2 推理过程

@torch.no_grad()
def generate(self, initial_state, target_return, env, max_steps=200):
    """
    自回归生成动作序列
    """
    # 初始化
    states = [initial_state]
    actions = []
    rewards = []
    returns = [target_return]
    
    for t in range(max_steps):
        # 构建当前序列
        seq_states = torch.stack(states[-30:]).unsqueeze(0)  # 固定窗口
        seq_actions = torch.stack(actions[-29:]) if actions else torch.zeros(0)
        seq_rewards = torch.tensor(rewards[-30:]).unsqueeze(-1)
        
        # 预测下一个动作
        action_logits = self.forward(seq_states, seq_actions, seq_rewards)
        action = action_logits[0, -1]  # 最后一个动作
        
        # 在环境中执行
        next_state, reward, done, _ = env.step(action)
        
        # 更新
        states.append(next_state)
        actions.append(action)
        rewards.append(reward)
        
        if done:
            break
    
    return actions, states, rewards

4. 与传统RL的对比

4.1 范式差异

特性传统RLDecision Transformer
学习目标值函数Q(s,a)条件生成p(a|s,R)
OOD处理悲观估计/约束序列建模自然避免
长程依赖Bellman方程传播Transformer直接建模
训练稳定性需要target network监督学习,稳定
可扩展性一般强(受益于LLM进展)

4.2 Stitching能力

DT的一个关键能力是Stitching:通过组合不同轨迹的片段来达到高回报区域。

原始轨迹1: A → B → C (中回报)
原始轨迹2: D → E → C (中回报)  
原始轨迹3: C → F → G (高回报)

DT可以学到: D → E → C → F → G (高回报)
                   ↑
                   └── 拼接点

5. 变体与扩展

5.1 Q-learning Decision Transformer (QDT)2

QDT在DT中引入动态规划的思想:

class QDT(nn.Module):
    """
    QDT核心改进:
    1. 学习Q函数而非直接预测动作
    2. 结合TD学习进行值更新
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Q函数头
        self.q_head = nn.Linear(self.hidden_size, 1)
        
    def forward(self, states, actions, rewards, target_return):
        x = self.build_sequence(states, actions, rewards)
        x = self.Transformer(x)
        
        # 预测Q值
        action_indices = torch.arange(2, x.size(1), 3)
        q_values = self.q_head(x[:, action_indices])
        
        return q_values

5.2 Trajectory Transformer3

将整个轨迹视为生成任务:

  • 使用GPT-style的decoder架构
  • 联合建模状态、动作、奖励
  • 使用beam search进行规划

5.3 Critic-Guided Decision Transformer

在DT基础上加入Critic来引导生成4

class CriticGuidedDT:
    """
    训练时:DT + Critic联合训练
    推理时:Critic过滤/重排DT生成的候选动作
    """
    def __init__(self, dt, critic):
        self.dt = dt
        self.critic = critic
    
    @torch.no_grad()
    def act(self, state, target_return, k=5):
        # 1. DT生成k个候选动作
        candidates = self.dt.generate_candidates(state, target_return, k)
        
        # 2. Critic评估并选择
        q_values = self.critic(state, candidates)
        best_idx = q_values.argmax()
        
        return candidates[best_idx]

6. 实验结果

6.1 D4RL基准对比

环境BCCQLIQLDT
halfcheetah-medium42.644.047.442.6
hopper-medium52.958.566.367.6
walker2d-medium75.372.578.374.0

观察:DT在某些任务(hopper-medium)上表现优异,但在需要精确Q值估计的任务上可能不如IQL。

6.2 长程依赖任务

DT的优势在长序列决策任务中更明显:

┌─────────────────────────────────────────┐
│        Atari 游戏(长序列)              │
│  ┌───────────────────────────────────┐  │
│  │ DT: 更擅长处理跨时间的依赖关系    │  │
│  │ IQL: 短视,容易陷入局部最优       │  │
│  └───────────────────────────────────┘  │
└─────────────────────────────────────────┘

7. 优缺点分析

7.1 优点

  1. 训练稳定:纯监督学习,无需Bellman迭代
  2. 长程建模能力强:Transformer天然适合长序列
  3. 可扩展性好:受益于大语言模型的研究进展
  4. 灵活的条件生成:可以指定不同的目标回报

7.2 缺点

  1. 计算成本高:Transformer的O(n²)复杂度
  2. 序列长度限制:需要截断历史
  3. Stitching能力受限:取决于数据的覆盖程度
  4. 超参数敏感:位置编码、模型深度等

8. 参考文献

Footnotes

  1. Chen et al. “Decision Transformer: Reinforcement Learning via Sequence Modeling” NeurIPS 2021

  2. Yamagata et al. “Q-learning Decision Transformer: Leveraging Dynamic Programming for Conditional Sequence Modelling in Offline RL” ICML 2023

  3. Janner et al. “Trajectory Transformer: Reinforcement Learning with Transformer World Models” 2021

  4. Wu et al. “Critic-Guided Decision Transformer for Offline Reinforcement Learning” AAAI 2024