概述

LLaDA(Large Language Diffusion with mAsking)是首个从零开始训练的大规模离散扩散语言模型,由上海人工智能实验室等机构提出。1 该工作挑战了”LLaM级别的语言能力必须依赖自回归建模”的传统观点,证明了扩散模型同样可以胜任大规模语言建模任务。

核心贡献

  1. 首个8B参数扩散语言模型:从零预训练,而非从AR模型蒸馏
  2. 强上下文学习能力:在ICL任务上与LLaMA3 8B相当
  3. 解决Reversal Curse:在反向诗句补全任务上超越GPT-4o
  4. 双向建模优势:利用完整上下文信息

架构设计

Masked Diffusion Model框架

LLaDA采用标准的Masked Diffusion Model(MDM)框架:

┌──────────────────────────────────────────────────────────────┐
│  LLaDA训练范式                                               │
│                                                              │
│  前向过程(数据处理):                                        │
│  x₀ (原始文本) → x₁ → x₂ → ... → x_T (全MASK)               │
│     ↓           ↓                             ↓             │
│   保留      10%MASK    20%MASK         100%MASK              │
│                                                              │
│  反向过程(模型预测):                                        │
│  x_T → p_θ(x_{T-1}|x_T) → ... → x₀                       │
│                                                              │
│  训练目标:最大化原始数据的似然下界                            │
└──────────────────────────────────────────────────────────────┘

网络架构

LLaDA使用标准Transformer Encoder作为骨干网络:

组件配置
模型规模8B参数
隐藏维度4096
注意力头数32
层数32
上下文长度4096
词汇表大小~32K BPE
class LLaDAConfig:
    vocab_size = 32000
    d_model = 4096       # 隐藏维度
    n_heads = 32          # 注意力头数
    n_layers = 32         # Transformer层数
    max_len = 4096        # 最大序列长度
    
class LLaDA(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.tok_embed = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embed = nn.Embedding(config.max_len, config.d_model)
        self.time_embed = nn.Embedding(1000, config.d_model)  # 时间步嵌入
        
        self.layers = nn.ModuleList([
            TransformerLayer(config) for _ in range(config.n_layers)
        ])
        
        self.output_proj = nn.Linear(config.d_model, config.vocab_size)
        
    def forward(self, x_masked, t):
        # x_masked: 已部分mask的token序列
        # t: 归一化时间步 (0-1)
        h = self.tok_embed(x_masked) + self.pos_embed.weight[:len(x_masked)]
        h = h + self.time_embed((t * 999).long())
        
        for layer in self.layers:
            h = layer(h)
            
        return self.output_proj(h)

时间步嵌入

与图像扩散模型类似,LLaDA使用时间步嵌入来调制表示:

class TimeEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )
        
    def forward(self, t):
        # t: 形状 (batch,) , 值在 [0, 1]
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb).to(t.device)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return self.mlp(emb)

训练流程

1. 预训练

预训练数据:2.3T tokens
训练硬件:13万H800 GPU小时
优化器:AdamW
学习率:余弦调度,峰值 4e-4

预训练配置

参数
批量大小4M tokens
序列长度4096
训练步数~575K
峰值学习率4e-4
权重衰减0.1
预热步数2000

2. 监督微调(SFT)

# SFT阶段:使用指令微调数据
sft_data = [
    {"prompt": "什么是机器学习?", "response": "机器学习是..."},
    {"prompt": "写一首诗", "response": "春风又绿江南岸..."},
    ...
]
 
def sft_collate(batch):
    # 构建对话格式
    text = "<|user|>" + batch["prompt"] + "<|assistant|>" + batch["response"]
    return tokenize_and_mask(text)

3. 无问答对构建

LLaDA使用特殊的无问答对数据构建方法:

┌─────────────────────────────────────────────────────────────┐
│  无问答对训练示例                                            │
│                                                              │
│  输入:对话历史                                              │
│  "用户:今天天气如何?\n助手:今天晴朗温暖。"                  │
│                                                              │
│  训练方式:                                                  │
│  1. 将"助手:"作为prompt,固定不动                           │
│  2. 模型学习生成"今天晴朗温暖。"                             │
│  3. 通过mask控制只预测assistant回复部分                       │
└─────────────────────────────────────────────────────────────┘

