最优传输与Wasserstein距离

最优传输(Optimal Transport, OT)理论是概率论与几何学的交叉领域,由 Monge 在 1781 年提出。经过两个多世纪的发展,它已成为现代机器学习中不可或缺的工具,尤其在生成模型、领域适应、聚类分析等领域发挥着关键作用。1

背景:从搬砖问题说起

Monge 的搬砖问题

想象你要将一堆砖从施工现场(分布 )运到建筑工地(分布 ):

        施工现场 P                建筑工地 Q
        
     ● ● ●                       ●
   ● ●                         ● ● ●
     ● ● ●                   ● ● ●
         ↘ ↘ ↘                   ↙ ↙ ↙
          搬 运                  运 输

问题:如何安排运输计划使得总搬运成本最小?

这就是最优传输的核心问题。


形式化定义

离散情形:Monge 问题

设:

  • :源分布( 位置, 质量)
  • :目标分布( 位置, 质量)
  • :从 的单位运输成本

Monge 传输问题

其中 表示 的质量恰好映射到

关键问题

Monge 问题存在两个困难:

  1. 映射可能不存在:当质量分布不匹配时
  2. 优化问题高度非线性 的组合性质

Kantorovich 松弛

Kantorovich 在 1942 年提出将问题松弛为线性规划:

Kantorovich 传输问题

其中 耦合矩阵的集合,满足:

物理意义 表示从 运到 的质量。

耦合矩阵 γ:

          y_1    y_2    y_3   ...  → Q 的概率
        ┌──────┬──────┬──────┐
    x_1 │ 0.2  │ 0.1  │ 0.0  │
        ├──────┼──────┼──────┤
    x_2 │ 0.0  │ 0.3  │ 0.1  │
        ├──────┼──────┼──────┤
    x_3 │ 0.0  │ 0.0  │ 0.2  │
        └──────┴──────┴──────┘
          ↑      ↑      ↑
        P 的概率

Monge vs Kantorovich

特性MongeKantorovich
传输映射确定性 概率耦合
解的存在性不一定存在总是存在
计算复杂度非凸线性规划(凸)
适用场景质量精确匹配一般分布

Wasserstein 距离

定义

当成本函数为 )时,Kantorovich 问题的最优值定义了 -Wasserstein 距离

对于离散分布:

常用 Wasserstein 距离

距离公式特点
(Earth Mover’s)线性成本
二次成本
最小化最大距离

Wasserstein-1 的特殊性质

距离也称为 Earth Mover’s Distance (EMD),具有以下重要性质:

这称为 Kantorovich-Rubinstein 对偶性

Wasserstein 距离的几何理解

                    P          Q
                    
                    ●          ●
                   ╱ ╲        ╱ ╲
                  ╱   ╲      ╱   ╲
                 ●─────●    ●─────●
                 
                 传输成本 = Σ 质量 × 距离
                          = 0.3×0.2 + 0.2×0.3 + ...

与 KL 散度的对比

