条件扩散模型与Classifier-Free Guidance

在实际应用中,我们通常需要条件生成——根据文本描述、类别标签或其他信息来引导生成过程。本文深入解析条件扩散模型的设计和Classifier-Free Guidance(CFG)技术。1

条件生成问题

什么是条件生成

条件生成(Conditional Generation)是指在给定额外信息(条件)的情况下生成目标数据:

应用场景

  • Text-to-Image:根据文本描述生成图像(Stable Diffusion, DALL-E)
  • Image-to-Image:根据输入图像和条件生成(img2img, inpainting)
  • 类别条件生成:根据类别标签生成对应类别的图像(Class-conditional CIFAR-10)

条件扩散模型

时间调节

最简单的条件形式是将时间步 作为条件输入。但真正的条件(如文本)需要更复杂的处理。

class ConditionalDiffusionModel(torch.nn.Module):
    def __init__(self, denoiser, condition_encoder):
        super().__init__()
        self.denoiser = denoiser
        self.condition_encoder = condition_encoder  # 文本编码器等
    
    def forward(self, xt, t, condition):
        """
        Args:
            xt: 当前噪声图像 (batch, C, H, W)
            t: 时间步 (batch,)
            condition: 条件输入 (batch, seq_len)
        """
        # 编码条件
        cond_emb = self.condition_encoder(condition)  # (batch, d_model)
        
        # 将条件注入去噪网络
        return self.denoiser(xt, t, cond_emb)

条件注入方式

注入方式实现特点
Concatenation在通道维度拼接简单,适用于低维条件
Cross-AttentionQKV交叉注意力灵活,支持变长条件
Adaptive NormAdaGN, AdaLN参数高效,常用
Feature ModulationFiLM, Scale-Shift表达能力强
# Cross-Attention注入
class CrossAttentionBlock(torch.nn.Module):
    def __init__(self, d_model, d_cond, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_cond, d_model)
        self.to_v = nn.Linear(d_cond, d_model)
        self.to_out = nn.Linear(d_model, d_model)
    
    def forward(self, x, cond):
        """
        Args:
            x: 图像特征 (batch, seq, d_model)
            cond: 条件嵌入 (batch, cond_seq, d_cond)
        """
        x_norm = self.norm(x)
        q = self.to_q(x_norm)
        k = self.to_k(cond)
        v = self.to_v(cond)
        
        # 注意力
        attn = F.scaled_dot_product_attention(q, k, v)
        return x + self.to_out(attn)
 
# AdaGN注入
class AdaGNBlock(torch.nn.Module):
    def __init__(self, d_model, num_groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, d_model)
        self.proj = nn.Linear(d_model * 2, d_model * 4)  # scale, shift from condition
    
    def forward(self, x, cond_emb):
        """
        Args:
            x: 特征 (batch, C, H, W)
            cond_emb: 条件嵌入 (batch, d_model)
        """
        x_norm = self.norm(x)
        scale_shift = self.proj(F.silu(cond_emb))
        scale, shift = scale_shift.chunk(2, dim=1)
        scale = scale.unsqueeze(-1).unsqueeze(-1)
        shift = shift.unsqueeze(-1).unsqueeze(-1)
        return x_norm * (1 + scale) + shift

Classifier-Guidance

贝叶斯条件化

假设有一个预训练分类器 ,我们想将其知识迁移到扩散模型的条件生成:

在扩散模型框架下,这等价于在每个时间步注入分类器的梯度:

其中 引导强度(guidance scale)。

Classifier-Guidance采样

@torch.no_grad()
def classifier_guided_sampling(diffusion_model, classifier, xt, y, gamma=1.0, T=1000):
    """
    带分类器引导的采样
    
    Args:
        diffusion_model: 去噪模型
        classifier: 预训练分类器
        xt: 初始噪声
        y: 目标类别
        gamma: 引导强度
    """
    for t in reversed(range(T)):
        # 1. 用无分类器模型预测噪声
        eps_uncond = diffusion_model(xt, t)  # 无条件预测
        
        # 2. 计算分类器梯度
        with torch.enable_grad():
            xt.requires_grad = True
            logits = classifier(xt)
            log_prob = F.log_softmax(logits, dim=1)
            # 分类器梯度 = score梯度
            classifier_grad = torch.autograd.grad(
                log_prob[:, y].sum(), xt
            )[0]
        
        # 3. 结合引导
        # 分类器梯度指向使类别概率增大的方向
        # 对于DDPM: μ = μ_uncond + γ * σ² * ∇log p(y|x)
        guidance = gamma * classifier_grad
        
        # 4. 修正均值估计
        alpha_bar_t = alphas_cumprod[t]
        eps_guided = eps_uncond + guidance * (1 - alpha_bar_t)
        
        # 5. 采样下一步...
        xt = denoise_step(xt, eps_guided, t)
    
    return xt

Classifier-Guidance的问题

  1. 需要额外训练分类器:增加训练成本
  2. 分类器与扩散模型可能不匹配:需要针对噪声图像训练的分类器
  3. 梯度可能不稳定:需要小心处理

Classifier-Free Guidance (CFG)

