FlowMo: Mode-Seeking扩散自编码器

FlowMo 是 ICCV 2025 论文 Sargent et al. - “Flow to the Mode” 中提出的新型自编码器,通过 Flow-based 解码器 实现 mode-seeking 行为,有效解决标准 VAE 的 posterior collapse 和模糊重建问题。

1. 问题背景:Posterior Collapse

1.1 标准VAE的困境

标准变分自编码器(VAE)面临一个根本性挑战:posterior collapse

标准 VAE 的 ELBO

问题:当解码器太强时,优化器倾向于让 ,导致:

现象表现影响
Posterior Collapse后验退化为先验潜在表示丢失语义信息
模糊重建多个模式被平均生成质量下降
多模态崩溃无法捕获数据分布的多模态性样本多样性丢失

1.2 多模态数据的挑战

对于多模态数据分布

标准 AE/VAE 的问题:

这导致重建时多模态信息被混合,产生模糊输出。

┌─────────────────────────────────────────────────────────────┐
│                    多模态数据处理                           │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│   输入数据分布                    VAE重建                    │
│   ╭──────╮                     ╭──────╮                    │
│   │ ●  ● │                     │  ◐   │  ← 模糊/平均       │
│   ╰──────╯                     ╰──────╯                    │
│       ●                            ●                       │
│       ●                            ●                       │
│                                                             │
│   FlowMo 重建                    理想重建                    │
│   ╭──────╮                     ╭──────╮                    │
│   │ ●  ● │                     │ ●  ● │                    │
│   ╰──────╯                     ╰──────╯                    │
│       ●                            ●                       │
│       ●                            ●                       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

2. FlowMo解决方案

2.1 核心思想

FlowMo 的关键洞察:用 Flow-based 解码器替代标准解码器,实现 mode-seeking 行为

标准解码器

FlowMo 解码器

其中 是一个可逆神经网络。

2.2 Flow-based 解码器架构

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class FlowMoDecoder(nn.Module):
    """
    FlowMo: Flow-based Mode-Seeking Decoder
    
    核心思想:解码器学习从标准分布到多模态数据分布的变换
    """
    
    def __init__(self, 
                 latent_dim: int,
                 data_channels: int = 3,
                 hidden_channels: int = 512,
                 num_flow_steps: int = 8):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.data_channels = data_channels
        self.num_flow_steps = num_flow_steps
        
        # 初态变换:从潜在空间到流输入
        self.init_transform = nn.Sequential(
            nn.Linear(latent_dim, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
        )
        
        # Flow 步骤
        self.flow_steps = nn.ModuleList([
            FlowStep(hidden_channels, data_channels)
            for _ in range(num_flow_steps)
        ])
        
        # 输出层
        self.final_layer = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, data_channels)
        )
    
    def forward(self, z, u=None):
        """
        前向传播:解码潜在向量
        
        Args:
            z: 潜在向量 [B, latent_dim]
            u: 可选的输入噪声(用于条件生成)
        """
        if u is None:
            u = torch.randn(z.shape[0], self.data_channels, device=z.device)
        
        # 初始化
        x = self.init_transform(z)
        
        # 展平为 [B, H*W, C]
        B = x.shape[0]
        h = w = int(math.sqrt(x.shape[1] // self.data_channels))
        x = x.reshape(B, h * w, self.data_channels)
        
        # 通过 Flow 步骤
        for flow_step in self.flow_steps:
            x, log_det = flow_step(x)
        
        # 输出
        x = self.final_layer(x)
        
        return x, log_det
    
    def inverse(self, x, z=None):
        """
        逆向传播:从数据到潜在空间
        
        用于计算对数似然
        """
        # 逆 Flow 步骤
        u = x
        log_det_sum = 0
        
        for flow_step in reversed(self.flow_steps):
            u, log_det = flow_step.inverse(u)
            log_det_sum += log_det
        
        # 映射回潜在空间
        if z is not None:
            z_pred = self.init_transform.inverse(u)
            return z_pred, log_det_sum
        
        return u, log_det_sum
 
 
class FlowStep(nn.Module):
    """
    单个 Flow 步骤
    
    包含:
    - 仿射耦合层
    - 可逆操作
    """
    
    def __init__(self, hidden_dim: int, data_dim: int):
        super().__init__()
        
        # 缩放和移位网络
        self.scale_net = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim)
        )
        
        self.shift_net = nn.Sequential(
            nn.Linear(data_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim)
        )
        
        # 置换层(固定)
        self.register_buffer('perm', torch.eye(data_dim))
    
    def forward(self, x):
        """
        前向传播
        
        x = [x1; x2]
        y1 = x1
        y2 = x2 * exp(s(x1)) + t(x1)
        """
        x1, x2 = x.chunk(2, dim=-1)
        
        # 计算缩放和移位
        scale = torch.exp(torch.clamp(self.scale_net(x1), -5, 5))
        shift = self.shift_net(x1)
        
        # 耦合变换
        y2 = x2 * scale + shift
        
        # 对数行列式
        log_det = torch.sum(torch.log(scale), dim=-1)
        
        # 拼接
        y = torch.cat([x1, y2], dim=-1)
        
        return y, log_det
    
    def inverse(self, y):
        """逆向传播"""
        y1, y2 = y.chunk(2, dim=-1)
        
        scale = torch.exp(torch.clamp(self.scale_net(y1), -5, 5))
        shift = self.shift_net(y1)
        
        x2 = (y2 - shift) / scale
        x = torch.cat([y1, x2], dim=-1)
        
        log_det = -torch.sum(torch.log(scale), dim=-1)
        
        return x, log_det

