Sinkhorn算法与熵正则化最优传输

Sinkhorn算法是计算熵正则化最优传输的高效迭代方法,由 Sinkhorn 在1964年提出。该算法在近年来成为机器学习中大规模OT计算的核心工具。1

问题回顾

标准OT的计算复杂度

标准的Kantorovich问题的最优传输距离需要求解线性规划:

其中

方法时间复杂度空间复杂度
线性规划(单纯形)
内点法
网络单纯形

对于 的分布, 运算量是不可接受的!

熵正则化的引入

Cuturi (2013) 提出在目标函数中加入熵正则化项

其中:

  • 香农熵
  • 正则化参数

正则化的效果

  • 使问题变为强凸,解唯一
  • 解具有可分离形式,可以用Sinkhorn算法高效计算
  • 时,趋向原始OT解

Sinkhorn 算法

核心洞察

定理(Sinkhorn):熵正则化OT问题的最优解可以表示为:

其中:

  • 称为Gibbs核
  • 是两个正向量

物理意义:传输计划被分解为”源侧缩放”和”目标侧缩放”的乘积。

迭代公式

给定 ,向量 满足以下交替归一化:

其中除法为逐元素运算。

展开后的迭代:

def sinkhorn_iteration(a, b, K, num_iters=100, epsilon=1e-9):
    """
    Sinkhorn 算法迭代
    
    Args:
        a: 源分布, shape [n]
        b: 目标分布, shape [m]
        K: Gibbs核, shape [n, m]
        num_iters: 迭代次数
        epsilon: 收敛阈值
    
    Returns:
        u, v: 缩放向量
        gamma: 最优传输计划
    """
    n, m = len(a), len(b)
    
    # 初始化
    u = torch.ones(n)
    v = torch.ones(m)
    
    # 迭代
    for t in range(num_iters):
        u_prev = u.clone()
        
        # 更新 u
        u = a / (K @ v + epsilon)
        
        # 更新 v
        v = b / (K.T @ u + epsilon)
        
        # 检查收敛
        diff = torch.max(torch.abs(u - u_prev)).item()
        if diff < epsilon:
            print(f"Sinkhorn converged at iteration {t}")
            break
    
    # 计算最优传输计划
    gamma = u.view(-1, 1) * K * v.view(1, -1)
    
    return u, v, gamma

矩阵视角

Sinkhorn算法可以优雅地表示为矩阵运算:

def sinkhorn_matrix_form(a, b, C, epsilon):
    """
    Sinkhorn 算法的矩阵形式
    """
    # Gibbs核
    K = torch.exp(-C / epsilon)
    
    # Sinkhorn 迭代的矩阵形式
    # diag(u) @ K @ diag(v) 的Sinkhorn不动点
    
    for _ in range(100):
        # 行归一化:diag(a/(K@v)) @ K
        K = torch.diag(a / (K @ torch.ones(len(b)))) @ K
        
        # 列归一化:K @ diag(b/(K.T@u))
        K = K @ torch.diag(b / (K.T @ torch.ones(len(a))))
    
    # K 最终是一个近似传输计划
    return K

数值稳定性

问题:下溢(Underflow)

很小或 很大时, 会下溢到0。

# 问题演示
epsilon = 0.01
C_large = 100.0
K = np.exp(-C_large / epsilon)
print(K)  # 输出: 0.0 (下溢!)

解决方案:对数空间计算

对数空间进行所有计算,避免数值下溢:

def sinkhorn_log_stabilized(a, b, C, epsilon=0.1, num_iters=100):
    """
    对数空间稳定的 Sinkhorn 算法
    
    关键洞察:
    - 不直接计算 exp(-C/ε)
    - 在 log 空间进行所有运算
    - 使用 log-sum-exp 技巧
    """
    n, m = len(a), len(b)
    
    # log 空间的核(数值稳定)
    log_K = -C / epsilon
    
    # log 空间的初始化
    log_u = torch.zeros(n)
    log_v = torch.zeros(m)
    
    # log_a 和 log_b
    log_a = torch.log(a + 1e-50)
    log_b = torch.log(b + 1e-50)
    
    for _ in range(num_iters):
        # log(u) = log(a) - logsumexp(log_K + log_v)
        log_u = log_a - torch.logsumexp(log_K + log_v.unsqueeze(0), dim=1)
        
        # log(v) = log(b) - logsumexp(log_K.T + log_u)
        log_v = log_b - torch.logsumexp(log_K.T + log_u.unsqueeze(0), dim=1)
    
    # 计算传输计划(如果有需要)
    # gamma_ij = exp(log_u_i + log_K_ij + log_v_j)
    
    return log_u, log_v
 
 
