1. 概述

能量基模型(Energy-Based Model, EBM)是一类通过能量函数定义概率分布的生成模型。与直接参数化概率分布不同,EBM通过设计能量函数来描述数据的内在结构,具有表达能力强、条件化自然等优点。1

内容框架

┌─────────────────────────────────────────────────────────────────────┐
│                    能量基模型(EBM)全景图                             │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│   ┌─────────────────────────────────────────────────────────────┐   │
│   │                    理论基础                                  │   │
│   │   能量函数 E(x) → Gibbs分布 → 概率建模                       │   │
│   └────────────────────────────┬────────────────────────────────┘   │
│                                │                                     │
│   ┌────────────────────────────┼────────────────────────────────┐   │
│   │                       历史发展                                │   │
│   │   Hopfield网络 ──→ 现代EBM ──→ 与扩散模型融合                  │   │
│   └────────────────────────────┬────────────────────────────────┘   │
│                                │                                     │
│   ┌────────────────────────────┼────────────────────────────────┐   │
│   │                       采样方法                                │   │
│   │   郎之万动力学 ──→ 对比散度 ──→ 其他MCMC方法                   │   │
│   └────────────────────────────┬────────────────────────────────┘   │
│                                │                                     │
│   ┌────────────────────────────┼────────────────────────────────┐   │
│   │                       应用场景                                │   │
│   │   图像生成 ──→ 异常检测 ──→ 对抗鲁棒 ──→ 逆问题求解            │   │
│   └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

2. 能量基模型基础定义

2.1 能量函数的定义

能量函数 将每个输入样本映射到一个实数标量,表示该样本的”能量”水平。低能量对应高概率,高能量对应低概率。

典型形式

其中 是某个非负函数(如神经网络的输出)。

2.2 从能量到概率:Gibbs分布

通过Gibbs分布(也称Boltzmann分布)将能量函数转换为概率分布:

其中:

  • 配分函数
  • 逆温度参数(通常设为1)

物理直觉

┌─────────────────────────────────────────────────────────────────────┐
│                      能量景观与概率分布                                 │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  能量 E(x)                    概率 p(x)                              │
│                                                                     │
│     高能量                         低概率                            │
│        ↑                            ↓                               │
│     ┌──┴──┐                    ┌────┴────┐                         │
│     │     │                    │         │                          │
│     │  峰  │                    │  谷底   │                          │
│     │     │                    │ (高p)  │                          │
│     └─────┘                    └─────────┘                          │
│                                                                     │
│     低能量                         高概率                            │
│        ↓                            ↑                               │
│     ┌──┬──┐                    ┌────┬────┐                         │
│     │  │  │                    │    │    │                          │
│     │  全局  │                    │  峰顶  │                          │
│     │  最小值 │                    │ (最高p)│                          │
│     │  │  │                    │    │    │                          │
│     └─────┘                    └─────────┘                          │
│                                                                     │
│  Boltzmann因子: exp(-E(x)) 越大,能量越低,概率越高                    │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

2.3 配分函数的角色

配分函数 是概率分布的归一化常数:

核心挑战:对于高维数据, 通常无法解析计算,也难以数值估计。这是EBM训练的主要困难。

为什么配分函数难计算

维度样本空间大小配分函数计算难度
32×32图像不可能
784维MNIST不可能

2.4 能量模型的优势

优势说明
表达能力强任何非归一化密度都可以写成
自然条件化只需在能量函数中添加条件项即可
逆问题求解通过最小化能量可以直接求解逆问题
模式覆盖能量函数天然支持多模态分布
与物理的联系Boltzmann分布源于统计力学

3. EBM与概率分布的关系

3.1 各种分布的能量表示

Bernoulli分布

高斯分布

混合高斯分布

3.2 对数似然梯度

尽管配分函数无法直接计算,其梯度可以通过对数似然梯度间接获得:

对配分函数求导:

因此:

直观理解


4. Hopfield网络的历史回顾

4.1 Hopfield网络的起源

Hopfield网络(1982年)是最早的能量基模型之一,由John Hopfield提出,用于模拟生物神经网络的联想记忆功能。2

网络结构

其中 是神经元状态, 是连接权重, 是偏置。

4.2 能量函数与吸引子

Hopfield网络的吸引子对应能量的局部最小值,每个吸引子可以存储一个记忆模式:

┌─────────────────────────────────────────────────────────────────────┐
│                      Hopfield网络能量景观                             │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│                         E(s)                                        │
│                          │                                          │
│            ┌──────────┐  │  ┌──────────┐                            │
│           ╱│          │╲ │ ╱│          │╲                           │
│          ╱ │  吸引子1 │ ╲│╱ │ 吸引子2  │ ╲                          │
│         ╱  │ (记忆m₁) │  ╲  │ (记忆m₂) │  ╲                         │
│     ──╱────│          │──╲─│─│          │──╲─────                    │
│           │          │   ╲│╱│          │   ╲                        │
│            ╲──────────╱    │  ╲──────────╱    ╲                       │
│                          │                                          │
│     状态s沿能量下山梯度演化,趋向最近的吸引子                           │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

4.3 Hopfield更新规则

异步更新

能量单调下降

4.4 现代EBM与Hopfield网络的关系

现代能量基模型可以视为Hopfield网络的扩展:

特征Hopfield网络现代EBM
能量函数二次型神经网络参数化
状态 连续
更新确定性随机(郎之万)
容量无限制
应用联想记忆生成建模

现代扩展:Modern Hopfield Networks 和 注意力机制有深刻联系,详见相关研究。3


5. 现代EBM架构

5.1 基于能量的自编码器

Joint Embedding Models能量基模型形成对比。能量基自编码器(EBAE)直接参数化能量函数:

class EnergyBasedAutoEncoder(nn.Module):
    """
    基于能量的自编码器
    编码器:学习隐表示
    能量函数:测量 (x, z) 的兼容性
    解码器:可选,用于重建
    """
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        
        # 编码器
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        
        # 能量函数网络
        self.energy_net = nn.Sequential(
            nn.Linear(input_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # 输出标量能量
        )
        
        # 解码器(可选)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
    
    def energy(self, x, z=None):
        """计算能量 E(x, z)"""
        if z is None:
            z = self.encoder(x)
        return self.energy_net(torch.cat([x, z], dim=-1))
    
    def encode(self, x):
        """编码"""
        return self.encoder(x)
    
    def decode(self, z):
        """解码"""
        return self.decoder(z)
    
    def reconstruct(self, x):
        """重建"""
        z = self.encode(x)
        return self.decode(z)

5.2 联合能量基模型(JEM)

Joint Energy-based Models (JEM) 将分类器与EBM结合,同时实现分类和生成。4

class JEM(nn.Module):
    """
    Joint Energy-based Model
    统一分类和生成的能量基模型
    """
    def __init__(self, classifier, energy_fn):
        super().__init__()
        self.classifier = classifier
        self.energy_fn = energy_fn
    
    def energy(self, x, y):
        """
        计算条件能量 E(x | y)
        结合分类logit和能量函数
        """
        # 分类器输出的logit作为能量项
        class_energy = -self.classifier(x)  # 负logit
        generative_energy = self.energy_fn(x)
        
        # 组合
        return class_energy[:, y] + generative_energy
    
    def joint_energy(self, x, y):
        """
        计算联合能量 E(x, y) = E(x | y) + log p(y)
        """
        return self.energy(x, y) + torch.log_softmax(self.classifier(x), dim=1)[:, y]
    
    def forward(self, x, y=None):
        """
        前向传播
        如果提供y,计算 E(x, y)
        否则返回所有类别的能量
        """
        if y is not None:
            return self.joint_energy(x, y)
        return self.energy(x, torch.arange(x.size(0), device=x.device))
    
    def classify_with_energy(self, x):
        """
        基于能量的分类:选择能量最低的类别
        """
        return (-self.classifier(x)).argmax(dim=1)
    
    def log_prob(self, x, y, n_steps=20, step_size=0.1):
        """
        近似计算 log p(y|x) = -E(x,y) - log Z + const
        """
        energy = self.joint_energy(x, y)
        
        # 简化估计:使用分类器logit近似
        # 注意:这是近似,不涉及MCMC采样
        return -energy

5.3 能量函数设计原则

良好能量函数的设计要点

原则说明
灵活的表达能力使用深度网络捕捉复杂数据分布
平滑性相似的输入应有相似的能量
模式覆盖确保覆盖所有数据模式
可优化性梯度容易计算,便于训练
采样友好能量景观便于MCMC采样

6. 郎之万动力学采样

6.1 郎之万采样原理

**郎之万动力学(Langevin Dynamics)**是连续空间的MCMC采样方法,通过随机梯度下降从Boltzmann分布采样。

物理背景:布朗运动中粒子的运动方程:

其中 是高斯白噪声。

过阻尼极限):