特性Wasserstein 距离 KL 散度
定义几何距离信息熵差
度量性质满足度量(对称、三角)不满足对称性
对零重叠的处理有定义(有意义)无定义(
计算复杂度(线性规划)
梯度特性平滑可能无界

为什么 Wasserstein 更好?

import numpy as np
 
def demo_wasserstein_vs_kl():
    """
    演示 Wasserstein 距离 vs KL 散度
    """
    # 两个分布:支撑集几乎不重叠
    P = np.array([0.99, 0.01])
    Q = np.array([0.01, 0.99])
    
    # KL 散度(不重叠时爆炸)
    kl_PQ = np.sum(P * np.log(P / (Q + 1e-10) + 1e-10))
    kl_QP = np.sum(Q * np.log(Q / (P + 1e-10) + 1e-10))
    
    # Wasserstein-1(始终有定义)
    w1 = 0.99 * 1 + 0.01 * 0  # 简化计算
    
    print(f"KL(P||Q): {kl_PQ:.2f}")  # ~4.56
    print(f"KL(Q||P): {kl_QP:.2f}")  # ~4.56
    print(f"W_1(P,Q): {w1:.2f}")    # ~0.98
    
    """
    结论:当分布支撑集不重叠时
    - KL → ∞(无法优化)
    - Wasserstein → 有意义的具体数值
    """
 
demo_wasserstein_vs_kl()

梯度对比可视化

                KL divergence梯度
                
                ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
                ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
                ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
                ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑  ← 梯度爆炸!
           ─────┼────────────────→ x
                 0


                Wasserstein梯度
                
                ↑                 ↓
                ↑                ↓
                ↑                ↓
           ─────┼────────────────→ x
                0                 ← 平滑过渡

这就是 Wasserstein GAN 优于标准 GAN 的原因!


Sinkhorn 算法

Entropic 正则化

直接在 Wasserstein 距离上优化计算量很大。Cuturi (2013) 提出了 entropic 正则化

正则化项使得最优耦合 具有可分离形式

其中 称为 Gibbs 核

Sinkhorn 迭代

给定 矩阵,Sinkhorn 算法通过交替归一化求解:

import torch
import numpy as np
 
def sinkhorn(a, b, C, epsilon=0.1, num_iters=100, device='cpu'):
    """
    Sinkhorn 算法计算 entropic OT
    
    Args:
        a: 源分布权重, shape [n]
        b: 目标分布权重, shape [m]
        C: 成本矩阵, shape [n, m]
        epsilon: 正则化参数(越小越接近真实 OT)
        num_iters: 迭代次数
    
    Returns:
        gamma: 最优耦合矩阵
    """
    n, m = len(a), len(b)
    
    # 确保是 PyTorch 张量
    a = torch.tensor(a, dtype=torch.float64, device=device)
    b = torch.tensor(b, dtype=torch.float64, device=device)
    C = torch.tensor(C, dtype=torch.float64, device=device)
    
    # Gibbs 核
    K = torch.exp(-C / epsilon)
    
    # 初始化
    u = torch.ones(n, dtype=torch.float64, device=device)
    v = torch.ones(m, dtype=torch.float64, device=device)
    
    # Sinkhorn 迭代
    for i in range(num_iters):
        # 更新 u
        u = a / (K @ v + 1e-10)
        # 更新 v
        v = b / (K.T @ u + 1e-10)
        
        # 数值稳定性检查
        if torch.isnan(u).any() or torch.isnan(v).any():
            print(f"Sinkhorn diverged at iteration {i}")
            break
    
    # 计算最优耦合
    gamma = u.view(-1, 1) * K * v.view(1, -1)
    
    return gamma
 
 
def wasserstein_1(P, Q, epsilon=1e-3):
    """
    近似计算 Wasserstein-1 距离
    
    使用 Sinkhorn 近似
    """
    n, m = P.shape[0], Q.shape[0]
    
    # 成本矩阵(假设 1D 分布)
    x = np.arange(n) / n
    y = np.arange(m) / m
    C = np.abs(x[:, None] - y[None, :])
    
    # 均匀分布
    a = np.ones(n) / n
    b = np.ones(m) / m
    
    # Sinkhorn 计算
    gamma = sinkhorn(a, b, C, epsilon)
    
    # Wasserstein-1 = <γ, C>
    return torch.sum(gamma * torch.tensor(C)).item()

计算复杂度分析

方法时间复杂度空间复杂度
原始线性规划
Sinkhorn
稀疏 Sinkhorn
低秩近似

其中 是迭代次数, 是低秩近似秩。

正则化参数 的影响

def plot_epsilon_effect():
    """
    展示正则化参数 ε 的影响
    """
    import matplotlib.pyplot as plt
    
    x = np.linspace(0, 1, 100)
    P = np.zeros(100); P[10] = 1.0
    Q = np.zeros(100); Q[90] = 1.0
    
    C = np.abs(x[:, None] - x[None, :])
    
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    
    epsilons = [0.01, 0.1, 1.0]
    for ax, eps in zip(axes, epsilons):
        gamma = sinkhorn(P, Q, C, epsilon=eps, num_iters=100)
        ax.imshow(gamma, cmap='Blues')
        ax.set_title(f'ε = {eps}, W = {torch.sum(gamma * C).item():.3f}')
    
    plt.tight_layout()
    plt.show()
    
    """
    观察:
    - ε → 0: γ 趋向稀疏(接近原始 OT)
    - ε → ∞: γ 趋向均匀分布
    """

Wasserstein GAN

标准 GAN 的问题

标准 GAN 使用 JS 散度,面临以下问题:

  • 当生成分布与真实分布不重叠时,梯度消失
  • 训练不稳定,难以平衡生成器和判别器

WGAN 的改进

Wasserstein GAN (Arjovsky et al., 2017) 使用 Wasserstein-1 距离作为损失:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class DiscriminatorWGAN(nn.Module):
    """
    WGAN 判别器(称为 Critic)
    
    关键变化:
    1. 输出不是 sigmoid(去掉概率解释)
    2. 权重裁剪保证 Lipschitz 约束
    """
    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x):
        return self.net(x)
 
 
