Flow-Anchored Consistency Models (FACM)

流锚定一致性模型(Flow-Anchored Consistency Models,FACM)是一种解决连续时间一致性模型训练不稳定性问题的创新方法1。该方法通过将Flow Matching任务作为动态锚定,同时优化一致性目标,实现了稳定高效的少步生成。

1. 问题背景:连续时间一致性模型的训练不稳定性

1.1 连续时间一致性模型回顾

一致性模型(Consistency Models)学习一个映射 ,满足自一致性约束:

其中 是扩散过程。

1.2 训练不稳定性的根源

核心冲突:连续时间一致性模型面临一个根本性的矛盾——

目标要求冲突
Shortcut学习学习从 的捷径只需捕捉端点关系
Velocity Field保持保持速度场的准确性需要理解轨迹细节

灾难性遗忘现象:当网络专注于学习shortcut时,它会逐渐遗忘定义Flow的速度场信息。

1.3 不稳定性的表现

训练过程中的典型问题:

  1. 损失振荡:训练损失周期性剧烈波动
  2. 轨迹漂移:生成的样本质量随采样步数增加而下降
  3. 边界条件违反
  4. 模式崩溃:生成样本多样性降低

2. Flow-Anchoring解决方案

2.1 核心思想

解决思路:将Flow Matching任务作为动态锚定,确保网络在优化一致性目标的同时,不遗忘底层Flow的结构。

传统一致性模型:
┌─────────────────────────────────────────────┐
│  输入x_t → Shortcut目标 → 输出x̂_0           │
│           ↓                                  │
│     遗忘速度场信息                           │
└─────────────────────────────────────────────┘

FACM:
┌─────────────────────────────────────────────┐
│  输入x_t → Shortcut目标 → 输出x̂_0           │
│           ↓              ↑                   │
│     ┌──────────────────────────┐             │
│     │   Flow Matching锚定      │             │
│     │   保持速度场准确性      │             │
│     └──────────────────────────┘             │
└─────────────────────────────────────────────┘

2.2 双重优化目标

FACM的损失函数同时包含两个目标:

其中:

2.3 理论保证

定理:Flow-Anchoring确保以下不变性:

  1. 边界条件保持
  2. 轨迹一致性
  3. 速度场准确性:速度场不因shortcut学习而退化

3. 关键技术:Expanded Time Interval

3.1 设计动机

传统方法在固定时间间隔上训练,这导致了两个任务的耦合:

  • 接近时:一致性约束强,速度场约束弱
  • 远离时:一致性约束弱,速度场约束强

3.2 扩展时间间隔策略

核心创新:使用非对称扩展时间间隔:

def facm_loss(model, x0, x1, ema_model):
    """
    FACM损失函数(扩展时间间隔)
    
    参数:
        model: 主网络
        x0: 原始数据
        x1: 噪声样本
        ema_model: EMA目标网络
    """
    batch_size = x0.shape[0]
    
    # 采样两个时间点(扩展间隔)
    # t1: [ε, 1), 通常较大
    # t2: [0, t1 - δ), 确保t2 < t1
    t1 = torch.rand(batch_size) * (1 - eps) + eps
    delta = torch.rand(batch_size) * t1  # 随机间隔
    t2 = t1 - delta
    
    # 生成轨迹点
    x_t1 = (1 - t1) * x0 + t1 * x1  # 线性插值
    x_t2 = (1 - t2) * x0 + t2 * x1
    
    # 一致性损失
    f_t1 = model(x_t1, t1)
    with torch.no_grad():
        f_t2_ema = ema_model(x_t2, t2)
    L_cm = torch.mean((f_t1 - f_t2_ema) ** 2)
    
    # Flow锚定损失
    v_t1 = velocity_head(model, x_t1, t1)  # 速度预测头
    v_target = x1 - x0  # 目标速度(直线轨迹)
    L_fm = torch.mean((v_t1 - v_target) ** 2)
    
    return L_cm + lambda_fm * L_fm

3.3 优势分析

方面固定间隔扩展间隔
任务解耦强耦合任务分离
训练稳定性不稳定稳定
表达能力受限增强
收敛速度

3.4 Auxiliary Time Condition

另一种实现Flow-Anchoring的方法是使用辅助时间条件

其中 是一个额外的辅助时间输入,用于同时编码一致性信息和Flow信息。


4. 规模化训练:Chain-JVP

4.1 内存瓶颈

FACM的双重目标在规模化时面临内存挑战:

  • 主网络前向传播:
  • EMA网络存储:
  • Jacobian计算:

