Flow-Anchored Consistency Models (FACM)
流锚定一致性模型(Flow-Anchored Consistency Models,FACM)是一种解决连续时间一致性模型训练不稳定性问题的创新方法1。该方法通过将Flow Matching任务作为动态锚定,同时优化一致性目标,实现了稳定高效的少步生成。
1. 问题背景:连续时间一致性模型的训练不稳定性
1.1 连续时间一致性模型回顾
一致性模型(Consistency Models)学习一个映射 ,满足自一致性约束:
其中 是扩散过程。
1.2 训练不稳定性的根源
核心冲突:连续时间一致性模型面临一个根本性的矛盾——
| 目标 | 要求 | 冲突 |
|---|---|---|
| Shortcut学习 | 学习从 到 的捷径 | 只需捕捉端点关系 |
| Velocity Field保持 | 保持速度场的准确性 | 需要理解轨迹细节 |
灾难性遗忘现象:当网络专注于学习shortcut时,它会逐渐遗忘定义Flow的速度场信息。
1.3 不稳定性的表现
训练过程中的典型问题:
- 损失振荡:训练损失周期性剧烈波动
- 轨迹漂移:生成的样本质量随采样步数增加而下降
- 边界条件违反:
- 模式崩溃:生成样本多样性降低
2. Flow-Anchoring解决方案
2.1 核心思想
解决思路:将Flow Matching任务作为动态锚定,确保网络在优化一致性目标的同时,不遗忘底层Flow的结构。
传统一致性模型:
┌─────────────────────────────────────────────┐
│ 输入x_t → Shortcut目标 → 输出x̂_0 │
│ ↓ │
│ 遗忘速度场信息 │
└─────────────────────────────────────────────┘
FACM:
┌─────────────────────────────────────────────┐
│ 输入x_t → Shortcut目标 → 输出x̂_0 │
│ ↓ ↑ │
│ ┌──────────────────────────┐ │
│ │ Flow Matching锚定 │ │
│ │ 保持速度场准确性 │ │
│ └──────────────────────────┘ │
└─────────────────────────────────────────────┘
2.2 双重优化目标
FACM的损失函数同时包含两个目标:
其中:
2.3 理论保证
定理:Flow-Anchoring确保以下不变性:
- 边界条件保持:
- 轨迹一致性:
- 速度场准确性:速度场不因shortcut学习而退化
3. 关键技术:Expanded Time Interval
3.1 设计动机
传统方法在固定时间间隔上训练,这导致了两个任务的耦合:
- 和 接近时:一致性约束强,速度场约束弱
- 和 远离时:一致性约束弱,速度场约束强
3.2 扩展时间间隔策略
核心创新:使用非对称扩展时间间隔:
def facm_loss(model, x0, x1, ema_model):
"""
FACM损失函数(扩展时间间隔)
参数:
model: 主网络
x0: 原始数据
x1: 噪声样本
ema_model: EMA目标网络
"""
batch_size = x0.shape[0]
# 采样两个时间点(扩展间隔)
# t1: [ε, 1), 通常较大
# t2: [0, t1 - δ), 确保t2 < t1
t1 = torch.rand(batch_size) * (1 - eps) + eps
delta = torch.rand(batch_size) * t1 # 随机间隔
t2 = t1 - delta
# 生成轨迹点
x_t1 = (1 - t1) * x0 + t1 * x1 # 线性插值
x_t2 = (1 - t2) * x0 + t2 * x1
# 一致性损失
f_t1 = model(x_t1, t1)
with torch.no_grad():
f_t2_ema = ema_model(x_t2, t2)
L_cm = torch.mean((f_t1 - f_t2_ema) ** 2)
# Flow锚定损失
v_t1 = velocity_head(model, x_t1, t1) # 速度预测头
v_target = x1 - x0 # 目标速度(直线轨迹)
L_fm = torch.mean((v_t1 - v_target) ** 2)
return L_cm + lambda_fm * L_fm3.3 优势分析
| 方面 | 固定间隔 | 扩展间隔 |
|---|---|---|
| 任务解耦 | 强耦合 | 任务分离 |
| 训练稳定性 | 不稳定 | 稳定 |
| 表达能力 | 受限 | 增强 |
| 收敛速度 | 慢 | 快 |
3.4 Auxiliary Time Condition
另一种实现Flow-Anchoring的方法是使用辅助时间条件:
其中 是一个额外的辅助时间输入,用于同时编码一致性信息和Flow信息。
4. 规模化训练:Chain-JVP
4.1 内存瓶颈
FACM的双重目标在规模化时面临内存挑战:
- 主网络前向传播:
- EMA网络存储:
- Jacobian计算:
4.2 Chain-JVP方法
核心思想:利用Jacobian-Vector Product(JVP)的链式法则,避免显式存储Jacobian:
def chain_jvp(model, x, t, v):
"""
Chain-JVP计算
避免显式存储完整的Jacobian矩阵
"""
# 前向传播
y = model(x, t)
# 构造计算图
grad_y = torch.autograd.grad(
outputs=y,
inputs=model.parameters(),
grad_outputs=v,
create_graph=True
)
# JVP通过链式法则
# 不需要存储中间Jacobian
return grad_y
class FACMWithChainJVP:
def __init__(self, model, lambda_fm=0.1):
self.model = model
self.lambda_fm = lambda_fm
def training_step(self, x0, x1):
# 主前向
t = torch.rand(len(x0))
x_t = (1 - t) * x0 + t * x1
f_t = self.model(x_t, t)
# Chain-JVP计算Flow锚定梯度
# 避免存储完整Jacobian
v_target = x1 - x0
jvp_grad = self.chain_jvp(
lambda x, t: self.velocity_head(x, t),
x_t, t, v_target
)
# 一致性损失
with torch.no_grad():
x_t2 = (1 - 2*t) * x0 + 2*t * x1
f_t2_ema = self.ema_model(x_t2, 2*t)
L_cm = torch.mean((f_t - f_t2_ema) ** 2)
L_fm = torch.mean(jvp_grad ** 2)
return L_cm + self.lambda_fm * L_fm4.3 与FSDP的兼容性
Chain-JVP方法与Fully Sharded Data Parallel (FSDP) 兼容:
- 分片存储:参数分片存储在不同GPU上
- JVP计算:仅需要局部梯度信息
- 通信优化:减少跨GPU通信开销
5. 实验结果
5.1 ImageNet 256×256基准
| 方法 | NFE=1 FID | NFE=2 FID | 训练稳定性 |
|---|---|---|---|
| sCM | 2.45 | 2.12 | 不稳定 |
| Consistency Model | 2.89 | 2.34 | 较稳定 |
| FACM | 1.70 | 1.32 | 稳定 |
5.2 训练稳定性对比
训练步数
↑
│ ╱╲ ╱╲ ╱╲
│ ╱ ╲ ╱ ╲ ╱ ╲ ← sCM(振荡剧烈)
Loss│ ╱ ╲╱ ╲╱ ╲
│╱
│
│▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ← FACM(稳定下降)
└──────────────────────────→
5.3 规模化结果
| 模型 | 参数量 | NFE | FID | 文本到图像质量 |
|---|---|---|---|---|
| SDXL | 3.5B | 40 | 1.23 | 优秀 |
| FACM (Wan 2.2) | 14B | 2-8 | 1.05 | 优秀 |
6. 与Consistency-FM的联系
FACM和Consistency-FM都致力于解决一致性模型的训练问题,但侧重点不同:
| 方面 | Consistency-FM | FACM |
|---|---|---|
| 核心思想 | 速度一致性 | Flow锚定 |
| 优化目标 | 单一(速度一致性) | 双重(一致性+Flow) |
| 实现方式 | 多段训练 | 扩展时间间隔 |
| 稳定性 | 良好 | 优秀 |
| 表达能力 | 强 | 很强 |
互补性:两者可以结合使用,进一步提升性能。
7. 实践指南
7.1 推荐配置
# FACM推荐配置
config = {
'lambda_fm': 0.1, # Flow锚定权重
'eps': 0.001, # 最小时间值
'use_chain_jvp': True, # 使用Chain-JVP
'ema_decay': 0.9999, # EMA衰减
'lr': 1e-4, # 学习率
'batch_size': 2048, # 批量大小(可扩展)
}7.2 与现有架构的兼容性
FACM可以应用于任何backbone网络:
- DiT (Diffusion Transformer)
- U-Net
- UViT
- MaskDiT
7.3 微调策略
对于从预训练Diffusion模型的迁移:
- 初始阶段:只启用Flow锚定目标( 较大)
- 中期阶段:逐渐增加一致性目标权重
- 后期阶段:两者权重平衡
8. 总结
FACM通过Flow-Anchoring机制解决了连续时间一致性模型的训练不稳定性问题:
- 问题根源:Shortcut学习与Velocity Field保持的冲突
- 解决方案:双重优化目标 + 扩展时间间隔
- 技术贡献:Chain-JVP实现可扩展训练
- 实验验证:SOTA FID@1/2步,训练稳定
参考文献
Footnotes
-
“FACM: Flow-Anchored Consistency Models.” ICLR 2026. https://arxiv.org/abs/2507.03738 ↩