def logsumexp_rows(log_K, log_v):
    """
    计算 logsumexp_j (K_ij * v_j) 的数值稳定版本
    
    logsumexp(x_i) = max(x) + log(sum(exp(x_i - max(x))))
    """
    max_log_v = torch.max(log_v)
    shifted_v = log_v - max_log_v
    return max_log_v + torch.log(torch.sum(torch.exp(shifted_v)))

PythonOT 实现

# 使用 POT (Python Optimal Transport) 库
try:
    import ot
    
    # 标准 Sinkhorn
    gamma = ot.sinkhorn(a, b, C, reg=0.1)
    
    # 对数稳定版本
    gamma = ot.sinkhorn_lpl1_mm(a, b, C, reg=0.1, log=True)
    
    # 半松弛 Sinkhorn(Unbalanced OT)
    gamma = ot.sinkhorn2(a, b, C, reg=0.1)
    
except ImportError:
    print("请安装 POT: pip install POT")

收敛性分析

Sinkhorn 的收敛速率

定理:Sinkhorn算法线性收敛到最优解,收敛速率由条件数决定:

其中 是问题的条件数。

收敛速度与 的关系

def plot_convergence():
    """
    展示不同 ε 下的收敛速度
    """
    import matplotlib.pyplot as plt
    
    n = 100
    a = torch.ones(n) / n
    b = torch.ones(n) / n
    x = torch.arange(n, dtype=torch.float) / n
    C = torch.cdist(x.unsqueeze(1), x.unsqueeze(1)) ** 2
    
    epsilons = [0.01, 0.05, 0.1, 0.5]
    colors = ['r', 'g', 'b', 'orange']
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    for eps, color in zip(epsilons, colors):
        errors = []
        u = torch.ones(n)
        v = torch.ones(n)
        K = torch.exp(-C / eps)
        
        for t in range(100):
            u_prev = u.clone()
            u = a / (K @ v + 1e-50)
            v = b / (K.T @ u + 1e-50)
            
            # 计算与真值的误差(近似)
            gamma = u.view(-1, 1) * K * v.view(1, -1)
            marginal_error = torch.max(
                torch.abs(torch.sum(gamma, dim=1) - a),
                torch.abs(torch.sum(gamma, dim=0) - b)
            )
            errors.append(marginal_error.item())
        
        axes[0].semilogy(errors, color=color, label=f'ε={eps}')
        axes[1].semilogy(errors[:20], color=color, label=f'ε={eps}')
    
    axes[0].set_xlabel('Iteration')
    axes[0].set_ylabel('Marginal Error')
    axes[0].set_title('Full Convergence')
    axes[0].legend()
    
    axes[1].set_xlabel('Iteration')
    axes[1].set_ylabel('Marginal Error')
    axes[1].set_title('Early Convergence')
    axes[1].legend()
    
    plt.tight_layout()
    plt.show()
    
    """
    观察:
    - ε 越大,收敛越快
    - ε 越小,最终精度越高
    - 存在 ε-依赖的收敛速率权衡
    """

Early Stopping

class EarlyStoppingSinkhorn:
    """
    带早停的 Sinkhorn 算法
    """
    def __init__(self, tol=1e-6, max_iter=1000):
        self.tol = tol
        self.max_iter = max_iter
    
    def fit(self, a, b, C, epsilon):
        self.history = {'error': [], 'cost': []}
        
        u = torch.ones_like(a)
        v = torch.ones_like(b)
        K = torch.exp(-C / epsilon)
        
        for t in range(self.max_iter):
            u_prev, v_prev = u.clone(), v.clone()
            
            u = a / (K @ v + 1e-50)
            v = b / (K.T @ u + 1e-50)
            
            # 计算边际误差
            gamma = u.view(-1, 1) * K * v.view(1, -1)
            error = max(
                torch.norm(torch.sum(gamma, dim=1) - a).item(),
                torch.norm(torch.sum(gamma, dim=0) - b).item()
            )
            
            self.history['error'].append(error)
            
            if error < self.tol:
                print(f"Converged at iteration {t}")
                break
        
        return u, v, gamma