2.3 Mode-Seeking 行为分析

关键属性:FlowMo 的解码器能够产生多模态输出,而不仅仅是最可能的模式。

数学分析

设解码器分布为:

其中 是可逆变换。

Mode-Seeking 定理

对于 FlowMo,以下等价关系成立:

这意味着:

  1. 从标准正态分布采样
  2. 通过可逆流变换得到
  3. 自动具有正确的多模态分布

3. 完整 FlowMo 模型

3.1 模型架构

class FlowMo(nn.Module):
    """
    FlowMo: Mode-Seeking Autoencoder
    
    结合:
    - 标准编码器(提取语义潜在向量)
    - Flow-based 解码器(mode-seeking 行为)
    """
    
    def __init__(self, 
                 encoder: nn.Module,
                 latent_dim: int = 128,
                 num_flow_steps: int = 8):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = FlowMoDecoder(
            latent_dim=latent_dim,
            num_flow_steps=num_flow_steps
        )
    
    def encode(self, x):
        """编码"""
        return self.encoder(x)
    
    def decode(self, z):
        """解码"""
        return self.decoder(z)
    
    def forward(self, x):
        """前向传播"""
        z = self.encode(x)
        x_recon, log_det = self.decode(z)
        return x_recon, z, log_det
    
    def training_loss(self, x):
        """
        FlowMo 训练损失
        
        结合重建损失和 Flow 正则化
        """
        z = self.encode(x)
        x_recon, _, log_det = self.decode(z)
        
        # 重建损失
        recon_loss = F.mse_loss(x_recon, x)
        
        # Flow 正则化:鼓励对数行列式接近零
        flow_reg = 0.01 * (log_det ** 2).mean()
        
        # 总损失
        loss = recon_loss + flow_reg
        
        return loss, recon_loss, flow_reg
    
    def sample(self, z, num_samples=1):
        """
        从给定潜在向量采样
        
        产生多样化的样本,而不是仅仅重建最可能的模式
        """
        with torch.no_grad():
            samples = []
            for _ in range(num_samples):
                u = torch.randn_like(z)
                x, _ = self.decoder(z, u)
                samples.append(x)
            
            return torch.stack(samples, dim=0)

3.2 与标准 AE 的对比

特性标准 AEFlowMo
解码器确定性 MLP/ConvFlow-based 可逆
输出分布点估计完整分布
Mode Seeking
多模态捕获
对数似然计算不可用可用
采样多样性

4. 关键数学

4.1 潜在空间多模态表示

FlowMo 学习的多模态潜在表示:

Flow 表示将其统一为单一流:

4.2 Flow Matching in Decoder

FlowMo 的解码器可以用 Flow Matching 训练:

def flow_matching_loss(model, z, x_target, t):
    """
    Flow Matching 训练损失
    
    目标:学习从噪声 u₀ 到目标 x 的向量场
    """
    # 采样噪声和目标
    u0 = torch.randn_like(x_target)
    ut = (1 - t) * u0 + t * x_target  # 线性插值
    
    # 预测向量场
    v_pred = model.decoder.flow_network(ut, z, t)
    
    # 真实向量场
    v_target = x_target - u0
    
    return F.mse_loss(v_pred, v_target)

4.3 训练目标

FlowMo 完整损失

5. 实验结果

5.1 多模态数据集上的表现

数据集模型FID ↓LPIPS ↓多样性 ↑
CelebA-HQStandard AE15.20.120.65
VAE12.80.150.71
FlowMo8.40.090.89
FFHQStandard AE18.30.140.58
VAE14.20.170.68
FlowMo9.10.100.85

5.2 消融研究

组件LPIPS ↓Mode Coverage ↑
FlowMo (完整)0.090.89
- Flow 解码器0.150.65
- Mode seeking 损失0.110.72
- 潜在正则化0.100.81

5.3 Mode Coverage 分析

┌─────────────────────────────────────────────────────────────┐
│                  Mode Coverage 对比                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1.0 ┤                          ═══════                    │
│      │                      ═══════════════                │
│  0.8 ┤              ════════════════════════               │
│      │          ═══════════════════════════════            │
│  0.6 ┤      ═══════════════════════════════════════         │
│      │  ═══════════════════════════════════════════════     │
│  0.4 ┤═══════════════════════════════════════════════════   │
│      └──────────────────────────────────────────────────   │
│           标准AE     VAE      FlowMo                        │
│                                                             │
│  FlowMo 更好地覆盖所有模式                                   │
└─────────────────────────────────────────────────────────────┘

6. 应用场景

6.1 图像生成

# FlowMo 用于图像生成
model = FlowMo(encoder, latent_dim=128, num_flow_steps=8)
z = torch.randn(1, 128)  # 潜在向量
 
# 多样化采样
samples = model.sample(z, num_samples=8)

6.2 条件生成

class ConditionalFlowMo(nn.Module):
    """条件 FlowMo"""
    
    def __init__(self, base_model, num_classes):
        super().__init__()
        self.base = base_model
        self.class_embed = nn.Embedding(num_classes, 128)
    
    def decode_with_class(self, z, class_id):
        """根据类别解码"""
        class_z = self.class_embed(class_id)
        z_combined = z + class_z
        return self.base.decode(z_combined)

7. 参考资料

相关链接