4.2 Chain-JVP方法

核心思想:利用Jacobian-Vector Product(JVP)的链式法则,避免显式存储Jacobian:

def chain_jvp(model, x, t, v):
    """
    Chain-JVP计算
    
    避免显式存储完整的Jacobian矩阵
    """
    # 前向传播
    y = model(x, t)
    
    # 构造计算图
    grad_y = torch.autograd.grad(
        outputs=y,
        inputs=model.parameters(),
        grad_outputs=v,
        create_graph=True
    )
    
    # JVP通过链式法则
    # 不需要存储中间Jacobian
    return grad_y
 
 
class FACMWithChainJVP:
    def __init__(self, model, lambda_fm=0.1):
        self.model = model
        self.lambda_fm = lambda_fm
    
    def training_step(self, x0, x1):
        # 主前向
        t = torch.rand(len(x0))
        x_t = (1 - t) * x0 + t * x1
        
        f_t = self.model(x_t, t)
        
        # Chain-JVP计算Flow锚定梯度
        # 避免存储完整Jacobian
        v_target = x1 - x0
        jvp_grad = self.chain_jvp(
            lambda x, t: self.velocity_head(x, t),
            x_t, t, v_target
        )
        
        # 一致性损失
        with torch.no_grad():
            x_t2 = (1 - 2*t) * x0 + 2*t * x1
            f_t2_ema = self.ema_model(x_t2, 2*t)
        
        L_cm = torch.mean((f_t - f_t2_ema) ** 2)
        L_fm = torch.mean(jvp_grad ** 2)
        
        return L_cm + self.lambda_fm * L_fm

4.3 与FSDP的兼容性

Chain-JVP方法与Fully Sharded Data Parallel (FSDP) 兼容:

  • 分片存储:参数分片存储在不同GPU上
  • JVP计算:仅需要局部梯度信息
  • 通信优化:减少跨GPU通信开销

5. 实验结果

5.1 ImageNet 256×256基准

方法NFE=1 FIDNFE=2 FID训练稳定性
sCM2.452.12不稳定
Consistency Model2.892.34较稳定
FACM1.701.32稳定

5.2 训练稳定性对比

训练步数
  ↑
  │   ╱╲    ╱╲    ╱╲
  │  ╱  ╲  ╱  ╲  ╱  ╲    ← sCM(振荡剧烈)
Loss│ ╱    ╲╱    ╲╱    ╲
  │╱
  │
  │▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁   ← FACM(稳定下降)
  └──────────────────────────→

5.3 规模化结果

模型参数量NFEFID文本到图像质量
SDXL3.5B401.23优秀
FACM (Wan 2.2)14B2-81.05优秀

6. 与Consistency-FM的联系

FACM和Consistency-FM都致力于解决一致性模型的训练问题,但侧重点不同:

方面Consistency-FMFACM
核心思想速度一致性Flow锚定
优化目标单一(速度一致性)双重(一致性+Flow)
实现方式多段训练扩展时间间隔
稳定性良好优秀
表达能力很强

互补性:两者可以结合使用,进一步提升性能。


7. 实践指南

7.1 推荐配置

# FACM推荐配置
config = {
    'lambda_fm': 0.1,           # Flow锚定权重
    'eps': 0.001,               # 最小时间值
    'use_chain_jvp': True,      # 使用Chain-JVP
    'ema_decay': 0.9999,        # EMA衰减
    'lr': 1e-4,                 # 学习率
    'batch_size': 2048,         # 批量大小(可扩展)
}

7.2 与现有架构的兼容性

FACM可以应用于任何backbone网络:

  • DiT (Diffusion Transformer)
  • U-Net
  • UViT
  • MaskDiT

7.3 微调策略

对于从预训练Diffusion模型的迁移:

  1. 初始阶段:只启用Flow锚定目标( 较大)
  2. 中期阶段:逐渐增加一致性目标权重
  3. 后期阶段:两者权重平衡

8. 总结

FACM通过Flow-Anchoring机制解决了连续时间一致性模型的训练不稳定性问题:

  1. 问题根源:Shortcut学习与Velocity Field保持的冲突
  2. 解决方案:双重优化目标 + 扩展时间间隔
  3. 技术贡献:Chain-JVP实现可扩展训练
  4. 实验验证:SOTA FID@1/2步,训练稳定

参考文献

Footnotes

  1. “FACM: Flow-Anchored Consistency Models.” ICLR 2026. https://arxiv.org/abs/2507.03738