Sinkhorn算法与熵正则化最优传输
Sinkhorn算法是计算熵正则化最优传输的高效迭代方法,由 Sinkhorn 在1964年提出。该算法在近年来成为机器学习中大规模OT计算的核心工具。1
问题回顾
标准OT的计算复杂度
标准的Kantorovich问题的最优传输距离需要求解线性规划:
其中 。
| 方法 | 时间复杂度 | 空间复杂度 |
|---|---|---|
| 线性规划(单纯形) | ||
| 内点法 | ||
| 网络单纯形 |
对于 的分布, 运算量是不可接受的!
熵正则化的引入
Cuturi (2013) 提出在目标函数中加入熵正则化项:
其中:
- 是香农熵
- 是正则化参数
正则化的效果:
- 使问题变为强凸,解唯一
- 解具有可分离形式,可以用Sinkhorn算法高效计算
- 当 时,趋向原始OT解
Sinkhorn 算法
核心洞察
定理(Sinkhorn):熵正则化OT问题的最优解可以表示为:
其中:
- 称为Gibbs核
- 是两个正向量
物理意义:传输计划被分解为”源侧缩放”和”目标侧缩放”的乘积。
迭代公式
给定 ,向量 和 满足以下交替归一化:
其中除法为逐元素运算。
展开后的迭代:
def sinkhorn_iteration(a, b, K, num_iters=100, epsilon=1e-9):
"""
Sinkhorn 算法迭代
Args:
a: 源分布, shape [n]
b: 目标分布, shape [m]
K: Gibbs核, shape [n, m]
num_iters: 迭代次数
epsilon: 收敛阈值
Returns:
u, v: 缩放向量
gamma: 最优传输计划
"""
n, m = len(a), len(b)
# 初始化
u = torch.ones(n)
v = torch.ones(m)
# 迭代
for t in range(num_iters):
u_prev = u.clone()
# 更新 u
u = a / (K @ v + epsilon)
# 更新 v
v = b / (K.T @ u + epsilon)
# 检查收敛
diff = torch.max(torch.abs(u - u_prev)).item()
if diff < epsilon:
print(f"Sinkhorn converged at iteration {t}")
break
# 计算最优传输计划
gamma = u.view(-1, 1) * K * v.view(1, -1)
return u, v, gamma矩阵视角
Sinkhorn算法可以优雅地表示为矩阵运算:
def sinkhorn_matrix_form(a, b, C, epsilon):
"""
Sinkhorn 算法的矩阵形式
"""
# Gibbs核
K = torch.exp(-C / epsilon)
# Sinkhorn 迭代的矩阵形式
# diag(u) @ K @ diag(v) 的Sinkhorn不动点
for _ in range(100):
# 行归一化:diag(a/(K@v)) @ K
K = torch.diag(a / (K @ torch.ones(len(b)))) @ K
# 列归一化:K @ diag(b/(K.T@u))
K = K @ torch.diag(b / (K.T @ torch.ones(len(a))))
# K 最终是一个近似传输计划
return K数值稳定性
问题:下溢(Underflow)
当 很小或 很大时, 会下溢到0。
# 问题演示
epsilon = 0.01
C_large = 100.0
K = np.exp(-C_large / epsilon)
print(K) # 输出: 0.0 (下溢!)解决方案:对数空间计算
在对数空间进行所有计算,避免数值下溢:
def sinkhorn_log_stabilized(a, b, C, epsilon=0.1, num_iters=100):
"""
对数空间稳定的 Sinkhorn 算法
关键洞察:
- 不直接计算 exp(-C/ε)
- 在 log 空间进行所有运算
- 使用 log-sum-exp 技巧
"""
n, m = len(a), len(b)
# log 空间的核(数值稳定)
log_K = -C / epsilon
# log 空间的初始化
log_u = torch.zeros(n)
log_v = torch.zeros(m)
# log_a 和 log_b
log_a = torch.log(a + 1e-50)
log_b = torch.log(b + 1e-50)
for _ in range(num_iters):
# log(u) = log(a) - logsumexp(log_K + log_v)
log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
# log(v) = log(b) - logsumexp(log_K.T + log_u)
log_v = log_b - torch.logsumexp(log_K.T + log_u.unsqueeze(0), dim=1)
# 计算传输计划(如果有需要)
# gamma_ij = exp(log_u_i + log_K_ij + log_v_j)
return log_u, log_v
def logsumexp_rows(log_K, log_v):
"""
计算 logsumexp_j (K_ij * v_j) 的数值稳定版本
logsumexp(x_i) = max(x) + log(sum(exp(x_i - max(x))))
"""
max_log_v = torch.max(log_v)
shifted_v = log_v - max_log_v
return max_log_v + torch.log(torch.sum(torch.exp(shifted_v)))PythonOT 实现
# 使用 POT (Python Optimal Transport) 库
try:
import ot
# 标准 Sinkhorn
gamma = ot.sinkhorn(a, b, C, reg=0.1)
# 对数稳定版本
gamma = ot.sinkhorn_lpl1_mm(a, b, C, reg=0.1, log=True)
# 半松弛 Sinkhorn(Unbalanced OT)
gamma = ot.sinkhorn2(a, b, C, reg=0.1)
except ImportError:
print("请安装 POT: pip install POT")收敛性分析
Sinkhorn 的收敛速率
定理:Sinkhorn算法线性收敛到最优解,收敛速率由条件数决定:
其中 , 是问题的条件数。
收敛速度与 的关系
def plot_convergence():
"""
展示不同 ε 下的收敛速度
"""
import matplotlib.pyplot as plt
n = 100
a = torch.ones(n) / n
b = torch.ones(n) / n
x = torch.arange(n, dtype=torch.float) / n
C = torch.cdist(x.unsqueeze(1), x.unsqueeze(1)) ** 2
epsilons = [0.01, 0.05, 0.1, 0.5]
colors = ['r', 'g', 'b', 'orange']
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for eps, color in zip(epsilons, colors):
errors = []
u = torch.ones(n)
v = torch.ones(n)
K = torch.exp(-C / eps)
for t in range(100):
u_prev = u.clone()
u = a / (K @ v + 1e-50)
v = b / (K.T @ u + 1e-50)
# 计算与真值的误差(近似)
gamma = u.view(-1, 1) * K * v.view(1, -1)
marginal_error = torch.max(
torch.abs(torch.sum(gamma, dim=1) - a),
torch.abs(torch.sum(gamma, dim=0) - b)
)
errors.append(marginal_error.item())
axes[0].semilogy(errors, color=color, label=f'ε={eps}')
axes[1].semilogy(errors[:20], color=color, label=f'ε={eps}')
axes[0].set_xlabel('Iteration')
axes[0].set_ylabel('Marginal Error')
axes[0].set_title('Full Convergence')
axes[0].legend()
axes[1].set_xlabel('Iteration')
axes[1].set_ylabel('Marginal Error')
axes[1].set_title('Early Convergence')
axes[1].legend()
plt.tight_layout()
plt.show()
"""
观察:
- ε 越大,收敛越快
- ε 越小,最终精度越高
- 存在 ε-依赖的收敛速率权衡
"""Early Stopping
class EarlyStoppingSinkhorn:
"""
带早停的 Sinkhorn 算法
"""
def __init__(self, tol=1e-6, max_iter=1000):
self.tol = tol
self.max_iter = max_iter
def fit(self, a, b, C, epsilon):
self.history = {'error': [], 'cost': []}
u = torch.ones_like(a)
v = torch.ones_like(b)
K = torch.exp(-C / epsilon)
for t in range(self.max_iter):
u_prev, v_prev = u.clone(), v.clone()
u = a / (K @ v + 1e-50)
v = b / (K.T @ u + 1e-50)
# 计算边际误差
gamma = u.view(-1, 1) * K * v.view(1, -1)
error = max(
torch.norm(torch.sum(gamma, dim=1) - a).item(),
torch.norm(torch.sum(gamma, dim=0) - b).item()
)
self.history['error'].append(error)
if error < self.tol:
print(f"Converged at iteration {t}")
break
return u, v, gammaSinkhorn 距离的性质
Sinkhorn 距离的定义
使用 Sinkhorn 计算的熵正则化距离:
其中 是 Sinkhorn 的收敛解。
与真实 Wasserstein 距离的关系
def compare_distances():
"""
比较 Sinkhorn 距离与真实 Wasserstein 距离
"""
# 两个 Dirac 分布
P = torch.tensor([1.0, 0.0])
Q = torch.tensor([0.0, 1.0])
x = torch.tensor([[0.0], [1.0]])
C = torch.cdist(x, x) # [[0,1],[1,0]]
epsilons = [0.01, 0.1, 0.5, 1.0]
print("ε Sinkhorn Wasserstein")
for eps in epsilons:
u, v, gamma = sinkhorn_iteration(P, Q, C, eps)
sinkhorn_dist = torch.sum(gamma * C).item()
# 真实 Wasserstein-1
w_dist = 1.0
print(f"{eps:.2f} {sinkhorn_dist:.4f} {w_dist:.4f}")
"""
输出示例:
ε Sinkhorn Wasserstein
0.01 0.9900 1.0000
0.10 0.9000 1.0000
0.50 0.7500 1.0000
1.00 0.6321 1.0000
观察:ε 越小,Sinkhorn 距离越接近真实 Wasserstein 距离
"""三角不等式
重要:Sinkhorn 距离不满足三角不等式(因为正则化破坏了度量性质)。
| 性质 | Wasserstein | Sinkhorn |
|---|---|---|
| 非负性 | ✅ | ✅ |
| 同一性 | ✅ | ✅ |
| 对称性 | ✅ | ✅ |
| 三角不等式 | ✅ | ❌ |
| 收敛到 | - | ✅ 当 |
Sinkhorn 在深度学习中的应用
1. Contrastive Learning
class SinkhornContrastiveLoss(nn.Module):
"""
基于 Sinkhorn 的对比损失
用于无监督/自监督学习中的分布对齐
"""
def __init__(self, temperature=0.1, epsilon=0.1):
super().__init__()
self.temperature = temperature
self.epsilon = epsilon
def forward(self, z1, z2):
"""
Args:
z1, z2: 两个视图的特征, shape [batch, dim]
Returns:
loss: Sinkhorn 对比损失
"""
batch_size = z1.size(0)
# 归一化特征
z1 = F.normalize(z1, dim=1)
z2 = F.normalize(z2, dim=1)
# 拼接特征
z = torch.cat([z1, z2], dim=0)
# 计算相似度矩阵
sim = torch.mm(z, z.T) / self.temperature
# Sinkhorn 距离作为损失
# 成本矩阵:负相似度
C = -sim
# 均匀分布
a = torch.ones(2 * batch_size) / (2 * batch_size)
# Sinkhorn 计算
K = torch.exp(-C / self.epsilon)
u = torch.ones(2 * batch_size)
v = torch.ones(2 * batch_size)
for _ in range(10):
u = a / (K @ v + 1e-50)
v = a / (K.T @ u + 1e-50)
gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
# 损失 = Sinkhorn 距离
loss = torch.sum(gamma * C)
return loss2. Image Generation (SinkhornGAN)
class SinkhornGenerativeLoss(nn.Module):
"""
基于 Sinkhorn 距离的生成损失
比 WGAN 更稳定的替代方案
"""
def __init__(self, epsilon=0.1):
super().__init__()
self.epsilon = epsilon
def sinkhorn_divergence(self, real, fake):
"""
计算真实分布与生成分布之间的 Sinkhorn 散度
近似 Wasserstein-1 距离
"""
batch_size = real.size(0)
# 特征级别的距离
C = torch.cdist(real, fake, p=2) ** 2
# 均匀分布
a = torch.ones(batch_size) / batch_size
b = torch.ones(batch_size) / batch_size
# Sinkhorn 计算
K = torch.exp(-C / self.epsilon)
u = torch.ones(batch_size)
v = torch.ones(batch_size)
for _ in range(20):
u = a / (K @ v + 1e-50)
v = b / (K.T @ u + 1e-50)
gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
# Sinkhorn 距离
return torch.sum(gamma * C)
def forward(self, real_features, fake_features):
"""
计算生成损失
最小化真实与生成特征的 Sinkhorn 距离
"""
loss = self.sinkhorn_divergence(real_features, fake_features)
return loss3. Multi-Domain Translation
class MultiDomainOT(nn.Module):
"""
多域传输:学习域不变表示
使用 Sinkhorn 对齐不同域的分布
"""
def __init__(self, feature_dim, num_domains, epsilon=0.1):
super().__init__()
self.feature_extractor = nn.Sequential(
nn.Linear(feature_dim, 128),
nn.ReLU(),
nn.Linear(128, 64)
)
self.num_domains = num_domains
self.epsilon = epsilon
def compute_domain_alignment_loss(self, features, domain_labels):
"""
计算域对齐损失
目标:最小化所有域之间的 Sinkhorn 距离之和
"""
unique_domains = torch.unique(domain_labels)
n_domains = len(unique_domains)
total_loss = 0.0
count = 0
for i in range(n_domains):
for j in range(i + 1, n_domains):
mask_i = (domain_labels == unique_domains[i])
mask_j = (domain_labels == unique_domains[j])
features_i = features[mask_i]
features_j = features[mask_j]
# 计算两个域之间的 Sinkhorn 距离
C = torch.cdist(features_i, features_j, p=2) ** 2
a = torch.ones(len(features_i)) / len(features_i)
b = torch.ones(len(features_j)) / len(features_j)
K = torch.exp(-C / self.epsilon)
u = torch.ones(len(features_i))
v = torch.ones(len(features_j))
for _ in range(20):
u = a / (K @ v + 1e-50)
v = b / (K.T @ u + 1e-50)
gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
loss_ij = torch.sum(gamma * C)
total_loss += loss_ij
count += 1
return total_loss / max(count, 1)计算效率优化
并行化
def sinkhorn_batched(a, b, C, epsilon=0.1, num_iters=100):
"""
批量 Sinkhorn 算法
同时处理多个传输问题
"""
# a: [batch, n]
# b: [batch, m]
# C: [batch, n, m]
batch_size = a.size(0)
# 初始化
u = torch.ones(batch_size, a.size(1))
v = torch.ones(batch_size, b.size(1))
# Gibbs 核(批量计算)
K = torch.exp(-C / epsilon)
for _ in range(num_iters):
# 更新 u: [batch, n] / ([batch, n, m] @ [batch, m, 1]) -> [batch, n]
u = a / (torch.bmm(K, v.unsqueeze(2)).squeeze(2) + 1e-50)
# 更新 v: [batch, m] / ([batch, m, n] @ [batch, n, 1]) -> [batch, m]
v = b / (torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2) + 1e-50)
# 批量计算传输计划
gamma = u.unsqueeze(2) * K * v.unsqueeze(1)
return gammaGPU 加速
def sinkhorn_gpu_demo():
"""
演示 GPU 加速
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 大规模问题
n = 10000
m = 10000
a = torch.ones(n, device=device) / n
b = torch.ones(m, device=device) / m
# 随机成本矩阵
x = torch.randn(n, 1, device=device)
y = torch.randn(m, 1, device=device)
C = (x - y.T) ** 2
# GPU 计算
gamma = sinkhorn_iteration(a, b, C, epsilon=0.1, num_iters=100)
print(f"使用设备: {device}")
print(f"传输计划形状: {gamma.shape}")
print(f"传输计划范围: [{gamma.min():.6f}, {gamma.max():.6f}]")核心公式速查
| 概念 | 公式 |
|---|---|
| 熵正则化OT | |
| Gibbs核 | |
| Sinkhorn迭代 | , |
| 最优传输计划 | |
| 收敛速率 |
参考
扩展阅读
- 最优传输与Wasserstein距离 — 理论基础
- 信息论基础 — 熵与互信息的联系
- PythonOT — OT 计算库
- POT GitHub
Footnotes
-
Sinkhorn, R. (1964). “Relationship between Positive Matrices and Successive Contructions of Diagonal Matrices”. Proceedings of the American Mathematical Society. ↩