6.2 离散时间郎之万采样

迭代公式

其中 是步长。

收敛性:当 ,链收敛到目标分布

┌─────────────────────────────────────────────────────────────────────┐
│                    郎之万采样轨迹示意                                 │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│         E(x)                                                        │
│          │                                                          │
│     高   │     ╱╲                                                    │
│          │    ╱  ╲            ╱╲                                   │
│          │   ╱    ╲          ╱  ╲         能量景观                    │
│     能   │  ╱      ╲   ╱╲  ╱    ╲        郎之万采样                 │
│          │ ╱        ╲╱  ╲╱      ╲       从高能区域                  │
│     量   │╱          向谷底移动    ╲      逐步下降到                  │
│          │                            ╲     低能区域                  │
│     低   │                             ╲                             │
│          │                              ● ←─采样点                   │
│          │                             ╱                              │
│          │                            ╱                               │
│     ─────┼──────────────────────────────────────────── x             │
│          │                                                          │
│          │                                                          │
│     ════════════════════════════════════════════════════            │
│                      时间演化                                        │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

6.3 郎之万采样实现

class LangevinSampler:
    """
    郎之万采样器
    用于从能量基模型定义的分布中采样
    """
    def __init__(self, energy_fn, device='cuda'):
        self.energy_fn = energy_fn
        self.device = device
    
    def sample(self, x_init, n_steps=100, step_size=0.01, noise_scale=None, return_trajectory=False):
        """
        郎之万采样
        
        Args:
            x_init: 初始点
            n_steps: 采样步数
            step_size: 步长 η
            noise_scale: 噪声尺度,默认 sqrt(2*step_size)
            return_trajectory: 是否返回采样轨迹
        
        Returns:
            采样结果 (可能包含轨迹)
        """
        x = x_init.detach().clone().to(self.device).requires_grad_(True)
        
        if noise_scale is None:
            noise_scale = np.sqrt(2 * step_size)
        
        trajectory = [x.detach().cpu()]
        
        for step in range(n_steps):
            # 计算能量梯度
            energy = self.energy_fn(x)
            grad = torch.autograd.grad(
                energy.sum(), x, create_graph=False
            )[0]
            
            # 郎之万更新
            noise = torch.randn_like(x)
            x = x - step_size * grad + noise_scale * noise
            x = x.detach().requires_grad_(True)
            
            if return_trajectory:
                trajectory.append(x.detach().cpu())
        
        if return_trajectory:
            return x.detach(), torch.stack(trajectory)
        return x.detach()
    
    def sample_multiple(self, n_samples, x_init_fn, n_steps=100, step_size=0.01):
        """
        并行采样多个样本
        """
        samples = []
        for _ in range(n_samples):
            x_init = x_init_fn()
            x_sample = self.sample(x_init, n_steps, step_size)
            samples.append(x_sample)
        return torch.stack(samples)
 
 
class EBMWithLangevinSampling(nn.Module):
    """
    带郎之万采样的能量基模型
    """
    def __init__(self, energy_net, sampler=None):
        super().__init__()
        self.energy_net = energy_net
        self.sampler = sampler or LangevinSampler(self.energy)
    
    def energy(self, x):
        """能量函数"""
        return self.energy_net(x).squeeze(-1)
    
    def score(self, x):
        """分数函数 ∇_x log p(x) = -∇_x E(x)"""
        return -torch.autograd.grad(
            self.energy(x).sum(), x, create_graph=True
        )[0]
    
    def sample(self, x_init, n_steps=100, step_size=0.01):
        """从模型分布中采样"""
        return self.sampler.sample(x_init, n_steps, step_size)
    
    def generate(self, batch_size, shape, n_steps=100, step_size=0.01):
        """生成样本"""
        x_init = torch.randn(batch_size, *shape, device=next(self.parameters()).device)
        return self.sample(x_init, n_steps, step_size)

6.4 退火郎之万采样

退火技术可以加速高维复杂分布的采样:

def annealed_langevin_sampling(energy_fn, x_init, n_steps=100, 
                                step_size=0.01, noise_scale=0.01,
                                n_temps=10):
    """
    退火郎之万采样
    从高温逐步降低到低温,帮助跳出局部极小
    """
    x = x_init.detach().clone().requires_grad_(True)
    
    # 温度调度:从高到低
    temps = torch.linspace(10.0, 0.1, n_temps)
    
    for temp in temps:
        steps_per_temp = n_steps // n_temps
        
        for step in range(steps_per_temp):
            # 缩放步长和噪声
            scaled_step = step_size * temp
            scaled_noise = noise_scale * np.sqrt(temp)
            
            # 计算梯度
            energy = energy_fn(x) / temp
            grad = torch.autograd.grad(energy.sum(), x, create_graph=False)[0]
            
            # 郎之万更新
            noise = torch.randn_like(x)
            x = x - scaled_step * grad + scaled_noise * noise
            x = x.detach().requires_grad_(True)
    
    return x.detach()

7. 对比散度(Contrastive Divergence)算法

7.1 CD算法的动机

直接计算配分函数的梯度需要从模型分布中采样,这本身就是困难的问题。**对比散度(CD)**通过近似解决这个问题。5

核心思想:使用少量吉布斯采样步骤从数据初始化,而不是从随机状态开始。

7.2 CD-k算法

CD-k算法步骤

┌─────────────────────────────────────────────────────────────────────┐
│                    对比散度 (CD-k) 算法                               │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  输入: 数据样本 x₀ ~ p_data, E_θ(x)                                  │
│  输出: 梯度估计                                                      │
│                                                                     │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                                                             │   │
│  │  1. 计算正相位梯度(数据驱动)                               │   │
│  │     ▽_θ⁺ = ∂E_θ(x₀)/∂θ                                     │   │
│  │                                                             │   │
│  │  2. k步吉布斯采样(从数据初始化)                           │   │
│  │     for t = 1 to k:                                         │   │
│  │         x_t ~ p_θ(x | x_{t-1})                              │   │
│  │         # 交替更新隐变量和可见变量                          │   │
│  │                                                             │   │
│  │  3. 计算负相位梯度(模型驱动)                              │   │
│  │     ▽_θ⁻ = ∂E_θ(x_k)/∂θ                                    │   │
│  │                                                             │   │
│  │  4. 梯度估计                                                │   │
│  │     ▽_θ ≈ ▽_θ⁺ - ▽_θ⁻                                      │   │
│  │                                                             │   │
│  └─────────────────────────────────────────────────────────────┘   │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

数学推导

其中 是经过 步吉布斯采样后的分布。

7.3 CD算法变体

变体描述特点
CD-k标准k步CD简单,常用k=1
PCD持续CD维护负样本池
CDnn个k步采样更准确但慢
Fast CD调整学习率加速收敛

7.4 PyTorch实现CD

class ContrastiveDivergenceTrainer:
    """
    对比散度训练器
    用于训练能量基模型
    """
    def __init__(self, energy_net, k=1, lr=0.01, momentum=0.5):
        self.energy_net = energy_net
        self.k = k  # 吉布斯采样步数
        self.lr = lr
        self.momentum = momentum
        
        # 动量
        self.m = {name: torch.zeros_like(param) 
                  for name, param in energy_net.named_parameters()}
    
    def gibbs_step(self, x):
        """
        吉布斯采样一步
        对于EBM,简化使用朗之万一步
        """
        x = x.detach().requires_grad_(True)
        
        # 计算能量
        energy = self.energy_net(x).sum()
        
        # 梯度
        grad = torch.autograd.grad(energy, x, create_graph=False)[0]
        
        # 朗之万更新(简化)
        step_size = 0.01
        noise_scale = 0.01
        x_new = x - step_size * grad + noise_scale * torch.randn_like(x)
        
        return x_new.detach()
    
    def negative_sampling(self, x_data, k=None):
        """
        从数据样本出发,进行k步采样生成负样本
        """
        k = k or self.k
        x_neg = x_data.clone()
        
        for _ in range(k):
            x_neg = self.gibbs_step(x_neg)
        
        return x_neg
    
    def compute_gradients(self, x_pos, x_neg):
        """
        计算CD梯度
        """
        x_pos = x_pos.detach().requires_grad_(True)
        x_neg = x_neg.detach().requires_grad_(True)
        
        # 正相位梯度
        energy_pos = self.energy_net(x_pos).sum()
        grad_pos = torch.autograd.grad(energy_pos, 
                                        self.energy_net.parameters(),
                                        retain_graph=True)
        
        # 负相位梯度
        energy_neg = self.energy_net(x_neg).sum()
        grad_neg = torch.autograd.grad(energy_neg, 
                                        self.energy_net.parameters())
        
        # CD梯度 = 正相位 - 负相位
        grads = []
        for g_pos, g_neg in zip(grad_pos, grad_neg):
            grads.append(g_pos - g_neg)
        
        return grads
    
    def step(self, x_data, x_neg_init=None):
        """
        一步训练更新
        """
        # 生成负样本
        if x_neg_init is not None:
            x_neg = self.negative_sampling(x_neg_init)
        else:
            x_neg = self.negative_sampling(x_data)
        
        # 计算梯度
        grads = self.compute_gradients(x_data, x_neg)
        
        # 更新参数(带动量)
        for (name, param), grad in zip(self.energy_net.named_parameters(), grads):
            self.m[name] = self.momentum * self.m[name] + (1 - self.momentum) * grad
            param.data = param.data - self.lr * self.m[name]
        
        # 计算损失(能量差)
        with torch.no_grad():
            loss = self.energy_net(x_neg).mean() - self.energy_net(x_data).mean()
        
        return loss.item(), x_neg
 
 