实验结果

1. 语言建模性能

模型PPL (WikiText-103)PPL (Penn Treebank)
LLaDA 8B12.818.5
LLaMA3 8B12.618.2
LLaMA2 7B14.220.8

LLaDA 8B在语言建模困惑度上与同规模AR模型相当。

2. 上下文学习(ICL)

LLaDA在上下文学习任务上展现了与LLaMA3 8B相当的性能:

任务LLaDA 8BLLaMA3 8B
LAMBADA (Accuracy)67.2%68.2%
HellaSwag76.8%77.2%
WinoGrande70.1%71.3%

3. Reversal Curse解决

Reversal Curse指的是模型难以学习”A是B”形式知识的反向版本(如”奥赛罗是《威尼斯商人》的作者”→“谁写了《威尼斯商人》?”)。

模型Reversal Poem Completion
LLaDA 8B82.3%
GPT-4o71.5%
GPT-465.8%

LLaDA的双向建模有效缓解了Reversal Curse问题。

4. 指令跟随

经过SFT后,LLaDA展现出良好的指令跟随能力:

# 示例交互
prompt = """<|user|>帮我写一个快速排序算法,用Python实现<|assistant|>"""
 
# 生成结果
result = llada.generate(prompt, max_len=500, temperature=0.8)
print(result)
# 输出:当然可以!以下是Python实现的快速排序算法...

关键发现与洞察

1. 扩散模型可扩展性

训练损失曲线:

   3.5 ┤                        ╭─╮
       │                   ╭───╯  ╰───╮
   3.0 ┤              ╭──╯            ╰───╮
       │         ╭───╯                      ╰────
   2.5 ┤    ╭───╯
       │ ───╯
   2.0 ┤ ──
       └────┬────┬────┬────┬────┬────┬────┬────→ 训练步数
             100K  200K  300K  400K  500K

LLaDA展现出与AR模型类似的Scaling行为,损失随模型规模平稳下降。

2. 双向注意力的优势

与AR模型相比,MDM在以下场景具有优势:

  • 填充任务:需要利用双向上下文
  • 条件生成:给定部分内容生成其余
  • Reversal Curse缓解:双向建模有助于学习双向关联

3. 训练稳定性

# 训练稳定性技巧
class StableTraining:
    # 1. 渐进式mask调度
    mask_schedule = "linear"  # 从少到多的mask
    
    # 2. 梯度裁剪
    max_grad_norm = 1.0
    
    # 3. 混合精度训练
    dtype = torch.bfloat16
    
    # 4. 学习率调度
    scheduler = "cosine"  # 余弦退火

与LLaMA对比分析

维度LLaDA (扩散)LLaMA (AR)
建模方式双向mask预测单向自回归
生成方式并行解mask顺序自回归
上下文利用完整上下文仅前缀
ICL能力相当相当
Reversal处理更好较弱
推理效率T步并行N步顺序
训练复杂度较高标准

代码实现框架

class LLaDAModel:
    def __init__(self, config):
        self.config = config
        self.model = TransformerForMaskedLM(config)
        
    def training_step(self, batch):
        x0 = batch["input_ids"]  # 原始token
        
        # 随机采样时间步
        t = torch.rand(len(x0))
        
        # 前向mask过程
        x_masked = self.mask_tokens(x0, t)
        
        # 模型预测
        logits = self.model(x_masked, t)
        
        # 交叉熵损失
        loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),
            x0.view(-1)
        )
        
        return loss
    
    @torch.no_grad()
    def generate(self, prompt, max_len=100, temperature=1.0):
        x = [MASK] * max_len
        x[:len(prompt)] = prompt
        
        for step in range(self.num_diffusion_steps):
            t = step / self.num_diffusion_steps
            logits = self.model(x, t)
            
            # 采样并更新
            probs = F.softmax(logits / temperature, dim=-1)
            x = torch.multinomial(probs, 1).squeeze(-1)
            
        return x

参考

Footnotes

  1. LLaDA: Large Language Diffusion Models