class WGAN(nn.Module):
    """
    Wasserstein GAN
    
    损失函数:W_1(P_real, P_fake) ≈ E[critic(real)] - E[critic(fake)]
    """
    def __init__(self, latent_dim, data_dim, hidden_dim=64):
        super().__init__()
        
        # 生成器
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
            nn.Tanh()  # 输出归一化到 [-1, 1]
        )
        
        # 判别器 (Critic)
        self.critic = DiscriminatorWGAN(data_dim, hidden_dim)
        
        # 权重裁剪参数(WGAN 原始方法)
        self.clip_value = 0.01
    
    def wasserstein_loss(self, real, fake):
        """
        WGAN 损失函数
        
        目标:最大化 E[critic(real)] - E[critic(fake)]
              即最小化 -(上述值)
        """
        real_output = self.critic(real)
        fake_output = self.critic(fake)
        
        # 负的 Wasserstein 距离(要最小化)
        loss = -(real_output.mean() - fake_output.mean())
        return loss
    
    def train_step(self, real_data, optimizer_g, optimizer_c):
        """
        一步训练
        """
        batch_size = real_data.size(0)
        z = torch.randn(batch_size, self.generator[0].in_features)
        fake_data = self.generator(z)
        
        # 1. 训练 Critic(多步)
        optimizer_c.zero_grad()
        critic_loss = self.wasserstein_loss(real_data, fake_data.detach())
        critic_loss.backward()
        optimizer_c.step()
        
        # 权重裁剪
        for p in self.critic.parameters():
            p.data.clamp_(-self.clip_value, self.clip_value)
        
        # 2. 训练 Generator
        optimizer_g.zero_grad()
        gen_loss = self.wasserstein_loss(real_data, fake_data)
        gen_loss.backward()
        optimizer_g.step()
        
        return critic_loss.item(), gen_loss.item()
 
 
# 使用示例
# WGAN-GP (Gradient Penalty) 更稳定
class WGAN_GP(nn.Module):
    """
    WGAN with Gradient Penalty
    
    用梯度惩罚替代权重裁剪,更稳定
    """
    def __init__(self, latent_dim, data_dim, hidden_dim=64):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, data_dim),
            nn.Tanh()
        )
        self.critic = DiscriminatorWGAN(data_dim, hidden_dim)
        self.lambda_gp = 10  # 梯度惩罚系数
    
    def gradient_penalty(self, real, fake):
        """
        计算梯度惩罚
        
        目标:||∇_x critic(x)|| ≈ 1
        """
        batch_size = real.size(0)
        alpha = torch.rand(batch_size, 1)
        alpha = alpha.expand_as(real)
        
        interpolated = alpha * real + (1 - alpha) * fake
        interpolated.requires_grad_(True)
        
        critic_interpolated = self.critic(interpolated)
        
        gradients = torch.autograd.grad(
            outputs=critic_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(critic_interpolated),
            create_graph=True
        )[0]
        
        gradients = gradients.view(batch_size, -1)
        gradient_norm = gradients.norm(2, dim=1)
        
        penalty = ((gradient_norm - 1) ** 2).mean()
        return penalty
    
    def train_step(self, real_data, optimizer_g, optimizer_c):
        batch_size = real_data.size(0)
        z = torch.randn(batch_size, self.generator[0].in_features)
        fake_data = self.generator(z)
        
        # Critic 损失
        optimizer_c.zero_grad()
        critic_loss = self.wasserstein_loss(real_data, fake_data)
        gp = self.gradient_penalty(real_data, fake_data)
        c_total = critic_loss + self.lambda_gp * gp
        c_total.backward()
        optimizer_c.step()
        
        # Generator 损失
        optimizer_g.zero_grad()
        gen_loss = self.wasserstein_loss(real_data, fake_data)
        gen_loss.backward()
        optimizer_g.step()
        
        return critic_loss.item(), gen_loss.item(), gp.item()

WGAN vs 标准 GAN

特性标准 GANWGAN
散度JS 散度Wasserstein-1
梯度消失/爆炸平滑
训练稳定性敏感稳定
评估指标Loss 与生成质量相关
收敛保证理论上更好

应用场景

1. 领域适应(Domain Adaptation)

