概述

dUltra1是由多所高校和研究机构提出的创新工作,旨在解决Masked Diffusion Language Models(MDLMs)的推理效率问题。尽管MDLMs具有并行生成的潜力,但现有开源模型即使使用复杂采样策略,每个前向传播也只能解码不到5个token,极大限制了其并行生成优势。

dUltra的核心创新是引入基于GRPO(Group Relative Policy Optimization)的强化学习框架,学习最优的解mask顺序(unmasking strategy),从而实现真正的高效并行生成。


问题背景

MDLMs的效率瓶颈

传统MDLM生成过程:

┌──────────────────────────────────────────────────────────────┐
│  每个时间步需要:                                              │
│  1. 全序列前向传播 → O(n) 计算                                │
│  2. 预测所有token的概率分布                                    │
│  3. 只解mask部分token                                          │
│                                                              │
│  问题:                                                       │
│  - 即使只恢复部分token,仍需处理全序列                         │
│  - mask比例小时,大量计算浪费在已确定token上                    │
│  - "5 tokens/pass"远低于理论上限                               │
└──────────────────────────────────────────────────────────────┘

现有加速方法的问题

方法原理局限性
置信度启发式高置信度token优先解mask固定策略,缺乏适应性
蒸馏加速从AR模型蒸馏可能off-policy,性能受限
dParallel/d3LLM行为克隆受限于base模型质量

dUltra核心设计

1. Unmasking Planner Head

dUltra引入了Unmasking Planner Head,这是一个可学习的模块,预测每个token的解mask概率:

class UnmaskingPlannerHead(nn.Module):
    """
    预测每个位置在当前时间步应该解mask的概率
    """
    def __init__(self, d_model, n_heads, n_layers):
        super().__init__()
        self.encoder = TransformerEncoder(d_model, n_heads, n_layers)
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()  # 输出(0,1)之间的概率
        )
        
    def forward(self, x_masked, hidden_states):
        """
        x_masked: 当前位置是否被mask (0/1)
        hidden_states: 扩散模型各层的隐藏状态
        """
        # 利用多层隐藏状态
        h = torch.stack(hidden_states, dim=0).mean(dim=0)
        h = self.encoder(h)
        
        # 预测每个位置的解mask概率
        unmask_probs = self.head(h).squeeze(-1)  # (batch, seq_len)
        
        # mask位置的概率保持,未mask位置设为0
        unmask_probs = unmask_probs * x_masked.float()
        
        return unmask_probs

2. 问题形式化

dUltra将解mask顺序建模为独立伯努利分布的组合:

其中:

  • :时间步t的隐状态
  • :解mask动作(哪些位置恢复)
  • :第i个位置解mask的概率

3. GRPO优化框架

dUltra使用GRPO(Group Relative Policy Optimization)进行优化:

class GRPOOptimizer:
    """
    Group Relative Policy Optimization
    """
    def __init__(self, model, planner, ref_model, ref_planner):
        self.model = model
        self.planner = planner
        self.ref_model = ref_model
        self.ref_planner = ref_planner
        
    def compute_reward(self, response, ground_truth, steps_used):
        """
        奖励函数设计
        """
        # 1. 可验证奖励:答案正确性
        verifiable_reward = calculate_accuracy(response, ground_truth)
        
        # 2. 蒸馏奖励:与AR模型的一致性
        distillation_reward = self.compute_distillation_reward(response)
        
        # 3. 效率奖励:使用的步数越少越好
        efficiency_reward = -0.01 * steps_used
        
        return verifiable_reward + distillation_reward + efficiency_reward
    
    def update(self, batch):
        # 采样多个unmasking轨迹
        responses, steps_list = self.sample_trajectories(batch)
        
        # 计算每个轨迹的奖励
        rewards = [
            self.compute_reward(r, gt, s) 
            for r, gt, s in zip(responses, batch['gt'], steps_list)
        ]
        
        # GRPO损失
        advantages = self.compute_advantages(rewards)
        
        # 策略更新
        loss = self.grpo_loss(advantages)
        
        return loss

4. 训练目标

dUltra的完整训练目标:

其中:

  • :标准扩散训练损失
  • :GRPO策略优化损失
  • :蒸馏损失

算法流程

