扩散语言模型理论基础

扩散语言模型(Diffusion Language Models)将连续扩散模型的理论框架扩展到离散状态空间,为语言建模提供了全新的生成范式。12本文档系统阐述其数学理论基础,包括离散状态空间建模、变分推断框架以及与自回归模型的数学联系。

概述

传统的语言模型主要基于自回归(Autoregressive, AR)分解,而扩散语言模型将文本生成建模为从噪声到数据的渐进式转换过程。这一视角的转变带来了以下理论优势:

理论特性自回归模型扩散语言模型
分解结构条件概率链式分解层次化变量模型
推断方向前向因果推断双向近似推断
计算复杂度 次前向传播 次并行前向传播
上下文利用单向上下文双向上下文

扩散语言模型的理论核心在于如何在离散状态空间定义有效的扩散过程,以及如何通过变分推断优化边际似然。这两个问题构成了本文的理论主线。


离散状态空间扩散

离散数据建模

Token序列的One-Hot表示

对于词汇表大小为 的语言模型,每个token 。使用one-hot编码将token映射到离散概率分布空间:

其中 表示 维概率单纯形。对于独热向量 ,其编码为:

对于长度为 的序列 ,其联合表示为:

其中 表示张量积。然而,实际实现中通常使用嵌入表示以降低计算复杂度。

分类分布的处理

在离散扩散模型中,核心问题是如何在分类分布空间定义有效的扩散过程。主要有两种处理方式:

方式一:嵌入空间连续化

将one-hot向量投影到嵌入空间,在连续空间执行扩散:

其中 是词嵌入矩阵。模型预测嵌入空间中的去噪信号:

方式二:离散状态直接扩散

直接在分类分布空间定义马尔可夫过渡,保持状态的离散性质。这种方式需要特殊的过渡矩阵设计以确保数学上的良定义性。

词汇表的嵌入空间

词嵌入矩阵 建立了离散token空间与连续表示空间的桥梁:

其中每行 是词汇表中第 个token的 维嵌入向量。嵌入空间的几何性质对扩散模型性能有重要影响:

  • 语义相似性:语义相近的token在嵌入空间中距离较近
  • 分布均匀性:嵌入向量应均匀覆盖表示空间
  • 各向同性:嵌入分布应接近各向同性高斯分布

过渡动力学

离散时间马尔可夫链

表示时间步 的token。离散扩散的前向过程定义为一阶马尔可夫链:

其中 是过渡矩阵,满足:

完整的前向过程分布为:

对于任意时间步 ,边际分布可递归计算:

其中

过渡矩阵的性质

吸收态设计:Masked Diffusion Model使用特殊的过渡矩阵设计,将 [MASK] 状态作为吸收态:

其中第一个状态为 [MASK],其他 个状态对应原始token。参数 控制每步被mask的概率。

吸收时间的分布:设 为token被完全mask的时间步。对于独立token序列,被mask的比例服从二项分布:

细致平衡条件

对于连续扩散模型,细致平衡(Detailed Balance)条件确保存在可学习的反向过程:

在离散状态下,类似的平衡条件要求:

其中 是平稳分布。通过适当设计 ,可以使 接近均匀分布,从而满足边界条件。


变分推断框架

ELBO推导

证据下界(Evidence Lower Bound)

给定数据 ,我们希望最大化边际似然 。通过引入潜在变量 ,使用变分推断进行优化。

边际似然的分解

引入变分分布 进行重要性采样:

Jensen不等式应用:利用对数的凹性( 是凹函数),通过Jensen不等式得到下界:

等号成立当且仅当

ELBO的展开:定义证据下界(ELBO)为:

展开后得到:

进一步分解为:

简化的训练目标:在实践中,常使用简化的损失形式:

这个目标可以直接优化,而无需计算完整的ELBO。

Jensen不等式应用

Jensen不等式在变分推断中的核心作用体现在两个层面:

层面一:下界构造

对于任意分布

其中 必须是对数凹函数。当 时:

层面二:KL散度的非负性

从Jensen不等式可直接导出KL散度的非负性:

这一性质确保ELBO始终是边际似然的下界。

重参数化技巧

在离散扩散模型中,重参数化主要用于简化梯度估计和实现高效的变分推断。

连续空间的重参数化

给定 的后验分布 ,将其表示为确定性变换:

例如连续扩散中的形式:

离散去噪的蒙特卡洛估计:对于离散状态,梯度通过下式估计:

其中

训练目标

简化损失函数

从ELBO出发,通过一系列近似可得到可操作的训练目标。假设每个token独立扩散:

其中

等价的去噪形式:对于masked diffusion model,训练目标可写为:

其中 是时间步 对应的mask操作。

import torch
import torch.nn.functional as F
 
def diffusion_loss(model, x0, mask_schedule, vocab_size):
    """
    离散扩散模型的标准训练损失
    
    参数:
        model: 去噪网络 p_theta(x_0 | x_t, t)
        x0: 原始token序列 [batch, seq_len]
        mask_schedule: 每个时间步的mask概率
        vocab_size: 词汇表大小
    """
    batch_size, seq_len = x0.shape
    T = len(mask_schedule)
    
    # 采样时间步
    t = torch.randint(0, T, (batch_size,), device=x0.device)
    
    # 生成mask后的输入
    x_masked = mask_tokens(x0, t, mask_schedule)
    
    # 计算mask比例作为条件
    mask_ratio = mask_schedule[t].view(-1, 1, 1)  # [batch, 1, 1]
    
    # 归一化时间步 (0-1)
    t_normalized = t.float() / T
    
    # 模型预测
    logits = model(x_masked, t_normalized)  # [batch, seq_len, vocab_size]
    
    # 交叉熵损失 (仅在被mask位置计算)
    loss = F.cross_entropy(
        logits.view(-1, vocab_size), 
        x0.view(-1), 
        reduction='none'
    )
    
    return loss.mean()
 
 