Sinkhorn 距离的性质

Sinkhorn 距离的定义

使用 Sinkhorn 计算的熵正则化距离:

其中 是 Sinkhorn 的收敛解。

与真实 Wasserstein 距离的关系

def compare_distances():
    """
    比较 Sinkhorn 距离与真实 Wasserstein 距离
    """
    # 两个 Dirac 分布
    P = torch.tensor([1.0, 0.0])
    Q = torch.tensor([0.0, 1.0])
    x = torch.tensor([[0.0], [1.0]])
    
    C = torch.cdist(x, x)  # [[0,1],[1,0]]
    
    epsilons = [0.01, 0.1, 0.5, 1.0]
    
    print("ε      Sinkhorn    Wasserstein")
    for eps in epsilons:
        u, v, gamma = sinkhorn_iteration(P, Q, C, eps)
        sinkhorn_dist = torch.sum(gamma * C).item()
        
        # 真实 Wasserstein-1
        w_dist = 1.0
        
        print(f"{eps:.2f}    {sinkhorn_dist:.4f}      {w_dist:.4f}")
    
    """
    输出示例:
    ε      Sinkhorn    Wasserstein
    0.01    0.9900      1.0000
    0.10    0.9000      1.0000
    0.50    0.7500      1.0000
    1.00    0.6321      1.0000
    
    观察:ε 越小,Sinkhorn 距离越接近真实 Wasserstein 距离
    """

三角不等式

重要:Sinkhorn 距离不满足三角不等式(因为正则化破坏了度量性质)。

性质Wasserstein Sinkhorn
非负性
同一性
对称性
三角不等式
收敛到 -✅ 当

Sinkhorn 在深度学习中的应用

1. Contrastive Learning

class SinkhornContrastiveLoss(nn.Module):
    """
    基于 Sinkhorn 的对比损失
    
    用于无监督/自监督学习中的分布对齐
    """
    def __init__(self, temperature=0.1, epsilon=0.1):
        super().__init__()
        self.temperature = temperature
        self.epsilon = epsilon
    
    def forward(self, z1, z2):
        """
        Args:
            z1, z2: 两个视图的特征, shape [batch, dim]
        
        Returns:
            loss: Sinkhorn 对比损失
        """
        batch_size = z1.size(0)
        
        # 归一化特征
        z1 = F.normalize(z1, dim=1)
        z2 = F.normalize(z2, dim=1)
        
        # 拼接特征
        z = torch.cat([z1, z2], dim=0)
        
        # 计算相似度矩阵
        sim = torch.mm(z, z.T) / self.temperature
        
        # Sinkhorn 距离作为损失
        # 成本矩阵:负相似度
        C = -sim
        
        # 均匀分布
        a = torch.ones(2 * batch_size) / (2 * batch_size)
        
        # Sinkhorn 计算
        K = torch.exp(-C / self.epsilon)
        u = torch.ones(2 * batch_size)
        v = torch.ones(2 * batch_size)
        
        for _ in range(10):
            u = a / (K @ v + 1e-50)
            v = a / (K.T @ u + 1e-50)
        
        gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
        
        # 损失 = Sinkhorn 距离
        loss = torch.sum(gamma * C)
        
        return loss

2. Image Generation (SinkhornGAN)

class SinkhornGenerativeLoss(nn.Module):
    """
    基于 Sinkhorn 距离的生成损失
    
    比 WGAN 更稳定的替代方案
    """
    def __init__(self, epsilon=0.1):
        super().__init__()
        self.epsilon = epsilon
    
    def sinkhorn_divergence(self, real, fake):
        """
        计算真实分布与生成分布之间的 Sinkhorn 散度
        
        近似 Wasserstein-1 距离
        """
        batch_size = real.size(0)
        
        # 特征级别的距离
        C = torch.cdist(real, fake, p=2) ** 2
        
        # 均匀分布
        a = torch.ones(batch_size) / batch_size
        b = torch.ones(batch_size) / batch_size
        
        # Sinkhorn 计算
        K = torch.exp(-C / self.epsilon)
        u = torch.ones(batch_size)
        v = torch.ones(batch_size)
        
        for _ in range(20):
            u = a / (K @ v + 1e-50)
            v = b / (K.T @ u + 1e-50)
        
        gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
        
        # Sinkhorn 距离
        return torch.sum(gamma * C)
    
    def forward(self, real_features, fake_features):
        """
        计算生成损失
        
        最小化真实与生成特征的 Sinkhorn 距离
        """
        loss = self.sinkhorn_divergence(real_features, fake_features)
        return loss