核心思想

Classifier-Free Guidance(Ho & Salimans, 2022)1 通过无条件模型同时学习条件和无条件生成,无需单独的分类器

关键洞察

  • 无条件预测:
  • 条件预测:
  • 两者做差,得到”条件方向”

其中 是引导强度(通常 )。

def cfg_guidance(eps_cond, eps_uncond, w):
    """
    Classifier-Free Guidance
    
    Args:
        eps_cond: 条件预测 (batch, ...)
        eps_uncond: 无条件预测 (batch, ...)
        w: 引导强度
    
    Returns:
        带引导的预测
    """
    return eps_uncond + w * (eps_cond - eps_uncond)

训练策略

CFG的核心是在训练时随机丢弃条件,让模型同时学习条件和条件生成:

def training_step(model, x0, condition, cfg_prob=0.1):
    """
    CFG训练:随机丢弃条件
    
    Args:
        model: 去噪模型(接受 x, t, condition)
        x0: 原始图像
        condition: 条件输入(如文本)
        cfg_prob: 条件丢弃概率
    """
    batch_size = x0.shape[0]
    
    # 随机决定哪些样本丢弃条件
    mask = torch.rand(batch_size) > cfg_prob
    condition_masked = condition.clone()
    condition_masked[~mask] = ""  # 或特殊标记
    
    # 前向过程
    t = torch.randint(0, T, (batch_size,))
    eps = torch.randn_like(x0)
    xt = add_noise(x0, t, alphas_cumprod)
    
    # 预测(条件或无条件)
    eps_theta = model(xt, t, condition_masked)
    
    # 计算损失
    loss = F.mse_loss(eps_theta, eps)
    return loss

为什么CFG有效

数学解释

是真实条件分布,我们可以定义:

做一阶近似(朗顿展开):

这意味着条件信息可以通过梯度形式注入,而CFG正是通过条件预测与无条件预测的差来估计这个梯度!

引导强度分析

效果
无条件生成(随机)
标准条件生成
增强条件约束,生成更符合条件
过大过饱和、伪影、降低质量

经验法则

  • 文本到图像:
  • 类别条件:
  • img2img:

CFG的实践技巧

动态阈值(Dynamic Thresholding)

高引导强度下,像素值可能超出 范围。动态阈值可以解决这个问题:

def dynamic_threshold(eps, percentile=0.99):
    """
    动态阈值:将超出范围的像素裁剪到百分位数
    
    Args:
        eps: 噪声预测 (batch, C, H, W)
        percentile: 百分位数(0.99即保留99%的值范围)
    
    Returns:
        裁剪后的噪声
    """
    # 获取绝对值的百分位数阈值
    threshold = torch.quantile(eps.abs(), percentile)
    # 裁剪
    return torch.clamp(eps, min=-threshold, max=threshold)
 
# 在采样时使用
eps_guided = cfg_guidance(eps_cond, eps_uncond, w)
eps_clipped = dynamic_threshold(eps_guided, percentile=0.99)

Negative Prompt(负提示)

CFG可以自然地实现负提示——引导模型避免生成某些内容:

def cfg_with_negative_prompt(eps_cond, eps_neg, eps_uncond=None, w=7.5, w_neg=1.0):
    """
    带负提示的CFG
    
    Args:
        eps_cond: 正向条件预测
        eps_neg: 负向条件预测
        eps_uncond: 无条件预测(可选)
        w: 正向引导强度
        w_neg: 负向引导强度
    """
    if eps_uncond is None:
        # 如果没有单独的无条件预测,使用负提示作为无条件近似
        eps_uncond = eps_neg
    
    # 正向引导
    eps_pos_guidance = eps_cond - eps_uncond
    # 负向引导(从条件中减去)
    eps_neg_guidance = eps_cond - eps_neg
    
    # 组合
    return eps_cond + w * eps_pos_guidance - w_neg * eps_neg_guidance
 
# 使用示例
# 负提示可以是 "blurry, low quality, distorted"
prompt = "a beautiful landscape"
negative_prompt = "blurry, low quality, distorted"
eps_cond = model(xt, t, prompt_emb)
eps_neg = model(xt, t, neg_prompt_emb)
eps_guided = cfg_with_negative_prompt(eps_cond, eps_neg, w=7.5, w_neg=1.0)

CFG++(增强版CFG)

CFG++通过引入动态引导强度来改进原始CFG:

def cfgpp_guidance(eps_cond, eps_uncond, x, w_base=7.5, s_base=1.0):
    """
    CFG++: 结合动态缩放
    
    核心思想:让模型同时预测有条件和无条件,动态调整引导强度
    """
    # 基础CFG
    eps_guidance = eps_cond - eps_uncond
    
    # 动态缩放:基于当前像素范围
    pixel_scale = x.abs().mean()
    s = s_base / (pixel_scale + 1e-6)
    
    return eps_uncond + w_base * eps_guidance * s

渐进式CFG

在采样初期使用高引导强度快速建立构图,后期降低引导强度以保留细节:

@torch.no_grad()
def progressive_cfg_sampling(model, xt, condition, guidance_fn):
    """
    渐进式CFG采样
    
    早期高引导(构图),后期低引导(细节)
    """
    guidance_schedule = cosine_schedule(T, w_min=1.0, w_max=15.0)
    
    for t in reversed(range(T)):
        w = guidance_schedule[t]
        eps_cond = model(xt, t, condition)
        eps_uncond = model(xt, t, None)  # 无条件
        
        eps_guided = guidance_fn(eps_cond, eps_uncond, w)
        xt = denoise_step(xt, eps_guided, t)
    
    return xt
 
def cosine_schedule(T, w_min=1.0, w_max=15.0):
    """余弦引导强度调度"""
    t_norm = torch.arange(T) / T  # [0, 1]
    # 余弦调度:早期高引导,后期低引导
    w = w_max - (w_max - w_min) * (1 - torch.cos(torch.pi * t_norm)) / 2
    return w.flip(0)  # 翻转:从T-1到0

CFG的理论分析

Apple ML的理论工作

Apple ML团队的研究表明2

  1. CFG ≈ Predictor-Corrector:CFG可以理解为在DDPM采样过程中加入了”校正器”
  2. 一阶近似的有效性:CFG是一阶近似,高阶项通常较小
  3. 最优引导强度:存在理论上的最优 ,与数据分布有关

线性扩散模型中的分析

在简化的一维线性扩散模型中3

CFG的最优引导强度可以解析求得:

其中 是噪声方差。这解释了为什么 不需要选太大——过高的 会放大噪声!

CFG的局限性

问题描述解决方案
计算成本需要两次前向传播(条件+无条件)使用共享编码器,分支推理
引导强度敏感过高/过低都影响质量网格搜索,或使用渐进式CFG
负提示冲突正负提示可能产生冲突Prompt weighting、LoRA
分布外条件对意外条件可能产生奇怪输出模型微调、RLHF

完整实现示例

import torch
import torch.nn.functional as F
 
class CFGDiffusionSampler:
    def __init__(self, model, text_encoder, alphas_cumprod, cfg_scale=7.5):
        self.model = model
        self.text_encoder = text_encoder
        self.alphas_cumprod = alphas_cumprod
        self.cfg_scale = cfg_scale
    
    @torch.no_grad()
    def sample(self, xt, prompt, negative_prompt="", num_steps=50, latent=True):
        """
        CFG采样主循环
        
        Args:
            xt: 初始噪声
            prompt: 正向提示词
            negative_prompt: 负向提示词
            num_steps: 采样步数
        """
        batch_size = xt.shape[0]
        
        # 编码提示词
        pos_emb = self.text_encoder(prompt)
        neg_emb = self.text_encoder(negative_prompt) if negative_prompt else None
        
        # 时间步调度
        timesteps = torch.linspace(999, 0, num_steps, device=xt.device).long()
        
        for i, t in enumerate(timesteps):
            # 调整时间步到实际DDPM步数
            t_expanded = t.unsqueeze(0).expand(batch_size)
            
            # 条件预测
            eps_pos = self.model(xt, t_expanded, pos_emb)
            
            # 无条件/负向预测
            if negative_prompt:
                eps_uncond = self.model(xt, t_expanded, neg_emb)
            else:
                eps_uncond = self.model(xt, t_expanded, torch.zeros_like(pos_emb))
            
            # CFG引导
            eps_guided = eps_uncond + self.cfg_scale * (eps_pos - eps_uncond)
            
            # 动态阈值(可选)
            if self.cfg_scale > 1:
                eps_guided = self.dynamic_threshold(eps_guided)
            
            # 去噪步骤
            xt = self.denoise_step(xt, eps_guided, t.item())
        
        return xt
    
    @torch.no_grad()
    def denoise_step(self, xt, eps, t):
        """单个去噪步骤"""
        alpha_t = self.alphas_cumprod[t]
        alpha_bar_t = self.alphas_cumprod[t]
        
        # 系数
        coef1 = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
        coef2 = 1 / torch.sqrt(alpha_t)
        
        # 均值
        pred_x0 = coef2 * (xt - coef1 * eps)
        
        # 添加噪声(如果不是最后一步)
        if t > 0:
            noise = torch.randn_like(xt)
            beta_t = 1 - alpha_t / (1 - alpha_bar_t) * (1 - alpha_t)
            xt = torch.sqrt(alpha_t) * pred_x0 + torch.sqrt(beta_t) * noise
        
        return xt
    
    def dynamic_threshold(self, eps, percentile=0.995):
        """动态阈值"""
        threshold = torch.quantile(eps.abs(), percentile)
        return torch.clamp(eps, -threshold, threshold)

参考

相关链接:扩散模型理论基础

Footnotes

  1. Ho & Salimans, “Classifier-Free Diffusion Guidance”, NeurIPS Workshop 2022. https://arxiv.org/abs/2207.12598 2

  2. Apple ML Team, “Understanding Classifier-Free Guidance”, Apple ML Research 2023. https://machinelearning.apple.com/research/classifier-free-guidance

  3. “Towards Understanding the Mechanisms of Classifier-Free Guidance”, arXiv 2025. https://arxiv.org/abs/2505.19210