引言

掩码扩散模型(Masked Diffusion Model,MDM)为离散数据的任意顺序生成提供了一种引人注目的范式,特别适用于缺乏自然因果顺序的数据领域。1 然而,当前主流MDM与连续扩散模型的成功实践存在显著差距:其推理过程过于简化,已解mask的token无法被迭代精炼——即使生成存在错误。

路径规划(Path Planning,P2) 方法的提出旨在解决这一核心问题。P2将每个生成步骤分解为两个子阶段:规划(Planning)去噪(Denoising)。通过引入可学习的规划器,P2不仅能够选择待解mask的token,还支持对已生成token的重新采样和修正,从而显著提升生成质量。


1. 掩码扩散模型背景

1.1 离散扩散的核心思想

离散扩散模型在离散状态空间中操作,区别于连续扩散模型中的高斯噪声过程。对于词汇表大小为 的语言模型,token ,采用one-hot编码表示:

1.2 前向掩码过程

MDM的前向过程逐步将token替换为 [MASK]

每个token在每步有固定的mask概率,经过 步后,序列变为完全mask状态。

1.3 反向去噪过程

模型预测每个位置应该恢复的token:

其中 是词汇表上的分类分布,由Transformer网络参数化。

1.4 现有方法的问题

传统MDM推理采用**统一掩码(Uniform Masking)**策略,存在以下缺陷:

问题描述
路径固定所有token按相同概率解mask,缺乏针对性
错误累积早期生成的错误token无法修正
信息利用不足未利用去噪器对unmasked位置的信息

这些问题导致生成质量对解mask顺序高度敏感,且无法像连续扩散模型那样灵活地迭代精炼。


2. P2方法:规划与去噪的两阶段分解

2.1 核心思想

P2的核心创新在于将每个生成步骤分解为两个明确分工的子阶段:

┌─────────────────────────────────────────────────────────────┐
│                    P2 采样框架                               │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  输入:部分解mask序列 x_t                                    │
│         ↓                                                   │
│  ┌─────────────────┐                                       │
│  │   规划阶段       │  规划器 G_φ 选择待更新的token           │
│  │  (Planning)     │  - 对当前去噪质量较低的token重新mask      │
│  └────────┬────────┘  - 支持tokens的remasking               │
│           ↓                                                │
│  ┌─────────────────┐                                       │
│  │   去噪阶段       │  去噪器 D_θ 预测token取值               │
│  │  (Denoising)    │  - 基于双向上下文进行预测               │
│  └────────┬────────┘  - 输出干净的预测序列 z                │
│           ↓                                                │
│  输出:更新后的序列 x_{t-1}                                 │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2.2 数学形式化

时刻的部分mask序列, 为序列长度。P2的两阶段过程可形式化为:

规划阶段:规划器 基于去噪器输出 和当前状态 选择token子集 进行更新:

其中 近似表示第 个token在给定部分mask序列和预测干净数据条件下应被重新采样的概率。

去噪阶段:对选中的token进行重新采样:

2.3 与标准MDM的关系

标准MDM采样是P2的特殊情况,当规划器满足以下条件时:

即规划器仅选择当前为MASK的位置,不进行remasking,不对已解mask的位置做修正。


3. 规划器设计

P2框架支持多种规划器设计,从无需训练的启发式方法到可学习的策略网络。

3.1 Self-Planning(自规划)

自规划利用去噪器自身的置信度来指导token选择,无需额外模型或训练。

核心思想:利用去噪器输出的概率分布评估每个位置的质量,选择置信度最低的token进行重新采样。

实现方式

  1. 计算去噪器对每个位置 的预测置信度:
  1. 选择置信度最低的 个token进行remasking:

其中 由噪声调度决定:

是时间相关的掩码率函数。

变体

  • Top-K采样:选择概率最低的K个位置
  • Top-P采样:选择累计概率低于阈值的token集合
  • Temperature采样:对置信度分布应用温度缩放
def self_planning_score(x0_probs, last_mask, kappa_t):
    """
    Self-Planning: 使用去噪器置信度作为规划信号
    """
    # 计算每个位置的预测置信度
    confidence = x0_probs.max(dim=-1).values
    
    # 固定位置不可重采样
    confidence[last_mask] = float('inf')
    
    return confidence
 
