能量模型与Flow Matching统一

1. 能量模型基础

1.1 定义

能量基础模型(Energy-Based Model, EBM)通过能量函数 定义概率分布:

其中 配分函数,通常难以计算。

1.2 与其他生成模型的对比

特性EBMVAEFlowDiffusion
概率表示能量函数隐变量可逆变换噪声添加
采样方式Langevin/MCMC解码器逆变换去噪
训练目标对比散度ELBO对数似然去噪匹配
模式覆盖自然模糊精确渐进

1.3 能量函数设计

常见的能量函数架构:

class EnergyNetwork(nn.Module):
    """能量函数网络"""
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # 标量能量
        )
    
    def forward(self, x):
        return self.net(x).squeeze(-1)
    
    def free_energy(self, x):
        """计算自由能 F = -log ∫ exp(-E(x')) dx'"""
        E = self.forward(x)
        return -torch.logaddexp(E, torch.zeros_like(E))

2. Flow Matching基础

2.1 条件路径

Flow Matching(FM)定义从源分布 到目标分布 插值路径

对应的向量场 定义为:

2.2 最优传输Flow Matching

最优传输(OT)Flow Matching 选择使路径能量最小的向量场:

关键性质:OT路径保证没有模式坍缩,每个模式都被平等覆盖。

2.3 损失函数

FM的训练目标是匹配向量场

其中 是模型预测的速度场。


3. Energy Matching统一框架

3.1 核心思想

Energy Matching(ICLR 2026)提出EBM和FM的统一形式

其中 学习的能量函数,而FM的向量场可以通过能量梯度得到:

3.2 统一优化目标

Energy Matching损失

这与标准FM损失的等价性

3.3 偏观测学习

EM的关键优势是自然支持偏观测条件

class ConditionalEnergyMatching:
    """
    偏观测条件下的Energy Matching
    """
    def __init__(self, energy_net, mask_strategy='random'):
        self.energy = energy_net
        self.mask_strategy = mask_strategy
    
    def compute_conditional_energy(self, x, mask):
        """
        x: 完整数据
        mask: 观测掩码 (1=观测, 0=缺失)
        """
        # 只对观测部分计算能量
        x_masked = x * mask
        E = self.energy(x_masked)
        
        # 缺失部分的能量从先验采样
        return E, x_masked
    
    def training_step(self, batch, t):
        x, mask = batch
        # 采样时间点
        t = torch.rand(len(x))
        
        # 插值
        x_t = (1-t) * self.noise + t * x
        
        # 掩码插值
        mask_t = (1-t) * torch.ones_like(mask) + t * mask
        
        # 能量梯度
        E_grad = torch.autograd.grad(
            self.energy(x_t).sum(), x_t
        )[0]
        
        # 目标向量场(条件OT)
        u_t = self.target_vector_field(x_t, x, mask_t)
        
        loss = ((E_grad - u_t) * mask_t).pow(2).mean()
        return loss

4. Equilibrium Matching

4.1 动机

传统FM依赖于时间条件 ,而Equilibrium Matching(EqM)直接从隐式能量景观学习平衡:

这避免了时间建模的复杂性。

4.2 形式化

EqM定义平衡条件而非路径:

训练目标是最小化平衡误差

4.3 与FM的对比

特性Flow MatchingEquilibrium Matching
时间依赖
采样步数需多步理论上单步
模式覆盖依赖路径选择自然覆盖
实现复杂度中等较低

5. 能量基础语言模型(NRGPT)

5.1 核心思想

NRGPT(2025)将EBM引入语言建模:

  • 替代自回归分解
  • 能量景观建模:整个序列的能量函数

5.2 推理机制

生成推理转化为能量最小化

def generate_nrgpt(energy_net, prompt, num_steps=50):
    """
    通过Langevin动力学生成
    """
    # 初始化
    x = torch.randn(1, seq_len, vocab_size).to(device)
    x = torch.softmax(x, dim=-1)
    
    # 固定prompt
    x[:, :len(prompt)] = prompt
    
    # Langevin采样
    for step in range(num_steps):
        # 计算梯度
        x.requires_grad_(True)
        E = energy_net(x)
        grad = torch.autograd.grad(E.sum(), x)[0]
        
        # 更新
        x = x - 0.1 * grad + torch.randn_like(x) * 0.01
        x = torch.softmax(x, dim=-1)
        
        # 重新固定prompt
        x[:, :len(prompt)] = prompt
    
    return x

5.3 优势与挑战

优势

  • 自然的序列级建模
  • 可纳入复杂的约束
  • 模式覆盖有理论保证

挑战

  • 高维序列的Langevin采样效率
  • 配分函数估计

6. 能量基础Diffusion模型

6.1 从EBM到Diffusion

Energy-based Diffusion Language Model(EDLM)将扩散过程解释为能量演化

这统一了Score Matching和EBM:

6.2 训练目标

序列级EDLM损失

其中 是学习的分数函数。


7. 实践指南

7.1 方法选择

场景推荐方法
简单分布标准Flow Matching
偏观测条件Energy Matching
避免时间建模Equilibrium Matching
语言生成NRGPT/EDLM
高质量采样对比EBM + FM

7.2 实现注意事项

# 关键实现细节
class EnergyFlowMatching(nn.Module):
    def __init__(self, energy_net, flow_net, lambda_energy=0.1):
        self.energy = energy_net
        self.flow = flow_net
        self.lambda_energy = lambda_energy
    
    def training_step(self, x0, x1):
        # 采样时间
        t = torch.rand(len(x0))
        
        # 插值
        x_t = (1-t)[:, None] * x0 + t[:, None] * x1
        
        # Flow Matching目标
        v_pred = self.flow(x_t, t)
        v_target = x1 - x0
        loss_fm = (v_pred - v_target).pow(2).mean()
        
        # Energy Matching正则
        if self.lambda_energy > 0:
            E = self.energy(x_t)
            grad_E = torch.autograd.grad(E.sum(), x_t)[0]
            loss_e = ((v_pred + grad_E) * t[:, None]).pow(2).mean()
        else:
            loss_e = 0
        
        return loss_fm + self.lambda_energy * loss_e

8. 参考文献