def train_ebm_cd(energy_net, data_loader, n_epochs=100, k=1):
    """
    使用对比散度训练EBM
    """
    trainer = ContrastiveDivergenceTrainer(energy_net, k=k)
    
    for epoch in range(n_epochs):
        epoch_loss = 0
        neg_samples = None
        
        for batch_idx, x_data in enumerate(data_loader):
            x_data = x_data.view(x_data.size(0), -1).to(next(energy_net.parameters()).device)
            
            # 训练
            loss, neg = trainer.step(x_data, neg_samples)
            epoch_loss += loss
            
            # 持久化:使用上一轮的负样本作为下一轮的初始点
            neg_samples = neg.detach()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {epoch_loss/len(data_loader):.4f}")
    
    return energy_net

8. EBM与扩散模型的联系

8.1 统一视角

EBM和扩散模型(DDPM)本质上都定义了数据分布,但参数化方式和采样方法不同。6

对比表

维度能量基模型 (EBM)扩散模型 (DDPM)
分布表示
显式归一化否(需是(自动归一化)
采样机制MCMC(郎之万等)逆时间SDE/ODE
条件化自然(加能量项)需classifier guidance
训练目标对比散度/分数匹配去噪重建

8.2 分数函数连接

EBM的分数函数与扩散模型的核心量有深刻联系:

分数匹配是连接两者的桥梁:

详见 分数匹配理论

8.3 EBM作为扩散模型的一步

扩散模型的逆过程可以视为从噪声分布出发,逐步降低能量到达数据分布:

这正是郎之万采样的形式!

8.4 能量引导的扩散采样

可以利用EBM的能量函数引导扩散模型采样:

class EnergyGuidedDiffusion:
    """
    能量引导的扩散模型
    结合扩散模型的语义生成能力和EBM的细粒度控制
    """
    def __init__(self, diffusion_model, energy_net):
        self.diffusion = diffusion_model
        self.energy = energy_net
        self.beta = 0.1  # 引导强度
    
    def guided_score(self, x_t, t, condition=None):
        """
        能量引导的分数估计
        """
        # 扩散模型分数
        diffusion_score = self.diffusion.score(x_t, t)
        
        # EBM分数(仅在低噪声时使用)
        if t > self.diffusion.num_timesteps // 2:
            energy_score = self.energy.score(x_t)
            # 引导
            diffusion_score = diffusion_score + self.beta * energy_score
        
        return diffusion_score
    
    def sample_guided(self, shape, condition=None, n_steps=1000):
        """
        引导采样
        """
        x_t = torch.randn(shape, device=next(self.diffusion.parameters()).device)
        
        for t in reversed(range(n_steps)):
            # 估计分数
            score = self.guided_score(x_t, t, condition)
            
            # 更新(简化的DDIM采样)
            x_t = x_t - self.diffusion.step_size * score
        
        return x_t

9. PyTorch完整实现示例

9.1 完整EBM训练流程

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, MNIST
import torchvision.transforms as transforms
 
class SimpleEBM(nn.Module):
    """
    简单能量基模型
    用于MNIST图像生成
    """
    def __init__(self, hidden_dim=512):
        super().__init__()
        
        # 能量网络:多层感知机
        self.net = nn.Sequential(
            nn.Linear(784, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # 输出能量
        )
    
    def energy(self, x):
        """计算能量"""
        x = x.view(x.size(0), -1)
        return self.net(x).squeeze(-1)
    
    def score(self, x):
        """分数函数"""
        x = x.view(x.size(0), -1).requires_grad_(True)
        e = self.energy(x)
        return -torch.autograd.grad(e.sum(), x, create_graph=True)[0]
    
    def langevin_sample(self, x_init, n_steps=60, step_size=0.01, noise_scale=0.01):
        """郎之万采样"""
        x = x_init.view(x_init.size(0), -1).clone().detach().requires_grad_(True)
        
        for _ in range(n_steps):
            # 梯度
            grad = torch.autograd.grad(self.energy(x).sum(), x)[0]
            
            # 更新
            noise = torch.randn_like(x)
            x = x - step_size * grad + noise_scale * noise
            x = x.detach().requires_grad_(True)
        
        return x
    
    def forward(self, x):
        return self.energy(x)
 
 
class EBMTrainer:
    """EBM训练器"""
    def __init__(self, model, lr=1e-4, lazy_grad=True):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.lazy_grad = lazy_grad
    
    def train_step(self, x_pos, x_neg_init=None, n_cd=1):
        """
        单步训练
        使用对比散度
        """
        self.model.train()
        self.optimizer.zero_grad()
        
        # 正样本能量
        e_pos = self.model.energy(x_pos)
        
        # 负样本生成
        if x_neg_init is not None:
            x_neg = x_neg_init
        else:
            x_neg = torch.randn_like(x_pos) * 0.1  # 从接近零的噪声初始化
        
        # CD-k采样
        for _ in range(n_cd):
            x_neg = self.model.langevin_sample(x_neg, n_steps=20, 
                                               step_size=0.003, noise_scale=0.003)
        
        # 负样本能量
        e_neg = self.model.energy(x_neg)
        
        # 对比散度损失
        # 目标:正样本能量低,负样本能量高
        loss = F.relu(e_pos.mean() - e_neg.mean() + 0.1)  # margin=0.1
        
        loss.backward()
        self.optimizer.step()
        
        return {
            'loss': loss.item(),
            'e_pos': e_pos.mean().item(),
            'e_neg': e_neg.mean().item()
        }
    
    def sample(self, n_samples=64):
        """生成样本"""
        self.model.eval()
        with torch.no_grad():
            x_init = torch.randn(n_samples, 1, 28, 28, device=next(self.model.parameters()).device)
            x_samples = self.model.langevin_sample(x_init, n_steps=100, 
                                                   step_size=0.01, noise_scale=0.01)
            x_samples = x_samples.view(-1, 1, 28, 28)
        return x_samples
 
 
def main():
    # 加载数据
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (x - 0.5) * 2)  # 归一化到 [-1, 1]
    ])
    
    dataset = MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
    
    # 创建模型
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleEBM(hidden_dim=1024).to(device)
    
    # 创建训练器
    trainer = EBMTrainer(model, lr=1e-4)
    
    # 训练
    neg_samples = None
    for epoch in range(50):
        epoch_stats = {'loss': 0, 'e_pos': 0, 'e_neg': 0}
        n_batches = 0
        
        for batch_idx, (x_data, _) in enumerate(dataloader):
            x_data = x_data.to(device)
            
            # 训练
            stats = trainer.train_step(x_data, neg_samples)
            
            # 持久化:使用最后一个负样本
            if batch_idx == len(dataloader) - 1:
                neg_samples = model.langevin_sample(x_data, n_steps=20).detach()
            
            epoch_stats['loss'] += stats['loss']
            epoch_stats['e_pos'] += stats['e_pos']
            epoch_stats['e_neg'] += stats['e_neg']
            n_batches += 1
        
        # 打印统计
        print(f"Epoch {epoch}: Loss={epoch_stats['loss']/n_batches:.4f}, "
              f"E_pos={epoch_stats['e_pos']/n_batches:.4f}, "
              f"E_neg={epoch_stats['e_neg']/n_batches:.4f}")
        
        # 定期生成样本
        if epoch % 10 == 0:
            samples = trainer.sample(n_samples=64)
            # 可以保存或可视化 samples
 
 