def select_tokens_for_remasking(scores, kappa_t, eta=1.0):
    """
    基于得分选择待重采样的token
    """
    num_to_mask = ((1 - kappa_t) * len(scores)).long()
    
    # eta控制重采样强度
    scores_for_masking = scores * eta
    
    # 选择得分最低的token(置信度最低)
    _, indices = torch.topk(scores_for_masking, num_to_mask, largest=False)
    
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[indices] = True
    
    return mask

优点

  • 无需额外训练或模型
  • 计算开销小
  • 可作为即插即用的采样策略

3.2 BERT-Planning(BERT规划)

BERT规划利用预训练的BERT模型作为规划器,利用其双向上下文建模能力评估token质量。

核心思想:BERT在MLM任务上训练,具备判断token是否适合当前位置的能力。将BERT的MLM logits作为token质量评估信号。

实现方式

  1. 将部分mask序列输入BERT:
  1. 获取BERT对每个位置的预测分布:
  1. 结合去噪器和BERT的预测计算规划得分:

融合策略

策略公式适用场景
几何平均平衡双方信息
加权求和显式控制权重
串联拼接端到端融合
def bert_planning_score(model, bert_model, x_t, lambda_weight=0.5):
    """
    BERT-Planning: 结合去噪器和BERT的置信度
    """
    # 去噪器预测
    denoiser_logits = model(x_t)
    denoiser_probs = F.softmax(denoiser_logits, dim=-1)
    
    # BERT预测
    bert_logits = bert_model(x_t)
    bert_probs = F.softmax(bert_logits, dim=-1)
    
    # 融合得分(对数空间加权)
    log_den = torch.log(denoiser_probs + 1e-8)
    log_bert = torch.log(bert_probs + 1e-8)
    
    combined_score = (1 - lambda_weight) * log_den + lambda_weight * log_bert
    
    return combined_score

优点

  • 利用预训练知识,无需微调
  • BERT和MDM的互补性:BERT擅长MLM任务,MDM擅长序列生成
  • 可灵活调整融合权重

3.3 Trained-Planning(可学习规划)

可学习规划通过训练额外的策略网络来学习最优的token选择策略。

核心思想:将token选择建模为马尔可夫决策过程,策略网络学习在给定当前状态和去噪器输出的情况下选择最优的token子集。

训练目标:策略网络 最大化生成序列的期望质量:

约束条件为规划决策对去噪目标的贡献。

实现方式

class LearnedPlanner(nn.Module):
    """
    可学习的规划器网络
    """
    def __init__(self, d_model, n_heads, n_layers):
        super().__init__()
        self.encoder = TransformerEncoder(
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers
        )
        self.score_head = nn.Linear(d_model, 1)
        
    def forward(self, x_t, z, mask=None):
        """
        x_t: 当前部分mask序列
        z: 去噪器预测的干净序列
        """
        # 拼接当前状态和去噪预测
        h = torch.cat([x_t, z], dim=-1)
        h = self.encoder(h, key_padding_mask=mask)
        
        # 输出每个位置的选择得分
        scores = self.score_head(h).squeeze(-1)
        
        return scores
    
    def select_tokens(self, scores, kappa_t, training=True):
        """
        基于得分选择token
        """
        num_to_select = int((1 - kappa_t) * len(scores))
        
        if training:
            # 训练时使用Gumbel-Softmax进行可微分采样
            logits = scores.unsqueeze(0)
            gumbel = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8))
            sample = F.softmax((logits + gumbel) / 0.1, dim=-1)
        else:
            # 推理时使用确定性选择
            sample = F.one_hot(
                torch.topk(scores, num_to_select).indices,
                num_classes=len(scores)
            ).float()
        
        return sample

规划器训练损失

优点

  • 可端到端优化
  • 可针对特定任务学习最优策略
  • 可探索复杂的token依赖关系

4. 扩展ELBO理论

4.1 标准MDM的ELBO

标准MDM的变分下界定义为数据对数似然的近似:

4.2 P2的扩展ELBO

P2引入规划器后,推导出扩展的ELBO:

其中 包含两个新项:

规划正则化项

Remasking修正项

4.3 形式化证明

定理:对于任意规划器 和去噪器 ,P2采样的对数似然满足:

其中 是规划器的正则化项。

证明概要

  1. 从标准变分推断出发:
  1. 引入规划器的条件分布:
  1. 应用Jensen不等式并重排项,可分离出规划器相关的正则化项。

  2. 最终得到扩展ELBO的分解形式。

4.4 关键洞察

