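Overview

Optimal transport (OT) theory offers an elegant geometric framework for understanding attention. Recent work shows that the self-attention matrix can be interpreted exactly as the solution of a semi-relaxed entropic optimal transport problem. This result gives attention a geometric intuition and also a theoretical basis for designing new attention variants. [1][2]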


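Optimal Transport Basics

The Entropic Optimal Transport Problem

Given two probability distributions $a \in \Delta_n$ (source) and $b \in \Delta_m$ (target), the Kantorovich transport problem seeks an optimal coupling $\gamma$ that minimizes the total cost:

$$\min_{\gamma \in \Pi(a, b)} \langle \gamma, C \rangle = \sum_{i,j} \gamma_{ij} C_{ij}$$

where $C_{ij}$ is the cost of moving mass from the $i$-th source point to the $j$-th target point, and $\Pi(a, b) = \{\gamma \in \mathbb{R}_{\ge 0}^{n \times m} : \gamma \mathbf{1} = a,\ \gamma^\top \mathbf{1} = b\}$ is the feasible set of all joint distributions with marginals $a$ and $b$.

The entropy-regularized version (entropic OT) adds a Kullback-Leibler regularization term, which on $\Pi(a, b)$ equals a negative entropy up to an additive constant:

$$\min_{\gamma \in \Pi(a, b)} \langle \gamma, C \rangle - \varepsilon H(\gamma)$$

where $H(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$ is the entropy term and $\varepsilon > 0$ is the temperature parameter.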
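
To make the trade-off between cost and entropy concrete, here is a minimal sketch in plain Python (no torch) that evaluates the entropic objective for two feasible couplings of a 2x2 problem. The helper name `entropic_objective` and the toy data are illustrative assumptions, not part of the original text.

```python
import math

def entropic_objective(gamma, C, eps):
    """Return <gamma, C> - eps * H(gamma), with H(gamma) = -sum gamma log gamma."""
    cost = sum(g * c for g_row, c_row in zip(gamma, C) for g, c in zip(g_row, c_row))
    entropy = -sum(g * math.log(g) for g_row in gamma for g in g_row if g > 0)
    return cost - eps * entropy

a = [0.5, 0.5]                       # source marginal
b = [0.5, 0.5]                       # target marginal
C = [[0.0, 1.0],
     [1.0, 0.0]]                     # leaving the diagonal costs 1

diagonal = [[0.5, 0.0], [0.0, 0.5]]                  # optimal plan without regularization
independent = [[ai * bj for bj in b] for ai in a]    # product coupling, maximal entropy

for eps in (0.01, 1.0, 10.0):
    print(eps, entropic_objective(diagonal, C, eps), entropic_objective(independent, C, eps))
```

For small `eps` the diagonal plan has the lower objective, while for large `eps` the product coupling wins, which is the sense in which $\varepsilon$ acts as a temperature.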

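Sinkhorn Algorithm

The entropy-regularized problem can be solved efficiently with the well-known Sinkhorn algorithm, which alternately rescales the rows and columns of the Gibbs kernel $K = \exp(-C / \varepsilon)$ until both marginal constraints hold: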

import torch
import torch.nn.functional as F
 
def sinkhorn(a, b, C, epsilon, num_iters=10):
    """
    Sinkhorn algorithm for entropic optimal transport
    
    Args:
        a: Source distribution, shape (n,)
        b: Target distribution, shape (m,)
        C: Cost matrix, shape (n, m)
        epsilon: Regularization parameter
        num_iters: Number of iterations
    
    Returns:
        Optimal transport plan gamma
    """
    K = torch.exp(-C / epsilon)  # Gibbs kernel
    
    # Initialize scalings
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    
    for _ in range(num_iters):
        u = a / (K @ v + 1e-8)
        v = b / (K.T @ u + 1e-8)
    
    # Compute optimal transport plan
    gamma = u.view(-1, 1) * K * v.view(1, -1)
    
    return gamma
 
def sinkhorn_divergence(p, q, cost_matrix, epsilon=0.1):
    """
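    Compute Sinkhorn divergence (regularized OT distance)
    
    S_epsilon(p, q) = OT_epsilon(p, q) - 0.5*OT_epsilon(p, p) - 0.5*OT_epsilon(q, q)
    
    p and q must share one support of size n, so that the same (n, n)
    cost_matrix is valid for all three terms. Each OT_epsilon term is
    approximated by the transport cost <gamma, C> of the Sinkhorn plan.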
    """
    gamma_pq = sinkhorn(p, q, cost_matrix, epsilon)
    gamma_pp = sinkhorn(p, p, cost_matrix, epsilon)
    gamma_qq = sinkhorn(q, q, cost_matrix, epsilon)
    
    ot_pq = torch.sum(gamma_pq * cost_matrix)
    ot_pp = torch.sum(gamma_pp * cost_matrix)
    ot_qq = torch.sum(gamma_qq * cost_matrix)
    
    return ot_pq - 0.5 * ot_pp - 0.5 * ot_qq
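
The kernel `exp(-C / epsilon)` underflows to zero when `epsilon` is small relative to the costs, which makes the scaling updates divide by (almost) zero. A common remedy is to iterate on the dual potentials in the log domain. The following is a minimal sketch of that variant in plain Python, under the assumption that inputs are small lists; `sinkhorn_log` and `logsumexp` are hypothetical helper names, not functions from the code above.

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) for a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def sinkhorn_log(a, b, C, eps, num_iters=200):
    """Log-domain Sinkhorn on plain lists; returns the transport plan."""
    n, m = len(a), len(b)
    f = [0.0] * n   # row potentials, u = exp(f / eps)
    g = [0.0] * m   # column potentials, v = exp(g / eps)
    for _ in range(num_iters):
        for i in range(n):
            f[i] = eps * (math.log(a[i]) - logsumexp([(g[j] - C[i][j]) / eps for j in range(m)]))
        for j in range(m):
            g[j] = eps * (math.log(b[j]) - logsumexp([(f[i] - C[i][j]) / eps for i in range(n)]))
    # gamma_ij = u_i * K_ij * v_j, written with potentials
    return [[math.exp((f[i] + g[j] - C[i][j]) / eps) for j in range(m)] for i in range(n)]

a = [0.2, 0.8]
b = [0.5, 0.3, 0.2]
C = [[0.0, 1.0, 2.0],
     [1.0, 0.0, 1.0]]

plan = sinkhorn_log(a, b, C, eps=0.5)
print([sum(row) for row in plan])                             # row marginals, close to a
print([sum(plan[i][j] for i in range(2)) for j in range(3)])  # column marginals, close to b
```

The same function stays finite for much smaller values such as `eps=0.001`, where a direct evaluation of `exp(-2.0 / 0.001)` would underflow to zero, although convergence of the row marginals then takes many more iterations.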

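Self-Attention as Semi-Relaxed OT

Problem Formulation

Consider standard scaled dot-product attention:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

where $Q, K \in \mathbb{R}^{N \times d_k}$ are the query and key matrices. Let $S = QK^\top / \sqrt{d_k}$ be the attention score matrix.

Core theorem: the attention weight matrix $A = \mathrm{softmax}(S / \varepsilon)$, which is standard attention at $\varepsilon = 1$, is the closed-form solution of the following semi-relaxed entropic optimal transport problem with cost $C = -S$:

$$A = \arg\min_{\gamma \ge 0,\ \gamma \mathbf{1} = \mathbf{1}} \langle \gamma, -S \rangle - \varepsilon H(\gamma), \qquad H(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$$

where the constraint $\gamma \mathbf{1} = \mathbf{1}$ requires every row to sum to 1 (row-stochastic), while the column sums are left unconstrained.

Mathematical Proof

Lemma 1: for any matrix $S \in \mathbb{R}^{N \times M}$ and any $\varepsilon > 0$, the optimization problem

$$\min_{\gamma \ge 0,\ \gamma \mathbf{1} = \mathbf{1}} -\sum_{i,j} \gamma_{ij} S_{ij} + \varepsilon \sum_{i,j} \gamma_{ij} \log \gamma_{ij}$$

has the unique solution

$$\gamma^*_{ij} = \frac{\exp(S_{ij} / \varepsilon)}{\sum_{k} \exp(S_{ik} / \varepsilon)}$$

Proof: introduce one multiplier $\lambda_i$ per row constraint and form the Lagrangian

$$\mathcal{L}(\gamma, \lambda) = -\sum_{i,j} \gamma_{ij} S_{ij} + \varepsilon \sum_{i,j} \gamma_{ij} \log \gamma_{ij} + \sum_{i} \lambda_i \Big( \sum_{j} \gamma_{ij} - 1 \Big)$$

Take the partial derivative with respect to $\gamma_{ij}$ and set it to zero:

$$\frac{\partial \mathcal{L}}{\partial \gamma_{ij}} = -S_{ij} + \varepsilon (\log \gamma_{ij} + 1) + \lambda_i = 0$$

Solving gives $\gamma_{ij} = \exp\big((S_{ij} - \lambda_i) / \varepsilon - 1\big)$, and choosing $\lambda_i$ so that row $i$ sums to 1 yields:

$$\gamma^*_{ij} = \frac{\exp(S_{ij} / \varepsilon)}{\sum_{k} \exp(S_{ik} / \varepsilon)} = \mathrm{softmax}(S_{i\cdot} / \varepsilon)_j$$

This is exactly the row-normalized softmax form. The solution is strictly positive, so the constraint $\gamma \ge 0$ is inactive, and the objective is strictly convex, so the solution is unique.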
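
Lemma 1 can be checked numerically for a single row. The sketch below, in plain Python, compares the softmax row against random points of the simplex under the equivalent maximization form $\langle p, s \rangle + \varepsilon H(p)$, and against the closed-form optimal value $\varepsilon \log \sum_j \exp(s_j / \varepsilon)$. The helper names and the sample scores are assumptions chosen for illustration.

```python
import math
import random

def softmax(scores, eps):
    """Softmax of scores / eps, shifted by the maximum for stability."""
    m = max(scores)
    w = [math.exp((s - m) / eps) for s in scores]
    z = sum(w)
    return [x / z for x in w]

def row_objective(p, scores, eps):
    """<p, s> + eps * H(p): the quantity maximized by one attention row."""
    entropy = -sum(x * math.log(x) for x in p if x > 0)
    return sum(x * s for x, s in zip(p, scores)) + eps * entropy

random.seed(0)
scores, eps = [2.0, 0.5, -1.0, 0.0], 0.7
p_star = softmax(scores, eps)
best = row_objective(p_star, scores, eps)

# Closed form of the optimal value: eps * log sum_j exp(s_j / eps)
closed_form = eps * math.log(sum(math.exp(s / eps) for s in scores))

# Largest objective value found among random points of the simplex
best_random = float("-inf")
for _ in range(1000):
    raw = [random.random() for _ in scores]
    total = sum(raw)
    p = [x / total for x in raw]
    best_random = max(best_random, row_objective(p, scores, eps))

print(best, closed_form, best_random)
```

No random candidate should exceed the softmax value, and the softmax value should agree with the closed form up to floating-point error.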

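Physical Interpretation

From a physical point of view:

  • Source distribution: each query vector $q_i$ is treated as a "particle" carrying unit mass
  • Target distribution: the uniform distribution over the keys $k_j$, which serves as the reference measure of the entropy term rather than as an enforced column marginal
  • Cost matrix: $C = -S$ (the negative attention scores)
  • Entropy term: encourages each query to spread its attention over several keys instead of concentrating on a single one
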
class AttentionAsOT(torch.nn.Module):
    """
    Attention mechanism interpreted as optimal transport
    """
    def __init__(self, d_model, n_heads, epsilon=1.0):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.epsilon = epsilon
        
        # Projections
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, N, D = x.shape
        
        # Compute Q, K, V
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        
        # OT interpretation: attention = soft sorting
        # S_ij = similarity between query i and key j
        S = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        
        # Row-wise softmax = solution to semi-relaxed entropic OT
        # Each query distributes its "mass" (1) across keys
        A = F.softmax(S / self.epsilon, dim=-1)  # Soft OT plan
        
        # Apply to values
        out = torch.matmul(A, V)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        
        return self.W_o(out)
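
The "semi-relaxed" part of the interpretation is easy to see on a toy score matrix: row sums are pinned to 1, while column sums are free and can be far from 1 when many queries prefer the same key. A minimal plain-Python sketch follows; `softmax_rows` and the matrix `S` are illustrative assumptions.

```python
import math

def softmax_rows(S, eps=1.0):
    """Row-wise softmax of S / eps on a list of lists."""
    out = []
    for row in S:
        m = max(row)
        w = [math.exp((s - m) / eps) for s in row]
        z = sum(w)
        out.append([x / z for x in w])
    return out

# Every query scores key 0 highest
S = [[3.0, 0.0, 0.0],
     [2.5, 0.1, 0.0],
     [2.0, 0.0, 0.3]]

A = softmax_rows(S)
row_sums = [sum(r) for r in A]
col_sums = [sum(A[i][j] for i in range(3)) for j in range(3)]

print("row sums:", row_sums)   # each equals 1 (source marginal is enforced)
print("col sums:", col_sums)   # key 0 receives most of the mass (target marginal is free)
```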

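Sinkhorn Attention and Full Iteration

Standard Attention vs Sinkhorn Attention

| Variant | Mathematical form | OT interpretation |
|---|---|---|
| Standard attention | $\mathrm{softmax}(QK^\top / \sqrt{d_k})$ | Single Sinkhorn step (truncated iteration) |
| Sinkhorn Attention | Full Sinkhorn iteration | Actually attains OT optimality |
| Sinkformers | Repeated normalization iterations | Closer to the theoretical optimum |
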
def sinkhorn_attention(Q, K, V, epsilon=0.1, num_iters=10):
    """
    Sinkhorn Attention: Full iterative OT solution
    
    Standard attention uses one Sinkhorn iteration.
    Sinkhorn attention runs multiple iterations for convergence.
    """
    # Cost matrix (negative similarity = transport cost)
    C = -torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
    
    # Gibbs kernel; shifting each row by its minimum cost avoids overflow
    # and is absorbed into the row scaling u
    C = C - C.min(dim=-1, keepdim=True).values
    K_mat = torch.exp(-C / epsilon)
    
    # Target marginals: every row sums to 1 (as in softmax attention),
    # every column sums to N_q / N_k so that the total mass matches
    n_q, n_k = Q.shape[-2], K.shape[-2]
    a = torch.ones(Q.shape[:-1], device=Q.device, dtype=Q.dtype)
    b = torch.full(K.shape[:-1], n_q / n_k, device=Q.device, dtype=Q.dtype)
    
    # Sinkhorn-Knopp iterations on the scaling vectors
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(num_iters):
        u = a / (torch.matmul(K_mat, v.unsqueeze(-1)).squeeze(-1) + 1e-8)
        v = b / (torch.matmul(K_mat.transpose(-2, -1), u.unsqueeze(-1)).squeeze(-1) + 1e-8)
    
    # Transport plan: diag(u) K diag(v)
    gamma = u.unsqueeze(-1) * K_mat * v.unsqueeze(-2)
    
    # Apply to values
    output = torch.matmul(gamma, V)
    
    return output
 
 
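def compute_attention_ot_divergence(Q, K, epsilon=0.1, num_iters=50):
    """
    Measure how far softmax attention is from the balanced entropic OT plan
    
    Q: (B, N_q, d), K: (B, N_k, d). For each batch element the softmax plan
    (rows carry mass 1/N_q, columns unconstrained) is compared with the
    Sinkhorn plan for uniform marginals via total variation distance.
    Returns a tensor of shape (B,); 0 means the two plans coincide.
    """
    # Cost matrix, shape (B, N_q, N_k)
    C = -torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
    
    # Uniform marginals over queries and keys
    n_q, n_k = C.shape[-2], C.shape[-1]
    p = torch.ones(n_q, device=Q.device, dtype=Q.dtype) / n_q
    q = torch.ones(n_k, device=Q.device, dtype=Q.dtype) / n_k
    
    # Semi-relaxed plan: row-wise softmax, rescaled so the total mass is 1
    A = F.softmax(-C / epsilon, dim=-1) / n_q
    
    # sinkhorn() is unbatched, so solve each batch element separately
    gamma = torch.stack([sinkhorn(p, q, C_b, epsilon, num_iters) for C_b in C])
    
    return 0.5 * (A - gamma).abs().sum(dim=(-2, -1))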
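
The claim that standard attention is a single Sinkhorn step can be illustrated on a tiny example: one row normalization of the Gibbs kernel reproduces softmax, and alternating column and row normalizations drives the matrix toward a doubly stochastic plan. This is a minimal plain-Python sketch with $\varepsilon = 1$; the helper names and the matrix `S` are assumptions for illustration.

```python
import math

def normalize_rows(M):
    """Divide every row by its sum."""
    return [[x / sum(row) for x in row] for row in M]

def normalize_cols(M):
    """Divide every column by its sum."""
    col = [sum(row[j] for row in M) for j in range(len(M[0]))]
    return [[x / col[j] for j, x in enumerate(row)] for row in M]

def column_sums(M):
    return [sum(row[j] for row in M) for j in range(len(M[0]))]

S = [[3.0, 0.0, 0.0],
     [2.5, 0.1, 0.0],
     [2.0, 0.0, 0.3]]

# Gibbs kernel exp(-C / eps) with C = -S and eps = 1
kernel = [[math.exp(s) for s in row] for row in S]

# One row normalization: identical to row-wise softmax(S)
A = normalize_rows(kernel)

# Many full sweeps: columns, then rows
P = kernel
for _ in range(200):
    P = normalize_rows(normalize_cols(P))

print("softmax column sums: ", column_sums(A))
print("Sinkhorn column sums:", column_sums(P))
```

After the sweeps both the row sums and the column sums of `P` are close to 1, whereas the softmax matrix `A` only satisfies the row constraint.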

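OT Perspective on Attention Variants

OT Interpretation of Linear Attention

Linear attention avoids the $O(N^2)$ complexity through a kernel approximation:

$$\mathrm{Attention}(Q, K, V)_i = \frac{\phi(q_i)^\top \sum_{j} \phi(k_j) v_j^\top}{\phi(q_i)^\top \sum_{j} \phi(k_j)}$$

where $\kappa(q, k) = \phi(q)^\top \phi(k)$ is a positive-definite kernel. From the OT perspective, this corresponds to approximating the deterministic transport map with a random kernel.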

class LinearAttentionAsOT(torch.nn.Module):
    """
    Linear attention with OT-inspired kernel design
    """
    def __init__(self, d_model, num_features=64):
        super().__init__()
        self.d_model = d_model
        self.num_features = num_features
        
        # Feature maps for Q and K
        self.phi_q = torch.nn.Linear(d_model, num_features, bias=False)
        self.phi_k = torch.nn.Linear(d_model, num_features, bias=False)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)
        
        # OT-inspired: use ReLU kernel (positive, sparse)
        self.activation = torch.nn.ReLU()
    
    def forward(self, x):
        # Feature extraction
        q = self.phi_q(x)  # (B, N, F)
        k = self.phi_k(x)  # (B, N, F)
        v = self.W_v(x)    # (B, N, D)
        
        # OT kernel: sparse positive features
        q = self.activation(q)
        k = self.activation(k)
        
        # Numerically stable linear attention
        kv = torch.einsum('bnd,bnv->bdv', k, v)
        qkv = torch.einsum('bmd,bdv->bmv', q, kv)
        Z = torch.einsum('bmd,bd->bm', q, torch.sum(k, dim=1))
        
        out = qkv / (Z.unsqueeze(-1) + 1e-8)
        
        return self.W_o(out)
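
The complexity saving comes purely from reordering the matrix products: the normalized kernel attention can be computed without ever forming the N x N matrix. The plain-Python sketch below checks that both orderings agree on a toy input with ReLU features; the helper names and data are illustrative assumptions.

```python
def matmul(A, B):
    """Matrix product of two lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def relu_features(X):
    """Non-negative feature map phi(x) = max(x, 0)."""
    return [[max(x, 0.0) for x in row] for row in X]

Q = [[1.0, -0.5], [0.3, 0.8], [-0.2, 0.4]]
K = [[0.5, 0.1], [-0.4, 0.9], [0.7, 0.2]]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

phi_q, phi_k = relu_features(Q), relu_features(K)

# Quadratic form: build the N x N kernel matrix, then normalize each row
scores = matmul(phi_q, transpose(phi_k))
quadratic = [[x / sum(row) for x in out_row]
             for row, out_row in zip(scores, matmul(scores, V))]

# Linear form: summarize keys and values once, reuse the summary for every query
kv = matmul(transpose(phi_k), V)              # shape (F, D)
k_sum = [sum(col) for col in zip(*phi_k)]     # shape (F,)
linear = []
for q, numerator in zip(phi_q, matmul(phi_q, kv)):
    z = sum(a * b for a, b in zip(q, k_sum))  # row normalizer
    linear.append([x / z for x in numerator])

print(quadratic)
print(linear)
```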

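OT Interpretation of FlashAttention

FlashAttention improves efficiency through IO-aware computation. From the OT perspective:

  • Tile-wise processing: local OT computations followed by a global normalization
  • Online algorithm: a streaming evaluation of the same row-normalized plan (FlashAttention itself is exact, not an approximation)
  • Numerical stability: the running maximum and normalizer kept per row play the role of the row scaling in the outer Sinkhorn iteration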
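
The tile-wise idea can be sketched for a single query row with scalar values: scores arrive in tiles, and a running maximum and running normalizer are rescaled whenever a larger score appears. This is a simplified illustration of the online-softmax recurrence in plain Python, not FlashAttention's actual kernel; the function name and tile size are assumptions.

```python
import math

def online_softmax_weighted_sum(scores, values, tile=2):
    """One attention row computed tile by tile with a running max and normalizer."""
    m = float("-inf")   # running maximum of the scores seen so far
    l = 0.0             # running normalizer: sum of exp(s_j - m)
    acc = 0.0           # running unnormalized output: sum of exp(s_j - m) * v_j
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        scale = math.exp(m - m_new)   # rescale what was accumulated under the old max
        l = l * scale + sum(math.exp(s - m_new) for s in s_tile)
        acc = acc * scale + sum(math.exp(s - m_new) * v for s, v in zip(s_tile, v_tile))
        m = m_new
    return acc / l

scores = [1.0, 3.0, -2.0, 0.5, 2.0]
values = [10.0, 20.0, 30.0, 40.0, 50.0]

# Reference: full softmax over the whole row at once
weights = [math.exp(s - max(scores)) for s in scores]
reference = sum(w * v for w, v in zip(weights, values)) / sum(weights)

print(online_softmax_weighted_sum(scores, values, tile=2), reference)
```

The streamed result matches the full-row softmax, and because every exponent is taken relative to the running maximum, very large scores do not overflow.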

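Implications of OT Theory for Attention Design

Effect of the Temperature Parameter

| $\varepsilon$ | Attention behavior | OT interpretation |
|---|---|---|
| $\varepsilon \to 0$ | Approximately one-hot (hard attention) | Entropy term vanishes; tends to the optimal matching |
| $\varepsilon = 1$ | Standard softmax | Moderate regularization |
| $\varepsilon \to \infty$ | Uniform distribution | Strong regularization; fully random |
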
def analyze_attention_temperature(Q, K, V, epsilon_range=(0.01, 0.1, 0.5, 1.0, 2.0)):
    """
    Analyze how temperature epsilon affects attention behavior
    through OT lens (V is accepted for API symmetry but not used)
    """
    results = []
    
    # Scores do not depend on the temperature, so compute them once
    S = torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
    
    for eps in epsilon_range:
        # Attention weights at this temperature
        A = F.softmax(S / eps, dim=-1)
        
        # Row-wise Shannon entropy of the attention distribution
        row_entropy = -torch.sum(A * torch.log(A + 1e-8), dim=-1)
        entropy = row_entropy.mean()
        
        # Concentration: how peaked is the attention?
        max_attention = A.max(dim=-1).values.mean()
        
        # Effective number of keys attended to (perplexity of each row)
        effective_size = torch.exp(row_entropy).mean()
        
        results.append({
            'epsilon': eps,
            'entropy': entropy.item(),
            'max_attention': max_attention.item(),
            'effective_size': effective_size.item()
        })
    
    return results
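
The three temperature regimes can be reproduced on one fixed row of scores without any model. The sketch below, in plain Python, sweeps $\varepsilon$ and records the largest weight, the entropy, and the effective number of keys (the exponential of the entropy). The score vector and the names `softmax`, `entropy` and `sweep` are assumptions for illustration.

```python
import math

def softmax(scores, eps):
    """Softmax of scores / eps, shifted by the maximum for stability."""
    m = max(scores)
    w = [math.exp((s - m) / eps) for s in scores]
    z = sum(w)
    return [x / z for x in w]

def entropy(p):
    """Shannon entropy in nats, skipping exact zeros."""
    return -sum(x * math.log(x) for x in p if x > 0)

scores = [2.0, 1.0, 0.0, -1.0]
sweep = []
for eps in (0.01, 0.1, 1.0, 10.0, 100.0):
    p = softmax(scores, eps)
    h = entropy(p)
    sweep.append((eps, max(p), h))
    print(f"eps={eps:<6} max={max(p):.3f} entropy={h:.3f} effective_keys={math.exp(h):.2f}")
```

As $\varepsilon$ grows, the largest weight falls from nearly 1 toward 1/4 and the effective number of keys rises toward 4, the number of keys in this toy row.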

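Design Principles

According to OT theory, a good attention design should satisfy:

  1. Cost function design: use a semantically meaningful distance measure
  2. Temperature scheduling: adapt $\varepsilon$ to the task
  3. Row stochasticity: keep the constraint that every row sums to 1
  4. Sparsity: encourage sparse solutions through the choice of regularizer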
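
On the sparsity principle: entropic regularization always returns strictly positive weights, so exact zeros require a different regularizer. One well-known option, swapped in here purely as an illustration, is to replace the entropy by a squared Euclidean penalty, which gives sparsemax (the Euclidean projection of the scores onto the simplex). The following is a minimal plain-Python sketch of the standard sort-and-threshold procedure; the function name is an assumption.

```python
def sparsemax(scores):
    """Euclidean projection of the score vector onto the probability simplex."""
    z = sorted(scores, reverse=True)
    cumsum, tau = 0.0, 0.0
    for i, zi in enumerate(z, start=1):
        cumsum += zi
        if 1.0 + i * zi > cumsum:          # the i largest scores are all in the support
            tau = (cumsum - 1.0) / i       # threshold for a support of size i
    return [max(s - tau, 0.0) for s in scores]

print(sparsemax([1.0, 0.8, 0.1]))        # the weakest key gets exactly zero weight
print(sparsemax([2.0, 1.0, 0.0, -1.0]))  # a dominant score takes all the mass
```

Unlike softmax, the output still sums to 1 but can assign exactly zero weight to low-scoring keys.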

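Stability and Generalization Guarantees

Guarantees Provided by OT Theory

  1. Wasserstein stability: small input perturbations lead to a bounded change of the output in Wasserstein distance
  2. Interpolation property: the attention solution interpolates between the uniform distribution and one-hot
  3. Geometric intuition: gives a physical/geometric understanding of the attention mechanism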
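
A simple numerical check of a related, weaker stability statement: for one attention row, the Euclidean change of the softmax output is at most $1/\varepsilon$ times the Euclidean change of the scores, because the softmax Jacobian has spectral norm below 1. This plain-Python sketch measures the worst observed ratio on random perturbations; it illustrates stability of the weights only, not the Wasserstein statement itself, and all names are assumptions.

```python
import math
import random

def softmax(scores, eps):
    """Softmax of scores / eps, shifted by the maximum for stability."""
    m = max(scores)
    w = [math.exp((s - m) / eps) for s in scores]
    z = sum(w)
    return [x / z for x in w]

def l2(x, y):
    """Euclidean distance between two lists."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))

random.seed(0)
eps = 0.5
worst_ratio = 0.0
for _ in range(1000):
    s = [random.gauss(0.0, 1.0) for _ in range(8)]
    s_pert = [x + random.gauss(0.0, 0.01) for x in s]      # small score perturbation
    change_in = l2(s, s_pert)
    change_out = l2(softmax(s, eps), softmax(s_pert, eps))
    worst_ratio = max(worst_ratio, change_out / change_in)

print("worst observed ratio:", worst_ratio, "bound:", 1.0 / eps)
```

The bound loosens as $\varepsilon$ shrinks, which matches the interpolation picture: near one-hot attention is the least stable regime.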


Code Implementation: A Complete Example

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
 
class OTAttentionExperiment(nn.Module):
    """
    Experiment module for comparing standard vs OT-inspired attention
    """
    def __init__(self, d_model, n_heads=8):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.standard_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ot_attn = OptimalTransportAttention(d_model, n_heads, epsilon=0.1)
    
    def compare_attentions(self, x):
        """
        Compare standard and OT-inspired attention patterns
        """
        # Standard attention (keep per-head weights so both entropies are comparable)
        std_out, std_weights = self.standard_attn(
            x, x, x, need_weights=True, average_attn_weights=False
        )
        
        # OT-inspired attention
        ot_out, ot_weights = self.ot_attn(x, x, x, return_weights=True)
        
        # Compare
        print(f"Standard attention entropy: {self._compute_entropy(std_weights):.4f}")
        print(f"OT attention entropy: {self._compute_entropy(ot_weights):.4f}")
        
        return std_out, ot_out
    
    def _compute_entropy(self, weights):
        return (-weights * torch.log(weights + 1e-8)).sum(-1).mean()
 
 
class OptimalTransportAttention(nn.Module):
    """
    Attention mechanism with OT-inspired design choices
    """
    def __init__(self, d_model, n_heads, epsilon=0.1, sinkhorn_iters=3):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.epsilon = epsilon
        self.sinkhorn_iters = sinkhorn_iters
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
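        # Learnable temperature (per head), initialized so that
        # softplus(log_epsilon) equals the requested epsilon
        self.log_epsilon = nn.Parameter(
            torch.full((n_heads,), math.log(math.expm1(epsilon)))
        )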
    
    def forward(self, query, key, value, mask=None, return_weights=False):
        B, N_q, D = query.shape
        N_k = key.shape[1]
        
        # Project
        Q = self.W_q(query).view(B, N_q, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(B, N_k, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(B, N_k, self.n_heads, self.d_k).transpose(1, 2)
        
        # Similarity scores (the OT cost matrix is C = -S)
        S = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        
        if mask is not None:
            # mask must broadcast to (B, n_heads, N_q, N_k); zeros are blocked
            S = S.masked_fill(mask == 0, float('-inf'))
        
        # Learnable temperature per head
        eps = F.softplus(self.log_epsilon).view(1, self.n_heads, 1, 1)
        
        # OT solution: row-wise softmax
        A = F.softmax(S / eps, dim=-1)
        
        # Optional: run Sinkhorn iterations for tighter OT solution
        if self.sinkhorn_iters > 0:
            A = self._sinkhorn_iteration(A, eps, num_iters=self.sinkhorn_iters)
        
        # Apply attention
        out = torch.matmul(A, V)
        out = out.transpose(1, 2).contiguous().view(B, N_q, D)
        
        output = self.W_o(out)
        
        if return_weights:
            return output, A
        return output
    
    def _sinkhorn_iteration(self, A, eps, num_iters=3):
        """
        Run additional Sinkhorn iterations for better OT approximation
        
        A already equals exp(S / eps) up to a row scaling, so eps is not
        applied again; the loop only alternates the two normalizations.
        """
        n_q, n_k = A.shape[-2], A.shape[-1]
        for _ in range(num_iters):
            # Column normalization: every column sums to N_q / N_k
            A = A * (n_q / n_k) / (A.sum(-2, keepdim=True) + 1e-8)
            # Row normalization: every row sums to 1 again
            A = A / (A.sum(-1, keepdim=True) + 1e-8)
        return A
 
 
# Experiment: compare OT attention with standard attention
def run_ot_attention_experiment():
    """
    Demonstrate OT-inspired attention on a simple task
    """
    torch.manual_seed(42)
    
    batch_size, seq_len, d_model = 4, 32, 128
    x = torch.randn(batch_size, seq_len, d_model)
    
    model = OTAttentionExperiment(d_model)
    std_out, ot_out = model.compare_attentions(x)
    
    print(f"\nOutput shape: {std_out.shape}")
    # The two modules have independent random weights, so this is expected to be False
    print(f"Outputs close: {torch.allclose(std_out, ot_out, atol=1e-5)}")
 
 
if __name__ == "__main__":
    run_ot_attention_experiment()

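Summary

Optimal transport theory contributes the following to the understanding of attention:

| Aspect | Contribution of the OT perspective |
|---|---|
| Geometric intuition | Attention = soft sorting/ranking |
| Mathematical foundation | A rigorous formulation as an optimization problem |
| Design guidance | Principles for temperature, sparsity, and stability |
| Generalization guarantees | Wasserstein stability |
| Unifying variants | Standard attention, Sinkhorn Attention, Linear Attention |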

References

  1. Mensch, A., & Blondel, M. (2018). Differentiable dynamic programming for sequence modeling. OpenReview.

  2. Geshkovski, B., et al. (2025). A mathematical perspective on transformers. AMS Bulletin.