3. Multi-Domain Translation

class MultiDomainOT(nn.Module):
    """
    多域传输:学习域不变表示
    
    使用 Sinkhorn 对齐不同域的分布
    """
    def __init__(self, feature_dim, num_domains, epsilon=0.1):
        super().__init__()
        self.feature_extractor = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64)
        )
        self.num_domains = num_domains
        self.epsilon = epsilon
    
    def compute_domain_alignment_loss(self, features, domain_labels):
        """
        计算域对齐损失
        
        目标:最小化所有域之间的 Sinkhorn 距离之和
        """
        unique_domains = torch.unique(domain_labels)
        n_domains = len(unique_domains)
        
        total_loss = 0.0
        count = 0
        
        for i in range(n_domains):
            for j in range(i + 1, n_domains):
                mask_i = (domain_labels == unique_domains[i])
                mask_j = (domain_labels == unique_domains[j])
                
                features_i = features[mask_i]
                features_j = features[mask_j]
                
                # 计算两个域之间的 Sinkhorn 距离
                C = torch.cdist(features_i, features_j, p=2) ** 2
                
                a = torch.ones(len(features_i)) / len(features_i)
                b = torch.ones(len(features_j)) / len(features_j)
                
                K = torch.exp(-C / self.epsilon)
                u = torch.ones(len(features_i))
                v = torch.ones(len(features_j))
                
                for _ in range(20):
                    u = a / (K @ v + 1e-50)
                    v = b / (K.T @ u + 1e-50)
                
                gamma = u.unsqueeze(1) * K * v.unsqueeze(0)
                loss_ij = torch.sum(gamma * C)
                
                total_loss += loss_ij
                count += 1
        
        return total_loss / max(count, 1)

计算效率优化

并行化

def sinkhorn_batched(a, b, C, epsilon=0.1, num_iters=100):
    """
    批量 Sinkhorn 算法
    
    同时处理多个传输问题
    """
    # a: [batch, n]
    # b: [batch, m]
    # C: [batch, n, m]
    
    batch_size = a.size(0)
    
    # 初始化
    u = torch.ones(batch_size, a.size(1))
    v = torch.ones(batch_size, b.size(1))
    
    # Gibbs 核(批量计算)
    K = torch.exp(-C / epsilon)
    
    for _ in range(num_iters):
        # 更新 u: [batch, n] / ([batch, n, m] @ [batch, m, 1]) -> [batch, n]
        u = a / (torch.bmm(K, v.unsqueeze(2)).squeeze(2) + 1e-50)
        
        # 更新 v: [batch, m] / ([batch, m, n] @ [batch, n, 1]) -> [batch, m]
        v = b / (torch.bmm(K.transpose(1, 2), u.unsqueeze(2)).squeeze(2) + 1e-50)
    
    # 批量计算传输计划
    gamma = u.unsqueeze(2) * K * v.unsqueeze(1)
    
    return gamma

GPU 加速

def sinkhorn_gpu_demo():
    """
    演示 GPU 加速
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # 大规模问题
    n = 10000
    m = 10000
    
    a = torch.ones(n, device=device) / n
    b = torch.ones(m, device=device) / m
    
    # 随机成本矩阵
    x = torch.randn(n, 1, device=device)
    y = torch.randn(m, 1, device=device)
    C = (x - y.T) ** 2
    
    # GPU 计算
    gamma = sinkhorn_iteration(a, b, C, epsilon=0.1, num_iters=100)
    
    print(f"使用设备: {device}")
    print(f"传输计划形状: {gamma.shape}")
    print(f"传输计划范围: [{gamma.min():.6f}, {gamma.max():.6f}]")

核心公式速查

概念公式
熵正则化OT
Gibbs核
Sinkhorn迭代,
最优传输计划
收敛速率

参考


扩展阅读

Footnotes

  1. Sinkhorn, R. (1964). “Relationship between Positive Matrices and Successive Contructions of Diagonal Matrices”. Proceedings of the American Mathematical Society.