洞察含义
统一掩码是最优特例当去噪器完美时,统一掩码策略最优
非均匀规划提升现实性能真实去噪器存在误差,智能规划可补偿
Remasking是关键能力允许修正已生成token的错误
规划器是ELBO的放大器良好的规划策略可提高变分下界

5. 实验结果

5.1 实验设置

P2在多种领域和任务上进行评估,包括:

领域任务数据集
生物序列蛋白质折叠ProteinBench
生物序列RNA结构预测RNAcentral
代码生成Python代码生成HumanEval
数学推理数学问题求解MATH
文本生成故事补全WritingPrompts

5.2 主要结果

P2在各项任务上均取得显著提升:

任务指标基线P2相对提升
蛋白质折叠Foldability基准MDMP2+22%
RNA结构预测pLDDT基准MDMP2+8%
代码生成pass@1基准MDMP2+33%
数学推理Accuracy基准MDMP2+4%
故事生成ROUGE-L基准MDMP2+68%

5.3 结果分析

5.3.1 蛋白质折叠提升22%

蛋白质序列的折叠质量取决于氨基酸的全局一致性。P2通过以下机制提升折叠质量:

  1. 早期关键位识别:规划器优先精炼对折叠结构影响大的保守位点
  2. 错误修正能力:remasking允许修正破坏二级结构的错误token
  3. 双向信息利用:结合去噪器的全局建模能力
蛋白质序列生成质量对比:

基线MDM:  M K V L F L G I I L W A V E G L V L S ... (部分错误导致折叠失败)
P2:       M K V L F L G I I L W A V E G L V L S ... (修正后正确折叠)
                                      ↑           ↑
                                   修正位1     修正位2

5.3.2 RNA pLDDT提升8%

RNA的结构预测依赖于碱基配对的正确性。P2的优势在于:

  • 识别并优先修正影响配对约束的核苷酸
  • 允许多轮迭代精炼直到达到稳定的二级结构

5.3.3 代码生成pass@1提升33%

代码生成对局部语法和全局逻辑的一致性要求极高。P2的提升来自:

机制贡献
语法错误修正remasking修复语法错误
变量名一致性规划器维持标识符一致性
缩进结构保持双向上下文确保结构正确
# 基线MDM生成的代码(可能存在语法错误)
def calculate_fibonacci(n):
    if n <= 0
        return 0
    elif n == 1:
        return 1
    return calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2)  # 缺少冒号
 
# P2优化后的代码
def calculate_fibonacci(n):
    if n <= 0:
        return 0
    elif n == 1:
        return 1
    return calculate_fibonacci(n - 1) + calculate_fibonacci(n - 2)

5.4 规划器消融实验

规划器蛋白质Foldability代码pass@1故事ROUGE
无规划(基线)基准基准基准
Self-Planning+15%+22%+45%
BERT-Planning+18%+28%+55%
Trained-Planning+22%+33%+68%

实验表明:

  • 任务适配性:不同任务的最优规划器不同
  • 复杂度权衡:更复杂的规划器带来更好效果,但增加计算开销
  • 即插即用:即使是无训练的Self-Planning也能带来显著提升

6. 与现有MDM采样策略的泛化关系

6.1 现有策略分类

现有MDM采样策略可分为以下几类:

类别方法特点
随机掩码Uniform Random统一概率mask
确定性掩码Left-to-Right, Right-to-Left固定顺序
置信度掩码Most Confident First基于去噪置信度
启发式掩码Confidence Margin基于概率边界

6.2 P2的泛化能力

P2框架可以统一表示上述所有策略:

策略P2参数化描述
Uniform Random禁用remasking,仅mask当前MASK位置
Left-to-Right按位置顺序选择
Confidence First选择置信度最高的位置
P2 (Full)可学习的规划器,支持remasking

6.3 泛化关系图

                    P2 框架
                       │
        ┌──────────────┼──────────────┐
        │              │              │
   支持Remasking   规划+去噪分离   可学习的π_φ
        │              │              │
        ▼              ▼              ▼
    ┌─────────────────────────────────────┐
    │         现有采样策略(特例)          │
    ├─────────────────────────────────────┤
    │  Uniform │ L2R │ Confidence │ 其他  │
    └─────────────────────────────────────┘

6.4 理论保证

定理(策略包含性):设 为所有现有MDM采样策略的集合,则 ,其中 为P2可表示的策略空间。

推论:对于任意现有策略 ,存在P2参数配置使得:

即P2可以精确复现任意现有策略。


7. 算法实现

7.1 完整P2采样算法