def mask_tokens(x0, t, mask_schedule):
    """
    根据时间步生成mask后的序列
    
    每步独立决定是否mask每个token
    """
    batch_size, seq_len = x0.shape
    device = x0.device
    
    # 生成mask掩码
    mask_prob = mask_schedule[t].unsqueeze(1)  # [batch, 1]
    mask = torch.rand(batch_size, seq_len, device=device) < mask_prob
    
    # 创建mask token (假设vocab_size-1为[MASK] token)
    x_masked = x0.clone()
    x_masked[mask] = -100  # -100在交叉熵中表示忽略该位置
    
    return x_masked

对比不同参数化

参数化方式一: 预测

直接预测原始数据:

损失:

参数化方式二: 预测

预测添加的噪声(适用于连续空间):

参数化方式三: 预测

预测当前噪声状态:


与自回归模型的数学联系

边际似然对比

AR模型:分解的乘积形式

自回归模型将联合分布分解为条件概率的链式乘积:

对应的对数似然

序列建模的核心假设:每个token的条件分布只依赖于前面的token,即因果马尔可夫性质

扩散模型:层次化变量模型

扩散模型引入潜在变量 建立层次化结构:

边际似然的变分下界

两种模型的概率图模型对比

自回归模型 (链式结构):
x_1 → x_2 → x_3 → ... → x_n
 ↓     ↓     ↓
 p     p     p
(条件分布)

扩散模型 (层次结构):
       z_T → z_{T-1} → ... → z_1 → z_0 = x_0
       ↓      ↓              ↓     ↓
      p     p              p     p
    (先验)  (过渡)        (过渡) (重建)

推断复杂度分析

前向过程的高效计算

AR模型:生成 个token需要 次前向传播,每次依赖前一次的结果:

扩散模型前向过程

  • 前向加噪过程可以预先计算: 可直接闭式计算
  • 时间步采样是并行的: 次前向传播计算任意
# 连续扩散的闭式计算
def forward_process_closed_form(x0, alphas_bar, epsilon):
    """
    闭式计算前向过程
    
    x_t = sqrt(alphas_bar[t]) * x0 + sqrt(1 - alphas_bar[t]) * epsilon
    """
    return torch.sqrt(alphas_bar) * x0 + torch.sqrt(1 - alphas_bar) * epsilon

关键性质:前向过程无需迭代计算,可直接获得任意时间步的分布。

反向过程的近似推断

AR模型:推断过程即生成过程,精确但顺序执行:

扩散模型:反向过程需要通过学习或近似推断:

训练阶段的效率:两者均可并行训练。AR模型使用teacher forcing,扩散模型使用随机时间步:

阶段AR模型扩散模型
训练 并行(teacher forcing) 并行(随机时间步)
推断 顺序 顺序(可并行)

推断阶段的权衡:若 ,扩散模型在推断时可能具有优势。


生成质量与效率的理论分析

似然 vs 生成质量的权衡

步数与生成质量的关系

扩散模型的生成质量与采样步数 密切相关。理论分析表明:

渐近质量:当 时,在适当的平滑假设下,反向过程可精确恢复数据分布:

有限步数的误差界:设 为每步过渡的近似误差,则累积误差满足:

最优步数选择:实践中发现存在”临界步数”,超过该步数后增加步数的边际收益递减:

采样策略的理论分析

确定性去噪:DDPM使用确定性均值估计:

随机采样:DDIM等方法引入随机性,允许非马尔可夫反向过程:

def ddim_step(model, x_t, t, prev_t, eta=0.0):
    """
    DDIM采样步骤
    
    eta=0: 确定性 (DDIM)
    eta=1: 随机性 (DDPM)
    """
    alpha_t = compute_alpha(t)
    alpha_prev = compute_alpha(prev_t)
    
    # 预测噪声
    eps = model(x_t, t)
    
    # 估计原始数据
    x0_hat = (x_t - sqrt(1-alpha_t) * eps) / sqrt(alpha_t)
    
    # 方向估计
    pred_x0 = x0_hat.clamp(-1, 1)
    dir_xt = sqrt(1 - alpha_prev - eta * sigma_t**2) * eps
    
    # 随机成分
    sigma_t = eta * sqrt((1 - alpha_prev) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_prev)
    
    x_prev = sqrt(alpha_prev) * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)
    
    return x_prev

双向注意力的理论优势

上下文利用效率

扩散语言模型在每步去噪时利用完整序列的上下文信息,这与自回归模型形成对比:

AR模型的上下文利用

生成第 个token时,完全无法利用 的信息。

扩散模型的上下文利用

生成时利用完整序列的双向依赖关系。

互信息理论界:设 ,则:

其中 是序列内部依赖的互信息。扩散模型直接建模 ,而AR模型将其分解为单向项的组合。

条件生成的理论分析

无条件生成:从先验分布 开始,渐进式去噪:

条件生成:给定条件 (如提示文本),修改反向过程:

无条件 vs 条件生成的误差分析

扩散模型的条件生成通过修改网络输入或引导机制实现,其理论保证依赖于条件与数据分布的匹配程度。


参考

Footnotes

  1. LLaDA: Large Language Diffusion Models

  2. Masked Diffusion Models: A Study on Discrete Token Prediction