概述
dUltra1是由多所高校和研究机构提出的创新工作,旨在解决Masked Diffusion Language Models(MDLMs)的推理效率问题。尽管MDLMs具有并行生成的潜力,但现有开源模型即使使用复杂采样策略,每个前向传播也只能解码不到5个token,极大限制了其并行生成优势。
dUltra的核心创新是引入基于GRPO(Group Relative Policy Optimization)的强化学习框架,学习最优的解mask顺序(unmasking strategy),从而实现真正的高效并行生成。
问题背景
MDLMs的效率瓶颈
传统MDLM生成过程:
┌──────────────────────────────────────────────────────────────┐
│ 每个时间步需要: │
│ 1. 全序列前向传播 → O(n) 计算 │
│ 2. 预测所有token的概率分布 │
│ 3. 只解mask部分token │
│ │
│ 问题: │
│ - 即使只恢复部分token,仍需处理全序列 │
│ - mask比例小时,大量计算浪费在已确定token上 │
│ - "5 tokens/pass"远低于理论上限 │
└──────────────────────────────────────────────────────────────┘
现有加速方法的问题
| 方法 | 原理 | 局限性 |
|---|---|---|
| 置信度启发式 | 高置信度token优先解mask | 固定策略,缺乏适应性 |
| 蒸馏加速 | 从AR模型蒸馏 | 可能off-policy,性能受限 |
| dParallel/d3LLM | 行为克隆 | 受限于base模型质量 |
dUltra核心设计
1. Unmasking Planner Head
dUltra引入了Unmasking Planner Head,这是一个可学习的模块,预测每个token的解mask概率:
class UnmaskingPlannerHead(nn.Module):
"""
预测每个位置在当前时间步应该解mask的概率
"""
def __init__(self, d_model, n_heads, n_layers):
super().__init__()
self.encoder = TransformerEncoder(d_model, n_heads, n_layers)
self.head = nn.Sequential(
nn.Linear(d_model, d_model // 2),
nn.GELU(),
nn.Linear(d_model // 2, 1),
nn.Sigmoid() # 输出(0,1)之间的概率
)
def forward(self, x_masked, hidden_states):
"""
x_masked: 当前位置是否被mask (0/1)
hidden_states: 扩散模型各层的隐藏状态
"""
# 利用多层隐藏状态
h = torch.stack(hidden_states, dim=0).mean(dim=0)
h = self.encoder(h)
# 预测每个位置的解mask概率
unmask_probs = self.head(h).squeeze(-1) # (batch, seq_len)
# mask位置的概率保持,未mask位置设为0
unmask_probs = unmask_probs * x_masked.float()
return unmask_probs2. 问题形式化
dUltra将解mask顺序建模为独立伯努利分布的组合:
其中:
- :时间步t的隐状态
- :解mask动作(哪些位置恢复)
- :第i个位置解mask的概率
3. GRPO优化框架
dUltra使用GRPO(Group Relative Policy Optimization)进行优化:
class GRPOOptimizer:
"""
Group Relative Policy Optimization
"""
def __init__(self, model, planner, ref_model, ref_planner):
self.model = model
self.planner = planner
self.ref_model = ref_model
self.ref_planner = ref_planner
def compute_reward(self, response, ground_truth, steps_used):
"""
奖励函数设计
"""
# 1. 可验证奖励:答案正确性
verifiable_reward = calculate_accuracy(response, ground_truth)
# 2. 蒸馏奖励:与AR模型的一致性
distillation_reward = self.compute_distillation_reward(response)
# 3. 效率奖励:使用的步数越少越好
efficiency_reward = -0.01 * steps_used
return verifiable_reward + distillation_reward + efficiency_reward
def update(self, batch):
# 采样多个unmasking轨迹
responses, steps_list = self.sample_trajectories(batch)
# 计算每个轨迹的奖励
rewards = [
self.compute_reward(r, gt, s)
for r, gt, s in zip(responses, batch['gt'], steps_list)
]
# GRPO损失
advantages = self.compute_advantages(rewards)
# 策略更新
loss = self.grpo_loss(advantages)
return loss4. 训练目标
dUltra的完整训练目标:
其中:
- :标准扩散训练损失
- :GRPO策略优化损失
- :蒸馏损失
算法流程
┌──────────────────────────────────────────────────────────────┐
│ dUltra训练流程 │
│ │
│ for each batch (x₀, prompt, ground_truth): │
│ │ │
│ ├─→ 1. 扩散模型前向传播 │
│ │ x_T = fully_masked(x₀) │
│ │ │
│ ├─→ 2. Planner预测解mask概率 │
│ │ p = planner(x_T, hidden_states) │
│ │ │
│ ├─→ 3. 伯努利采样确定解mask位置 │
│ │ a ~ Bernoulli(p) │
│ │ │
│ ├─→ 4. 执行解mask,更新状态 │
│ │ x_{t-1} = model(x_t, a) │
│ │ │
│ ├─→ 5. 计算奖励 │
│ │ r = reward(response, gt, steps) │
│ │ │
│ └─→ 6. GRPO更新Planner │
│ θ ← θ + η ∇log π_θ(a|s) · A │
│ │
└──────────────────────────────────────────────────────────────┘
推理过程
@torch.no_grad()
def dultra_generate(model, planner, prompt, max_steps=8):
"""
dUltra推理过程
"""
x = full_mask(len(prompt))
x[:len(prompt)] = prompt
for step in range(max_steps):
# 1. 获取模型隐藏状态
hidden_states = model.get_hidden_states(x)
# 2. Planner预测解mask概率
unmask_probs = planner(x, hidden_states)
# 3. 基于概率采样解mask位置
# 使用Gumbel-Softmax进行可微采样
if step < max_steps - 1:
# 训练时:Gumbel-Softmax
gates = gumbel_softmax(unmask_probs, tau=1.0)
else:
# 最后一步:贪婪
gates = (unmask_probs > 0.5).float()
# 4. 更新状态
x_new = model.forward_step(x, gates)
x = x * (1 - gates) + x_new * gates
return x实验结果
1. 数学推理任务
| 模型 | GSM8K | MATH | avg_steps |
|---|---|---|---|
| AR (baseline) | 53.2% | 45.1% | 1.0x |
| dParallel | 51.8% | 43.5% | 0.9x |
| d3LLM | 52.5% | 44.8% | 0.85x |
| Fast-dLLM | 52.1% | 43.9% | 0.8x |
| dUltra | 54.8% | 47.2% | 0.75x |
2. 代码生成任务
| 模型 | HumanEval | MBPP | avg_steps |
|---|---|---|---|
| AR (baseline) | 58.3% | 62.1% | 1.0x |
| dParallel | 56.4% | 59.8% | 0.9x |
| d3LLM | 57.8% | 61.5% | 0.85x |
| dUltra | 60.1% | 64.3% | 0.72x |
3. 效率分析
推理速度对比 (tokens/forward_pass):
15 ┤
│ ┌───┐
12 ┤ │ │
│ ┌───┐ │ │
8 ┤ ┌───┐ │ │ │ │
│ │ │ │ │ ┌─┐ │ │
5 ┤ ┌───┐ │ │ ┌─┐ │ │ │ │ │ │
│ │ │ │ │ │ │ │ │ │ │ │ │
0 ┼───┴───┴──────┴───┴────┴─┴──┴───┴──┴─┴──┴───┴─→ 模型
MDLM dParallel d3LLM Fast dUltra
基线
4. 学习到的Unmasking策略
dUltra学习到的策略具有以下特点:
- 早期:优先解mask高熵位置(不确定性高的token)
- 中期:平衡探索与利用
- 后期:利用已恢复的上下文加速剩余token
与其他方法的对比
| 维度 | dUltra | 置信度启发式 | 蒸馏方法 |
|---|---|---|---|
| 策略学习 | 可学习、自适应 | 固定规则 | 静态策略 |
| off-policy风险 | 无(on-policy) | N/A | 有 |
| 与base模型关系 | 协同 | N/A | 从属 |
| 效率-质量权衡 | 最优 | 次优 | 受限 |
| 可扩展性 | 强 | 弱 | 中等 |
技术洞察
1. On-Policy的重要性
# Off-Policy vs On-Policy对比
# Off-Policy (d3LLM):
# 使用AR模型生成的轨迹训练扩散模型
# 问题:分布偏移,策略可能被误导
# On-Policy (dUltra):
# 扩散模型生成轨迹,扩散模型更新策略
# 优势:策略一致性,避免分布偏移2. 多目标奖励设计
dUltra的奖励函数精心平衡了多个目标:
- :保证生成质量
- :保持与AR模型的知识一致性
- :鼓励高效生成
3. Gumbel-Softmax在推理中的应用
def gumbel_softmax(logits, tau=1.0):
"""
Gumbel-Softmax采样
用于在推理时实现可微的随机采样
"""
gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
y = (logits + gumbels) / tau
return F.softmax(y, dim=-1)