1. 研究背景与问题定义

1.1 语言建模的两种范式

当前语言建模主要分为两种范式1

范式特点代表模型优势局限
自回归 (AR)逐token生成GPT, LLaMA生成质量高 延迟
扩散 (Discrete)逐步去噪LLaDA, MDM可并行离散空间建模困难

1.2 离散扩散的挑战

离散扩散模型在语言建模中面临:

  1. 嵌入空间不连续:离散token之间的语义关系难以捕捉
  2. 码本利用率低:大多数token被映射到少量活跃码字
  3. 梯度估计困难:离散操作不可导

1.3 连续潜空间的解决思路

字节跳动Seed团队的论文《Continuous Latent Diffusion Language Model》提出1

核心思想:在连续潜空间进行扩散建模,避免离散空间的局限性。

2. 技术框架

2.1 连续潜空间扩散

整体架构

┌─────────────────────────────────────────────────────────────────────────┐
│                    连续潜空间扩散语言模型                                │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  文本输入: "The cat sat on the mat"                                    │
│      │                                                                     │
│      ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    Tokenizer (VQ-VAE)                            │    │
│  │                                                                 │    │
│  │   文本 ──► 语义嵌入 ──► 量化 ──► 离散码序列                      │    │
│  │               │                                                    │    │
│  │               ▼                                                    │    │
│  │            连续潜向量 (用于扩散)                                   │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│      │                                                                     │
│      ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                 连续扩散过程 (Continuous Diffusion)                 │    │
│  │                                                                 │    │
│  │   z_0 (潜向量) ──► z_t ──► z_T ──► 去噪 ──► z_0               │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│      │                                                                     │
│      ▼                                                                     │
│  ┌─────────────────────────────────────────────────────────────────┐    │
│  │                    重建与输出                                      │    │
│  │                                                                 │    │
│  │   潜向量 ──► 解码器 ──► 文本输出                                  │    │
│  │                                                                 │    │
│  └─────────────────────────────────────────────────────────────────┘    │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘

2.2 连续潜空间定义

设原始文本序列为 ,tokenizer将其映射为:

其中 是连续语义嵌入,然后通过向量量化得到离散码:

2.3 连续扩散过程

正向过程(加噪)

逆向过程(去噪)

3. 核心创新

3.1 连续vs离散对比

特性离散扩散 (LLaDA)连续扩散 (CLDM)
表示空间离散码本连续潜空间
嵌入连续性断开保持
梯度流动困难(straight-through)自然
去噪目标分类回归
计算效率中等

3.2 语义流形假设

假设:自然语言的语义嵌入位于一个低维流形上。

定理(流形假设):设 是语义流形,则:

  • 语义相似的句子在流形上距离近
  • 扩散过程应保持在流形附近

连续扩散的优势

时, 接近高斯噪声;当 时, 接近数据流形。

3.3 自编码器-扩散联合

class ContinuousLatentDiffusionLM(nn.Module):
    """
    连续潜空间扩散语言模型
    """
    def __init__(self, config):
        super().__init__()
        
        # 语义编码器
        self.encoder = SemanticEncoder(config)
        
        # 量化器(可选,用于保留离散码本)
        self.quantizer = ResidualVectorQuantizer(config)
        
        # 扩散模型
        self.diffusion = DiffusionModel(config)
        
        # 解码器
        self.decoder = SemanticDecoder(config)
        
    def encode(self, x):
        """编码文本为连续潜向量"""
        z = self.encoder(x)
        return z
    
    def decode(self, z):
        """解码潜向量为文本"""
        return self.decoder(z)
    
    def forward(self, x, t):
        """
        前向传播
        """
        # 编码
        z0 = self.encode(x)
        
        # 加噪
        noise = torch.randn_like(z0)
        zt = torch.sqrt(self.alphas[t]) * z0 + torch.sqrt(1 - self.alphas[t]) * noise
        
        # 去噪预测
        z0_pred = self.diffusion(zt, t)
        
        return z0_pred, z0
    
    def training_loss(self, x, t):
        """训练损失"""
        z0_pred, z0 = self.forward(x, t)
        
        # MSE损失(连续空间)
        loss = F.mse_loss(z0_pred, z0)
        
        return loss

