概述
LLaDA(Large Language Diffusion with mAsking)是首个从零开始训练的大规模离散扩散语言模型,由上海人工智能实验室等机构提出。1 该工作挑战了”LLaM级别的语言能力必须依赖自回归建模”的传统观点,证明了扩散模型同样可以胜任大规模语言建模任务。
核心贡献
- 首个8B参数扩散语言模型:从零预训练,而非从AR模型蒸馏
- 强上下文学习能力:在ICL任务上与LLaMA3 8B相当
- 解决Reversal Curse:在反向诗句补全任务上超越GPT-4o
- 双向建模优势:利用完整上下文信息
架构设计
Masked Diffusion Model框架
LLaDA采用标准的Masked Diffusion Model(MDM)框架:
┌──────────────────────────────────────────────────────────────┐
│ LLaDA训练范式 │
│ │
│ 前向过程(数据处理): │
│ x₀ (原始文本) → x₁ → x₂ → ... → x_T (全MASK) │
│ ↓ ↓ ↓ │
│ 保留 10%MASK 20%MASK 100%MASK │
│ │
│ 反向过程(模型预测): │
│ x_T → p_θ(x_{T-1}|x_T) → ... → x₀ │
│ │
│ 训练目标:最大化原始数据的似然下界 │
└──────────────────────────────────────────────────────────────┘
网络架构
LLaDA使用标准Transformer Encoder作为骨干网络:
| 组件 | 配置 |
|---|---|
| 模型规模 | 8B参数 |
| 隐藏维度 | 4096 |
| 注意力头数 | 32 |
| 层数 | 32 |
| 上下文长度 | 4096 |
| 词汇表大小 | ~32K BPE |
class LLaDAConfig:
vocab_size = 32000
d_model = 4096 # 隐藏维度
n_heads = 32 # 注意力头数
n_layers = 32 # Transformer层数
max_len = 4096 # 最大序列长度
class LLaDA(nn.Module):
def __init__(self, config):
super().__init__()
self.tok_embed = nn.Embedding(config.vocab_size, config.d_model)
self.pos_embed = nn.Embedding(config.max_len, config.d_model)
self.time_embed = nn.Embedding(1000, config.d_model) # 时间步嵌入
self.layers = nn.ModuleList([
TransformerLayer(config) for _ in range(config.n_layers)
])
self.output_proj = nn.Linear(config.d_model, config.vocab_size)
def forward(self, x_masked, t):
# x_masked: 已部分mask的token序列
# t: 归一化时间步 (0-1)
h = self.tok_embed(x_masked) + self.pos_embed.weight[:len(x_masked)]
h = h + self.time_embed((t * 999).long())
for layer in self.layers:
h = layer(h)
return self.output_proj(h)时间步嵌入
与图像扩散模型类似,LLaDA使用时间步嵌入来调制表示:
class TimeEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.SiLU(),
nn.Linear(dim * 4, dim)
)
def forward(self, t):
# t: 形状 (batch,) , 值在 [0, 1]
half_dim = self.dim // 2
emb = torch.log(torch.tensor(10000)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim) * -emb).to(t.device)
emb = t[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
return self.mlp(emb)训练流程
1. 预训练
预训练数据:2.3T tokens
训练硬件:13万H800 GPU小时
优化器:AdamW
学习率:余弦调度,峰值 4e-4
预训练配置:
| 参数 | 值 |
|---|---|
| 批量大小 | 4M tokens |
| 序列长度 | 4096 |
| 训练步数 | ~575K |
| 峰值学习率 | 4e-4 |
| 权重衰减 | 0.1 |
| 预热步数 | 2000 |
2. 监督微调(SFT)
# SFT阶段:使用指令微调数据
sft_data = [
{"prompt": "什么是机器学习?", "response": "机器学习是..."},
{"prompt": "写一首诗", "response": "春风又绿江南岸..."},
...
]
def sft_collate(batch):
# 构建对话格式
text = "<|user|>" + batch["prompt"] + "<|assistant|>" + batch["response"]
return tokenize_and_mask(text)3. 无问答对构建
LLaDA使用特殊的无问答对数据构建方法:
┌─────────────────────────────────────────────────────────────┐
│ 无问答对训练示例 │
│ │
│ 输入:对话历史 │
│ "用户:今天天气如何?\n助手:今天晴朗温暖。" │
│ │
│ 训练方式: │
│ 1. 将"助手:"作为prompt,固定不动 │
│ 2. 模型学习生成"今天晴朗温暖。" │
│ 3. 通过mask控制只预测assistant回复部分 │
└─────────────────────────────────────────────────────────────┘
实验结果
1. 语言建模性能
| 模型 | PPL (WikiText-103) | PPL (Penn Treebank) |
|---|---|---|
| LLaDA 8B | 12.8 | 18.5 |
| LLaMA3 8B | 12.6 | 18.2 |
| LLaMA2 7B | 14.2 | 20.8 |
LLaDA 8B在语言建模困惑度上与同规模AR模型相当。
2. 上下文学习(ICL)
LLaDA在上下文学习任务上展现了与LLaMA3 8B相当的性能:
| 任务 | LLaDA 8B | LLaMA3 8B |
|---|---|---|
| LAMBADA (Accuracy) | 67.2% | 68.2% |
| HellaSwag | 76.8% | 77.2% |
| WinoGrande | 70.1% | 71.3% |
3. Reversal Curse解决
Reversal Curse指的是模型难以学习”A是B”形式知识的反向版本(如”奥赛罗是《威尼斯商人》的作者”→“谁写了《威尼斯商人》?”)。
| 模型 | Reversal Poem Completion |
|---|---|
| LLaDA 8B | 82.3% |
| GPT-4o | 71.5% |
| GPT-4 | 65.8% |
LLaDA的双向建模有效缓解了Reversal Curse问题。
4. 指令跟随
经过SFT后,LLaDA展现出良好的指令跟随能力:
# 示例交互
prompt = """<|user|>帮我写一个快速排序算法,用Python实现<|assistant|>"""
# 生成结果
result = llada.generate(prompt, max_len=500, temperature=0.8)
print(result)
# 输出:当然可以!以下是Python实现的快速排序算法...关键发现与洞察
1. 扩散模型可扩展性
训练损失曲线:
3.5 ┤ ╭─╮
│ ╭───╯ ╰───╮
3.0 ┤ ╭──╯ ╰───╮
│ ╭───╯ ╰────
2.5 ┤ ╭───╯
│ ───╯
2.0 ┤ ──
└────┬────┬────┬────┬────┬────┬────┬────→ 训练步数
100K 200K 300K 400K 500K
LLaDA展现出与AR模型类似的Scaling行为,损失随模型规模平稳下降。
2. 双向注意力的优势
与AR模型相比,MDM在以下场景具有优势:
- 填充任务:需要利用双向上下文
- 条件生成:给定部分内容生成其余
- Reversal Curse缓解:双向建模有助于学习双向关联
3. 训练稳定性
# 训练稳定性技巧
class StableTraining:
# 1. 渐进式mask调度
mask_schedule = "linear" # 从少到多的mask
# 2. 梯度裁剪
max_grad_norm = 1.0
# 3. 混合精度训练
dtype = torch.bfloat16
# 4. 学习率调度
scheduler = "cosine" # 余弦退火与LLaMA对比分析
| 维度 | LLaDA (扩散) | LLaMA (AR) |
|---|---|---|
| 建模方式 | 双向mask预测 | 单向自回归 |
| 生成方式 | 并行解mask | 顺序自回归 |
| 上下文利用 | 完整上下文 | 仅前缀 |
| ICL能力 | 相当 | 相当 |
| Reversal处理 | 更好 | 较弱 |
| 推理效率 | T步并行 | N步顺序 |
| 训练复杂度 | 较高 | 标准 |
代码实现框架
class LLaDAModel:
def __init__(self, config):
self.config = config
self.model = TransformerForMaskedLM(config)
def training_step(self, batch):
x0 = batch["input_ids"] # 原始token
# 随机采样时间步
t = torch.rand(len(x0))
# 前向mask过程
x_masked = self.mask_tokens(x0, t)
# 模型预测
logits = self.model(x_masked, t)
# 交叉熵损失
loss = F.cross_entropy(
logits.view(-1, self.config.vocab_size),
x0.view(-1)
)
return loss
@torch.no_grad()
def generate(self, prompt, max_len=100, temperature=1.0):
x = [MASK] * max_len
x[:len(prompt)] = prompt
for step in range(self.num_diffusion_steps):
t = step / self.num_diffusion_steps
logits = self.model(x, t)
# 采样并更新
probs = F.softmax(logits / temperature, dim=-1)
x = torch.multinomial(probs, 1).squeeze(-1)
return x