1. 研究背景与问题定义
1.1 语言建模的两种范式
当前语言建模主要分为两种范式1:
| 范式 | 特点 | 代表模型 | 优势 | 局限 |
|---|---|---|---|---|
| 自回归 (AR) | 逐token生成 | GPT, LLaMA | 生成质量高 | 延迟 |
| 扩散 (Discrete) | 逐步去噪 | LLaDA, MDM | 可并行 | 离散空间建模困难 |
1.2 离散扩散的挑战
离散扩散模型在语言建模中面临:
- 嵌入空间不连续:离散token之间的语义关系难以捕捉
- 码本利用率低:大多数token被映射到少量活跃码字
- 梯度估计困难:离散操作不可导
1.3 连续潜空间的解决思路
字节跳动Seed团队的论文《Continuous Latent Diffusion Language Model》提出1:
核心思想:在连续潜空间进行扩散建模,避免离散空间的局限性。
2. 技术框架
2.1 连续潜空间扩散
整体架构:
┌─────────────────────────────────────────────────────────────────────────┐
│ 连续潜空间扩散语言模型 │
├─────────────────────────────────────────────────────────────────────────┤
│ │
│ 文本输入: "The cat sat on the mat" │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ Tokenizer (VQ-VAE) │ │
│ │ │ │
│ │ 文本 ──► 语义嵌入 ──► 量化 ──► 离散码序列 │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ 连续潜向量 (用于扩散) │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 连续扩散过程 (Continuous Diffusion) │ │
│ │ │ │
│ │ z_0 (潜向量) ──► z_t ──► z_T ──► 去噪 ──► z_0 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────────────────┐ │
│ │ 重建与输出 │ │
│ │ │ │
│ │ 潜向量 ──► 解码器 ──► 文本输出 │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────────┘
2.2 连续潜空间定义
设原始文本序列为 ,tokenizer将其映射为:
其中 是连续语义嵌入,然后通过向量量化得到离散码:
2.3 连续扩散过程
正向过程(加噪):
逆向过程(去噪):
3. 核心创新
3.1 连续vs离散对比
| 特性 | 离散扩散 (LLaDA) | 连续扩散 (CLDM) |
|---|---|---|
| 表示空间 | 离散码本 | 连续潜空间 |
| 嵌入连续性 | 断开 | 保持 |
| 梯度流动 | 困难(straight-through) | 自然 |
| 去噪目标 | 分类 | 回归 |
| 计算效率 | 中等 | 高 |
3.2 语义流形假设
假设:自然语言的语义嵌入位于一个低维流形上。
定理(流形假设):设 是语义流形,则:
- 语义相似的句子在流形上距离近
- 扩散过程应保持在流形附近
连续扩散的优势:
当 时, 接近高斯噪声;当 时, 接近数据流形。
3.3 自编码器-扩散联合
class ContinuousLatentDiffusionLM(nn.Module):
"""
连续潜空间扩散语言模型
"""
def __init__(self, config):
super().__init__()
# 语义编码器
self.encoder = SemanticEncoder(config)
# 量化器(可选,用于保留离散码本)
self.quantizer = ResidualVectorQuantizer(config)
# 扩散模型
self.diffusion = DiffusionModel(config)
# 解码器
self.decoder = SemanticDecoder(config)
def encode(self, x):
"""编码文本为连续潜向量"""
z = self.encoder(x)
return z
def decode(self, z):
"""解码潜向量为文本"""
return self.decoder(z)
def forward(self, x, t):
"""
前向传播
"""
# 编码
z0 = self.encode(x)
# 加噪
noise = torch.randn_like(z0)
zt = torch.sqrt(self.alphas[t]) * z0 + torch.sqrt(1 - self.alphas[t]) * noise
# 去噪预测
z0_pred = self.diffusion(zt, t)
return z0_pred, z0
def training_loss(self, x, t):
"""训练损失"""
z0_pred, z0 = self.forward(x, t)
# MSE损失(连续空间)
loss = F.mse_loss(z0_pred, z0)
return loss4. 技术细节
4.1 语义编码器
class SemanticEncoder(nn.Module):
"""
语义编码器
将文本编码为连续的语义向量
"""
def __init__(self, config):
super().__init__()
self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
self.encoder_layers = nn.ModuleList([
SemanticTransformerLayer(config)
for _ in range(config.num_layers)
])
# 投影到潜空间
self.proj = nn.Linear(config.embed_dim, config.latent_dim)
def forward(self, x):
"""
Args:
x: token序列 [B, N]
Returns:
连续潜向量 [B, N, latent_dim]
"""
# Token嵌入
h = self.embedding(x)
# Transformer处理
for layer in self.encoder_layers:
h = layer(h)
# 投影到潜空间
z = self.proj(h)
return z4.2 残余向量量化
为了保持与离散方法的兼容性,使用残余向量量化:
class ResidualVectorQuantizer(nn.Module):
"""
残余向量量化器
将连续向量量化为离散码本
"""
def __init__(self, config):
super().__init__()
self.codebook_size = config.codebook_size
self.latent_dim = config.latent_dim
self.num_quantizers = config.num_quantizers
# 多个码本
self.codebooks = nn.ModuleList([
nn.Embedding(config.codebook_size, config.latent_dim)
for _ in range(config.num_quantizers)
])
def forward(self, z, temperature=1.0):
"""
Args:
z: 连续潜向量 [B, N, latent_dim]
temperature: 采样温度
Returns:
quantized: 量化后的向量
indices: 码本索引
"""
quantized = torch.zeros_like(z)
indices = []
residual = z
for codebook in self.codebooks:
# 计算到各码字的距离
dist = torch.cdist(residual, codebook.weight) # [B, N, K]
# 采样
if self.training:
# Gumbel-softmax采样
gumbels = -torch.log(-torch.log(torch.rand_like(dist) + 1e-8))
soft_idx = F.softmax((dist + gumbels) / temperature, dim=-1)
hard_idx = soft_idx.argmax(dim=-1)
else:
hard_idx = dist.argmin(dim=-1)
indices.append(hard_idx)
# 获取码字向量
code = F.embedding(hard_idx, codebook.weight)
quantized = quantized + code
# 计算残余
residual = residual - code
return quantized, torch.stack(indices, dim=-1)4.3 连续扩散模型
class DiffusionModel(nn.Module):
"""
连续扩散模型
在连续潜空间进行去噪
"""
def __init__(self, config):
super().__init__()
# 时间嵌入
self.time_embed = SinusoidalPositionEmbedding(config.latent_dim)
# 去噪网络
self.denoiser = nn.Sequential(
nn.Linear(config.latent_dim, config.hidden_dim),
nn.GELU(),
*[
ResidualBlock(config)
for _ in range(config.num_blocks)
],
nn.Linear(config.hidden_dim, config.latent_dim)
)
def forward(self, zt, t):
"""
Args:
zt: 加噪后的潜向量 [B, N, latent_dim]
t: 时间步 [B]
Returns:
预测的z0
"""
# 时间嵌入
t_emb = self.time_embed(t) # [B, latent_dim]
# 融入时间信息
h = zt + t_emb.unsqueeze(1)
# 去噪
z0_pred = self.denoiser(h)
return z0_pred5. 训练策略
5.1 课程学习
训练分为多个阶段:
def curriculum_training(model, dataloader, config):
"""
课程学习训练
逐步增加噪声强度
"""
for stage in range(config.num_stages):
# 当前阶段的噪声调度
num_steps = config.stage_steps[stage]
betas = cosine_beta_schedule(num_steps, s=0.008)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
model.diffusion.set_schedule(betas, alphas, alphas_cumprod)
# 训练
for epoch in range(config.stage_epochs[stage]):
for batch in dataloader:
# 采样时间步
t = torch.randint(0, num_steps, (batch_size,))
# 训练
loss = model.training_loss(batch, t)
# ...5.2 损失函数
连续扩散使用简单的MSE损失:
或预测 的损失:
6. 实验结果
6.1 语言建模
困惑度对比:
| 模型 | WikiText-103 | Penn Treebank | OpenWebText |
|---|---|---|---|
| LLaMA-7B (AR) | 12.3 | 15.2 | 14.1 |
| LLaDA-7B (离散) | 13.1 | 16.2 | 15.1 |
| CLDM-7B (连续) | 11.8 | 14.5 | 13.6 |
6.2 生成质量
人类评估:
| 模型 | 流畅性 | 一致性 | 多样性 |
|---|---|---|---|
| LLaMA | 4.2 | 4.5 | 3.8 |
| LLaDA | 3.9 | 4.1 | 4.2 |
| CLDM | 4.3 | 4.4 | 4.1 |
6.3 推理效率
生成延迟(序列长度=512):
| 模型 | 延迟 (ms) | 加速比 |
|---|---|---|
| LLaMA (AR) | 1250 | 1.0x |
| LLaDA (离散) | 680 | 1.8x |
| CLDM | 520 | 2.4x |
7. 与其他方法的关系
7.1 vs 自回归模型
| 特性 | AR LM | CLDM |
|---|---|---|
| 生成方式 | 顺序 | 并行 |
| 延迟 | ||
| 全局一致性 | 隐式 | 显式 |
| 训练效率 | 高 | 中 |
7.2 vs 离散扩散
| 特性 | 离散扩散 | 连续扩散 |
|---|---|---|
| 空间类型 | 离散 | 连续 |
| 梯度 | 近似 | 精确 |
| 损失类型 | 交叉熵 | MSE |
| 码本需求 | 必须 | 可选 |
8. 代码实现
8.1 完整模型
class CLDMLanguageModel(nn.Module):
"""
连续潜空间扩散语言模型完整实现
"""
def __init__(self, config):
super().__init__()
# 编码器
self.encoder = SemanticEncoder(config)
# 量化器(可选)
self.use_quantizer = config.use_quantizer
if self.use_quantizer:
self.quantizer = ResidualVectorQuantizer(config)
# 扩散模型
self.diffusion = DiffusionModel(config)
# 解码器
self.decoder = SemanticDecoder(config)
# 噪声调度
self.register_buffer('betas', None)
self.register_buffer('alphas', None)
self.register_buffer('alphas_cumprod', None)
def set_schedule(self, betas, alphas, alphas_cumprod):
"""设置噪声调度"""
self.betas = betas
self.alphas = alphas
self.alphas_cumprod = alphas_cumprod
@torch.no_grad()
def generate(self, batch_size, max_len, temperature=1.0):
"""生成文本"""
# 初始化:纯噪声
zT = torch.randn(batch_size, max_len, self.latent_dim, device=next(self.parameters()).device)
# 逐步去噪
for t in reversed(range(len(self.betas))):
# 预测z0
z0_pred = self.diffusion(zT, torch.tensor([t] * batch_size, device=zT.device))
# 计算zt-1
alpha_t = self.alphas[t]
alpha_bar_t = self.alphas_cumprod[t]
alpha_bar_t_1 = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
# 去噪
zT_1 = (zT - torch.sqrt(1 - alpha_bar_t) * self.diffusion.denoiser(zT, t)) / torch.sqrt(alpha_bar_t)
zT_1 = zT_1 * torch.sqrt(alpha_bar_t_1) + torch.randn_like(zT) * torch.sqrt(1 - alpha_bar_t_1)
zT = zT_1
# 解码
return self.decoder(zT)9. 总结与展望
9.1 主要贡献
- 连续潜空间:避免离散空间建模的局限性
- 高效训练:简单的MSE损失,自然的梯度流动
- 快速推理:并行生成, 延迟
9.2 局限性
- 语义流形假设:依赖嵌入空间的假设
- 生成质量:某些场景可能不如AR模型
- 实现复杂度:多组件联合训练
9.3 未来方向
- 更有效的潜空间表示
- 与AR模型的结合
- 多模态扩展