Efficient Sampling Techniques for Diffusion Models

DDPM requires 1000+ sampling steps to generate high-quality images, which is unacceptable in practice. This article surveys efficient sampling techniques that reduce the step count from 1000 to 1-50 while maintaining, or even improving, sample quality.[1][2]

Background

The Sampling Bottleneck

DDPM's reverse process requires T iterations (typically T = 1000), and each step involves:

  1. A forward pass through the denoising network
  2. Computing and adding noise

The total time complexity is therefore O(T) denoising-network forward passes.
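
The linear cost in the number of steps can be made concrete by counting network function evaluations (NFE). A minimal sketch with a hypothetical counting stand-in for the denoiser (pure Python, assumptions only):

```python
class CountingModel:
    """Stand-in denoiser that just counts forward passes (NFE)."""
    def __init__(self):
        self.nfe = 0

    def __call__(self, x, t):
        self.nfe += 1  # one "network forward" per call
        return x       # dummy prediction

def sample(model, num_steps):
    x = 0.0
    for t in range(num_steps):
        x = model(x, t)  # exactly one forward pass per sampling step
    return x

ddpm, ddim = CountingModel(), CountingModel()
sample(ddpm, 1000)  # DDPM-style full schedule
sample(ddim, 50)    # DDIM-style subset of the schedule
# ddpm.nfe == 1000, ddim.nfe == 50: wall-clock time scales directly with steps
```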

Impact of Step Count

Experiments show that DDPM quality degrades sharply at lower step counts:

| Sampling steps | FID (higher is worse) | SSIM |
| --- | --- | --- |
| 1000 | 4.76 | 0.83 |
| 100 | 8.67 | 0.71 |
| 50 | 13.37 | 0.62 |
| 10 | 42.75 | 0.31 |

DDIM: Implicit Generative Models

Core Idea

DDIM (Denoising Diffusion Implicit Models)[1] observes that the sampling path does not have to reverse the forward process step by step.

Key insight: a diffusion model only needs the marginals $q(x_t \mid x_0)$ to be correct; the reverse process does not have to be a Markov chain!

Non-Markovian Reverse Process

DDIM defines a non-Markovian reverse process:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t\,\epsilon_t$$

where $\hat{x}_0 = \big(x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)\big)/\sqrt{\bar\alpha_t}$ is the predicted clean image, $\epsilon_\theta$ is the predicted "score direction", and $\sigma_t$ controls the stochasticity.

DDIM Sampling Formula

The parameter $\eta$ scales $\sigma_t$ via

$$\sigma_t = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}$$

and thereby controls the randomness of the sampling process:

  • $\eta = 0$: deterministic sampling (same noise → same image)
  • $\eta = 1$: stochastic sampling (equivalent to DDPM)

import torch
from tqdm import tqdm

@torch.no_grad()
def ddim_sampling(model, xt, alphas_cumprod, T_target, eta=0.0):
    """
    DDIM sampling.

    Args:
        model: denoising model (predicts noise)
        xt: initial noise (batch, C, H, W)
        alphas_cumprod: cumulative alphas (T,)
        T_target: number of sampling steps (a subset of the full T steps)
        eta: stochasticity parameter (0 = deterministic, 1 = DDPM-like)
    """
    T_full = len(alphas_cumprod)
    timesteps = torch.linspace(T_full - 1, 0, T_target).long()

    for i, t_curr in enumerate(tqdm(timesteps)):
        # Next (lower) timestep; -1 marks the final jump to t = 0
        t_prev = timesteps[i + 1] if i < len(timesteps) - 1 else -1

        # Predict noise
        eps_theta = model(xt, t_curr)

        # DDIM coefficients
        alpha_t = alphas_cumprod[t_curr]
        alpha_t_prev = alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0)

        # Predicted x0
        x0_pred = (xt - torch.sqrt(1 - alpha_t) * eps_theta) / torch.sqrt(alpha_t)

        # Variance (zero when eta == 0)
        sigma_t = eta * torch.sqrt(
            (1 - alpha_t_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_t_prev)
        )

        # Random noise (only used when eta > 0)
        noise = torch.randn_like(xt) if eta > 0 else 0

        # DDIM update
        xt = torch.sqrt(alpha_t_prev) * x0_pred + \
             torch.sqrt(1 - alpha_t_prev - sigma_t**2) * eps_theta + \
             sigma_t * noise

    return xt
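
The η = 0 determinism claimed above can be checked with a tiny NumPy sketch; the linear "noise predictor" below is a hypothetical stand-in for a trained network, and the alpha-bar values are toy numbers:

```python
import numpy as np

def ddim_step(x, eps, a_t, a_prev, eta, rng):
    """One DDIM update (NumPy version of the torch loop above)."""
    x0 = (x - np.sqrt(1 - a_t) * eps) / np.sqrt(a_t)
    sigma = eta * np.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
    noise = rng.normal(size=x.shape) if eta > 0 else 0.0
    return np.sqrt(a_prev) * x0 + np.sqrt(1 - a_prev - sigma**2) * eps + sigma * noise

def sample(eta, seed):
    rng = np.random.default_rng(seed)
    x = np.random.default_rng(0).normal(size=8)  # same starting noise every call
    a_bars = [0.1, 0.3, 0.5, 0.7, 0.9]           # toy alpha_bar values (low t -> high alpha_bar)
    for a_t, a_prev in zip(a_bars[:-1], a_bars[1:]):
        eps_pred = 0.5 * x                       # hypothetical "noise predictor"
        x = ddim_step(x, eps_pred, a_t, a_prev, eta, rng)
    return x

# eta=0 ignores the RNG entirely, so different seeds give identical samples;
# eta=1 injects fresh noise each step, so different seeds diverge.
```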

DDPM vs DDIM

| Property | DDPM | DDIM |
| --- | --- | --- |
| Reverse process | Markov chain | non-Markovian |
| Sampling steps | 1000+ | 20-100 |
| Stochasticity | stochastic (noise added each step) | can be deterministic (η = 0) |
| Error behavior | errors accumulate across iterations | errors do not accumulate |
| Quality (50 steps) | 13.37 FID | 4.76 FID |

Step Selection Strategies

Not all timesteps are equally important. DDIM can use non-uniform step spacing:

import torch

def get_ddim_timesteps(T_full=1000, T_target=50, schedule='linear'):
    """
    Generate the timestep sequence for DDIM sampling.

    Schedules:
    - 'linear': uniformly spaced timesteps
    - 'quadratic': quadratic spacing (more steps late in sampling)
    - 'karras': Karras et al.'s strategy (uniform spacing in noise level)
    """
    if schedule == 'linear':
        return torch.linspace(T_full - 1, 0, T_target).long()

    elif schedule == 'quadratic':
        # Quadratic spacing: denser near t = 0 (low noise, late in sampling)
        return ((torch.linspace(1, 0, T_target) ** 2) * (T_full - 1)).long()

    elif schedule == 'karras':
        # Karras-style schedule: pick timesteps so that the noise level
        # sigma = sqrt(1 - alpha_bar) varies linearly
        sigmas = torch.linspace(0, 1, T_target)
        alphas_cumprod = 1 - sigmas**2

        # Map each target alpha_bar to the closest full timestep
        alphas_cumprod_full = torch.linspace(1, 0, T_full)
        timesteps = torch.zeros(T_target, dtype=torch.long)
        for i in range(T_target):
            idx = torch.argmin((alphas_cumprod_full - alphas_cumprod[i]).abs())
            timesteps[i] = idx

        return timesteps.flip(0)  # high-to-low order

    return torch.linspace(T_full - 1, 0, T_target).long()
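
A quick NumPy re-sketch of the Karras-style mapping above (toy numbers, assuming T_full = 1000) shows how it concentrates steps at low noise levels, in contrast to the uniform linear schedule:

```python
import numpy as np

T_full, T_target = 1000, 10

# Linear schedule: uniform gaps between timesteps
linear = np.linspace(T_full - 1, 0, T_target).astype(int)

# Karras-style: uniform in sigma = sqrt(1 - alpha_bar)
sigmas = np.linspace(0, 1, T_target)
alpha_bar = 1 - sigmas**2
full = np.linspace(1, 0, T_full)                       # full alpha_bar grid
karras = np.array([np.abs(full - a).argmin() for a in alpha_bar])[::-1]

gaps = -np.diff(karras)  # timesteps are descending, so gaps are positive
# The gaps shrink toward t = 0: many small steps at low noise,
# a few large steps at high noise.
```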

DDIM Inversion

DDIM's deterministic nature makes inversion possible, i.e. recovering the noise latent from an image:

import torch
from tqdm import tqdm

@torch.no_grad()
def ddim_inversion(model, x0, alphas_cumprod, num_steps=50):
    """
    DDIM inversion: recover the noise latent from an image.

    This is the core of techniques such as ReNoise!
    """
    xt = x0  # start from the clean image
    timesteps = torch.linspace(0, len(alphas_cumprod) - 1, num_steps).long()

    for i in tqdm(range(len(timesteps) - 1)):
        t, t_next = timesteps[i], timesteps[i + 1]

        # Predict noise at the current point
        eps_theta = model(xt, t)

        alpha_t = alphas_cumprod[t]
        alpha_t_next = alphas_cumprod[t_next]

        # Reverse of the deterministic (eta = 0) DDIM update:
        # recover x0 from (xt, eps), then re-noise to the *higher* level t_next
        x0_pred = (xt - torch.sqrt(1 - alpha_t) * eps_theta) / torch.sqrt(alpha_t)
        xt = torch.sqrt(alpha_t_next) * x0_pred + torch.sqrt(1 - alpha_t_next) * eps_theta

    return xt

Consistency Models

Core Idea

Consistency Models[2], proposed by Yang Song et al. in 2023, rest on the observation that:

Any intermediate state on a diffusion trajectory can be mapped self-consistently back to the trajectory's origin.

Mathematically, one defines a consistency function $f$ with $f(x_t, t) = x_\epsilon$ for every point on a trajectory.

This means that for any $t, t' \in [\epsilon, T]$ on the same trajectory: $f(x_t, t) = f(x_{t'}, t')$.

Training Objective

Consistency models are trained with a loss of the form

$$\mathcal{L}(\theta) = \mathbb{E}\big[\, d\big(f_\theta(x_{t_{n+1}}, t_{n+1}),\; f_{\theta^-}(x_{t_n}, t_n)\big) \,\big]$$

where:

  • $f_\theta$: the current (online) model
  • $f_{\theta^-}$: the target network, updated by EMA
  • $d(\cdot,\cdot)$: a distance metric (e.g. MSE)
  • $t_n, t_{n+1}$: two time points on the same trajectory

import copy
import torch
import torch.nn.functional as F

class ConsistencyModel(torch.nn.Module):
    def __init__(self, network, sigma_min=0.002, sigma_max=80.0):
        super().__init__()
        self.network = network  # denoising network, e.g. a UNet
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

        # Target network (EMA copy, not trained by backprop)
        self.theta = copy.deepcopy(network)
        for param in self.theta.parameters():
            param.requires_grad = False

    def _coeffs(self, sigma):
        # Boundary-condition parameterization: f(x, sigma_min) = x
        # c_skip = sigma_min² / (sigma² + sigma_min²)
        # c_out  = sigma * sigma_min / sqrt(sigma² + sigma_min²)
        c_skip = (self.sigma_min ** 2) / (sigma ** 2 + self.sigma_min ** 2)
        c_out = sigma * self.sigma_min / torch.sqrt(sigma ** 2 + self.sigma_min ** 2)
        return c_skip, c_out

    def forward(self, x, sigma):
        """
        Consistency model forward pass.

        Args:
            x: noisy image
            sigma: noise level, broadcastable against x

        Returns:
            f(x, sigma) ≈ x0
        """
        c_skip, c_out = self._coeffs(sigma)
        return c_skip * x + c_out * self.network(x, sigma)

    def training_loss(self, x0):
        """Consistency training loss."""
        # Sample one noise level per example, broadcastable against images
        sigma = torch.rand(x0.shape[0], device=x0.device) \
                * (self.sigma_max - self.sigma_min) + self.sigma_min
        sigma = sigma.view(-1, 1, 1, 1)

        # Noisy image at level sigma
        eps = torch.randn_like(x0)
        xt = x0 + sigma * eps

        # A second, smaller noise level s < sigma on the same trajectory
        s = torch.rand_like(sigma) * sigma
        x_s = x0 + s * eps

        # Target from the EMA network, using the same parameterization
        with torch.no_grad():
            c_skip, c_out = self._coeffs(s)
            target = c_skip * x_s + c_out * self.theta(x_s, s)

        pred = self.forward(xt, sigma)
        return F.mse_loss(pred, target)

Single-Step and Multi-Step Sampling

class ConsistencySampler:
    def __init__(self, model):
        self.model = model

    @torch.no_grad()
    def single_step_sample(self, xt, sigma):
        """Single-step sampling (high quality, but requires distillation)."""
        return self.model(xt, sigma)

    @torch.no_grad()
    def multi_step_sample(self, xt, sigma_max=80.0, N=64):
        """
        Multi-step sampling (no distillation needed, but more steps).

        Roughly equivalent to running an ODE solver on the discretized trajectory.
        """
        sigma = torch.tensor([sigma_max], device=xt.device)
        dt = sigma_max / N

        for _ in range(N):
            # Step toward the model's current x0 estimate
            d = self.model(xt, sigma) - xt
            xt = xt + d * dt
            sigma = torch.clamp(sigma - dt, min=self.model.sigma_min)

        return xt

Consistency Model vs DDPM

| Property | DDPM | DDIM | Consistency Model |
| --- | --- | --- | --- |
| Sampling steps | 1000 | 20-100 | 1-10 |
| Training | reconstruction loss | reconstruction loss | consistency loss |
| Requires distillation | no | no | for single-step |
| Quality | — | — | close to DDPM |
| FID (CIFAR-10) | 4.76 | 4.16 | 3.55 (single-step) |

Progressive Distillation

The Distillation Idea

The core of Progressive Distillation is knowledge distillation: a student with fewer steps learns to match the behavior of a many-step teacher. Each round halves the number of steps:

Round 1: T-step teacher → T/2-step student (distill)
Round 2: T/2-step teacher → T/4-step student (distill)
...
Final round: 2-step teacher → 1-step student (distill)
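
The round structure can be sketched in a few lines of pure Python (toy example, assuming a hypothetical 1024-step teacher; each round halves the step count):

```python
def distillation_rounds(T=1024):
    """Each round distills an N-step teacher into an N/2-step student."""
    rounds = []
    while T > 1:
        rounds.append((T, T // 2))  # (teacher steps, student steps)
        T //= 2
    return rounds

# e.g. distillation_rounds(8) -> [(8, 4), (4, 2), (2, 1)]
```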

import torch
import torch.nn.functional as F

class ProgressiveDistillation:
    def __init__(self, teacher_model, student_model):
        self.teacher = teacher_model
        self.student = student_model

    def distill_step(self, x0, ratio=0.5):
        """
        One distillation step: transfer the teacher's multi-step
        behavior to the student.

        ratio: determines the intermediate timestep s = ratio * t
        """
        batch_size = x0.shape[0]
        T = 1000

        # Pick timestep pairs (t, s) with s = ratio * t
        t = torch.randint(0, T, (batch_size,))
        s = torch.clamp((t.float() * ratio).long(), min=0)

        # Noise the clean images to level t (add_noise is assumed defined elsewhere)
        xt = add_noise(x0, t)

        # Teacher: multi-step denoising from t down to s
        with torch.no_grad():
            x_s_teacher = self.teacher.multi_step_denoise(xt, t, s)

        # Student: jump directly from xt to xs in one step
        x_s_student = self.student(xt, t, s)

        # Distillation loss
        return F.mse_loss(x_s_student, x_s_teacher)

Autoencoder-Style Distillation

A more efficient approach uses an autoencoder-style objective:

import torch
import torch.nn.functional as F

class AutoEncoderDistillation:
    """
    Autoencoder-style distillation: the student directly predicts
    the denoised image.
    """
    def __init__(self, student):
        self.student = student

    def train_step(self, x0, noise_levels):
        """
        Args:
            x0: original images
            noise_levels: noise levels to distill
        """
        # Add noise
        noise = torch.randn_like(x0)
        xt = x0 + noise_levels.view(-1, 1, 1, 1) * noise

        # Student prediction
        x0_pred = self.student(xt, noise_levels)

        # Regress directly to the original image
        return F.mse_loss(x0_pred, x0)

LCM: Latent Consistency Models

Core Idea

LCM (Latent Consistency Models)[3] applies the consistency-model idea in the latent space, greatly accelerating models such as Stable Diffusion.

Key Insight

LCM exploits a property of the probability-flow ODE trajectory: the consistency function satisfies

$$f(x_t, c, t) = f(x_{t'}, c, t') \quad \text{for any } t, t' \text{ on the same trajectory},$$

where $c$ is the (text) condition. Therefore $x_0$ can be predicted directly from any intermediate point.

LCM Training

import torch
import torch.nn.functional as F

class LCM(torch.nn.Module):
    def __init__(self, vae, unet, text_encoder, cfg_scale=8.0):
        super().__init__()
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.cfg_scale = cfg_scale

    def training_loss(self, images, prompts, negative_prompts):
        """
        LCM training (simplified): predict the origin of the ODE trajectory.
        """
        batch_size = images.shape[0]

        # Encode images into latent space
        with torch.no_grad():
            latents = self.vae.encode(images).latent_dist.sample()
            latents = latents * self.vae.config.scaling_factor

        # Sample timesteps
        t = torch.rand(batch_size, device=latents.device) * 999

        # Add noise (add_noise is assumed defined elsewhere)
        latents_noisy = add_noise(latents, t)

        # Encode text
        pos_emb = self.text_encoder(prompts)
        neg_emb = self.text_encoder(negative_prompts)

        # UNet predictions with and without the condition
        noise_pred = self.unet(latents_noisy, t, pos_emb)
        noise_pred_uncond = self.unet(latents_noisy, t, neg_emb)

        # Classifier-free guidance (CFG)
        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred - noise_pred_uncond)

        # LCM target: predict x0 directly (get_alpha_bar assumed defined elsewhere)
        alpha_bar_t = get_alpha_bar(t)
        x0_pred = (latents_noisy - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)

        # Objective: x0_pred should match the clean latents
        return F.mse_loss(x0_pred, latents)

    @torch.no_grad()
    def sampling(self, latent_dist, prompts, num_steps=4):
        """
        Fast LCM sampling: 4-8 steps suffice.

        Args:
            latent_dist: initial noise distribution
            prompts: text prompts
            num_steps: number of sampling steps (typically 4-8)
        """
        latent = latent_dist.sample()
        text_emb = self.text_encoder(prompts)

        # Timestep schedule
        timesteps = torch.linspace(999, 0, num_steps).long().to(latent.device)

        for i in range(num_steps):
            t = timesteps[i]

            # Predict noise (CFG with a negative prompt could be applied
            # here as well, as in training_loss)
            noise_pred = self.unet(latent, t, text_emb)
            alpha_bar_t = get_alpha_bar(t)

            # Compute x0
            x0_pred = (latent - torch.sqrt(1 - alpha_bar_t) * noise_pred) / torch.sqrt(alpha_bar_t)

            if i < num_steps - 1:
                # Re-noise the x0 estimate to the next (lower) noise level
                alpha_bar_next = get_alpha_bar(timesteps[i + 1])
                latent = torch.sqrt(alpha_bar_next) * x0_pred + \
                         torch.sqrt(1 - alpha_bar_next) * torch.randn_like(latent)
            else:
                latent = x0_pred

        return latent

LCM vs Standard SD

| Property | Stable Diffusion (DDIM) | LCM |
| --- | --- | --- |
| Sampling steps | 20-50 | 4-8 |
| Speedup | 1x | 3-6x |
| Quality | — | slightly lower but acceptable |
| Extra training | none | LoRA / fine-tuning required |

Sampling Schedulers

Scheduler Comparison

def get_sampler(scheduler_name):
    """Look up a sampler by scheduler name (sampler classes assumed defined elsewhere)."""
    if scheduler_name == 'ddpm':
        return DDPMSampler()
    elif scheduler_name == 'ddim':
        return DDIMSampler()
    elif scheduler_name == 'dpm-solver':
        return DPMSolverSampler()
    elif scheduler_name == 'euler':
        return EulerSampler()
    elif scheduler_name == 'euler-ancestral':
        return EulerAncestralSampler()
    else:
        raise ValueError(f'unknown scheduler: {scheduler_name}')

Adaptive Step Sizes

For complex images, an adaptive step size can be used:

import torch

class AdaptiveSampler:
    """
    Adaptive sampling based on local error estimates:
    the step size is refined where the estimated error is large.
    """
    def __init__(self, model, alphas_cumprod, rtol=0.01, atol=0.01):
        self.model = model
        self.alphas_cumprod = alphas_cumprod
        self.rtol = rtol
        self.atol = atol

    @torch.no_grad()
    def sample(self, xt, num_steps=100):
        timesteps = torch.linspace(999, 0, num_steps).long()
        step = 1
        i = 0

        while i < len(timesteps) - 1:
            t = timesteps[i]
            eps = self.model(xt, t)

            j = min(i + step, len(timesteps) - 1)
            t_next = timesteps[j]

            # One large step
            x_next_1step = self.ode_step(xt, eps, t, t_next)

            # Two half steps for the error estimate
            t_mid = timesteps[(i + j) // 2]
            x_mid = self.ode_step(xt, eps, t, t_mid)
            eps_mid = self.model(x_mid, t_mid)
            x_next_2step = self.ode_step(x_mid, eps_mid, t_mid, t_next)

            # Local error estimate
            error = (x_next_1step - x_next_2step).abs().mean()
            tol = self.rtol * x_next_1step.abs().mean() + self.atol

            if error > tol and step > 1:
                # Error too large: halve the step and retry
                step = max(step // 2, 1)
            else:
                # Accept the (more accurate) two-step result, try a larger step next
                xt = x_next_2step
                i = j
                step = max(min(step * 2, len(timesteps) - 1 - i), 1)

        return xt

    def ode_step(self, xt, eps, t_curr, t_next):
        """Single deterministic ODE step (DDIM with eta = 0)."""
        alpha_bar_curr = self.alphas_cumprod[t_curr]
        alpha_bar_next = self.alphas_cumprod[t_next]

        x0_pred = (xt - torch.sqrt(1 - alpha_bar_curr) * eps) / torch.sqrt(alpha_bar_curr)
        return torch.sqrt(alpha_bar_next) * x0_pred + torch.sqrt(1 - alpha_bar_next) * eps

Overall Comparison

| Technique | Steps | Training cost | Quality | Best for |
| --- | --- | --- | --- | --- |
| DDPM | 1000 | — | highest | baseline |
| DDIM | 20-100 | no retraining | close to DDPM | general use |
| Consistency (distilled) | 1-2 | — | close to DDPM | high-throughput applications |
| Consistency (from scratch) | 64 | — | slightly lower | rapid prototyping |
| Progressive Distillation | 1-8 | — | close to DDPM | production deployment |
| LCM + LoRA | 4-8 | — | acceptable | Stable Diffusion |


Advanced Sampling Techniques: 2024-2025 Developments

Overview

In 2024-2025, diffusion-model sampling saw significant advances, concentrated in:

  1. DPM-Solver-v3: sampler parameters optimized using learned noise-schedule statistics
  2. TCD (Trajectory Consistency Distillation): fixes LCM's quality degradation at higher NFE
  3. ParaDiGMS: parallel sampling with sublinear wall-clock complexity
  4. Flow Matching samplers: a unified sampling framework for diffusion and flow matching

1. DPM-Solver-v3

1.1 Core Idea

DPM-Solver-v3 (NeurIPS 2023) is the latest member of the DPM-Solver family. Its core innovation is using empirical model statistics of the noise schedule to optimize the sampler's parameterization.[4]

1.2 Differences from v1/v2

| Version | Core idea | Steps | Best for |
| --- | --- | --- | --- |
| DPM-Solver-v1 | order control | 15-25 | general use |
| DPM-Solver-v2 | stability improvements | 10-20 | production |
| DPM-Solver-v3 | learned schedule statistics | 5-15 | few-step generation |

1.3 Implementation

import torch

def dpm_solver_v3(model, xT, alphas_cumprod, num_steps=10):
    """
    DPM-Solver-v3: high-quality sampling in few steps (schematic sketch;
    get_dpm_v3_schedule and get_logSNR are assumed helpers).
    """
    model.eval()

    # Timestep schedule (non-linear)
    timesteps = get_dpm_v3_schedule(num_steps)

    x = xT
    for i in range(num_steps):
        t_cur = timesteps[i]
        t_next = timesteps[i + 1] if i < num_steps - 1 else 0

        # Order selection: use lower orders for the final steps
        order = min(3, num_steps - i)

        x = dpm_solver_step_v3(model, x, t_cur, t_next, order=order)

    return x

def dpm_solver_step_v3(model, x, t, s, order=2):
    """Single DPM-Solver-v3 update (schematic; h is a log-SNR increment)."""
    if order == 1:
        noise_pred = model(x, t)
        h = get_logSNR(s) - get_logSNR(t)
        return x - h * noise_pred

    elif order == 2:
        # Second-order update with one extra evaluation
        noise_pred_1 = model(x, t)
        h = get_logSNR(s) - get_logSNR(t)
        x_mid = x - h / 2 * noise_pred_1
        noise_pred_2 = model(x_mid, s)
        return x - h * ((1 - 1/(2*order)) * noise_pred_1 + 1/(2*order) * noise_pred_2)

    elif order == 3:
        # Third-order update with two intermediate evaluations
        noise_pred_1 = model(x, t)
        t_mid = (get_logSNR(s) + get_logSNR(t)) / 2
        h_1 = t_mid - get_logSNR(t)
        x_mid = x - h_1 * noise_pred_1
        noise_pred_2 = model(x_mid, t_mid)
        h_2 = get_logSNR(s) - get_logSNR(t)
        x_mid_2 = x - h_2 * ((1 - 1/(2*order)) * noise_pred_1 + 1/(2*order) * noise_pred_2)
        noise_pred_3 = model(x_mid_2, s)
        return x - h_2 * ((1 - 1/(3*order)) * noise_pred_1 + (1/(3*order) + 1/(6*order)) * noise_pred_2 + 1/(6*order) * noise_pred_3)

    return x

1.4 Connection to Flow Matching

DPM-Solver and the Flow Matching Euler sampler are mathematically equivalent. When the model predicts the vector field $v_\theta(x, t)$:

import numpy as np

def equivalence_check(n=1000):
    """Numerically verify that a deterministic DDIM step equals a Flow
    Matching Euler step under the interpolation x_t = (1-t)*x0 + t*eps."""
    rng = np.random.default_rng(0)
    x0, eps = rng.normal(size=n), rng.normal(size=n)
    t, t_next = 0.8, 0.7
    x_t = (1 - t) * x0 + t * eps

    # Flow Matching: probability-flow ODE dx/dt = v = eps - x0, one Euler step
    x_fm = x_t + (t_next - t) * (eps - x0)

    # DDIM (eta = 0): recover x0 from (x_t, eps), then jump to t_next
    x0_hat = (x_t - t * eps) / (1 - t)
    x_ddim = (1 - t_next) * x0_hat + t_next * eps

    return np.allclose(x_fm, x_ddim)  # True: the two steps coincide

2. TCD: Trajectory Consistency Distillation

2.1 Core Idea

TCD (Trajectory Consistency Distillation) addresses the quality degradation that consistency models (LCM) exhibit at higher NFE.[5]

2.2 The Problem with LCM

LCM works very well with few sampling steps, but at moderate step counts (10-20) its quality actually drops. TCD introduces a trajectory consistency function (TCF) to fix this.

2.3 Implementation

import torch
import torch.nn.functional as F

class TCDLoss:
    """Trajectory Consistency Distillation loss."""

    def __init__(self, model, teacher_model, num_trajectory_steps=5):
        self.model = model
        self.teacher = teacher_model
        self.num_steps = num_trajectory_steps

    def compute_loss(self, x0, t, epsilon_sample=0.002):
        """Compute the TCD loss (epsilon_sample is the minimum trajectory time)."""
        eps = torch.randn_like(x0)
        xt = add_noise(x0, eps, t)  # add_noise is assumed defined elsewhere

        # Teacher: multi-step integration along the trajectory
        with torch.no_grad():
            trajectory_end = self.teacher.multi_step_trajectory(
                xt, t, target_t=epsilon_sample, num_steps=self.num_steps
            )

        # Student: predict the trajectory endpoint in a single step
        pred_end = self.model(xt, t)

        loss = F.mse_loss(pred_end, trajectory_end)

        # Trajectory-consistency regularizer at intermediate points
        for step_frac in [0.25, 0.5, 0.75]:
            t_mid = t * (1 - step_frac) + epsilon_sample * step_frac
            x_mid = add_noise(x0, eps, t_mid)
            pred_mid = self.model(x_mid, t_mid)
            target_mid = self.teacher(x_mid, t_mid)
            loss += 0.1 * F.mse_loss(pred_mid, target_mid)

        return loss

2.4 Results

FID at different NFE (lower is better):

| Method | NFE=2 | NFE=5 | NFE=20 | NFE=50 |
| --- | --- | --- | --- | --- |
| LCM | 5.2 | 3.8 | 4.1 | 5.5 |
| TCD | 4.1 | 2.9 | 2.5 | 2.3 |

3. ParaDiGMS: Parallel Sampling

3.1 Core Idea

ParaDiGMS parallelizes the multi-step sampling process, achieving sublinear wall-clock complexity.[6]

3.2 Implementation

import torch

class ParaDiGMSampler:
    """ParaDiGMS: parallel multi-step sampling (schematic sketch;
    get_noise_at_t, add_noise_to_x and sequential_update are assumed helpers)."""

    def __init__(self, model, parallel_steps=4):
        self.model = model
        self.k = parallel_steps

    @torch.no_grad()
    def sample(self, xT, num_steps=50):
        x = xT
        timesteps = torch.linspace(1.0, 0.0, num_steps + 1)

        for i in range(0, num_steps, self.k):
            batch_timesteps = timesteps[i:i + self.k]
            if len(batch_timesteps) < 2:
                break
            x = self.parallel_step(x, batch_timesteps)

        return x

    def parallel_step(self, x, timesteps):
        B, C, H, W = x.shape
        device = x.device

        # Build noisy inputs for several timesteps at once
        batch_noise = []
        for t in timesteps[1:]:
            noise_t = self.get_noise_at_t(t)
            xt = self.add_noise_to_x(x, t, noise_t)
            batch_noise.append(xt)

        x_batch = torch.stack(batch_noise, dim=0).reshape(-1, C, H, W)
        # Timestep-major stacking: repeat each t for all B examples
        t_batch = timesteps[1:].to(device).repeat_interleave(B)

        # One batched (parallel) denoiser call covering all timesteps
        noise_pred = self.model(x_batch, t_batch)
        noise_pred = noise_pred.reshape(len(timesteps) - 1, B, C, H, W)

        # Apply the updates sequentially using the parallel predictions
        x = self.sequential_update(x, timesteps, noise_pred)
        return x

4. Flow Matching Samplers

4.1 Unification with Diffusion Samplers

As "Diffusion Meets Flow Matching" argues, DDIM and the Flow Matching Euler solver are mathematically equivalent:

import torch

def unified_sampling(model, xT, num_steps=50, framework='flow_matching'):
    """Unified sampler: diffusion and Flow Matching share the same Euler loop.

    Convention: x_t = (1 - t) * x0 + t * eps, with t running from 1 (noise) to 0 (data).
    """
    if framework == 'flow_matching':
        ts = torch.linspace(1.0, 0.0, num_steps + 1)
    else:
        ts = get_cosine_schedule(num_steps + 1)  # assumed helper for diffusion-style schedules

    x = xT
    for i in range(num_steps):
        t, t_next = ts[i], ts[i + 1]

        if hasattr(model, 'predict_v'):
            # Velocity parameterization: dx/dt = v(x, t)
            v = model.predict_v(x, t)
            x = x - (t - t_next) * v
        else:
            # Noise parameterization: recover x0, then jump to t_next (DDIM-style)
            eps = model.predict_epsilon(x, t)
            x0 = (x - t * eps) / torch.clamp(1 - t, min=1e-3)  # guard the t = 1 endpoint
            x = (1 - t_next) * x0 + t_next * eps

    return x

4.2 Rectified Flow Sampling

Rectified Flow uses straighter ODE paths, which favors few-step sampling:

import torch

class RectifiedFlowSampler:
    """Rectified Flow sampler."""

    @torch.no_grad()
    def sample(self, model, xT, num_steps=10):
        x = xT
        dt = 1.0 / num_steps

        for i in range(num_steps):
            t = 1.0 - i * dt
            v = model(x, t)  # velocity field prediction
            x = x - dt * v   # Euler step from noise (t = 1) toward data (t = 0)

        return x
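
Why straight paths help: if the learned velocity field is exactly constant along each path (the ideal rectified case), a single Euler step already lands on the data point, and many small steps give the same answer. A NumPy illustration with a hypothetical perfectly-rectified field:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.normal(size=16)    # data sample
eps = rng.normal(size=16)   # noise sample
v = eps - x0                # constant velocity along the straight path x_t = (1-t)*x0 + t*eps

x1 = eps                    # start at t = 1 (pure noise)
one_step = x1 - 1.0 * v     # single Euler step with dt = 1

# Many small steps give the same answer on a straight path
x = x1
for _ in range(10):
    x = x - 0.1 * v
```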

5. Selection Guide

5.1 Sampler Comparison

| Sampler | NFE range | Quality | Speed | Best for |
| --- | --- | --- | --- | --- |
| DDPM | 1000+ | highest | slowest | baseline |
| DDIM | 20-100 | — | medium | general generation |
| DPM-Solver-v3 | 5-15 | — | — | few-step generation |
| LCM | 1-4 | medium | fastest | real-time applications |
| TCD | 2-20 | highest | — | quality/speed balance |
| ParaDiGMS | 50-100 | — | parallel speedup | parallel inference |
| Rectified Flow | 10-50 | — | — | sampling after distillation |

5.2 Recommendations

def select_sampler(scenario):
    """Pick a sampler for a given scenario."""
    if scenario == 'maximum speed':
        return 'LCM (1-2 steps)'
    elif scenario == 'quality/speed balance':
        return 'DPM-Solver-v3 (5-10 steps) or TCD (5-20 steps)'
    elif scenario == 'quality first':
        return 'DDIM (50 steps) or TCD (20 steps)'
    elif scenario == 'parallel inference':
        return 'ParaDiGMS (parallelism 4-8)'
    elif scenario == 'sampling after distillation':
        return 'Rectified Flow (10 steps)'
    else:
        raise ValueError(f'unknown scenario: {scenario}')


Footnotes

  1. Song et al., “Denoising Diffusion Implicit Models”, ICLR 2021. https://arxiv.org/abs/2010.02502

  2. Song et al., “Consistency Models”, ICML 2023. https://arxiv.org/abs/2303.01469

  3. Luo et al., “Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference”, 2023. https://arxiv.org/abs/2310.04378

  4. Zheng et al., “DPM-Solver-v3: Improved Diffusion ODE Solver with Empirical Model Statistics”, NeurIPS 2023.

  5. Yang et al., “Trajectory Consistency Distillation”, 2024.

  6. Shih et al., “Parallel Sampling of Diffusion Models” (ParaDiGMS), NeurIPS 2023.