条件扩散模型与Classifier-Free Guidance
在实际应用中,我们通常需要条件生成——根据文本描述、类别标签或其他信息来引导生成过程。本文深入解析条件扩散模型的设计和Classifier-Free Guidance(CFG)技术。1
条件生成问题
什么是条件生成
条件生成(Conditional Generation)是指在给定额外信息(条件)的情况下生成目标数据:
应用场景:
- Text-to-Image:根据文本描述生成图像(Stable Diffusion, DALL-E)
- Image-to-Image:根据输入图像和条件生成(img2img, inpainting)
- 类别条件生成:根据类别标签生成对应类别的图像(Class-conditional CIFAR-10)
条件扩散模型
时间调节
最简单的条件形式是将时间步 作为条件输入。但真正的条件(如文本)需要更复杂的处理。
class ConditionalDiffusionModel(torch.nn.Module):
def __init__(self, denoiser, condition_encoder):
super().__init__()
self.denoiser = denoiser
self.condition_encoder = condition_encoder # 文本编码器等
def forward(self, xt, t, condition):
"""
Args:
xt: 当前噪声图像 (batch, C, H, W)
t: 时间步 (batch,)
condition: 条件输入 (batch, seq_len)
"""
# 编码条件
cond_emb = self.condition_encoder(condition) # (batch, d_model)
# 将条件注入去噪网络
return self.denoiser(xt, t, cond_emb)条件注入方式
| 注入方式 | 实现 | 特点 |
|---|---|---|
| Concatenation | 在通道维度拼接 | 简单,适用于低维条件 |
| Cross-Attention | QKV交叉注意力 | 灵活,支持变长条件 |
| Adaptive Norm | AdaGN, AdaLN | 参数高效,常用 |
| Feature Modulation | FiLM, Scale-Shift | 表达能力强 |
# Cross-Attention注入
class CrossAttentionBlock(torch.nn.Module):
def __init__(self, d_model, d_cond, num_heads=8):
super().__init__()
self.norm = nn.LayerNorm(d_model)
self.to_q = nn.Linear(d_model, d_model)
self.to_k = nn.Linear(d_cond, d_model)
self.to_v = nn.Linear(d_cond, d_model)
self.to_out = nn.Linear(d_model, d_model)
def forward(self, x, cond):
"""
Args:
x: 图像特征 (batch, seq, d_model)
cond: 条件嵌入 (batch, cond_seq, d_cond)
"""
x_norm = self.norm(x)
q = self.to_q(x_norm)
k = self.to_k(cond)
v = self.to_v(cond)
# 注意力
attn = F.scaled_dot_product_attention(q, k, v)
return x + self.to_out(attn)
# AdaGN注入
class AdaGNBlock(torch.nn.Module):
def __init__(self, d_model, num_groups=32):
super().__init__()
self.norm = nn.GroupNorm(num_groups, d_model)
self.proj = nn.Linear(d_model * 2, d_model * 4) # scale, shift from condition
def forward(self, x, cond_emb):
"""
Args:
x: 特征 (batch, C, H, W)
cond_emb: 条件嵌入 (batch, d_model)
"""
x_norm = self.norm(x)
scale_shift = self.proj(F.silu(cond_emb))
scale, shift = scale_shift.chunk(2, dim=1)
scale = scale.unsqueeze(-1).unsqueeze(-1)
shift = shift.unsqueeze(-1).unsqueeze(-1)
return x_norm * (1 + scale) + shiftClassifier-Guidance
贝叶斯条件化
假设有一个预训练分类器 ,我们想将其知识迁移到扩散模型的条件生成:
在扩散模型框架下,这等价于在每个时间步注入分类器的梯度:
其中 是引导强度(guidance scale)。
Classifier-Guidance采样
@torch.no_grad()
def classifier_guided_sampling(diffusion_model, classifier, xt, y, gamma=1.0, T=1000):
"""
带分类器引导的采样
Args:
diffusion_model: 去噪模型
classifier: 预训练分类器
xt: 初始噪声
y: 目标类别
gamma: 引导强度
"""
for t in reversed(range(T)):
# 1. 用无分类器模型预测噪声
eps_uncond = diffusion_model(xt, t) # 无条件预测
# 2. 计算分类器梯度
with torch.enable_grad():
xt.requires_grad = True
logits = classifier(xt)
log_prob = F.log_softmax(logits, dim=1)
# 分类器梯度 = score梯度
classifier_grad = torch.autograd.grad(
log_prob[:, y].sum(), xt
)[0]
# 3. 结合引导
# 分类器梯度指向使类别概率增大的方向
# 对于DDPM: μ = μ_uncond + γ * σ² * ∇log p(y|x)
guidance = gamma * classifier_grad
# 4. 修正均值估计
alpha_bar_t = alphas_cumprod[t]
eps_guided = eps_uncond + guidance * (1 - alpha_bar_t)
# 5. 采样下一步...
xt = denoise_step(xt, eps_guided, t)
return xtClassifier-Guidance的问题
- 需要额外训练分类器:增加训练成本
- 分类器与扩散模型可能不匹配:需要针对噪声图像训练的分类器
- 梯度可能不稳定:需要小心处理
Classifier-Free Guidance (CFG)
核心思想
Classifier-Free Guidance(Ho & Salimans, 2022)1 通过无条件模型同时学习条件和无条件生成,无需单独的分类器!
关键洞察:
- 无条件预测:
- 条件预测:
- 两者做差,得到”条件方向”
其中 是引导强度(通常 )。
def cfg_guidance(eps_cond, eps_uncond, w):
"""
Classifier-Free Guidance
Args:
eps_cond: 条件预测 (batch, ...)
eps_uncond: 无条件预测 (batch, ...)
w: 引导强度
Returns:
带引导的预测
"""
return eps_uncond + w * (eps_cond - eps_uncond)训练策略
CFG的核心是在训练时随机丢弃条件,让模型同时学习条件和条件生成:
def training_step(model, x0, condition, cfg_prob=0.1):
"""
CFG训练:随机丢弃条件
Args:
model: 去噪模型(接受 x, t, condition)
x0: 原始图像
condition: 条件输入(如文本)
cfg_prob: 条件丢弃概率
"""
batch_size = x0.shape[0]
# 随机决定哪些样本丢弃条件
mask = torch.rand(batch_size) > cfg_prob
condition_masked = condition.clone()
condition_masked[~mask] = "" # 或特殊标记
# 前向过程
t = torch.randint(0, T, (batch_size,))
eps = torch.randn_like(x0)
xt = add_noise(x0, t, alphas_cumprod)
# 预测(条件或无条件)
eps_theta = model(xt, t, condition_masked)
# 计算损失
loss = F.mse_loss(eps_theta, eps)
return loss为什么CFG有效
数学解释:
设 是真实条件分布,我们可以定义:
对 做一阶近似(朗顿展开):
这意味着条件信息可以通过梯度形式注入,而CFG正是通过条件预测与无条件预测的差来估计这个梯度!
引导强度分析
| 值 | 效果 |
|---|---|
| 无条件生成(随机) | |
| 标准条件生成 | |
| 增强条件约束,生成更符合条件 | |
| 过大 | 过饱和、伪影、降低质量 |
经验法则:
- 文本到图像:
- 类别条件:
- img2img:
CFG的实践技巧
动态阈值(Dynamic Thresholding)
高引导强度下,像素值可能超出 范围。动态阈值可以解决这个问题:
def dynamic_threshold(eps, percentile=0.99):
"""
动态阈值:将超出范围的像素裁剪到百分位数
Args:
eps: 噪声预测 (batch, C, H, W)
percentile: 百分位数(0.99即保留99%的值范围)
Returns:
裁剪后的噪声
"""
# 获取绝对值的百分位数阈值
threshold = torch.quantile(eps.abs(), percentile)
# 裁剪
return torch.clamp(eps, min=-threshold, max=threshold)
# 在采样时使用
eps_guided = cfg_guidance(eps_cond, eps_uncond, w)
eps_clipped = dynamic_threshold(eps_guided, percentile=0.99)Negative Prompt(负提示)
CFG可以自然地实现负提示——引导模型避免生成某些内容:
def cfg_with_negative_prompt(eps_cond, eps_neg, eps_uncond=None, w=7.5, w_neg=1.0):
"""
带负提示的CFG
Args:
eps_cond: 正向条件预测
eps_neg: 负向条件预测
eps_uncond: 无条件预测(可选)
w: 正向引导强度
w_neg: 负向引导强度
"""
if eps_uncond is None:
# 如果没有单独的无条件预测,使用负提示作为无条件近似
eps_uncond = eps_neg
# 正向引导
eps_pos_guidance = eps_cond - eps_uncond
# 负向引导(从条件中减去)
eps_neg_guidance = eps_cond - eps_neg
# 组合
return eps_cond + w * eps_pos_guidance - w_neg * eps_neg_guidance
# 使用示例
# 负提示可以是 "blurry, low quality, distorted"
prompt = "a beautiful landscape"
negative_prompt = "blurry, low quality, distorted"
eps_cond = model(xt, t, prompt_emb)
eps_neg = model(xt, t, neg_prompt_emb)
eps_guided = cfg_with_negative_prompt(eps_cond, eps_neg, w=7.5, w_neg=1.0)CFG++(增强版CFG)
CFG++通过引入动态引导强度来改进原始CFG:
def cfgpp_guidance(eps_cond, eps_uncond, x, w_base=7.5, s_base=1.0):
"""
CFG++: 结合动态缩放
核心思想:让模型同时预测有条件和无条件,动态调整引导强度
"""
# 基础CFG
eps_guidance = eps_cond - eps_uncond
# 动态缩放:基于当前像素范围
pixel_scale = x.abs().mean()
s = s_base / (pixel_scale + 1e-6)
return eps_uncond + w_base * eps_guidance * s渐进式CFG
在采样初期使用高引导强度快速建立构图,后期降低引导强度以保留细节:
@torch.no_grad()
def progressive_cfg_sampling(model, xt, condition, guidance_fn):
"""
渐进式CFG采样
早期高引导(构图),后期低引导(细节)
"""
guidance_schedule = cosine_schedule(T, w_min=1.0, w_max=15.0)
for t in reversed(range(T)):
w = guidance_schedule[t]
eps_cond = model(xt, t, condition)
eps_uncond = model(xt, t, None) # 无条件
eps_guided = guidance_fn(eps_cond, eps_uncond, w)
xt = denoise_step(xt, eps_guided, t)
return xt
def cosine_schedule(T, w_min=1.0, w_max=15.0):
"""余弦引导强度调度"""
t_norm = torch.arange(T) / T # [0, 1]
# 余弦调度:早期高引导,后期低引导
w = w_max - (w_max - w_min) * (1 - torch.cos(torch.pi * t_norm)) / 2
return w.flip(0) # 翻转:从T-1到0CFG的理论分析
Apple ML的理论工作
Apple ML团队的研究表明2:
- CFG ≈ Predictor-Corrector:CFG可以理解为在DDPM采样过程中加入了”校正器”
- 一阶近似的有效性:CFG是一阶近似,高阶项通常较小
- 最优引导强度:存在理论上的最优 ,与数据分布有关
线性扩散模型中的分析
在简化的一维线性扩散模型中3:
CFG的最优引导强度可以解析求得:
其中 是噪声方差。这解释了为什么 不需要选太大——过高的 会放大噪声!
CFG的局限性
| 问题 | 描述 | 解决方案 |
|---|---|---|
| 计算成本 | 需要两次前向传播(条件+无条件) | 使用共享编码器,分支推理 |
| 引导强度敏感 | 过高/过低都影响质量 | 网格搜索,或使用渐进式CFG |
| 负提示冲突 | 正负提示可能产生冲突 | Prompt weighting、LoRA |
| 分布外条件 | 对意外条件可能产生奇怪输出 | 模型微调、RLHF |
完整实现示例
import torch
import torch.nn.functional as F
class CFGDiffusionSampler:
def __init__(self, model, text_encoder, alphas_cumprod, cfg_scale=7.5):
self.model = model
self.text_encoder = text_encoder
self.alphas_cumprod = alphas_cumprod
self.cfg_scale = cfg_scale
@torch.no_grad()
def sample(self, xt, prompt, negative_prompt="", num_steps=50, latent=True):
"""
CFG采样主循环
Args:
xt: 初始噪声
prompt: 正向提示词
negative_prompt: 负向提示词
num_steps: 采样步数
"""
batch_size = xt.shape[0]
# 编码提示词
pos_emb = self.text_encoder(prompt)
neg_emb = self.text_encoder(negative_prompt) if negative_prompt else None
# 时间步调度
timesteps = torch.linspace(999, 0, num_steps, device=xt.device).long()
for i, t in enumerate(timesteps):
# 调整时间步到实际DDPM步数
t_expanded = t.unsqueeze(0).expand(batch_size)
# 条件预测
eps_pos = self.model(xt, t_expanded, pos_emb)
# 无条件/负向预测
if negative_prompt:
eps_uncond = self.model(xt, t_expanded, neg_emb)
else:
eps_uncond = self.model(xt, t_expanded, torch.zeros_like(pos_emb))
# CFG引导
eps_guided = eps_uncond + self.cfg_scale * (eps_pos - eps_uncond)
# 动态阈值(可选)
if self.cfg_scale > 1:
eps_guided = self.dynamic_threshold(eps_guided)
# 去噪步骤
xt = self.denoise_step(xt, eps_guided, t.item())
return xt
@torch.no_grad()
def denoise_step(self, xt, eps, t):
"""单个去噪步骤"""
alpha_t = self.alphas_cumprod[t]
alpha_bar_t = self.alphas_cumprod[t]
# 系数
coef1 = (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)
coef2 = 1 / torch.sqrt(alpha_t)
# 均值
pred_x0 = coef2 * (xt - coef1 * eps)
# 添加噪声(如果不是最后一步)
if t > 0:
noise = torch.randn_like(xt)
beta_t = 1 - alpha_t / (1 - alpha_bar_t) * (1 - alpha_t)
xt = torch.sqrt(alpha_t) * pred_x0 + torch.sqrt(beta_t) * noise
return xt
def dynamic_threshold(self, eps, percentile=0.995):
"""动态阈值"""
threshold = torch.quantile(eps.abs(), percentile)
return torch.clamp(eps, -threshold, threshold)参考
相关链接:扩散模型理论基础
Footnotes
-
Ho & Salimans, “Classifier-Free Diffusion Guidance”, NeurIPS Workshop 2022. https://arxiv.org/abs/2207.12598 ↩ ↩2
-
Apple ML Team, “Understanding Classifier-Free Guidance”, Apple ML Research 2023. https://machinelearning.apple.com/research/classifier-free-guidance ↩
-
“Towards Understanding the Mechanisms of Classifier-Free Guidance”, arXiv 2025. https://arxiv.org/abs/2505.19210 ↩