4. 技术细节

4.1 语义编码器

class SemanticEncoder(nn.Module):
    """
    语义编码器
    将文本编码为连续的语义向量
    """
    def __init__(self, config):
        super().__init__()
        
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        
        self.encoder_layers = nn.ModuleList([
            SemanticTransformerLayer(config)
            for _ in range(config.num_layers)
        ])
        
        # 投影到潜空间
        self.proj = nn.Linear(config.embed_dim, config.latent_dim)
        
    def forward(self, x):
        """
        Args:
            x: token序列 [B, N]
        Returns:
            连续潜向量 [B, N, latent_dim]
        """
        # Token嵌入
        h = self.embedding(x)
        
        # Transformer处理
        for layer in self.encoder_layers:
            h = layer(h)
        
        # 投影到潜空间
        z = self.proj(h)
        
        return z

4.2 残余向量量化

为了保持与离散方法的兼容性,使用残余向量量化:

class ResidualVectorQuantizer(nn.Module):
    """
    残余向量量化器
    将连续向量量化为离散码本
    """
    def __init__(self, config):
        super().__init__()
        
        self.codebook_size = config.codebook_size
        self.latent_dim = config.latent_dim
        self.num_quantizers = config.num_quantizers
        
        # 多个码本
        self.codebooks = nn.ModuleList([
            nn.Embedding(config.codebook_size, config.latent_dim)
            for _ in range(config.num_quantizers)
        ])
        
    def forward(self, z, temperature=1.0):
        """
        Args:
            z: 连续潜向量 [B, N, latent_dim]
            temperature: 采样温度
        Returns:
            quantized: 量化后的向量
            indices: 码本索引
        """
        quantized = torch.zeros_like(z)
        indices = []
        
        residual = z
        
        for codebook in self.codebooks:
            # 计算到各码字的距离
            dist = torch.cdist(residual, codebook.weight)  # [B, N, K]
            
            # 采样
            if self.training:
                # Gumbel-softmax采样
                gumbels = -torch.log(-torch.log(torch.rand_like(dist) + 1e-8))
                soft_idx = F.softmax((dist + gumbels) / temperature, dim=-1)
                hard_idx = soft_idx.argmax(dim=-1)
            else:
                hard_idx = dist.argmin(dim=-1)
            
            indices.append(hard_idx)
            
            # 获取码字向量
            code = F.embedding(hard_idx, codebook.weight)
            quantized = quantized + code
            
            # 计算残余
            residual = residual - code
        
        return quantized, torch.stack(indices, dim=-1)

4.3 连续扩散模型

class DiffusionModel(nn.Module):
    """
    连续扩散模型
    在连续潜空间进行去噪
    """
    def __init__(self, config):
        super().__init__()
        
        # 时间嵌入
        self.time_embed = SinusoidalPositionEmbedding(config.latent_dim)
        
        # 去噪网络
        self.denoiser = nn.Sequential(
            nn.Linear(config.latent_dim, config.hidden_dim),
            nn.GELU(),
            *[
                ResidualBlock(config)
                for _ in range(config.num_blocks)
            ],
            nn.Linear(config.hidden_dim, config.latent_dim)
        )
        
    def forward(self, zt, t):
        """
        Args:
            zt: 加噪后的潜向量 [B, N, latent_dim]
            t: 时间步 [B]
        Returns:
            预测的z0
        """
        # 时间嵌入
        t_emb = self.time_embed(t)  # [B, latent_dim]
        
        # 融入时间信息
        h = zt + t_emb.unsqueeze(1)
        
        # 去噪
        z0_pred = self.denoiser(h)
        
        return z0_pred

5. 训练策略

5.1 课程学习

训练分为多个阶段:

def curriculum_training(model, dataloader, config):
    """
    课程学习训练
    逐步增加噪声强度
    """
    for stage in range(config.num_stages):
        # 当前阶段的噪声调度
        num_steps = config.stage_steps[stage]
        betas = cosine_beta_schedule(num_steps, s=0.008)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        
        model.diffusion.set_schedule(betas, alphas, alphas_cumprod)
        
        # 训练
        for epoch in range(config.stage_epochs[stage]):
            for batch in dataloader:
                # 采样时间步
                t = torch.randint(0, num_steps, (batch_size,))
                
                # 训练
                loss = model.training_loss(batch, t)
                # ...

5.2 损失函数

连续扩散使用简单的MSE损失

预测 的损失

6. 实验结果

6.1 语言建模

困惑度对比

模型WikiText-103Penn TreebankOpenWebText
LLaMA-7B (AR)12.315.214.1
LLaDA-7B (离散)13.116.215.1
CLDM-7B (连续)11.814.513.6

6.2 生成质量

人类评估

模型流畅性一致性多样性
LLaMA4.24.53.8
LLaDA3.94.14.2
CLDM4.34.44.1

6.3 推理效率

生成延迟(序列长度=512)

模型延迟 (ms)加速比
LLaMA (AR)12501.0x
LLaDA (离散)6801.8x
CLDM5202.4x

7. 与其他方法的关系

7.1 vs 自回归模型

特性AR LMCLDM
生成方式顺序并行
延迟
全局一致性隐式显式
训练效率

7.2 vs 离散扩散

特性离散扩散连续扩散
空间类型离散连续
梯度近似精确
损失类型交叉熵MSE
码本需求必须可选

8. 代码实现

8.1 完整模型

class CLDMLanguageModel(nn.Module):
    """
    连续潜空间扩散语言模型完整实现
    """
    def __init__(self, config):
        super().__init__()
        
        # 编码器
        self.encoder = SemanticEncoder(config)
        
        # 量化器(可选)
        self.use_quantizer = config.use_quantizer
        if self.use_quantizer:
            self.quantizer = ResidualVectorQuantizer(config)
        
        # 扩散模型
        self.diffusion = DiffusionModel(config)
        
        # 解码器
        self.decoder = SemanticDecoder(config)
        
        # 噪声调度
        self.register_buffer('betas', None)
        self.register_buffer('alphas', None)
        self.register_buffer('alphas_cumprod', None)
        
    def set_schedule(self, betas, alphas, alphas_cumprod):
        """设置噪声调度"""
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        
    @torch.no_grad()
    def generate(self, batch_size, max_len, temperature=1.0):
        """生成文本"""
        # 初始化:纯噪声
        zT = torch.randn(batch_size, max_len, self.latent_dim, device=next(self.parameters()).device)
        
        # 逐步去噪
        for t in reversed(range(len(self.betas))):
            # 预测z0
            z0_pred = self.diffusion(zT, torch.tensor([t] * batch_size, device=zT.device))
            
            # 计算zt-1
            alpha_t = self.alphas[t]
            alpha_bar_t = self.alphas_cumprod[t]
            alpha_bar_t_1 = self.alphas_cumprod[t-1] if t > 0 else torch.tensor(1.0)
            
            # 去噪
            zT_1 = (zT - torch.sqrt(1 - alpha_bar_t) * self.diffusion.denoiser(zT, t)) / torch.sqrt(alpha_bar_t)
            zT_1 = zT_1 * torch.sqrt(alpha_bar_t_1) + torch.randn_like(zT) * torch.sqrt(1 - alpha_bar_t_1)
            
            zT = zT_1
        
        # 解码
        return self.decoder(zT)

9. 总结与展望

9.1 主要贡献

  1. 连续潜空间:避免离散空间建模的局限性
  2. 高效训练:简单的MSE损失,自然的梯度流动
  3. 快速推理:并行生成, 延迟

9.2 局限性

  1. 语义流形假设:依赖嵌入空间的假设
  2. 生成质量:某些场景可能不如AR模型
  3. 实现复杂度:多组件联合训练

9.3 未来方向

  • 更有效的潜空间表示
  • 与AR模型的结合
  • 多模态扩展

参考文献

Footnotes

  1. Continuous Latent Diffusion LM: ByteDance Seed, “Continuous Latent Diffusion Language Model”, arXiv:2605.06548 2