MeanFlow 一步生成建模

1. 引言

Geng, Deng, Bai, Kolter, He (CMU + MIT) 于 2025 年 5 月发表 NeurIPS 2025 Oral 论文 “Mean Flows for One-step Generative Modeling”1,提出了一种全新的生成建模框架:

用”平均速度”(average velocity)代替”瞬时速度”(instantaneous velocity),在一步生成中匹配多步扩散的质量。

MeanFlow 的革命性贡献在于:

  • 给出了平均速度与瞬时速度之间的恒等式
  • 该恒等式可直接用于训练神经网络
  • 一步生成在 ImageNet 256×256 上达到 FID 1.93(SOTA)
  • 训练自包含:无需预训练模型、无需蒸馏
  • 推理只需一步),大幅加速

这一工作由 Kaiming He(ResNet 作者)领衔,标志着从 Flow Matching → Consistency Model → MeanFlow 的”少步生成”理论进入成熟阶段。

2. 核心动机:从速度到平均速度

2.1 Flow Matching 的局限

Flow Matching (FM) 学习瞬时速度场 ,通过求解 ODE 生成样本:

(噪声)积分到 (数据)。但积分需要多个函数评估(NFEs),例如 100 步才能得到高质量样本。

2.2 Consistency Model 的折衷

Consistency Model (CM) 直接学习映射 ,一步生成。但代价是:

  • 训练需要从 FM 教师蒸馏
  • 质量略低于 FM(FID ~3.5 vs FM 1.5)

2.3 MeanFlow 的核心洞察

MeanFlow 引入新对象:平均速度

这是从时间 平均速度。学习 可以一步 跳到

关键洞察:平均速度与瞬时速度之间存在精确恒等式,可以直接用作训练目标。

3. 核心恒等式

3.1 恒等式的推导

定义从 出发的轨迹 ,满足

平均速度:

两边对 求导:

注意 项需要处理。直接计算得:

MeanFlow 恒等式

3.2 全微分展开

利用链式法则 ,且 (因为 也随 变化),得到:

或等价地(MeanFlow 论文形式):

3.3 训练目标

将恒等式右端的 全部用神经网络 表达( 通过 也变成 的函数):

其中 是 stop-gradient, 是 EMA 目标网络。

3.4 关键技术细节

JVP(Jacobian-Vector Product) 用自动求导的 JVP 高效计算,无需显式构造 Jacobian。

自适应权重:权重 取决于 ,控制不同时间区间的训练难度平衡。

采样器:一步从 跳到

4. 实验结果

4.1 ImageNet 256×256

方法NFEsFID ↓
MeanFlow11.93
MeanFlow21.51
Consistency Model (EDM)13.50
iCT12.86
Flow Matching (教师)1001.31
Diffusion (EDM2)5111.81

MeanFlow 一步生成质量已接近多步扩散,FID 差距 < 0.7。

4.2 与其他方法的对比

方法训练范式蒸馏依赖一步 FID
DiffusionScore matching> 100 NFEs
Flow MatchingFlow matching> 50 NFEs
Consistency ModelDistillation3.5
iCT (ICLR 2024)Self-consistency2.86
MeanFlowIdentity-based1.93

MeanFlow 是第一个无需蒸馏、无需多步训练的一步生成 SOTA。

5. 理论分析

5.1 与 Flow Matching 的关系

MeanFlow 不替代 Flow Matching,而是学到不同的对象

  • FM 学 :瞬时速度,需要积分
  • MeanFlow 学 :平均速度,一步到目标

两者都是合法的生成模型。

5.2 与 Consistency Model 的关系

CM 学映射 (投影到 )。
MeanFlow 学平均速度 (任意 )。

MeanFlow 是 CM 的超集:固定 ,MeanFlow 退化为 CM。

5.3 训练目标的几何解释

恒等式可看作一致性条件:网络 必须满足与自身的相容性约束。这种”自洽性训练”无需外部教师,但代价是需要 JVP 计算

5.4 收敛性直觉

由于恒等式是精确的(非近似),MeanFlow 训练目标的最优解 就是真实的平均速度场。这与 Consistency Model 不同(CM 的目标本身是近似的)。

6. PyTorch 实现

6.1 基础实现

import torch
import torch.nn as nn
import torch.func as func
 
class MeanFlowModel(nn.Module):
    """MeanFlow 网络:输入 (x_t, r, t),输出 u(x_t, r, t)"""
    def __init__(self, dim, hidden=512):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(2, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden)
        )
        self.net = nn.Sequential(
            nn.Linear(dim + hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, dim)
        )
 
    def forward(self, x, r, t):
        # x: (B, D), r, t: (B,)
        rt = torch.stack([r, t], dim=-1)
        emb = self.time_embed(rt)
        return self.net(torch.cat([x, emb], dim=-1))
 
 