if __name__ == '__main__':
    main()

9.2 条件EBM实现

class ConditionalEBM(nn.Module):
    """
    条件能量基模型
    支持类别条件生成
    """
    def __init__(self, input_dim, n_classes, hidden_dim=512):
        super().__init__()
        self.n_classes = n_classes
        self.input_dim = input_dim
        
        # 类别嵌入
        self.class_embed = nn.Embedding(n_classes, hidden_dim)
        
        # 能量网络
        self.net = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def energy(self, x, y=None):
        """
        计算条件能量 E(x | y)
        """
        if y is not None:
            y_emb = self.class_embed(y)
            x_y = torch.cat([x.view(x.size(0), -1), y_emb], dim=-1)
        else:
            x_y = x.view(x.size(0), -1)
        
        return self.net(x_y).squeeze(-1)
    
    def conditional_sample(self, y, n_steps=100):
        """从条件分布 p(x|y) 采样"""
        y_tensor = torch.full((1,), y, dtype=torch.long, device=next(self.parameters()).device)
        x_init = torch.randn(1, *self.input_dim, device=next(self.parameters()).device)
        
        x = x_init.view(1, -1).clone().detach().requires_grad_(True)
        
        for _ in range(n_steps):
            e = self.energy(x, y_tensor)
            grad = torch.autograd.grad(e.sum(), x)[0]
            
            x = x - 0.01 * grad + 0.01 * torch.randn_like(x)
            x = x.detach().requires_grad_(True)
        
        return x.view(*self.input_dim)
    
    def classify_energy(self, x):
        """基于能量分类"""
        energies = []
        for y in range(self.n_classes):
            y_tensor = torch.full((x.size(0),), y, dtype=torch.long, device=x.device)
            e = self.energy(x, y_tensor)
            energies.append(e)
        
        energies = torch.stack(energies, dim=1)
        return energies.argmin(dim=1)

