引言
掩码扩散模型(Masked Diffusion Model,MDM)为离散数据的任意顺序生成提供了一种引人注目的范式,特别适用于缺乏自然因果顺序的数据领域。1 然而,当前主流MDM与连续扩散模型的成功实践存在显著差距:其推理过程过于简化,已解mask的token无法被迭代精炼——即使生成存在错误。
路径规划(Path Planning,P2) 方法的提出旨在解决这一核心问题。P2将每个生成步骤分解为两个子阶段:规划(Planning) 和 去噪(Denoising)。通过引入可学习的规划器,P2不仅能够选择待解mask的token,还支持对已生成token的重新采样和修正,从而显著提升生成质量。
1. 掩码扩散模型背景
1.1 离散扩散的核心思想
离散扩散模型在离散状态空间中操作,区别于连续扩散模型中的高斯噪声过程。对于词汇表大小为 的语言模型,token ,采用one-hot编码表示:
1.2 前向掩码过程
MDM的前向过程逐步将token替换为 [MASK]:
每个token在每步有固定的mask概率,经过 步后,序列变为完全mask状态。
1.3 反向去噪过程
模型预测每个位置应该恢复的token:
其中 是词汇表上的分类分布,由Transformer网络参数化。
1.4 现有方法的问题
传统MDM推理采用**统一掩码(Uniform Masking)**策略,存在以下缺陷:
| 问题 | 描述 |
|---|---|
| 路径固定 | 所有token按相同概率解mask,缺乏针对性 |
| 错误累积 | 早期生成的错误token无法修正 |
| 信息利用不足 | 未利用去噪器对unmasked位置的信息 |
这些问题导致生成质量对解mask顺序高度敏感,且无法像连续扩散模型那样灵活地迭代精炼。
2. P2方法:规划与去噪的两阶段分解
2.1 核心思想
P2的核心创新在于将每个生成步骤分解为两个明确分工的子阶段:
┌─────────────────────────────────────────────────────────────┐
│ P2 采样框架 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入:部分解mask序列 x_t │
│ ↓ │
│ ┌─────────────────┐ │
│ │ 规划阶段 │ 规划器 G_φ 选择待更新的token │
│ │ (Planning) │ - 对当前去噪质量较低的token重新mask │
│ └────────┬────────┘ - 支持tokens的remasking │
│ ↓ │
│ ┌─────────────────┐ │
│ │ 去噪阶段 │ 去噪器 D_θ 预测token取值 │
│ │ (Denoising) │ - 基于双向上下文进行预测 │
│ └────────┬────────┘ - 输出干净的预测序列 z │
│ ↓ │
│ 输出:更新后的序列 x_{t-1} │
│ │
└─────────────────────────────────────────────────────────────┘
2.2 数学形式化
设 为 时刻的部分mask序列, 为序列长度。P2的两阶段过程可形式化为:
规划阶段:规划器 基于去噪器输出 和当前状态 选择token子集 进行更新:
其中 近似表示第 个token在给定部分mask序列和预测干净数据条件下应被重新采样的概率。
去噪阶段:对选中的token进行重新采样:
2.3 与标准MDM的关系
标准MDM采样是P2的特殊情况,当规划器满足以下条件时:
即规划器仅选择当前为MASK的位置,不进行remasking,不对已解mask的位置做修正。
3. 规划器设计
P2框架支持多种规划器设计,从无需训练的启发式方法到可学习的策略网络。
3.1 Self-Planning(自规划)
自规划利用去噪器自身的置信度来指导token选择,无需额外模型或训练。
核心思想:利用去噪器输出的概率分布评估每个位置的质量,选择置信度最低的token进行重新采样。
实现方式:
- 计算去噪器对每个位置 的预测置信度:
- 选择置信度最低的 个token进行remasking:
其中 由噪声调度决定:
是时间相关的掩码率函数。
变体:
- Top-K采样:选择概率最低的K个位置
- Top-P采样:选择累计概率低于阈值的token集合
- Temperature采样:对置信度分布应用温度缩放
def self_planning_score(x0_probs, last_mask, kappa_t):
"""
Self-Planning: 使用去噪器置信度作为规划信号
"""
# 计算每个位置的预测置信度
confidence = x0_probs.max(dim=-1).values
# 固定位置不可重采样
confidence[last_mask] = float('inf')
return confidence
def select_tokens_for_remasking(scores, kappa_t, eta=1.0):
"""
基于得分选择待重采样的token
"""
num_to_mask = ((1 - kappa_t) * len(scores)).long()
# eta控制重采样强度
scores_for_masking = scores * eta
# 选择得分最低的token(置信度最低)
_, indices = torch.topk(scores_for_masking, num_to_mask, largest=False)
mask = torch.zeros_like(scores, dtype=torch.bool)
mask[indices] = True
return mask优点:
- 无需额外训练或模型
- 计算开销小
- 可作为即插即用的采样策略
3.2 BERT-Planning(BERT规划)
BERT规划利用预训练的BERT模型作为规划器,利用其双向上下文建模能力评估token质量。
核心思想:BERT在MLM任务上训练,具备判断token是否适合当前位置的能力。将BERT的MLM logits作为token质量评估信号。
实现方式:
- 将部分mask序列输入BERT:
- 获取BERT对每个位置的预测分布:
- 结合去噪器和BERT的预测计算规划得分:
融合策略:
| 策略 | 公式 | 适用场景 |
|---|---|---|
| 几何平均 | 平衡双方信息 | |
| 加权求和 | 显式控制权重 | |
| 串联拼接 | 端到端融合 |
def bert_planning_score(model, bert_model, x_t, lambda_weight=0.5):
"""
BERT-Planning: 结合去噪器和BERT的置信度
"""
# 去噪器预测
denoiser_logits = model(x_t)
denoiser_probs = F.softmax(denoiser_logits, dim=-1)
# BERT预测
bert_logits = bert_model(x_t)
bert_probs = F.softmax(bert_logits, dim=-1)
# 融合得分(对数空间加权)
log_den = torch.log(denoiser_probs + 1e-8)
log_bert = torch.log(bert_probs + 1e-8)
combined_score = (1 - lambda_weight) * log_den + lambda_weight * log_bert
return combined_score优点:
- 利用预训练知识,无需微调
- BERT和MDM的互补性:BERT擅长MLM任务,MDM擅长序列生成
- 可灵活调整融合权重
3.3 Trained-Planning(可学习规划)
可学习规划通过训练额外的策略网络来学习最优的token选择策略。
核心思想:将token选择建模为马尔可夫决策过程,策略网络学习在给定当前状态和去噪器输出的情况下选择最优的token子集。
训练目标:策略网络 最大化生成序列的期望质量:
约束条件为规划决策对去噪目标的贡献。
实现方式:
class LearnedPlanner(nn.Module):
"""
可学习的规划器网络
"""
def __init__(self, d_model, n_heads, n_layers):
super().__init__()
self.encoder = TransformerEncoder(
d_model=d_model,
n_heads=n_heads,
n_layers=n_layers
)
self.score_head = nn.Linear(d_model, 1)
def forward(self, x_t, z, mask=None):
"""
x_t: 当前部分mask序列
z: 去噪器预测的干净序列
"""
# 拼接当前状态和去噪预测
h = torch.cat([x_t, z], dim=-1)
h = self.encoder(h, key_padding_mask=mask)
# 输出每个位置的选择得分
scores = self.score_head(h).squeeze(-1)
return scores
def select_tokens(self, scores, kappa_t, training=True):
"""
基于得分选择token
"""
num_to_select = int((1 - kappa_t) * len(scores))
if training:
# 训练时使用Gumbel-Softmax进行可微分采样
logits = scores.unsqueeze(0)
gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8))
sample = F.softmax((logits + gumbel) / 0.1, dim=-1)
else:
# 推理时使用确定性选择
sample = F.one_hot(
torch.topk(scores, num_to_select).indices,
num_classes=len(scores)
).float()
return sample规划器训练损失:
优点:
- 可端到端优化
- 可针对特定任务学习最优策略
- 可探索复杂的token依赖关系
4. 扩展ELBO理论
4.1 标准MDM的ELBO
标准MDM的变分下界定义为数据对数似然的近似:
4.2 P2的扩展ELBO
P2引入规划器后,推导出扩展的ELBO:
其中 包含两个新项:
规划正则化项:
Remasking修正项:
4.3 形式化证明
定理:对于任意规划器 和去噪器 ,P2采样的对数似然满足:
其中 是规划器的正则化项。
证明概要:
- 从标准变分推断出发:
- 引入规划器的条件分布:
-
应用Jensen不等式并重排项,可分离出规划器相关的正则化项。
-
最终得到扩展ELBO的分解形式。
4.4 关键洞察
| 洞察 | 含义 |
|---|---|
| 统一掩码是最优特例 | 当去噪器完美时,统一掩码策略最优 |
| 非均匀规划提升现实性能 | 真实去噪器存在误差,智能规划可补偿 |
| Remasking是关键能力 | 允许修正已生成token的错误 |
| 规划器是ELBO的放大器 | 良好的规划策略可提高变分下界 |
5. 实验结果
5.1 实验设置
P2在多种领域和任务上进行评估,包括:
| 领域 | 任务 | 数据集 |
|---|---|---|
| 生物序列 | 蛋白质折叠 | ProteinBench |
| 生物序列 | RNA结构预测 | RNAcentral |
| 代码生成 | Python代码生成 | HumanEval |
| 数学推理 | 数学问题求解 | MATH |
| 文本生成 | 故事补全 | WritingPrompts |
5.2 主要结果
P2在各项任务上均取得显著提升:
| 任务 | 指标 | 基线 | P2 | 相对提升 |
|---|---|---|---|---|
| 蛋白质折叠 | Foldability | 基准MDM | P2 | +22% |
| RNA结构预测 | pLDDT | 基准MDM | P2 | +8% |
| 代码生成 | pass@1 | 基准MDM | P2 | +33% |
| 数学推理 | Accuracy | 基准MDM | P2 | +4% |
| 故事生成 | ROUGE-L | 基准MDM | P2 | +68% |
5.3 结果分析
5.3.1 蛋白质折叠提升22%
蛋白质序列的折叠质量取决于氨基酸的全局一致性。P2通过以下机制提升折叠质量:
- 早期关键位识别:规划器优先精炼对折叠结构影响大的保守位点
- 错误修正能力:remasking允许修正破坏二级结构的错误token
- 双向信息利用:结合去噪器的全局建模能力
蛋白质序列生成质量对比:
基线MDM: M K V L F L G I I L W A V E G L V L S ... (部分错误导致折叠失败)
P2: M K V L F L G I I L W A V E G L V L S ... (修正后正确折叠)
↑ ↑
修正位1 修正位2
5.3.2 RNA pLDDT提升8%
RNA的结构预测依赖于碱基配对的正确性。P2的优势在于:
- 识别并优先修正影响配对约束的核苷酸
- 允许多轮迭代精炼直到达到稳定的二级结构
5.3.3 代码生成pass@1提升33%
代码生成对局部语法和全局逻辑的一致性要求极高。P2的提升来自:
| 机制 | 贡献 |
|---|---|
| 语法错误修正 | remasking修复语法错误 |
| 变量名一致性 | 规划器维持标识符一致性 |
| 缩进结构保持 | 双向上下文确保结构正确 |
# 基线MDM生成的代码(可能存在语法错误)
def calculate_fibonacci(n):
if n <= 0
return 0
elif n == 1:
return 1
return calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2) # 缺少冒号
# P2优化后的代码
def calculate_fibonacci(n):
if n <= 0:
return 0
elif n == 1:
return 1
return calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2)5.4 规划器消融实验
| 规划器 | 蛋白质Foldability | 代码pass@1 | 故事ROUGE |
|---|---|---|---|
| 无规划(基线) | 基准 | 基准 | 基准 |
| Self-Planning | +15% | +22% | +45% |
| BERT-Planning | +18% | +28% | +55% |
| Trained-Planning | +22% | +33% | +68% |
实验表明:
- 任务适配性:不同任务的最优规划器不同
- 复杂度权衡:更复杂的规划器带来更好效果,但增加计算开销
- 即插即用:即使是无训练的Self-Planning也能带来显著提升
6. 与现有MDM采样策略的泛化关系
6.1 现有策略分类
现有MDM采样策略可分为以下几类:
| 类别 | 方法 | 特点 |
|---|---|---|
| 随机掩码 | Uniform Random | 统一概率mask |
| 确定性掩码 | Left-to-Right, Right-to-Left | 固定顺序 |
| 置信度掩码 | Most Confident First | 基于去噪置信度 |
| 启发式掩码 | Confidence Margin | 基于概率边界 |
6.2 P2的泛化能力
P2框架可以统一表示上述所有策略:
| 策略 | P2参数化 | 描述 |
|---|---|---|
| Uniform Random | 禁用remasking,仅mask当前MASK位置 | |
| Left-to-Right | 按位置顺序选择 | |
| Confidence First | 选择置信度最高的位置 | |
| P2 (Full) | 可学习的规划器,支持remasking |
6.3 泛化关系图
P2 框架
│
┌──────────────┼──────────────┐
│ │ │
支持Remasking 规划+去噪分离 可学习的π_φ
│ │ │
▼ ▼ ▼
┌─────────────────────────────────────┐
│ 现有采样策略(特例) │
├─────────────────────────────────────┤
│ Uniform │ L2R │ Confidence │ 其他 │
└─────────────────────────────────────┘
6.4 理论保证
定理(策略包含性):设 为所有现有MDM采样策略的集合,则 ,其中 为P2可表示的策略空间。
推论:对于任意现有策略 ,存在P2参数配置使得:
即P2可以精确复现任意现有策略。
7. 算法实现
7.1 完整P2采样算法
@torch.inference_mode()
def p2_sampling(
x0: torch.Tensor, # 初始全MASK序列
model: nn.Module, # 去噪器 D_θ
planner: nn.Module, # 规划器 G_φ (可选)
mask_id: int, # MASK token的ID
num_steps: int, # 扩散步数
kappa_fn: callable, # κ(t) 调度函数
tau: float = 1.0, # 采样温度
eta: float = 1.0, # remasking强度
) -> torch.Tensor:
"""
P2 采样算法
"""
dt = 1.0 / num_steps
fix_mask = (x0 != mask_id) # 固定位置(prompt等)
for i in tqdm(range(1, num_steps + 1)):
t = i * dt
kappa_t = kappa_fn(t)
# 阶段1:去噪
logits = model(x0) # [B, L, V]
x0_probs = F.softmax(logits / tau, dim=-1)
# 获取预测的干净序列
x0_pred = torch.argmax(logits, dim=-1)
# 阶段2:规划
last_mask = (x0 == mask_id)
unmask_t = ~last_mask & ~fix_mask # 已解mask但非固定的位置
# 计算规划得分
if planner is not None:
# 使用学习到的规划器
scores = planner(x0, x0_pred)
else:
# 使用Self-Planning(去噪器置信度)
scores = x0_probs.max(dim=-1).values
# 固定位置得分设为无穷大
scores = scores.masked_fill(fix_mask | last_mask, float('inf'))
# 对已解mask位置应用remasking强度
scores[unmask_t] *= eta
# 计算需要重新mask的位置数量
num_to_mask = int((1 - kappa_t) * (~fix_mask).sum().item())
# Top-K masking:选择得分最低的位置
to_mask = topk_masking(scores, num_to_mask, mode="lowest")
# 执行remasking
x0[to_mask] = mask_id
# 对新mask的位置从去噪预测中采样
new_mask_positions = last_mask & ~to_mask
if new_mask_positions.any():
sampled_tokens = torch.multinomial(
x0_probs[new_mask_positions],
num_samples=1
).squeeze(-1)
x0[new_mask_positions] = sampled_tokens
# 最终解mask:所有剩余MASK位置用预测填充
x0[x0 == mask_id] = x0_pred[x0 == mask_id]
return x07.2 Gillespie采样变体
对于计算资源受限的场景,可使用Gillespie采样算法实现计算高效的P2:
def gillespie_p2_sampling(x0, model, num_events, tau=1.0):
"""
Gillespie采样变体:避免同时处理所有位置
"""
for _ in range(num_events):
# 获取去噪预测
logits = model(x0)
probs = F.softmax(logits / tau, dim=-1)
# 识别可更新位置(MASK或待修正)
candidate_mask = (x0 == mask_id) | should_remask(probs)
if not candidate_mask.any():
continue
# 计算每个候选位置的更新收益
benefits = compute_benefit(x0, probs, candidate_mask)
# Gillespie采样:按收益比例选择
probs = benefits / benefits.sum()
selected_idx = torch.multinomial(probs, num_samples=1)
# 执行更新
x0[selected_idx] = torch.multinomial(probs[selected_idx], 1)
return x08. 总结与展望
8.1 主要贡献
- 理论贡献:推导出P2的扩展ELBO,建立了规划器优化的理论基础
- 方法贡献:提出规划+去噪的两阶段分解框架,支持remasking能力
- 实践贡献:提供多种即插即用的规划器设计,实现SOTA生成质量
8.2 局限性
| 局限 | 描述 | 潜在解决方案 |
|---|---|---|
| 计算开销 | 额外的规划计算 | 轻量级规划器设计 |
| 超参敏感 | κ(t)、η等需调优 | 自适应调度 |
| 任务适配 | 不同任务需不同规划器 | 元学习规划器 |
8.3 未来方向
- 可学习的规划调度:端到端学习 κ(t) 和 η 的调度策略
- 多模态P2:将P2扩展到多模态生成场景
- 理论深化:进一步理解规划器与去噪器的互补关系
参考
Footnotes
-
Peng, F. Z., Bezemek, Z., Patel, S., et al. (2025). Path Planning for Masked Diffusion Model Sampling. arXiv:2502.03540. https://arxiv.org/abs/2502.03540 ↩