1. 概述

能量基模型(Energy-Based Models, EBM)是一类通过未归一化能量函数定义概率分布的生成模型。2025年,Energy Matching框架成功将EBM与Flow Matching统一,实现了两个领域的深度融合。

核心论文

1.1 能量基模型基础

概率定义

其中 是能量函数, 是配分函数(难以计算)。

与归一化流的对比

维度能量基模型 (EBM)归一化流 (Flow)
概率密度未归一化精确归一化
配分函数难以计算不存在
似然需估计 精确可计算
采样MCMC (慢)逆变换 (快)
条件化自然(加能量项)需要设计

2. Flow Matching基础回顾

2.1 条件Flow Matching

Flow Matching通过学习向量场 从噪声 移动到数据

边缘轨迹由条件轨迹加权平均得到:

损失函数

其中 是预测的速度场。

2.2 最优传输Flow Matching

最优传输(OT)条件轨迹

对应的向量场:

优势:边际分布 可以通过闭式计算:


3. Energy Matching框架

3.1 统一视角

Energy Matching提出统一的框架,将Flow MatchingEBM作为同一框架的两种特例:

┌─────────────────────────────────────────────────────────────────────┐
│                    Energy Matching 统一框架                          │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│                         统一目标函数                                  │
│                               │                                      │
│                               ▼                                      │
│                    ┌──────────────────┐                             │
│                    │  Energy Matching │                             │
│                    │    目标函数      │                             │
│                    └────────┬─────────┘                             │
│                             │                                        │
│              ┌─────────────┴─────────────┐                          │
│              │                           │                          │
│              ▼                           ▼                          │
│     ┌────────────────┐         ┌────────────────┐                   │
│     │  Flow Matching │         │   EBM Training │                   │
│     │   (特例1)      │         │   (特例2)      │                   │
│     └────────────────┘         └────────────────┘                   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

3.2 核心思想

Energy Matching损失

其中:

  • 分数函数
  • 是Flow Matching的速度场

物理解释

┌─────────────────────────────────────────────────────────────────────┐
│                      Energy Matching 物理解释                        │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  能量函数 E(x) 定义势能 landscape:                                   │
│                                                                     │
│         高能量                      低能量                          │
│           ↑                          ↓                              │
│      ┌────┴────┐                ┌────┬────┐                      │
│      │         │                │         │                       │
│      │  局部   │    ───→        │  全局   │                       │
│      │  极小值 │    梯度流       │  最优   │                       │
│      │         │                │  (OT)  │                       │
│      └─────────┘                └─────────┘                      │
│                                                                     │
│  分数 ∇log p = -∇E 指向数据分布的"吸引域"                            │
│  速度场 v 描述从噪声到数据的路径                                      │
│                                                                     │
│  Energy Matching: 让分数场逼近速度场                                  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

3.3 两阶段训练策略

Energy Matching采用两阶段训练

┌─────────────────────────────────────────────────────────────────────┐
│                    Energy Matching 两阶段训练                         │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  第一阶段:Flow Matching预训练                                        │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                                                             │   │
│  │   数据 x₁ ~ p_data                                         │   │
│  │       │                                                     │   │
│  │       ▼                                                     │   │
│  │   学习速度场 v_θ(x_t, t) ──→ OT条件轨迹                    │   │
│  │       │                                                     │   │
│  │       ▼                                                     │   │
│  │   生成器 g_θ(x_0, t) = x_t                                 │   │
│  │                                                             │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  第二阶段:能量对齐微调                                               │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                                                             │   │
│  │   从生成器提取隐式能量函数:                                  │   │
│  │   E_θ(x) = -log p_θ(x) + const                             │   │
│  │                                                             │   │
│  │   训练能量函数匹配Flow Matching的分布:                       │   │
│  │   min_φ D( p_v_φ || p_data )                               │   │
│  │                                                             │   │
│  │   采样: Langevin动态 → Boltzmann平衡分布                    │   │
│  │                                                             │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

3.4 理论保证