@torch.inference_mode()
def p2_sampling(
    x0: torch.Tensor,           # 初始全MASK序列
    model: nn.Module,           # 去噪器 D_θ
    planner: nn.Module,          # 规划器 G_φ (可选)
    mask_id: int,                # MASK token的ID
    num_steps: int,              # 扩散步数
    kappa_fn: callable,          # κ(t) 调度函数
    tau: float = 1.0,            # 采样温度
    eta: float = 1.0,            # remasking强度
) -> torch.Tensor:
    """
    P2 采样算法
    """
    dt = 1.0 / num_steps
    fix_mask = (x0 != mask_id)   # 固定位置(prompt等)
    
    for i in tqdm(range(1, num_steps + 1)):
        t = i * dt
        kappa_t = kappa_fn(t)
        
        # 阶段1:去噪
        logits = model(x0)                   # [B, L, V]
        x0_probs = F.softmax(logits / tau, dim=-1)
        
        # 获取预测的干净序列
        x0_pred = torch.argmax(logits, dim=-1)
        
        # 阶段2:规划
        last_mask = (x0 == mask_id)
        unmask_t = ~last_mask & ~fix_mask     # 已解mask但非固定的位置
        
        # 计算规划得分
        if planner is not None:
            # 使用学习到的规划器
            scores = planner(x0, x0_pred)
        else:
            # 使用Self-Planning(去噪器置信度)
            scores = x0_probs.max(dim=-1).values
        
        # 固定位置得分设为无穷大
        scores = scores.masked_fill(fix_mask | last_mask, float('inf'))
        
        # 对已解mask位置应用remasking强度
        scores[unmask_t] *= eta
        
        # 计算需要重新mask的位置数量
        num_to_mask = int((1 - kappa_t) * (~fix_mask).sum().item())
        
        # Top-K masking:选择得分最低的位置
        to_mask = topk_masking(scores, num_to_mask, mode="lowest")
        
        # 执行remasking
        x0[to_mask] = mask_id
        
        # 对新mask的位置从去噪预测中采样
        new_mask_positions = last_mask & ~to_mask
        if new_mask_positions.any():
            sampled_tokens = torch.multinomial(
                x0_probs[new_mask_positions], 
                num_samples=1
            ).squeeze(-1)
            x0[new_mask_positions] = sampled_tokens
    
    # 最终解mask:所有剩余MASK位置用预测填充
    x0[x0 == mask_id] = x0_pred[x0 == mask_id]
    
    return x0

7.2 Gillespie采样变体

对于计算资源受限的场景,可使用Gillespie采样算法实现计算高效的P2:

def gillespie_p2_sampling(x0, model, num_events, tau=1.0):
    """
    Gillespie采样变体:避免同时处理所有位置
    """
    for _ in range(num_events):
        # 获取去噪预测
        logits = model(x0)
        probs = F.softmax(logits / tau, dim=-1)
        
        # 识别可更新位置(MASK或待修正)
        candidate_mask = (x0 == mask_id) | should_remask(probs)
        
        if not candidate_mask.any():
            continue
        
        # 计算每个候选位置的更新收益
        benefits = compute_benefit(x0, probs, candidate_mask)
        
        # Gillespie采样:按收益比例选择
        probs = benefits / benefits.sum()
        selected_idx = torch.multinomial(probs, num_samples=1)
        
        # 执行更新
        x0[selected_idx] = torch.multinomial(probs[selected_idx], 1)
    
    return x0

8. 总结与展望

8.1 主要贡献

  1. 理论贡献:推导出P2的扩展ELBO,建立了规划器优化的理论基础
  2. 方法贡献:提出规划+去噪的两阶段分解框架,支持remasking能力
  3. 实践贡献:提供多种即插即用的规划器设计,实现SOTA生成质量

8.2 局限性

局限描述潜在解决方案
计算开销额外的规划计算轻量级规划器设计
超参敏感κ(t)、η等需调优自适应调度
任务适配不同任务需不同规划器元学习规划器

8.3 未来方向

  • 可学习的规划调度:端到端学习 κ(t) 和 η 的调度策略
  • 多模态P2:将P2扩展到多模态生成场景
  • 理论深化:进一步理解规划器与去噪器的互补关系

参考

Footnotes

  1. Peng, F. Z., Bezemek, Z., Patel, S., et al. (2025). Path Planning for Masked Diffusion Model Sampling. arXiv:2502.03540. https://arxiv.org/abs/2502.03540