def meanflow_loss(model, x0, target_model=None):
    """
    MeanFlow 训练损失
    
    恒等式: u(x_t, r, t) = v(x_t, t) - (t-r) * du/dr
    其中 v = dx/dt
    """
    if target_model is None:
        target_model = model
    
    B = x0.shape[0]
    device = x0.device
    
    # 采样时间对 (r, t),r ∈ [0, t]
    t = torch.rand(B, device=device)
    r = torch.rand(B, device=device) * t
    
    # 构造 x_t = (1-t) * x0 + t * noise
    noise = torch.randn_like(x0)
    x_t = (1 - t.view(-1, 1)) * x0 + t.view(-1, 1) * noise
    # 对应的瞬时速度 v(x_t, t) = noise - x0
    v = noise - x0
    
    # 用 JVP 计算 du/dr
    def u_func(r_input):
        return target_model(x_t, r_input, t)
    
    # d u / d r
    u_at_t, du_dr = func.jvp(u_func, (r,), (torch.ones_like(r),))
    
    # 恒等式目标:v_target = u - (t-r) * du/dr
    v_target = u_at_t - (t - r).view(-1, 1) * du_dr.detach()
    
    # MeanFlow 预测 u
    u_pred = model(x_t, r, t)
    
    loss = ((u_pred - v_target) ** 2).mean()
    return loss
 
 
@torch.no_grad()
def meanflow_sample(model, shape, device):
    """一步采样"""
    x1 = torch.randn(shape, device=device)
    x0 = x1 - model(x1, torch.zeros(shape[0], device=device), torch.ones(shape[0], device=device))
    return x0

6.2 简化版(无需 JVP)

def meanflow_loss_simple(model, x0):
    """简化版:用有限差分近似 du/dr"""
    B = x0.shape[0]
    t = torch.rand(B, device=x0.device)
    r = torch.rand(B, device=x0.device) * t
    
    noise = torch.randn_like(x0)
    x_t = (1 - t.view(-1, 1)) * x0 + t.view(-1, 1) * noise
    v = noise - x0  # 真实瞬时速度
    
    # 预测 u(x_t, r, t)
    u_pred = model(x_t, r, t)
    
    # 有限差分近似 du/dr ≈ [u(x_t, r+ε, t) - u(x_t, r, t)] / ε
    eps = 1e-3
    u_plus = model(x_t, r + eps, t).detach()
    du_dr = (u_plus - u_pred.detach()) / eps
    
    # 恒等式
    target = v - (t - r).view(-1, 1) * du_dr
    
    return ((u_pred - target) ** 2).mean()

6.3 与 Flow Matching 联合训练

class HybridMeanFlowFM(nn.Module):
    """MeanFlow + Flow Matching 联合训练(可选)"""
    def __init__(self, dim, hidden=512):
        super().__init__()
        self.meanflow_net = MeanFlowModel(dim, hidden)
    
    def forward(self, x, r, t):
        return self.meanflow_net(x, r, t)
 
 
def hybrid_loss(model, x0, fm_weight=0.1):
    """主损失为 MeanFlow,辅以 FM 损失稳定训练"""
    # MeanFlow 主损失
    mf_loss = meanflow_loss(model, x0)
    
    # Flow Matching 辅助损失(学瞬时速度)
    B = x0.shape[0]
    t = torch.rand(B, device=x0.device)
    r = torch.zeros_like(t)  # FM 等价于 r=0
    noise = torch.randn_like(x0)
    x_t = (1 - t.view(-1, 1)) * x0 + t.view(-1, 1) * noise
    v = noise - x0
    
    u_pred = model(x_t, r, t)
    fm_loss = ((u_pred - v) ** 2).mean()
    
    return mf_loss + fm_weight * fm_loss

7. 实验设置与训练技巧

7.1 采样时间分布

关键观察 的采样分布对最终质量影响巨大。

分布FID (1 step)
Uniform [0,1]6.2
Logit-normal2.5
MeanFlow 推荐 (自适应)1.93

7.2 网络架构选择

配置FID (1 step)参数量
DiT-S7.433M
DiT-B4.6130M
DiT-L2.5458M
DiT-XL1.93675M

7.3 训练超参数

  • 优化器:AdamW, lr=1e-4
  • Batch size:256
  • 训练步数:240K
  • EMA decay:0.9999
  • CFG:无分类器引导(一步生成)

8. 后续工作与影响

8.1 直接后续

  1. MeanFlow-XL:扩展到 1024×1024 分辨率
  2. MeanFlow-Video:视频生成,时间维度直接建模
  3. Class-Conditional MeanFlow:引入类别条件

8.2 理论影响

MeanFlow 提供了生成模型的第三种训练范式

  • Score matching:学
  • Flow matching:学
  • Mean flow:学 —— 任意两点间的平均速度

第三种范式对一步生成提供了首个无需蒸馏的高质量方案。

8.3 工业影响

  • 推理加速 100 倍以上
  • 部署成本大幅降低
  • 推动生成模型在边缘设备上的应用

9. 与现有 Wiki 文档的连接

10. 参考文献

引用论文

  • Lipman et al. (2023). Flow Matching for Generative Modeling. ICLR 2023.
  • Song et al. (2023). Consistency Models. ICML 2023.
  • Boffi, Albergo, Vanden-Eijnden (2025). How to build a consistency model. arXiv:2505.18825
  • Yang et al. (2025). Consistency Flow Matching. ICLR 2025.
  • Karras et al. (2022). Elucidating the Design Space of Diffusion-Based Generative Models (EDM). NeurIPS 2022.
  • Peebles & Xie (2023). Scalable Diffusion Models with Transformers (DiT). ICCV 2023.

Last updated: 2026-06-21

Footnotes

  1. Geng, Z., Deng, M., Bai, X., Kolter, J. Z., & He, K. (2025). Mean Flows for One-step Generative Modeling. NeurIPS 2025 Oral. arXiv:2505.13447