Energy Matching具有以下理论性质:

性质说明
一致性当样本量n→∞时,估计的能量函数收敛到真实能量
效率无需显式计算配分函数
统一性Flow Matching和EBM作为特例包含
采样保证收敛到Boltzmann平衡分布

4. Equilibrium Matching (EqM)

4.1 超越原始Energy Matching

EqM是Energy Matching的改进版本,核心创新是Boltzmann平衡分布的显式建模:

EqM目标函数

其中 是能量函数诱导的Boltzmann分布。

4.2 高分辨率实验结果

数据集方法FID ↓IS ↑
CIFAR-10Energy Matching3.39.8
EqM2.810.5
DDPM3.99.4
ImageNet 64×64Energy Matching8.228.5
EqM6.135.2
ADM7.232.1

4.3 与扩散模型对比

维度EqM扩散模型 (DDPM)
采样步数50-100100-1000
似然建模显式需要估计
条件化自然(加能量项)需要classifier guidance
采样质量可比或更好SOTA
训练稳定性较高中等

5. 算法实现

5.1 PyTorch实现框架

import torch
import torch.nn as nn
from torch.distributions import Normal
 
class EnergyMatchingModel(nn.Module):
    """
    Energy Matching: Unifying Flow Matching and EBM
    """
    def __init__(self, net, beta=1.0):
        super().__init__()
        self.net = net  # 能量函数/速度场网络
        self.beta = beta  # 温度参数
        
    def energy(self, x):
        """能量函数 E(x)"""
        return self.net(x)
    
    def score(self, x):
        """分数函数 ∇log p(x) = -∇E(x)"""
        return -torch.autograd.grad(
            self.energy(x).sum(), x, create_graph=True
        )[0]
    
    def velocity(self, x, t):
        """Flow Matching速度场 v(x, t)"""
        return self.net(x, t)
    
    def em_loss(self, x_data, x_noisy, t):
        """
        Energy Matching损失
        L_EM = E[ || ∇log p(x) - v(x, t)||^2 ]
        """
        # 分数
        score = self.score(x_noisy)
        # 速度场
        velocity = self.velocity(x_noisy, t)
        # 损失
        return ((score - velocity) ** 2).mean()
    
    def langevin_sampling(self, x_init, n_steps=100, lr=0.01):
        """
        Langevin采样: 从Boltzmann分布采样
        x_{t+1} = x_t - lr * ∇E(x_t) + sqrt(2*lr) * ε
        """
        x = x_init.detach().clone()
        x.requires_grad = True
        
        for _ in range(n_steps):
            energy = self.energy(x)
            grad = torch.autograd.grad(energy.sum(), x)[0]
            
            # Langevin动态
            noise = torch.randn_like(x)
            x = x - lr * grad + torch.sqrt(torch.tensor(2.0 * lr)) * noise
            x.requires_grad = True
            
        return x.detach()
 
 
class EquilibriumMatching(nn.Module):
    """
    Equilibrium Matching (EqM): 带有Boltzmann平衡的EM
    """
    def __init__(self, em_model, lambda_eq=0.1):
        super().__init__()
        self.em_model = em_model
        self.lambda_eq = lambda_eq
        
    def eqm_loss(self, x_data, x_neg, x_noisy, t):
        """
        EqM损失 = EM损失 + 平衡损失
        """
        # EM损失
        em_loss = self.em_model.em_loss(x_data, x_noisy, t)
        
        # 平衡损失: E(x_data) ≈ E(x_neg) for balanced distribution
        energy_data = self.em_model.energy(x_data)
        energy_neg = self.em_model.energy(x_neg)
        eq_loss = (energy_data - energy_neg).mean()
        
        return em_loss + self.lambda_eq * eq_loss

5.2 训练流程

