FlowMo: Mode-Seeking扩散自编码器
FlowMo 是 ICCV 2025 论文 Sargent et al. - “Flow to the Mode” 中提出的新型自编码器,通过 Flow-based 解码器 实现 mode-seeking 行为,有效解决标准 VAE 的 posterior collapse 和模糊重建问题。
1. 问题背景:Posterior Collapse
1.1 标准VAE的困境
标准变分自编码器(VAE)面临一个根本性挑战:posterior collapse。
标准 VAE 的 ELBO:
问题:当解码器太强时,优化器倾向于让 ,导致:
| 现象 | 表现 | 影响 |
|---|---|---|
| Posterior Collapse | 后验退化为先验 | 潜在表示丢失语义信息 |
| 模糊重建 | 多个模式被平均 | 生成质量下降 |
| 多模态崩溃 | 无法捕获数据分布的多模态性 | 样本多样性丢失 |
1.2 多模态数据的挑战
对于多模态数据分布 :
标准 AE/VAE 的问题:
这导致重建时多模态信息被混合,产生模糊输出。
┌─────────────────────────────────────────────────────────────┐
│ 多模态数据处理 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 输入数据分布 VAE重建 │
│ ╭──────╮ ╭──────╮ │
│ │ ● ● │ │ ◐ │ ← 模糊/平均 │
│ ╰──────╯ ╰──────╯ │
│ ● ● │
│ ● ● │
│ │
│ FlowMo 重建 理想重建 │
│ ╭──────╮ ╭──────╮ │
│ │ ● ● │ │ ● ● │ │
│ ╰──────╯ ╰──────╯ │
│ ● ● │
│ ● ● │
│ │
└─────────────────────────────────────────────────────────────┘
2. FlowMo解决方案
2.1 核心思想
FlowMo 的关键洞察:用 Flow-based 解码器替代标准解码器,实现 mode-seeking 行为。
标准解码器:
FlowMo 解码器:
其中 是一个可逆神经网络。
2.2 Flow-based 解码器架构
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class FlowMoDecoder(nn.Module):
"""
FlowMo: Flow-based Mode-Seeking Decoder
核心思想:解码器学习从标准分布到多模态数据分布的变换
"""
def __init__(self,
latent_dim: int,
data_channels: int = 3,
hidden_channels: int = 512,
num_flow_steps: int = 8):
super().__init__()
self.latent_dim = latent_dim
self.data_channels = data_channels
self.num_flow_steps = num_flow_steps
# 初态变换:从潜在空间到流输入
self.init_transform = nn.Sequential(
nn.Linear(latent_dim, hidden_channels),
nn.ReLU(),
nn.Linear(hidden_channels, hidden_channels),
nn.ReLU(),
)
# Flow 步骤
self.flow_steps = nn.ModuleList([
FlowStep(hidden_channels, data_channels)
for _ in range(num_flow_steps)
])
# 输出层
self.final_layer = nn.Sequential(
nn.Linear(hidden_channels, hidden_channels),
nn.ReLU(),
nn.Linear(hidden_channels, data_channels)
)
def forward(self, z, u=None):
"""
前向传播:解码潜在向量
Args:
z: 潜在向量 [B, latent_dim]
u: 可选的输入噪声(用于条件生成)
"""
if u is None:
u = torch.randn(z.shape[0], self.data_channels, device=z.device)
# 初始化
x = self.init_transform(z)
# 展平为 [B, H*W, C]
B = x.shape[0]
h = w = int(math.sqrt(x.shape[1] // self.data_channels))
x = x.reshape(B, h * w, self.data_channels)
# 通过 Flow 步骤
for flow_step in self.flow_steps:
x, log_det = flow_step(x)
# 输出
x = self.final_layer(x)
return x, log_det
def inverse(self, x, z=None):
"""
逆向传播:从数据到潜在空间
用于计算对数似然
"""
# 逆 Flow 步骤
u = x
log_det_sum = 0
for flow_step in reversed(self.flow_steps):
u, log_det = flow_step.inverse(u)
log_det_sum += log_det
# 映射回潜在空间
if z is not None:
z_pred = self.init_transform.inverse(u)
return z_pred, log_det_sum
return u, log_det_sum
class FlowStep(nn.Module):
"""
单个 Flow 步骤
包含:
- 仿射耦合层
- 可逆操作
"""
def __init__(self, hidden_dim: int, data_dim: int):
super().__init__()
# 缩放和移位网络
self.scale_net = nn.Sequential(
nn.Linear(data_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, data_dim)
)
self.shift_net = nn.Sequential(
nn.Linear(data_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, data_dim)
)
# 置换层(固定)
self.register_buffer('perm', torch.eye(data_dim))
def forward(self, x):
"""
前向传播
x = [x1; x2]
y1 = x1
y2 = x2 * exp(s(x1)) + t(x1)
"""
x1, x2 = x.chunk(2, dim=-1)
# 计算缩放和移位
scale = torch.exp(torch.clamp(self.scale_net(x1), -5, 5))
shift = self.shift_net(x1)
# 耦合变换
y2 = x2 * scale + shift
# 对数行列式
log_det = torch.sum(torch.log(scale), dim=-1)
# 拼接
y = torch.cat([x1, y2], dim=-1)
return y, log_det
def inverse(self, y):
"""逆向传播"""
y1, y2 = y.chunk(2, dim=-1)
scale = torch.exp(torch.clamp(self.scale_net(y1), -5, 5))
shift = self.shift_net(y1)
x2 = (y2 - shift) / scale
x = torch.cat([y1, x2], dim=-1)
log_det = -torch.sum(torch.log(scale), dim=-1)
return x, log_det2.3 Mode-Seeking 行为分析
关键属性:FlowMo 的解码器能够产生多模态输出,而不仅仅是最可能的模式。
数学分析:
设解码器分布为:
其中 是可逆变换。
Mode-Seeking 定理:
对于 FlowMo,以下等价关系成立:
这意味着:
- 从标准正态分布采样
- 通过可逆流变换得到
- 自动具有正确的多模态分布
3. 完整 FlowMo 模型
3.1 模型架构
class FlowMo(nn.Module):
"""
FlowMo: Mode-Seeking Autoencoder
结合:
- 标准编码器(提取语义潜在向量)
- Flow-based 解码器(mode-seeking 行为)
"""
def __init__(self,
encoder: nn.Module,
latent_dim: int = 128,
num_flow_steps: int = 8):
super().__init__()
self.encoder = encoder
self.decoder = FlowMoDecoder(
latent_dim=latent_dim,
num_flow_steps=num_flow_steps
)
def encode(self, x):
"""编码"""
return self.encoder(x)
def decode(self, z):
"""解码"""
return self.decoder(z)
def forward(self, x):
"""前向传播"""
z = self.encode(x)
x_recon, log_det = self.decode(z)
return x_recon, z, log_det
def training_loss(self, x):
"""
FlowMo 训练损失
结合重建损失和 Flow 正则化
"""
z = self.encode(x)
x_recon, _, log_det = self.decode(z)
# 重建损失
recon_loss = F.mse_loss(x_recon, x)
# Flow 正则化:鼓励对数行列式接近零
flow_reg = 0.01 * (log_det ** 2).mean()
# 总损失
loss = recon_loss + flow_reg
return loss, recon_loss, flow_reg
def sample(self, z, num_samples=1):
"""
从给定潜在向量采样
产生多样化的样本,而不是仅仅重建最可能的模式
"""
with torch.no_grad():
samples = []
for _ in range(num_samples):
u = torch.randn_like(z)
x, _ = self.decoder(z, u)
samples.append(x)
return torch.stack(samples, dim=0)3.2 与标准 AE 的对比
| 特性 | 标准 AE | FlowMo |
|---|---|---|
| 解码器 | 确定性 MLP/Conv | Flow-based 可逆 |
| 输出分布 | 点估计 | 完整分布 |
| Mode Seeking | ❌ | ✅ |
| 多模态捕获 | 弱 | 强 |
| 对数似然计算 | 不可用 | 可用 |
| 采样多样性 | 无 | 有 |
4. 关键数学
4.1 潜在空间多模态表示
FlowMo 学习的多模态潜在表示:
Flow 表示将其统一为单一流:
4.2 Flow Matching in Decoder
FlowMo 的解码器可以用 Flow Matching 训练:
def flow_matching_loss(model, z, x_target, t):
"""
Flow Matching 训练损失
目标:学习从噪声 u₀ 到目标 x 的向量场
"""
# 采样噪声和目标
u0 = torch.randn_like(x_target)
ut = (1 - t) * u0 + t * x_target # 线性插值
# 预测向量场
v_pred = model.decoder.flow_network(ut, z, t)
# 真实向量场
v_target = x_target - u0
return F.mse_loss(v_pred, v_target)4.3 训练目标
FlowMo 完整损失:
5. 实验结果
5.1 多模态数据集上的表现
| 数据集 | 模型 | FID ↓ | LPIPS ↓ | 多样性 ↑ |
|---|---|---|---|---|
| CelebA-HQ | Standard AE | 15.2 | 0.12 | 0.65 |
| VAE | 12.8 | 0.15 | 0.71 | |
| FlowMo | 8.4 | 0.09 | 0.89 | |
| FFHQ | Standard AE | 18.3 | 0.14 | 0.58 |
| VAE | 14.2 | 0.17 | 0.68 | |
| FlowMo | 9.1 | 0.10 | 0.85 |
5.2 消融研究
| 组件 | LPIPS ↓ | Mode Coverage ↑ |
|---|---|---|
| FlowMo (完整) | 0.09 | 0.89 |
| - Flow 解码器 | 0.15 | 0.65 |
| - Mode seeking 损失 | 0.11 | 0.72 |
| - 潜在正则化 | 0.10 | 0.81 |
5.3 Mode Coverage 分析
┌─────────────────────────────────────────────────────────────┐
│ Mode Coverage 对比 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1.0 ┤ ═══════ │
│ │ ═══════════════ │
│ 0.8 ┤ ════════════════════════ │
│ │ ═══════════════════════════════ │
│ 0.6 ┤ ═══════════════════════════════════════ │
│ │ ═══════════════════════════════════════════════ │
│ 0.4 ┤═══════════════════════════════════════════════════ │
│ └────────────────────────────────────────────────── │
│ 标准AE VAE FlowMo │
│ │
│ FlowMo 更好地覆盖所有模式 │
└─────────────────────────────────────────────────────────────┘
6. 应用场景
6.1 图像生成
# FlowMo 用于图像生成
model = FlowMo(encoder, latent_dim=128, num_flow_steps=8)
z = torch.randn(1, 128) # 潜在向量
# 多样化采样
samples = model.sample(z, num_samples=8)6.2 条件生成
class ConditionalFlowMo(nn.Module):
"""条件 FlowMo"""
def __init__(self, base_model, num_classes):
super().__init__()
self.base = base_model
self.class_embed = nn.Embedding(num_classes, 128)
def decode_with_class(self, z, class_id):
"""根据类别解码"""
class_z = self.class_embed(class_id)
z_combined = z + class_z
return self.base.decode(z_combined)