最优传输与Wasserstein距离
最优传输(Optimal Transport, OT)理论是概率论与几何学的交叉领域,由 Monge 在 1781 年提出。经过两个多世纪的发展,它已成为现代机器学习中不可或缺的工具,尤其在生成模型、领域适应、聚类分析等领域发挥着关键作用。1
背景:从搬砖问题说起
Monge 的搬砖问题
想象你要将一堆砖从施工现场(分布 )运到建筑工地(分布 ):
施工现场 P 建筑工地 Q
● ● ● ●
● ● ● ● ●
● ● ● ● ● ●
↘ ↘ ↘ ↙ ↙ ↙
搬 运 运 输
问题:如何安排运输计划使得总搬运成本最小?
这就是最优传输的核心问题。
形式化定义
离散情形:Monge 问题
设:
- :源分布( 位置, 质量)
- :目标分布( 位置, 质量)
- :从 到 的单位运输成本
Monge 传输问题:
其中 表示 将 的质量恰好映射到 。
关键问题
Monge 问题存在两个困难:
- 映射可能不存在:当质量分布不匹配时
- 优化问题高度非线性: 的组合性质
Kantorovich 松弛
Kantorovich 在 1942 年提出将问题松弛为线性规划:
Kantorovich 传输问题:
其中 是耦合矩阵的集合,满足:
物理意义: 表示从 运到 的质量。
耦合矩阵 γ:
y_1 y_2 y_3 ... → Q 的概率
┌──────┬──────┬──────┐
x_1 │ 0.2 │ 0.1 │ 0.0 │
├──────┼──────┼──────┤
x_2 │ 0.0 │ 0.3 │ 0.1 │
├──────┼──────┼──────┤
x_3 │ 0.0 │ 0.0 │ 0.2 │
└──────┴──────┴──────┘
↑ ↑ ↑
P 的概率
Monge vs Kantorovich
| 特性 | Monge | Kantorovich |
|---|---|---|
| 传输映射 | 确定性 | 概率耦合 |
| 解的存在性 | 不一定存在 | 总是存在 |
| 计算复杂度 | 非凸 | 线性规划(凸) |
| 适用场景 | 质量精确匹配 | 一般分布 |
Wasserstein 距离
定义
当成本函数为 ()时,Kantorovich 问题的最优值定义了 -Wasserstein 距离:
对于离散分布:
常用 Wasserstein 距离
| 距离 | 公式 | 特点 |
|---|---|---|
| (Earth Mover’s) | 线性成本 | |
| 二次成本 | ||
| 最小化最大距离 |
Wasserstein-1 的特殊性质
距离也称为 Earth Mover’s Distance (EMD),具有以下重要性质:
这称为 Kantorovich-Rubinstein 对偶性。
Wasserstein 距离的几何理解
P Q
● ●
╱ ╲ ╱ ╲
╱ ╲ ╱ ╲
●─────● ●─────●
传输成本 = Σ 质量 × 距离
= 0.3×0.2 + 0.2×0.3 + ...
与 KL 散度的对比
| 特性 | Wasserstein 距离 | KL 散度 |
|---|---|---|
| 定义 | 几何距离 | 信息熵差 |
| 度量性质 | 满足度量(对称、三角) | 不满足对称性 |
| 对零重叠的处理 | 有定义(有意义) | 无定义() |
| 计算复杂度 | (线性规划) | |
| 梯度特性 | 平滑 | 可能无界 |
为什么 Wasserstein 更好?
import numpy as np
def demo_wasserstein_vs_kl():
"""
演示 Wasserstein 距离 vs KL 散度
"""
# 两个分布:支撑集几乎不重叠
P = np.array([0.99, 0.01])
Q = np.array([0.01, 0.99])
# KL 散度(不重叠时爆炸)
kl_PQ = np.sum(P * np.log(P / (Q + 1e-10) + 1e-10))
kl_QP = np.sum(Q * np.log(Q / (P + 1e-10) + 1e-10))
# Wasserstein-1(始终有定义)
w1 = 0.99 * 1 + 0.01 * 0 # 简化计算
print(f"KL(P||Q): {kl_PQ:.2f}") # ~4.56
print(f"KL(Q||P): {kl_QP:.2f}") # ~4.56
print(f"W_1(P,Q): {w1:.2f}") # ~0.98
"""
结论:当分布支撑集不重叠时
- KL → ∞(无法优化)
- Wasserstein → 有意义的具体数值
"""
demo_wasserstein_vs_kl()梯度对比可视化
KL divergence梯度
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑
↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ← 梯度爆炸!
─────┼────────────────→ x
0
Wasserstein梯度
↑ ↓
↑ ↓
↑ ↓
─────┼────────────────→ x
0 ← 平滑过渡
这就是 Wasserstein GAN 优于标准 GAN 的原因!
Sinkhorn 算法
Entropic 正则化
直接在 Wasserstein 距离上优化计算量很大。Cuturi (2013) 提出了 entropic 正则化:
正则化项使得最优耦合 具有可分离形式:
其中 称为 Gibbs 核。
Sinkhorn 迭代
给定 矩阵,Sinkhorn 算法通过交替归一化求解:
import torch
import numpy as np
def sinkhorn(a, b, C, epsilon=0.1, num_iters=100, device='cpu'):
"""
Sinkhorn 算法计算 entropic OT
Args:
a: 源分布权重, shape [n]
b: 目标分布权重, shape [m]
C: 成本矩阵, shape [n, m]
epsilon: 正则化参数(越小越接近真实 OT)
num_iters: 迭代次数
Returns:
gamma: 最优耦合矩阵
"""
n, m = len(a), len(b)
# 确保是 PyTorch 张量
a = torch.tensor(a, dtype=torch.float64, device=device)
b = torch.tensor(b, dtype=torch.float64, device=device)
C = torch.tensor(C, dtype=torch.float64, device=device)
# Gibbs 核
K = torch.exp(-C / epsilon)
# 初始化
u = torch.ones(n, dtype=torch.float64, device=device)
v = torch.ones(m, dtype=torch.float64, device=device)
# Sinkhorn 迭代
for i in range(num_iters):
# 更新 u
u = a / (K @ v + 1e-10)
# 更新 v
v = b / (K.T @ u + 1e-10)
# 数值稳定性检查
if torch.isnan(u).any() or torch.isnan(v).any():
print(f"Sinkhorn diverged at iteration {i}")
break
# 计算最优耦合
gamma = u.view(-1, 1) * K * v.view(1, -1)
return gamma
def wasserstein_1(P, Q, epsilon=1e-3):
"""
近似计算 Wasserstein-1 距离
使用 Sinkhorn 近似
"""
n, m = P.shape[0], Q.shape[0]
# 成本矩阵(假设 1D 分布)
x = np.arange(n) / n
y = np.arange(m) / m
C = np.abs(x[:, None] - y[None, :])
# 均匀分布
a = np.ones(n) / n
b = np.ones(m) / m
# Sinkhorn 计算
gamma = sinkhorn(a, b, C, epsilon)
# Wasserstein-1 = <γ, C>
return torch.sum(gamma * torch.tensor(C)).item()计算复杂度分析
| 方法 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 原始线性规划 | ||
| Sinkhorn | ||
| 稀疏 Sinkhorn | ||
| 低秩近似 |
其中 是迭代次数, 是低秩近似秩。
正则化参数 的影响
def plot_epsilon_effect():
"""
展示正则化参数 ε 的影响
"""
import matplotlib.pyplot as plt
x = np.linspace(0, 1, 100)
P = np.zeros(100); P[10] = 1.0
Q = np.zeros(100); Q[90] = 1.0
C = np.abs(x[:, None] - x[None, :])
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
epsilons = [0.01, 0.1, 1.0]
for ax, eps in zip(axes, epsilons):
gamma = sinkhorn(P, Q, C, epsilon=eps, num_iters=100)
ax.imshow(gamma, cmap='Blues')
ax.set_title(f'ε = {eps}, W = {torch.sum(gamma * C).item():.3f}')
plt.tight_layout()
plt.show()
"""
观察:
- ε → 0: γ 趋向稀疏(接近原始 OT)
- ε → ∞: γ 趋向均匀分布
"""Wasserstein GAN
标准 GAN 的问题
标准 GAN 使用 JS 散度,面临以下问题:
- 当生成分布与真实分布不重叠时,梯度消失
- 训练不稳定,难以平衡生成器和判别器
WGAN 的改进
Wasserstein GAN (Arjovsky et al., 2017) 使用 Wasserstein-1 距离作为损失:
import torch
import torch.nn as nn
import torch.nn.functional as F
class DiscriminatorWGAN(nn.Module):
"""
WGAN 判别器(称为 Critic)
关键变化:
1. 输出不是 sigmoid(去掉概率解释)
2. 权重裁剪保证 Lipschitz 约束
"""
def __init__(self, input_dim, hidden_dim=64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, hidden_dim),
nn.LeakyReLU(0.2),
nn.Linear(hidden_dim, 1)
)
def forward(self, x):
return self.net(x)
class WGAN(nn.Module):
"""
Wasserstein GAN
损失函数:W_1(P_real, P_fake) ≈ E[critic(real)] - E[critic(fake)]
"""
def __init__(self, latent_dim, data_dim, hidden_dim=64):
super().__init__()
# 生成器
self.generator = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, data_dim),
nn.Tanh() # 输出归一化到 [-1, 1]
)
# 判别器 (Critic)
self.critic = DiscriminatorWGAN(data_dim, hidden_dim)
# 权重裁剪参数(WGAN 原始方法)
self.clip_value = 0.01
def wasserstein_loss(self, real, fake):
"""
WGAN 损失函数
目标:最大化 E[critic(real)] - E[critic(fake)]
即最小化 -(上述值)
"""
real_output = self.critic(real)
fake_output = self.critic(fake)
# 负的 Wasserstein 距离(要最小化)
loss = -(real_output.mean() - fake_output.mean())
return loss
def train_step(self, real_data, optimizer_g, optimizer_c):
"""
一步训练
"""
batch_size = real_data.size(0)
z = torch.randn(batch_size, self.generator[0].in_features)
fake_data = self.generator(z)
# 1. 训练 Critic(多步)
optimizer_c.zero_grad()
critic_loss = self.wasserstein_loss(real_data, fake_data.detach())
critic_loss.backward()
optimizer_c.step()
# 权重裁剪
for p in self.critic.parameters():
p.data.clamp_(-self.clip_value, self.clip_value)
# 2. 训练 Generator
optimizer_g.zero_grad()
gen_loss = self.wasserstein_loss(real_data, fake_data)
gen_loss.backward()
optimizer_g.step()
return critic_loss.item(), gen_loss.item()
# 使用示例
# WGAN-GP (Gradient Penalty) 更稳定
class WGAN_GP(nn.Module):
"""
WGAN with Gradient Penalty
用梯度惩罚替代权重裁剪,更稳定
"""
def __init__(self, latent_dim, data_dim, hidden_dim=64):
super().__init__()
self.generator = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, data_dim),
nn.Tanh()
)
self.critic = DiscriminatorWGAN(data_dim, hidden_dim)
self.lambda_gp = 10 # 梯度惩罚系数
def gradient_penalty(self, real, fake):
"""
计算梯度惩罚
目标:||∇_x critic(x)|| ≈ 1
"""
batch_size = real.size(0)
alpha = torch.rand(batch_size, 1)
alpha = alpha.expand_as(real)
interpolated = alpha * real + (1 - alpha) * fake
interpolated.requires_grad_(True)
critic_interpolated = self.critic(interpolated)
gradients = torch.autograd.grad(
outputs=critic_interpolated,
inputs=interpolated,
grad_outputs=torch.ones_like(critic_interpolated),
create_graph=True
)[0]
gradients = gradients.view(batch_size, -1)
gradient_norm = gradients.norm(2, dim=1)
penalty = ((gradient_norm - 1) ** 2).mean()
return penalty
def train_step(self, real_data, optimizer_g, optimizer_c):
batch_size = real_data.size(0)
z = torch.randn(batch_size, self.generator[0].in_features)
fake_data = self.generator(z)
# Critic 损失
optimizer_c.zero_grad()
critic_loss = self.wasserstein_loss(real_data, fake_data)
gp = self.gradient_penalty(real_data, fake_data)
c_total = critic_loss + self.lambda_gp * gp
c_total.backward()
optimizer_c.step()
# Generator 损失
optimizer_g.zero_grad()
gen_loss = self.wasserstein_loss(real_data, fake_data)
gen_loss.backward()
optimizer_g.step()
return critic_loss.item(), gen_loss.item(), gp.item()WGAN vs 标准 GAN
| 特性 | 标准 GAN | WGAN |
|---|---|---|
| 散度 | JS 散度 | Wasserstein-1 |
| 梯度 | 消失/爆炸 | 平滑 |
| 训练稳定性 | 敏感 | 稳定 |
| 评估指标 | 无 | Loss 与生成质量相关 |
| 收敛保证 | 无 | 理论上更好 |
应用场景
1. 领域适应(Domain Adaptation)
class OptimalTransportDomainAdaptation(nn.Module):
"""
基于 OT 的领域适应
核心思想:找到源域到目标域的最优传输映射
"""
def __init__(self, feature_dim, num_classes):
super().__init__()
self.feature_extractor = nn.Sequential(
nn.Linear(feature_dim, 128),
nn.ReLU(),
nn.Linear(128, 64)
)
self.classifier = nn.Linear(64, num_classes)
def ot_loss(self, source_features, target_features, epsilon=0.1):
"""
计算源域到目标域的 OT 损失
用于对齐两个域的特征分布
"""
batch_size = source_features.size(0)
# 成本矩阵
C = torch.cdist(source_features, target_features, p=2) ** 2
# 均匀分布
a = torch.ones(batch_size) / batch_size
b = torch.ones(batch_size) / batch_size
# Sinkhorn 计算
gamma = sinkhorn(a, b, C, epsilon)
# OT 损失
return torch.sum(gamma * C)
def forward(self, source_x, target_x):
source_features = self.feature_extractor(source_x)
target_features = self.feature_extractor(target_x)
# 分类损失 + OT 损失
source_logits = self.classifier(source_features)
ot_loss = self.ot_loss(source_features, target_features)
return source_logits, ot_loss2. Wasserstein Barycenter
def wasserstein_barycenter(distributions, weights=None, epsilon=0.1):
"""
计算 Wasserstein Barycenter(最优传输重心)
给定多个分布,找到与所有分布的 Wasserstein 距离加权和最小的分布
应用:多域学习的中心表示
"""
if weights is None:
weights = torch.ones(len(distributions)) / len(distributions)
# 迭代更新
barycenter = distributions[0].clone()
for _ in range(100):
new_barycenter = torch.zeros_like(barycenter)
for dist, w in zip(distributions, weights):
C = torch.cdist(barycenter, dist) ** 2
gamma = sinkhorn(
torch.ones_like(barycenter[:, 0]) / len(barycenter),
torch.ones_like(dist[:, 0]) / len(dist),
C, epsilon
)
new_barycenter += w * (gamma @ dist)
barycenter = new_barycenter / sum(weights)
return barycenter3. 图匹配
def graph_optimal_transport(A1, A2, epsilon=0.1):
"""
图的最优传输匹配
用于比较图结构相似性
"""
n1, n2 = A1.size(0), A2.size(0)
# 基于节点度构建成本矩阵
d1 = A1.sum(dim=1)
d2 = A2.sum(dim=1)
C = torch.abs(d1.unsqueeze(1) - d2.unsqueeze(0)) ** 2
# OT 计算
gamma = sinkhorn(
torch.ones(n1) / n1,
torch.ones(n2) / n2,
C, epsilon
)
# 图传输距离
ot_distance = torch.sum(gamma * C)
return ot_distance, gammaUnbalanced Optimal Transport
问题背景
标准 OT 要求传输计划严格满足边际约束 ,但在现实中:
- 分布可能只是近似归一化
- 质量可能”创建”或”销毁”
Unbalanced OT 形式化
引入 KL 正则化松弛:
这允许边际不完全匹配。
def unbalanced_sinkhorn(a, b, C, epsilon=0.1, tau=0.5):
"""
Unbalanced Sinkhorn 算法
Args:
a, b: 权重向量(不要求归一化)
tau: 平滑参数(tau → 0 趋向标准 OT)
"""
n, m = len(a), len(b)
# 软化归一化
a_tilde = a ** (tau / (tau + epsilon))
b_tilde = b ** (tau / (tau + epsilon))
K = torch.exp(-C / epsilon)
u = torch.ones(n)
v = torch.ones(m)
for _ in range(100):
u_prev = u.clone()
u = a_tilde / (K @ v + 1e-10)
v = b_tilde / (K.T @ u + 1e-10)
if torch.max(torch.abs(u - u_prev)) < 1e-6:
break
gamma = u.view(-1, 1) * K * v.view(1, -1)
return gamma核心公式速查
| 概念 | 公式 |
|---|---|
| Monge 问题 | |
| Kantorovich 问题 | |
| -Wasserstein | |
| 对偶 | |
| Entropic OT | |
| Sinkhorn 更新 | , |
参考
扩展阅读
- 扩散模型与变分推断 — 与 OT 的联系
- 信息瓶颈理论 — 另一个信息论视角
- 高斯过程 — 概率测度与 OT 的联系
- PythonOT — 最优传输 Python 库
- POT: Python Optimal Transport
Footnotes
-
Villani, C. (2008). Optimal Transport: Old and New. Springer. ↩