┌──────────────────────────────────────────────────────────────┐
│  dUltra训练流程                                               │
│                                                              │
│  for each batch (x₀, prompt, ground_truth):                 │
│      │                                                       │
│      ├─→ 1. 扩散模型前向传播                                  │
│      │      x_T = fully_masked(x₀)                          │
│      │                                                       │
│      ├─→ 2. Planner预测解mask概率                            │
│      │      p = planner(x_T, hidden_states)                  │
│      │                                                       │
│      ├─→ 3. 伯努利采样确定解mask位置                         │
│      │      a ~ Bernoulli(p)                                 │
│      │                                                       │
│      ├─→ 4. 执行解mask,更新状态                             │
│      │      x_{t-1} = model(x_t, a)                         │
│      │                                                       │
│      ├─→ 5. 计算奖励                                         │
│      │      r = reward(response, gt, steps)                  │
│      │                                                       │
│      └─→ 6. GRPO更新Planner                                  │
│             θ ← θ + η ∇log π_θ(a|s) · A                      │
│                                                              │
└──────────────────────────────────────────────────────────────┘

推理过程

@torch.no_grad()
def dultra_generate(model, planner, prompt, max_steps=8):
    """
    dUltra推理过程
    """
    x = full_mask(len(prompt))
    x[:len(prompt)] = prompt
    
    for step in range(max_steps):
        # 1. 获取模型隐藏状态
        hidden_states = model.get_hidden_states(x)
        
        # 2. Planner预测解mask概率
        unmask_probs = planner(x, hidden_states)
        
        # 3. 基于概率采样解mask位置
        # 使用Gumbel-Softmax进行可微采样
        if step < max_steps - 1:
            # 训练时:Gumbel-Softmax
            gates = gumbel_softmax(unmask_probs, tau=1.0)
        else:
            # 最后一步:贪婪
            gates = (unmask_probs > 0.5).float()
        
        # 4. 更新状态
        x_new = model.forward_step(x, gates)
        x = x * (1 - gates) + x_new * gates
    
    return x

实验结果

1. 数学推理任务

模型GSM8KMATHavg_steps
AR (baseline)53.2%45.1%1.0x
dParallel51.8%43.5%0.9x
d3LLM52.5%44.8%0.85x
Fast-dLLM52.1%43.9%0.8x
dUltra54.8%47.2%0.75x

2. 代码生成任务

模型HumanEvalMBPPavg_steps
AR (baseline)58.3%62.1%1.0x
dParallel56.4%59.8%0.9x
d3LLM57.8%61.5%0.85x
dUltra60.1%64.3%0.72x

3. 效率分析

推理速度对比 (tokens/forward_pass):

   15 ┤
      │                                          ┌───┐
   12 ┤                                          │   │
      │                              ┌───┐       │   │
    8 ┤              ┌───┐          │   │       │   │
      │              │   │          │   │  ┌─┐  │   │
    5 ┤   ┌───┐      │   │    ┌─┐  │   │  │ │  │   │
      │   │   │      │   │    │ │  │   │  │ │  │   │
    0 ┼───┴───┴──────┴───┴────┴─┴──┴───┴──┴─┴──┴───┴─→ 模型
         MDLM  dParallel  d3LLM  Fast   dUltra
         基线

4. 学习到的Unmasking策略

dUltra学习到的策略具有以下特点:

  1. 早期:优先解mask高熵位置(不确定性高的token)
  2. 中期:平衡探索与利用
  3. 后期:利用已恢复的上下文加速剩余token

与其他方法的对比

维度dUltra置信度启发式蒸馏方法
策略学习可学习、自适应固定规则静态策略
off-policy风险无(on-policy)N/A
与base模型关系协同N/A从属
效率-质量权衡最优次优受限
可扩展性中等

技术洞察

1. On-Policy的重要性

# Off-Policy vs On-Policy对比
 
# Off-Policy (d3LLM):
#   使用AR模型生成的轨迹训练扩散模型
#   问题:分布偏移,策略可能被误导
 
# On-Policy (dUltra):
#   扩散模型生成轨迹,扩散模型更新策略
#   优势:策略一致性,避免分布偏移

2. 多目标奖励设计

dUltra的奖励函数精心平衡了多个目标:

  • :保证生成质量
  • :保持与AR模型的知识一致性
  • :鼓励高效生成

3. Gumbel-Softmax在推理中的应用

def gumbel_softmax(logits, tau=1.0):
    """
    Gumbel-Softmax采样
    用于在推理时实现可微的随机采样
    """
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y = (logits + gumbels) / tau
    return F.softmax(y, dim=-1)

参考

Footnotes

  1. dUltra: Ultra-Fast Diffusion Language Models via Reinforcement Learning