class OptimalTransportDomainAdaptation(nn.Module):
    """
    基于 OT 的领域适应
    
    核心思想:找到源域到目标域的最优传输映射
    """
    def __init__(self, feature_dim, num_classes):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        self.classifier = nn.Linear(64, num_classes)
    
    def ot_loss(self, source_features, target_features, epsilon=0.1):
        """
        计算源域到目标域的 OT 损失
        
        用于对齐两个域的特征分布
        """
        batch_size = source_features.size(0)
        
        # 成本矩阵
        C = torch.cdist(source_features, target_features, p=2) ** 2
        
        # 均匀分布
        a = torch.ones(batch_size) / batch_size
        b = torch.ones(batch_size) / batch_size
        
        # Sinkhorn 计算
        gamma = sinkhorn(a, b, C, epsilon)
        
        # OT 损失
        return torch.sum(gamma * C)
    
    def forward(self, source_x, target_x):
        source_features = self.feature_extractor(source_x)
        target_features = self.feature_extractor(target_x)
        
        # 分类损失 + OT 损失
        source_logits = self.classifier(source_features)
        
        ot_loss = self.ot_loss(source_features, target_features)
        
        return source_logits, ot_loss

2. Wasserstein Barycenter

def wasserstein_barycenter(distributions, weights=None, epsilon=0.1):
    """
    计算 Wasserstein Barycenter(最优传输重心)
    
    给定多个分布,找到与所有分布的 Wasserstein 距离加权和最小的分布
    
    应用:多域学习的中心表示
    """
    if weights is None:
        weights = torch.ones(len(distributions)) / len(distributions)
    
    # 迭代更新
    barycenter = distributions[0].clone()
    
    for _ in range(100):
        new_barycenter = torch.zeros_like(barycenter)
        
        for dist, w in zip(distributions, weights):
            C = torch.cdist(barycenter, dist) ** 2
            gamma = sinkhorn(
                torch.ones_like(barycenter[:, 0]) / len(barycenter),
                torch.ones_like(dist[:, 0]) / len(dist),
                C, epsilon
            )
            new_barycenter += w * (gamma @ dist)
        
        barycenter = new_barycenter / sum(weights)
    
    return barycenter

3. 图匹配

def graph_optimal_transport(A1, A2, epsilon=0.1):
    """
    图的最优传输匹配
    
    用于比较图结构相似性
    """
    n1, n2 = A1.size(0), A2.size(0)
    
    # 基于节点度构建成本矩阵
    d1 = A1.sum(dim=1)
    d2 = A2.sum(dim=1)
    C = torch.abs(d1.unsqueeze(1) - d2.unsqueeze(0)) ** 2
    
    # OT 计算
    gamma = sinkhorn(
        torch.ones(n1) / n1,
        torch.ones(n2) / n2,
        C, epsilon
    )
    
    # 图传输距离
    ot_distance = torch.sum(gamma * C)
    
    return ot_distance, gamma

Unbalanced Optimal Transport

问题背景

标准 OT 要求传输计划严格满足边际约束 ,但在现实中:

  • 分布可能只是近似归一化
  • 质量可能”创建”或”销毁”

Unbalanced OT 形式化

引入 KL 正则化松弛

这允许边际不完全匹配。

def unbalanced_sinkhorn(a, b, C, epsilon=0.1, tau=0.5):
    """
    Unbalanced Sinkhorn 算法
    
    Args:
        a, b: 权重向量(不要求归一化)
        tau: 平滑参数(tau → 0 趋向标准 OT)
    """
    n, m = len(a), len(b)
    
    # 软化归一化
    a_tilde = a ** (tau / (tau + epsilon))
    b_tilde = b ** (tau / (tau + epsilon))
    
    K = torch.exp(-C / epsilon)
    
    u = torch.ones(n)
    v = torch.ones(m)
    
    for _ in range(100):
        u_prev = u.clone()
        u = a_tilde / (K @ v + 1e-10)
        v = b_tilde / (K.T @ u + 1e-10)
        
        if torch.max(torch.abs(u - u_prev)) < 1e-6:
            break
    
    gamma = u.view(-1, 1) * K * v.view(1, -1)
    return gamma

核心公式速查

概念公式
Monge 问题
Kantorovich 问题
-Wasserstein
对偶
Entropic OT
Sinkhorn 更新,

参考


扩展阅读

Footnotes

  1. Villani, C. (2008). Optimal Transport: Old and New. Springer.