1. 概述
能量基模型(Energy-Based Model, EBM)是一类通过能量函数定义概率分布的生成模型。与直接参数化概率分布不同,EBM通过设计能量函数来描述数据的内在结构,具有表达能力强、条件化自然等优点。1
内容框架:
┌─────────────────────────────────────────────────────────────────────┐
│ 能量基模型(EBM)全景图 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ 理论基础 │ │
│ │ 能量函数 E(x) → Gibbs分布 → 概率建模 │ │
│ └────────────────────────────┬────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────┼────────────────────────────────┐ │
│ │ 历史发展 │ │
│ │ Hopfield网络 ──→ 现代EBM ──→ 与扩散模型融合 │ │
│ └────────────────────────────┬────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────┼────────────────────────────────┐ │
│ │ 采样方法 │ │
│ │ 郎之万动力学 ──→ 对比散度 ──→ 其他MCMC方法 │ │
│ └────────────────────────────┬────────────────────────────────┘ │
│ │ │
│ ┌────────────────────────────┼────────────────────────────────┐ │
│ │ 应用场景 │ │
│ │ 图像生成 ──→ 异常检测 ──→ 对抗鲁棒 ──→ 逆问题求解 │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
2. 能量基模型基础定义
2.1 能量函数的定义
能量函数 将每个输入样本映射到一个实数标量,表示该样本的”能量”水平。低能量对应高概率,高能量对应低概率。
典型形式:
其中 是某个非负函数(如神经网络的输出)。
2.2 从能量到概率:Gibbs分布
通过Gibbs分布(也称Boltzmann分布)将能量函数转换为概率分布:
其中:
- 是配分函数
- 是逆温度参数(通常设为1)
物理直觉:
┌─────────────────────────────────────────────────────────────────────┐
│ 能量景观与概率分布 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ 能量 E(x) 概率 p(x) │
│ │
│ 高能量 低概率 │
│ ↑ ↓ │
│ ┌──┴──┐ ┌────┴────┐ │
│ │ │ │ │ │
│ │ 峰 │ │ 谷底 │ │
│ │ │ │ (高p) │ │
│ └─────┘ └─────────┘ │
│ │
│ 低能量 高概率 │
│ ↓ ↑ │
│ ┌──┬──┐ ┌────┬────┐ │
│ │ │ │ │ │ │ │
│ │ 全局 │ │ 峰顶 │ │
│ │ 最小值 │ │ (最高p)│ │
│ │ │ │ │ │ │ │
│ └─────┘ └─────────┘ │
│ │
│ Boltzmann因子: exp(-E(x)) 越大,能量越低,概率越高 │
│ │
└─────────────────────────────────────────────────────────────────────┘
2.3 配分函数的角色
配分函数 是概率分布的归一化常数:
核心挑战:对于高维数据, 通常无法解析计算,也难以数值估计。这是EBM训练的主要困难。
为什么配分函数难计算:
| 维度 | 样本空间大小 | 配分函数计算难度 |
|---|---|---|
| 32×32图像 | 不可能 | |
| 784维MNIST | 不可能 |
2.4 能量模型的优势
| 优势 | 说明 |
|---|---|
| 表达能力强 | 任何非归一化密度都可以写成 |
| 自然条件化 | 只需在能量函数中添加条件项即可 |
| 逆问题求解 | 通过最小化能量可以直接求解逆问题 |
| 模式覆盖 | 能量函数天然支持多模态分布 |
| 与物理的联系 | Boltzmann分布源于统计力学 |
3. EBM与概率分布的关系
3.1 各种分布的能量表示
Bernoulli分布:
高斯分布:
混合高斯分布:
3.2 对数似然梯度
尽管配分函数无法直接计算,其梯度可以通过对数似然梯度间接获得:
对配分函数求导:
因此:
直观理解:
4. Hopfield网络的历史回顾
4.1 Hopfield网络的起源
Hopfield网络(1982年)是最早的能量基模型之一,由John Hopfield提出,用于模拟生物神经网络的联想记忆功能。2
网络结构:
其中 是神经元状态, 是连接权重, 是偏置。
4.2 能量函数与吸引子
Hopfield网络的吸引子对应能量的局部最小值,每个吸引子可以存储一个记忆模式:
┌─────────────────────────────────────────────────────────────────────┐
│ Hopfield网络能量景观 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ E(s) │
│ │ │
│ ┌──────────┐ │ ┌──────────┐ │
│ ╱│ │╲ │ ╱│ │╲ │
│ ╱ │ 吸引子1 │ ╲│╱ │ 吸引子2 │ ╲ │
│ ╱ │ (记忆m₁) │ ╲ │ (记忆m₂) │ ╲ │
│ ──╱────│ │──╲─│─│ │──╲───── │
│ │ │ ╲│╱│ │ ╲ │
│ ╲──────────╱ │ ╲──────────╱ ╲ │
│ │ │
│ 状态s沿能量下山梯度演化,趋向最近的吸引子 │
│ │
└─────────────────────────────────────────────────────────────────────┘
4.3 Hopfield更新规则
异步更新:
能量单调下降:
4.4 现代EBM与Hopfield网络的关系
现代能量基模型可以视为Hopfield网络的扩展:
| 特征 | Hopfield网络 | 现代EBM |
|---|---|---|
| 能量函数 | 二次型 | 神经网络参数化 |
| 状态 | 连续 | |
| 更新 | 确定性 | 随机(郎之万) |
| 容量 | 无限制 | |
| 应用 | 联想记忆 | 生成建模 |
现代扩展:Modern Hopfield Networks 和 注意力机制有深刻联系,详见相关研究。3
5. 现代EBM架构
5.1 基于能量的自编码器
Joint Embedding Models 和 能量基模型形成对比。能量基自编码器(EBAE)直接参数化能量函数:
class EnergyBasedAutoEncoder(nn.Module):
"""
基于能量的自编码器
编码器:学习隐表示
能量函数:测量 (x, z) 的兼容性
解码器:可选,用于重建
"""
def __init__(self, input_dim, hidden_dim, latent_dim):
super().__init__()
# 编码器
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim)
)
# 能量函数网络
self.energy_net = nn.Sequential(
nn.Linear(input_dim + latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1) # 输出标量能量
)
# 解码器(可选)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim)
)
def energy(self, x, z=None):
"""计算能量 E(x, z)"""
if z is None:
z = self.encoder(x)
return self.energy_net(torch.cat([x, z], dim=-1))
def encode(self, x):
"""编码"""
return self.encoder(x)
def decode(self, z):
"""解码"""
return self.decoder(z)
def reconstruct(self, x):
"""重建"""
z = self.encode(x)
return self.decode(z)5.2 联合能量基模型(JEM)
Joint Energy-based Models (JEM) 将分类器与EBM结合,同时实现分类和生成。4
class JEM(nn.Module):
"""
Joint Energy-based Model
统一分类和生成的能量基模型
"""
def __init__(self, classifier, energy_fn):
super().__init__()
self.classifier = classifier
self.energy_fn = energy_fn
def energy(self, x, y):
"""
计算条件能量 E(x | y)
结合分类logit和能量函数
"""
# 分类器输出的logit作为能量项
class_energy = -self.classifier(x) # 负logit
generative_energy = self.energy_fn(x)
# 组合
return class_energy[:, y] + generative_energy
def joint_energy(self, x, y):
"""
计算联合能量 E(x, y) = E(x | y) + log p(y)
"""
return self.energy(x, y) + torch.log_softmax(self.classifier(x), dim=1)[:, y]
def forward(self, x, y=None):
"""
前向传播
如果提供y,计算 E(x, y)
否则返回所有类别的能量
"""
if y is not None:
return self.joint_energy(x, y)
return self.energy(x, torch.arange(x.size(0), device=x.device))
def classify_with_energy(self, x):
"""
基于能量的分类:选择能量最低的类别
"""
return (-self.classifier(x)).argmax(dim=1)
def log_prob(self, x, y, n_steps=20, step_size=0.1):
"""
近似计算 log p(y|x) = -E(x,y) - log Z + const
"""
energy = self.joint_energy(x, y)
# 简化估计:使用分类器logit近似
# 注意:这是近似,不涉及MCMC采样
return -energy5.3 能量函数设计原则
良好能量函数的设计要点:
| 原则 | 说明 |
|---|---|
| 灵活的表达能力 | 使用深度网络捕捉复杂数据分布 |
| 平滑性 | 相似的输入应有相似的能量 |
| 模式覆盖 | 确保覆盖所有数据模式 |
| 可优化性 | 梯度容易计算,便于训练 |
| 采样友好 | 能量景观便于MCMC采样 |
6. 郎之万动力学采样
6.1 郎之万采样原理
**郎之万动力学(Langevin Dynamics)**是连续空间的MCMC采样方法,通过随机梯度下降从Boltzmann分布采样。
物理背景:布朗运动中粒子的运动方程:
其中 是高斯白噪声。
过阻尼极限(,):
6.2 离散时间郎之万采样
迭代公式:
其中 是步长。
收敛性:当 ,链收敛到目标分布 。
┌─────────────────────────────────────────────────────────────────────┐
│ 郎之万采样轨迹示意 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ E(x) │
│ │ │
│ 高 │ ╱╲ │
│ │ ╱ ╲ ╱╲ │
│ │ ╱ ╲ ╱ ╲ 能量景观 │
│ 能 │ ╱ ╲ ╱╲ ╱ ╲ 郎之万采样 │
│ │ ╱ ╲╱ ╲╱ ╲ 从高能区域 │
│ 量 │╱ 向谷底移动 ╲ 逐步下降到 │
│ │ ╲ 低能区域 │
│ 低 │ ╲ │
│ │ ● ←─采样点 │
│ │ ╱ │
│ │ ╱ │
│ ─────┼──────────────────────────────────────────── x │
│ │ │
│ │ │
│ ════════════════════════════════════════════════════ │
│ 时间演化 │
│ │
└─────────────────────────────────────────────────────────────────────┘
6.3 郎之万采样实现
class LangevinSampler:
"""
郎之万采样器
用于从能量基模型定义的分布中采样
"""
def __init__(self, energy_fn, device='cuda'):
self.energy_fn = energy_fn
self.device = device
def sample(self, x_init, n_steps=100, step_size=0.01, noise_scale=None, return_trajectory=False):
"""
郎之万采样
Args:
x_init: 初始点
n_steps: 采样步数
step_size: 步长 η
noise_scale: 噪声尺度,默认 sqrt(2*step_size)
return_trajectory: 是否返回采样轨迹
Returns:
采样结果 (可能包含轨迹)
"""
x = x_init.detach().clone().to(self.device).requires_grad_(True)
if noise_scale is None:
noise_scale = np.sqrt(2 * step_size)
trajectory = [x.detach().cpu()]
for step in range(n_steps):
# 计算能量梯度
energy = self.energy_fn(x)
grad = torch.autograd.grad(
energy.sum(), x, create_graph=False
)[0]
# 郎之万更新
noise = torch.randn_like(x)
x = x - step_size * grad + noise_scale * noise
x = x.detach().requires_grad_(True)
if return_trajectory:
trajectory.append(x.detach().cpu())
if return_trajectory:
return x.detach(), torch.stack(trajectory)
return x.detach()
def sample_multiple(self, n_samples, x_init_fn, n_steps=100, step_size=0.01):
"""
并行采样多个样本
"""
samples = []
for _ in range(n_samples):
x_init = x_init_fn()
x_sample = self.sample(x_init, n_steps, step_size)
samples.append(x_sample)
return torch.stack(samples)
class EBMWithLangevinSampling(nn.Module):
"""
带郎之万采样的能量基模型
"""
def __init__(self, energy_net, sampler=None):
super().__init__()
self.energy_net = energy_net
self.sampler = sampler or LangevinSampler(self.energy)
def energy(self, x):
"""能量函数"""
return self.energy_net(x).squeeze(-1)
def score(self, x):
"""分数函数 ∇_x log p(x) = -∇_x E(x)"""
return -torch.autograd.grad(
self.energy(x).sum(), x, create_graph=True
)[0]
def sample(self, x_init, n_steps=100, step_size=0.01):
"""从模型分布中采样"""
return self.sampler.sample(x_init, n_steps, step_size)
def generate(self, batch_size, shape, n_steps=100, step_size=0.01):
"""生成样本"""
x_init = torch.randn(batch_size, *shape, device=next(self.parameters()).device)
return self.sample(x_init, n_steps, step_size)6.4 退火郎之万采样
退火技术可以加速高维复杂分布的采样:
def annealed_langevin_sampling(energy_fn, x_init, n_steps=100,
step_size=0.01, noise_scale=0.01,
n_temps=10):
"""
退火郎之万采样
从高温逐步降低到低温,帮助跳出局部极小
"""
x = x_init.detach().clone().requires_grad_(True)
# 温度调度:从高到低
temps = torch.linspace(10.0, 0.1, n_temps)
for temp in temps:
steps_per_temp = n_steps // n_temps
for step in range(steps_per_temp):
# 缩放步长和噪声
scaled_step = step_size * temp
scaled_noise = noise_scale * np.sqrt(temp)
# 计算梯度
energy = energy_fn(x) / temp
grad = torch.autograd.grad(energy.sum(), x, create_graph=False)[0]
# 郎之万更新
noise = torch.randn_like(x)
x = x - scaled_step * grad + scaled_noise * noise
x = x.detach().requires_grad_(True)
return x.detach()7. 对比散度(Contrastive Divergence)算法
7.1 CD算法的动机
直接计算配分函数的梯度需要从模型分布中采样,这本身就是困难的问题。**对比散度(CD)**通过近似解决这个问题。5
核心思想:使用少量吉布斯采样步骤从数据初始化,而不是从随机状态开始。
7.2 CD-k算法
CD-k算法步骤:
┌─────────────────────────────────────────────────────────────────────┐
│ 对比散度 (CD-k) 算法 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ 输入: 数据样本 x₀ ~ p_data, E_θ(x) │
│ 输出: 梯度估计 │
│ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ │ │
│ │ 1. 计算正相位梯度(数据驱动) │ │
│ │ ▽_θ⁺ = ∂E_θ(x₀)/∂θ │ │
│ │ │ │
│ │ 2. k步吉布斯采样(从数据初始化) │ │
│ │ for t = 1 to k: │ │
│ │ x_t ~ p_θ(x | x_{t-1}) │ │
│ │ # 交替更新隐变量和可见变量 │ │
│ │ │ │
│ │ 3. 计算负相位梯度(模型驱动) │ │
│ │ ▽_θ⁻ = ∂E_θ(x_k)/∂θ │ │
│ │ │ │
│ │ 4. 梯度估计 │ │
│ │ ▽_θ ≈ ▽_θ⁺ - ▽_θ⁻ │ │
│ │ │ │
│ └─────────────────────────────────────────────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────────┘
数学推导:
其中 是经过 步吉布斯采样后的分布。
7.3 CD算法变体
| 变体 | 描述 | 特点 |
|---|---|---|
| CD-k | 标准k步CD | 简单,常用k=1 |
| PCD | 持续CD | 维护负样本池 |
| CDn | n个k步采样 | 更准确但慢 |
| Fast CD | 调整学习率 | 加速收敛 |
7.4 PyTorch实现CD
class ContrastiveDivergenceTrainer:
"""
对比散度训练器
用于训练能量基模型
"""
def __init__(self, energy_net, k=1, lr=0.01, momentum=0.5):
self.energy_net = energy_net
self.k = k # 吉布斯采样步数
self.lr = lr
self.momentum = momentum
# 动量
self.m = {name: torch.zeros_like(param)
for name, param in energy_net.named_parameters()}
def gibbs_step(self, x):
"""
吉布斯采样一步
对于EBM,简化使用朗之万一步
"""
x = x.detach().requires_grad_(True)
# 计算能量
energy = self.energy_net(x).sum()
# 梯度
grad = torch.autograd.grad(energy, x, create_graph=False)[0]
# 朗之万更新(简化)
step_size = 0.01
noise_scale = 0.01
x_new = x - step_size * grad + noise_scale * torch.randn_like(x)
return x_new.detach()
def negative_sampling(self, x_data, k=None):
"""
从数据样本出发,进行k步采样生成负样本
"""
k = k or self.k
x_neg = x_data.clone()
for _ in range(k):
x_neg = self.gibbs_step(x_neg)
return x_neg
def compute_gradients(self, x_pos, x_neg):
"""
计算CD梯度
"""
x_pos = x_pos.detach().requires_grad_(True)
x_neg = x_neg.detach().requires_grad_(True)
# 正相位梯度
energy_pos = self.energy_net(x_pos).sum()
grad_pos = torch.autograd.grad(energy_pos,
self.energy_net.parameters(),
retain_graph=True)
# 负相位梯度
energy_neg = self.energy_net(x_neg).sum()
grad_neg = torch.autograd.grad(energy_neg,
self.energy_net.parameters())
# CD梯度 = 正相位 - 负相位
grads = []
for g_pos, g_neg in zip(grad_pos, grad_neg):
grads.append(g_pos - g_neg)
return grads
def step(self, x_data, x_neg_init=None):
"""
一步训练更新
"""
# 生成负样本
if x_neg_init is not None:
x_neg = self.negative_sampling(x_neg_init)
else:
x_neg = self.negative_sampling(x_data)
# 计算梯度
grads = self.compute_gradients(x_data, x_neg)
# 更新参数(带动量)
for (name, param), grad in zip(self.energy_net.named_parameters(), grads):
self.m[name] = self.momentum * self.m[name] + (1 - self.momentum) * grad
param.data = param.data - self.lr * self.m[name]
# 计算损失(能量差)
with torch.no_grad():
loss = self.energy_net(x_neg).mean() - self.energy_net(x_data).mean()
return loss.item(), x_neg
def train_ebm_cd(energy_net, data_loader, n_epochs=100, k=1):
"""
使用对比散度训练EBM
"""
trainer = ContrastiveDivergenceTrainer(energy_net, k=k)
for epoch in range(n_epochs):
epoch_loss = 0
neg_samples = None
for batch_idx, x_data in enumerate(data_loader):
x_data = x_data.view(x_data.size(0), -1).to(next(energy_net.parameters()).device)
# 训练
loss, neg = trainer.step(x_data, neg_samples)
epoch_loss += loss
# 持久化:使用上一轮的负样本作为下一轮的初始点
neg_samples = neg.detach()
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss = {epoch_loss/len(data_loader):.4f}")
return energy_net8. EBM与扩散模型的联系
8.1 统一视角
EBM和扩散模型(DDPM)本质上都定义了数据分布,但参数化方式和采样方法不同。6
对比表:
| 维度 | 能量基模型 (EBM) | 扩散模型 (DDPM) |
|---|---|---|
| 分布表示 | ||
| 显式归一化 | 否(需) | 是(自动归一化) |
| 采样机制 | MCMC(郎之万等) | 逆时间SDE/ODE |
| 条件化 | 自然(加能量项) | 需classifier guidance |
| 训练目标 | 对比散度/分数匹配 | 去噪重建 |
8.2 分数函数连接
EBM的分数函数与扩散模型的核心量有深刻联系:
分数匹配是连接两者的桥梁:
详见 分数匹配理论。
8.3 EBM作为扩散模型的一步
扩散模型的逆过程可以视为从噪声分布出发,逐步降低能量到达数据分布:
这正是郎之万采样的形式!
8.4 能量引导的扩散采样
可以利用EBM的能量函数引导扩散模型采样:
class EnergyGuidedDiffusion:
"""
能量引导的扩散模型
结合扩散模型的语义生成能力和EBM的细粒度控制
"""
def __init__(self, diffusion_model, energy_net):
self.diffusion = diffusion_model
self.energy = energy_net
self.beta = 0.1 # 引导强度
def guided_score(self, x_t, t, condition=None):
"""
能量引导的分数估计
"""
# 扩散模型分数
diffusion_score = self.diffusion.score(x_t, t)
# EBM分数(仅在低噪声时使用)
if t > self.diffusion.num_timesteps // 2:
energy_score = self.energy.score(x_t)
# 引导
diffusion_score = diffusion_score + self.beta * energy_score
return diffusion_score
def sample_guided(self, shape, condition=None, n_steps=1000):
"""
引导采样
"""
x_t = torch.randn(shape, device=next(self.diffusion.parameters()).device)
for t in reversed(range(n_steps)):
# 估计分数
score = self.guided_score(x_t, t, condition)
# 更新(简化的DDIM采样)
x_t = x_t - self.diffusion.step_size * score
return x_t9. PyTorch完整实现示例
9.1 完整EBM训练流程
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, MNIST
import torchvision.transforms as transforms
class SimpleEBM(nn.Module):
"""
简单能量基模型
用于MNIST图像生成
"""
def __init__(self, hidden_dim=512):
super().__init__()
# 能量网络:多层感知机
self.net = nn.Sequential(
nn.Linear(784, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1) # 输出能量
)
def energy(self, x):
"""计算能量"""
x = x.view(x.size(0), -1)
return self.net(x).squeeze(-1)
def score(self, x):
"""分数函数"""
x = x.view(x.size(0), -1).requires_grad_(True)
e = self.energy(x)
return -torch.autograd.grad(e.sum(), x, create_graph=True)[0]
def langevin_sample(self, x_init, n_steps=60, step_size=0.01, noise_scale=0.01):
"""郎之万采样"""
x = x_init.view(x_init.size(0), -1).clone().detach().requires_grad_(True)
for _ in range(n_steps):
# 梯度
grad = torch.autograd.grad(self.energy(x).sum(), x)[0]
# 更新
noise = torch.randn_like(x)
x = x - step_size * grad + noise_scale * noise
x = x.detach().requires_grad_(True)
return x
def forward(self, x):
return self.energy(x)
class EBMTrainer:
"""EBM训练器"""
def __init__(self, model, lr=1e-4, lazy_grad=True):
self.model = model
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
self.lazy_grad = lazy_grad
def train_step(self, x_pos, x_neg_init=None, n_cd=1):
"""
单步训练
使用对比散度
"""
self.model.train()
self.optimizer.zero_grad()
# 正样本能量
e_pos = self.model.energy(x_pos)
# 负样本生成
if x_neg_init is not None:
x_neg = x_neg_init
else:
x_neg = torch.randn_like(x_pos) * 0.1 # 从接近零的噪声初始化
# CD-k采样
for _ in range(n_cd):
x_neg = self.model.langevin_sample(x_neg, n_steps=20,
step_size=0.003, noise_scale=0.003)
# 负样本能量
e_neg = self.model.energy(x_neg)
# 对比散度损失
# 目标:正样本能量低,负样本能量高
loss = F.relu(e_pos.mean() - e_neg.mean() + 0.1) # margin=0.1
loss.backward()
self.optimizer.step()
return {
'loss': loss.item(),
'e_pos': e_pos.mean().item(),
'e_neg': e_neg.mean().item()
}
def sample(self, n_samples=64):
"""生成样本"""
self.model.eval()
with torch.no_grad():
x_init = torch.randn(n_samples, 1, 28, 28, device=next(self.model.parameters()).device)
x_samples = self.model.langevin_sample(x_init, n_steps=100,
step_size=0.01, noise_scale=0.01)
x_samples = x_samples.view(-1, 1, 28, 28)
return x_samples
def main():
# 加载数据
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: (x - 0.5) * 2) # 归一化到 [-1, 1]
])
dataset = MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
# 创建模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleEBM(hidden_dim=1024).to(device)
# 创建训练器
trainer = EBMTrainer(model, lr=1e-4)
# 训练
neg_samples = None
for epoch in range(50):
epoch_stats = {'loss': 0, 'e_pos': 0, 'e_neg': 0}
n_batches = 0
for batch_idx, (x_data, _) in enumerate(dataloader):
x_data = x_data.to(device)
# 训练
stats = trainer.train_step(x_data, neg_samples)
# 持久化:使用最后一个负样本
if batch_idx == len(dataloader) - 1:
neg_samples = model.langevin_sample(x_data, n_steps=20).detach()
epoch_stats['loss'] += stats['loss']
epoch_stats['e_pos'] += stats['e_pos']
epoch_stats['e_neg'] += stats['e_neg']
n_batches += 1
# 打印统计
print(f"Epoch {epoch}: Loss={epoch_stats['loss']/n_batches:.4f}, "
f"E_pos={epoch_stats['e_pos']/n_batches:.4f}, "
f"E_neg={epoch_stats['e_neg']/n_batches:.4f}")
# 定期生成样本
if epoch % 10 == 0:
samples = trainer.sample(n_samples=64)
# 可以保存或可视化 samples
if __name__ == '__main__':
main()9.2 条件EBM实现
class ConditionalEBM(nn.Module):
"""
条件能量基模型
支持类别条件生成
"""
def __init__(self, input_dim, n_classes, hidden_dim=512):
super().__init__()
self.n_classes = n_classes
self.input_dim = input_dim
# 类别嵌入
self.class_embed = nn.Embedding(n_classes, hidden_dim)
# 能量网络
self.net = nn.Sequential(
nn.Linear(input_dim + hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
def energy(self, x, y=None):
"""
计算条件能量 E(x | y)
"""
if y is not None:
y_emb = self.class_embed(y)
x_y = torch.cat([x.view(x.size(0), -1), y_emb], dim=-1)
else:
x_y = x.view(x.size(0), -1)
return self.net(x_y).squeeze(-1)
def conditional_sample(self, y, n_steps=100):
"""从条件分布 p(x|y) 采样"""
y_tensor = torch.full((1,), y, dtype=torch.long, device=next(self.parameters()).device)
x_init = torch.randn(1, *self.input_dim, device=next(self.parameters()).device)
x = x_init.view(1, -1).clone().detach().requires_grad_(True)
for _ in range(n_steps):
e = self.energy(x, y_tensor)
grad = torch.autograd.grad(e.sum(), x)[0]
x = x - 0.01 * grad + 0.01 * torch.randn_like(x)
x = x.detach().requires_grad_(True)
return x.view(*self.input_dim)
def classify_energy(self, x):
"""基于能量分类"""
energies = []
for y in range(self.n_classes):
y_tensor = torch.full((x.size(0),), y, dtype=torch.long, device=x.device)
e = self.energy(x, y_tensor)
energies.append(e)
energies = torch.stack(energies, dim=1)
return energies.argmin(dim=1)10. 总结与应用
10.1 EBM的应用场景
| 应用 | 说明 | 优势 |
|---|---|---|
| 图像生成 | 高质量样本生成 | 多模态覆盖好 |
| 异常检测 | 正常样本低能量,异常样本高能量 | 自然,无需异常样本 |
| 对抗鲁棒 | JEM等模型对抗样本抵抗力强 | 内置对抗训练 |
| 逆问题求解 | 通过最小化能量求解 | 无需专门求解器 |
| 半监督学习 | 利用未标注数据的能量景观 | 统一框架 |
10.2 与其他生成模型的关系
┌─────────────────────────────────────────────────────────────────────┐
│ 生成模型家族关系图 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ 生成模型 │
│ │ │
│ ┌───────────────┼───────────────┐ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────┐ ┌──────────┐ ┌──────────┐ │
│ │ GAN │ │ VAE │ │ EBM │ │
│ └────┬────┘ └─────┬────┘ └─────┬────┘ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌──────────┐ │ │
│ │ │ 归一化流 │ │ │
│ │ └─────┬────┘ │ │
│ │ │ │ │
│ └───────────────┼──────────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────┐ │
│ │ 扩散模型 │ │
│ │ (郎之万采样) │ │
│ └────────────────┘ │
│ │
│ Energy Matching: EBM与Flow Matching的统一框架 │
│ (详见 [[energy-based-models-2025-unified|能量基模型与Flow Matching统一]]) │
│ │
└─────────────────────────────────────────────────────────────────────┘
相关文档
- Energy Matching — EBM与FM统一框架
- EBM vs Diffusion — 两种方法对比
- Diffusion Models — 扩散模型基础
- MCMC方法 — 马尔可夫链蒙特卡洛
- Score Matching — 分数匹配理论
- Modern Hopfield Networks — Hopfield网络进阶
参考文献
Footnotes
-
LeCun et al., “A Tutorial on Energy-Based Learning”, Predicting Structured Data, 2006. ↩
-
Hopfield, J. J., “Neural networks and physical systems with emergent collective computational abilities”, PNAS 1982. ↩
-
Ramsauer et al., “Hopfield Networks is All You Need”, ICLR 2021. ↩
-
Grathwohl et al., “JEM: Joint Energy-Based Models”, ICLR 2021. ↩
-
Hinton, G. E., “Training Products of Experts by Minimizing Contrastive Divergence”, Neural Computation 2002. ↩
-
Energy Matching: Unifying Flow Matching and Energy-Based Models, NeurIPS 2025. ↩