def train_energy_matching(model, dataloader, n_epochs=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    for epoch in range(n_epochs):
        for batch in dataloader:
            x_data = batch.to(device)
            batch_size = x_data.shape[0]
            
            # 采样噪声和时间
            t = torch.rand(batch_size, device=device)
            x_0 = torch.randn_like(x_data)
            
            # 插值: x_t = (1-t)*x_0 + t*x_1 (OT轨迹)
            x_noisy = (1 - t.view(-1, 1, 1, 1)) * x_0 + t.view(-1, 1, 1, 1) * x_data
            
            # 计算损失
            loss = model.em_loss(x_data, x_noisy, t)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if epoch % 100 == 0:
                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
                
                # 生成样本验证
                with torch.no_grad():
                    z = torch.randn(16, *x_data.shape[1:], device=device)
                    samples = model.langevin_sampling(z, n_steps=50)
                    # 或使用Flow Matching逆过程采样
                    # samples = model.generate(n_steps=50)

6. 采样技术

6.1 从EM到采样

Energy Matching的核心优势之一是灵活的采样策略

┌─────────────────────────────────────────────────────────────────────┐
│                    Energy Matching 采样策略                         │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  方式1: Flow Matching逆采样                                          │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                                                             │   │
│  │   x_T ~ N(0, I) ──→ [逆向ODE求解] ──→ x_0 ≈ p_data          │   │
│  │                                                             │   │
│  │   dx/dt = -v_θ(x, 1-t)  (逆时间)                            │   │
│  │                                                             │   │
│  │   优点: 确定性轨迹,快速收敛                                  │   │
│  │   缺点: 需要ODE求解器                                        │   │
│  │                                                             │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  方式2: Langevin采样                                                  │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                                                             │   │
│  │   x_{k+1} = x_k - η∇E(x_k) + √(2η)ε                        │   │
│  │                                                             │   │
│  │   优点: 简单实现,适用于任意能量函数                           │   │
│  │   缺点: 收敛慢,需要很多步数                                   │   │
│  │                                                             │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                              │                                      │
│                              ▼                                      │
│  方式3: 混合采样                                                     │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                                                             │   │
│  │   Flow Matching快速初始化 → Langevin精调                      │   │
│  │                                                             │   │
│  │   结合两种方法的优势: 快速 + 精确                              │   │
│  │                                                             │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

6.2 MCMC加速技术

技术说明效果
Langevin MCMC梯度辅助Proposal加速收敛
哈密顿MCMC引入动量项更高效的Proposal
随机梯度MCMCSGLD用于大数据可扩展
并行链多起点并行采样更好覆盖

7. 应用场景

7.1 生成建模

数据类型应用优势
图像高分辨率生成质量可比扩散模型
分子分子生成条件化自然
音频语音合成似然建模
图结构图生成灵活的能量设计

7.2 异常检测

EBM天然适合异常检测:

def anomaly_detection(model, x):
    """
    异常检测: 低密度 = 异常
    """
    energy = model.energy(x)
    # 能量越高,密度越低,越可能是异常
    return energy

7.3 条件生成

EBM的条件化非常自然:

def conditional_generation(model, condition, target_energy):
    """
    条件生成: 在给定条件下找到满足能量约束的样本
    """
    def constrained_energy(x):
        return model.energy(x) + lambda_c * (h_c(x) - condition)**2
    
    # 从约束能量函数采样
    return langevin_sample(constrained_energy)

8. 与相关工作的对比

8.1 vs 扩散模型

维度Energy Matching扩散模型
训练目标分数匹配/速度场去噪重建
采样ODE/LangevinSDE
似然估计精确/估计
条件化自然classifier guidance
逆问题原生支持需专门设计

8.2 vs 传统EBM

维度Energy Matching传统EBM (如NICE, RealNVP)
可逆性不要求要求
训练稳定性低 (配分函数问题)
采样效率可用FM加速仅MCMC
与FM关系统一框架独立发展

9. 未来发展方向

方向说明
高分辨率应用扩展到ImageNet 256+
多模态生成统一处理多种数据类型
理论深化收敛性、泛化性分析
高效采样减少采样步数
与LLM结合用于文本/代码生成

10. 相关专题


参考文献