概述

Score Matching提供了一种无需计算归一化常数即可学习能量模型的方法,与扩散模型有着深刻的数学联系。Song等人(2021)将这一理论扩展到随机微分方程(SDE)框架,统一了离散扩散模型的各种变体。1


Score Function

定义

对于概率密度 ,其Score函数定义为对数密度的梯度:

几何意义上,Score函数指向概率密度增长最快的方向:

物理直觉

考虑一维情况,Score函数有直观解释:

其中 是能量函数, 是配分函数。则:

即Score函数是能量梯度的负方向,指向低能量(高概率)区域。

与朗之万动力学的联系

已知Score函数后,可通过**朗之万动力学(Langevin Dynamics)**采样:

其中 是步长。当 时, 收敛到真实分布


Score Matching

问题设定

能量模型定义为:

其中 是能量函数, 是** intractable**的配分函数。

直接最大化似然 需要计算 ,这通常不可行。

显式Score Matching (ESM)

目标函数直接最小化预测Score与真实Score的差异:

展开后:

其中 是与 无关的常数。

问题:需要计算 (Hessian的迹),对于神经网络计算量巨大。

切片Score Matching (SSM)

核心思想:用随机投影降低计算复杂度。

对于随机向量

切片Score Matching目标:

只需计算Jacobian-vector products,避免显式计算Hessian。

去噪Score Matching (DSM)

关键洞察:对于特定形式的噪声分布,可以得到Score的闭式表达式。

,则:

去噪Score Matching目标:

这正是DDPM训练目标的数学基础。


SDE统一框架

从离散到连续

DDPM的离散前向过程:

当步数 ,步长 时,离散过程收敛为连续SDE:

其中 是维纳过程(Wiener Process)。

前向SDE

一般形式的前向SDE:

类型漂移项 扩散项 特点
VP (Variance Preserving) 方差保持
VE (Variance Exploding)方差随时间爆炸
subVPVP的改进

逆向SDE

从时间 反向求解,需要知道反向时间的SDE:

其中 是反向维纳过程, 正是我们需要学习的Score函数。

统一视角

离散DDPM ←─────────────→ 连续SDE
    ↓                        ↓
预测噪声 ε_θ          预测Score s_θ
    ↓                        ↓
等价变换 ───────────→ 完全等价

核心结论:学习预测噪声 等价于学习Score函数


SDE求解与采样

Euler-Maruyama方法

最简单的SDE数值求解:

def euler_maruyama(score_fn, xT, T, dt=1e-5, device='cuda'):
    """Euler-Maruyama求解逆向SDE"""
    x = xT
    for t in reversed(range(0, int(T), int(1/dt))):
        t_norm = t / T
        drift = ...  # 漂移项
        diffusion = ...  # 扩散项
        x = x + drift * dt + diffusion * np.sqrt(dt) * np.random.randn()
    return x

Predictor-Corrector方法

Predictor:使用数值SDE求解器预测下一步
Corrector:使用朗之万动力学校正

def predictor_corrector_sampling(score_fn, xT, T, n_steps=1000, corrector_steps=10, eps=1e-5):
    """Predictor-Corrector采样"""
    dt = T / n_steps
    x = xT
    
    for i in range(n_steps):
        t = T - i * dt
        
        # Predictor: Euler-Maruyama
        drift = -0.5 * beta(t) * x - beta(t) * score_fn(x, t)
        diffusion = np.sqrt(beta(t))
        x_pred = x + drift * dt + diffusion * np.sqrt(dt) * np.random.randn()
        
        # Corrector: 朗之万动力学
        for _ in range(corrector_steps):
            grad = score_fn(x_pred, t - dt)
            x_pred = x_pred + corrector_lr * grad + np.sqrt(2 * corrector_lr) * np.random.randn()
        
        x = x_pred
    
    return x

概率流ODE

SDE对应的ODE形式(确定性采样):

使用ODE求解器(如Runge-Kutta)可实现确定性采样,路径可逆。


理论分析

收敛性保证

对于VP-SDE,可以证明以下收敛性:

定理:设Score估计误差为 ,则使用Euler-Maruyama采样的误差满足:

且步长 时,误差趋于零。

与DDPM的关系

方面DDPMScore SDE
时间离散 连续
采样固定步数 可变步长
理论基础变分推断Score Matching + SDE
加速DDIMODE/SDE求解器

统一损失函数

从Score Matching视角,DDPM的简化损失可写为:

其中权重 与噪声调度相关。


实践技巧

条件Score估计

对于条件生成 ,Classifier-Free Guidance扩展:

混合时间步训练

为提高低噪声时间步的Score估计质量:

def mixed_time_loss(model, x0, t_strategy='uniform'):
    # 策略1: 均匀采样
    if t_strategy == 'uniform':
        t = torch.randint(0, T, (batch_size,))
    
    # 策略2: 重要性采样(低t权重更高)
    elif t_strategy == 'importance':
        snr = alphas_bar / (1 - alphas_bar)
        weights = snr / snr.sum()
        t = torch.multinomial(weights, batch_size, replacement=True)
    
    # 策略3: 最小化SNR损失
    elif t_strategy == 'snr':
        t = torch.randint(1, T, (batch_size,))  # 跳过t=0
        snr_t = alphas_bar[t] / (1 - alphas_bar[t])
        loss = snr_t * ((1 + alphas_bar[t]) * eps_pred - eps_true).square()

架构选择

  • U-Net + Attention:标准选择,适合图像
  • Transformer:DiT架构,更好的可扩展性
  • EDM框架:Karras等人的统一实现框架

代码实现

完整Score SDE框架

import torch
import torch.nn as nn
import numpy as np
 
class ScoreSDE(nn.Module):
    """基于SDE的统一Score模型"""
    
    def __init__(self, sde_type='vp', beta_min=0.1, beta_max=20.0):
        super().__init__()
        self.sde_type = sde_type
        self.beta_min = beta_min
        self.beta_max = beta_max
        
        # 神经网络:预测Score
        self.score_net = ScoreNetwork()
    
    def beta(self, t):
        """噪声调度"""
        return self.beta_min + t * (self.beta_max - self.beta_min)
    
    def drift(self, x, t):
        """漂移项 f(x, t)"""
        if self.sde_type == 'vp':
            return -0.5 * self.beta(t) * x
        elif self.sde_type == 've':
            return torch.zeros_like(x)
    
    def diffusion(self, t):
        """扩散项 g(t)"""
        if self.sde_type == 'vp':
            return np.sqrt(self.beta(t))
        elif self.sde_type == 've':
            return self.beta(t)
    
    def forward(self, x, t):
        """预测Score函数"""
        return self.score_net(x, t)
    
    @torch.no_grad()
    def euler_maruyama_sampling(self, shape, n_steps=1000, device='cuda'):
        """Euler-Maruyama采样"""
        x = torch.randn(shape, device=device)
        dt = 1.0 / n_steps
        
        for i in range(n_steps):
            t = 1.0 - i * dt
            t_tensor = torch.full((shape[0],), t, device=device)
            
            score = self.score_net(x, t_tensor)
            drift = self.drift(x, t)
            diff = self.diffusion(t)
            
            # 逆向SDE
            x = x - (drift - diff**2 * score) * dt + diff * np.sqrt(dt) * torch.randn_like(x)
        
        return x
    
    @torch.no_grad()
    def ode_sampling(self, shape, n_steps=1000, device='cuda'):
        """概率流ODE采样(确定性)"""
        x = torch.randn(shape, device=device)
        dt = 1.0 / n_steps
        
        for i in range(n_steps):
            t = 1.0 - i * dt
            t_tensor = torch.full((shape[0],), t, device=device)
            
            score = self.score_net(x, t_tensor)
            drift = self.drift(x, t)
            diff = self.diffusion(t)
            
            # ODE
            x = x - (drift - 0.5 * diff**2 * score) * dt
        
        return x
 
class ScoreNetwork(nn.Module):
    """Score估计网络"""
    
    def __init__(self, dim=64, time_dim=128):
        super().__init__()
        self.time_mlp = SinusoidalPosEmb(time_dim)
        
        # U-Net风格编码器
        self.enc1 = nn.Sequential(
            nn.Conv2d(3, dim, 3, padding=1),
            nn.GroupNorm(8, dim),
            nn.SiLU()
        )
        self.enc2 = nn.Sequential(
            nn.Conv2d(dim, dim*2, 3, stride=2, padding=1),
            nn.GroupNorm(8, dim*2),
            nn.SiLU()
        )
        
        # 中间层
        self.mid = nn.Sequential(
            nn.Conv2d(dim*2, dim*2, 3, padding=1),
            nn.GroupNorm(8, dim*2),
            nn.SiLU()
        )
        
        # 解码器
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(dim*2, dim, 4, stride=2, padding=1),
            nn.GroupNorm(8, dim),
            nn.SiLU()
        )
        self.dec1 = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.GroupNorm(8, dim),
            nn.SiLU(),
            nn.Conv2d(dim, 3, 3, padding=1)
        )
    
    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        
        h1 = self.enc1(x)
        h2 = self.enc2(h1)
        h_mid = self.mid(h2)
        h2_out = self.dec2(h_mid)
        out = self.dec1(h2_out + h1)
        
        # 返回Score(而非噪声)
        return out
 
class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
    
    def forward(self, t):
        half = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half - 1)
        emb = torch.exp(torch.arange(half, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb

延伸阅读

重要论文

论文年份贡献
Score Matching2005Score Matching基础理论
NCSN2019噪声条件Score网络
Score SDE2021SDE统一框架
EDM2022改进的训练与采样框架

学习资源


参考资料

Footnotes

  1. Song, Y., Sohl-Dickstein, J., Kingma, D. P., Kumar, A., Ermon, S., & Poole, B. (2021). Score-Based Generative Modeling through Stochastic Differential Equations. ICLR 2021. https://arxiv.org/abs/2011.13456