扩散模型与 Flow Matching 的等价性
概述
2024年,Google DeepMind 发表了里程碑式的工作「Diffusion Meets Flow Matching」,证明了扩散模型和 Flow Matching 在高斯源分布下是数学等价的。1
这一发现统一了两个看似不同的生成式建模框架,使得研究者可以在统一的视角下理解和应用这些技术。
1. 两种框架的类比
扩散模型(Diffusion)
| 组件 | 描述 |
|---|---|
| 前向过程 | |
| 信噪比 | |
| 噪声调度 | 通常 , |
Flow Matching
| 组件 | 描述 |
|---|---|
| 路径 | , |
| 向量场 | |
| 条件分布 |
关键类比
┌─────────────────────────────────────────────────────────────┐
│ 统一表示 │
├─────────────────────────────────────────────────────────────┤
│ 扩散模型: x_t = αₜx₀ + σₜε │
│ Flow Match: x_t = (1-t)x₀ + tx₁ │
│ │
│ 当设置 αₜ = 1-t, σₜ = t, x₁ = ε 时,两者完全等价! │
└─────────────────────────────────────────────────────────────┘
2. 数学统一
2.1 统一的前向过程
定义参数化:
其中 是噪声源分布。
扩散模型通过噪声调度 定义路径:
- VP (Variance Preserving):
- VE (Variance Exploding):
- 余弦调度、线性调度等
Flow Matching使用线性插值:
2.2 统一的向量场
从概率路径 出发,定义概率流 ODE:
其中 是向量场,满足:
对于高斯条件分布:
2.3 Score 函数与向量场的关系
Score 函数定义为:
对于高斯分布 :
其中 是噪声。
概率流 ODE 与 Score 的关系:
其中 , 。
3. 训练目标的等价性
3.1 扩散模型的训练目标
预测噪声 :
预测 :
3.2 Flow Matching 的训练目标
预测向量场 :
3.3 目标等价性证明
定理:对于任意噪声调度 ,以下三种训练目标在最优解下等价:
- 预测噪声
- 预测
- 预测 (当 时)
关键关系:
4. 网络输出的选择
4.1 四种输出类型
| 输出类型 | 公式 | 特点 |
|---|---|---|
| 预测 | 高噪声时不稳定() | |
| 预测 | 低噪声时不稳定() | |
| 预测 | 平衡方案 | |
| (Flow) | Flow Matching 专用 |
4.2 -Prediction 的优势
-prediction(又称 -space 或 EDM 参数化)由 Karras 等人在 EDM 论文中提出:
优势:
- 数值稳定性:在所有时间步都有合理的值
- 等效权重:
- 与 Flow Matching 兼容: 本质上是流向场的估计
4.3 权重等价性
Google DeepMind 的发现表明:Flow Matching 权重等价于 -MSE 损失 + 余弦调度。
def v_prediction_loss(model, x0, eps, t, sigma):
"""
v-prediction 损失
v = alpha * eps - sigma * x0
"""
alpha = torch.sqrt(1 - sigma**2) # VP调度
v = alpha * eps - sigma * x0
pred_v = model(x0 + sigma * eps, sigma)
return F.mse_loss(pred_v, v)5. 采样器的等价性
5.1 DDIM = Flow Matching 采样器
定理:DDIM 采样器()等价于 Flow Matching 的欧拉求解器。
DDIM 更新公式:
Flow Matching 欧拉更新:
5.2 统一采样框架
def unified_sampling(model, xT, num_steps=50, eta=0.0):
"""
统一采样:适用于 Diffusion 和 Flow Matching
Args:
model: 去噪模型(预测 v 或 ε)
xT: 初始噪声
num_steps: 采样步数
eta: 随机性 (0=确定, 1=随机)
"""
# 时间步调度
timesteps = torch.linspace(1, 0, num_steps + 1)
x = xT
for i in range(num_steps):
t = timesteps[i]
t_next = timesteps[i + 1]
# 预测 v 向量场
v_pred = model(x, t)
# 欧拉更新(确定性部分)
dt = t - t_next
x = x + dt * v_pred
# 添加随机性(可选)
if eta > 0:
noise = torch.randn_like(x)
x = x + eta * np.sqrt(dt) * noise
return x6. Rectified Flow(整流流)
6.1 核心思想
Rectified Flow 由 Liu 等人提出,核心思想是通过蒸馏获得更直的路径。2
问题:原始 Flow Matching 的路径可能是弯曲的,导致需要更多采样步数。
解决方案:通过自编码器将路径「拉直」。
6.2 Reflow 操作
def reflow(model, x0_samples, num_steps=100):
"""
Rectified Flow 的 Reflow 操作
通过蒸馏获得更直的路径
"""
# 起始点:数据分布
x_start = x0_samples
# 目标点:纯噪声
x_end = torch.randn_like(x_start)
# 生成新的(更直的)轨迹
# 方法1:使用当前模型生成中间点
with torch.no_grad():
trajectory = []
x = x_end
timesteps = torch.linspace(0, 1, num_steps)
for t in timesteps:
v = model(x, t)
x = x - v * (1 / num_steps) # 逆向生成
trajectory.append(x)
# 方法2:从轨迹中采样新路径
# 新路径:直接从 x_start 到 x_end
# 但中间点来自原始轨迹的蒸馏
return trajectory6.3 路径「直度」与采样效率
弯曲路径 ←── 需要多步采样
│
│ Reflow
▼
直线路径 ←── 可以少步采样
关键洞察:路径直度不是唯一目标。端点分布的特性(如对齐、支撑覆盖)比路径直度更重要。
7. 理论深度
7.1 为什么等价性成立?
核心条件:源分布 是标准高斯分布 。
当这一条件成立时:
- 条件分布
任意参数化 都定义了一条从 到 的路径,只是路径形状不同。
7.2 训练调度不变性
重要发现:训练损失对噪声调度 不变(只取决于端点的信噪比范围)。
这意味着:
- 可以自由选择任何调度(如余弦、线性、Sigmoid)
- 调度只影响采样效率,不影响模型容量
7.3 连续时间极限
在连续时间极限下,DDPM(离散)和 Flow Matching(连续)完全等价:
8. 实践指南
8.1 框架选择
| 场景 | 推荐框架 | 理由 |
|---|---|---|
| 理论研究 | Flow Matching | 数学简洁 |
| 图像生成 | Diffusion (DiT) | 生态丰富 |
| 少步采样 | Flow Matching + Rectified | 直路径便于蒸馏 |
| 高分辨率 | Latent Diffusion | 计算效率 |
8.2 实现建议
# 推荐:使用统一的 v-prediction 接口
class UnifiedDiffusionModel:
def __init__(self, network, scheduler='cosine'):
self.network = network
self.scheduler = scheduler
def training_loss(self, x0, noise, t):
"""
统一的训练接口
"""
# 计算 alpha, sigma
if self.scheduler == 'cosine':
alpha = torch.cos(t * np.pi / 2)
sigma = torch.sin(t * np.pi / 2)
else:
alpha = 1 - t
sigma = t
# 前向过程
xt = alpha * x0 + sigma * noise
# v-prediction 目标
v_target = alpha * noise - sigma * x0
# 预测
v_pred = self.network(xt, t)
return F.mse_loss(v_pred, v_target)
@torch.no_grad()
def sampling(self, xT, num_steps=50):
"""
统一采样接口
"""
x = xT
timesteps = torch.linspace(1, 0, num_steps)
for i in range(num_steps):
t = timesteps[i]
t_next = timesteps[i + 1] if i < num_steps - 1 else 0
v = self.network(x, t)
x = x - (t - t_next) * v
return x9. 参考文献
相关链接
Footnotes
-
Diffusion Meets Flow Matching - Google DeepMind, 2024. https://diffusionflow.github.io/ ↩
-
Liu et al., “Flow Matching: Simplifying and Generalizing Diffusion Models”, ICML 2024. ↩