10. 总结与应用

10.1 EBM的应用场景

应用说明优势
图像生成高质量样本生成多模态覆盖好
异常检测正常样本低能量,异常样本高能量自然,无需异常样本
对抗鲁棒JEM等模型对抗样本抵抗力强内置对抗训练
逆问题求解通过最小化能量求解无需专门求解器
半监督学习利用未标注数据的能量景观统一框架

10.2 与其他生成模型的关系

┌─────────────────────────────────────────────────────────────────────┐
│                    生成模型家族关系图                                   │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│                      生成模型                                        │
│                         │                                           │
│         ┌───────────────┼───────────────┐                          │
│         │               │               │                          │
│         ▼               ▼               ▼                           │
│   ┌─────────┐    ┌──────────┐   ┌──────────┐                      │
│   │   GAN   │    │   VAE    │   │   EBM    │                      │
│   └────┬────┘    └─────┬────┘   └─────┬────┘                      │
│        │               │              │                            │
│        │               ▼              │                            │
│        │         ┌──────────┐         │                            │
│        │         │ 归一化流  │         │                            │
│        │         └─────┬────┘         │                            │
│        │               │              │                            │
│        └───────────────┼──────────────┘                            │
│                        │                                            │
│                        ▼                                            │
│               ┌────────────────┐                                    │
│               │    扩散模型     │                                    │
│               │  (郎之万采样)   │                                    │
│               └────────────────┘                                    │
│                                                                     │
│  Energy Matching: EBM与Flow Matching的统一框架                       │
│  (详见 [[energy-based-models-2025-unified|能量基模型与Flow Matching统一]])  │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

相关文档


参考文献

Footnotes

  1. LeCun et al., “A Tutorial on Energy-Based Learning”, Predicting Structured Data, 2006.

  2. Hopfield, J. J., “Neural networks and physical systems with emergent collective computational abilities”, PNAS 1982.

  3. Ramsauer et al., “Hopfield Networks is All You Need”, ICLR 2021.

  4. Grathwohl et al., “JEM: Joint Energy-Based Models”, ICLR 2021.

  5. Hinton, G. E., “Training Products of Experts by Minimizing Contrastive Divergence”, Neural Computation 2002.

  6. Energy Matching: Unifying Flow Matching and Energy-Based Models, NeurIPS 2025.