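Overview
Optimal transport (OT) theory offers an elegant geometric framework for understanding the attention mechanism. Recent work shows that the self-attention matrix can be interpreted exactly as the solution of a semi-relaxed entropic optimal transport problem. This result gives attention a geometric intuition and also supplies a theoretical basis for designing new attention variants. [1, 2]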
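Optimal Transport Basics
The Entropic Optimal Transport Problem
Given two probability distributions $a \in \Delta_n$ (source) and $b \in \Delta_m$ (target), the Kantorovich transport problem seeks the optimal coupling $\gamma$ that minimizes the total cost:
$$\min_{\gamma \in \Pi(a, b)} \langle \gamma, C \rangle = \sum_{i,j} \gamma_{ij} C_{ij}$$
where $C_{ij}$ is the cost of moving mass from the $i$-th source point to the $j$-th target point, and $\Pi(a, b) = \{\gamma \ge 0 : \gamma \mathbf{1} = a,\ \gamma^\top \mathbf{1} = b\}$ is the feasible set of all joint distributions with marginals $a$ and $b$.
The entropy-regularized version (entropic OT) adds a Kullback-Leibler regularization term:
$$\min_{\gamma \in \Pi(a, b)} \langle \gamma, C \rangle - \varepsilon H(\gamma)$$
where $H(\gamma) = -\sum_{i,j} \gamma_{ij} \log \gamma_{ij}$ is the entropy term and $\varepsilon > 0$ is the temperature parameter. On $\Pi(a, b)$, $-H(\gamma)$ equals $\mathrm{KL}(\gamma \,\|\, a \otimes b)$ up to an additive constant, which is why the two names are used interchangeably.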
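To make the two objectives concrete, the following minimal sketch evaluates them on a toy problem. The sizes, the cost values, and the helper name `entropic_objective` are illustrative assumptions rather than part of the original text. The independent coupling $a b^\top$ is used only because it is always feasible, not because it is optimal.

```python
import torch


def entropic_objective(gamma, C, epsilon):
    """Return <gamma, C> - epsilon * H(gamma), with H(gamma) = -sum gamma log gamma."""
    transport_cost = torch.sum(gamma * C)
    # xlogy returns 0 where gamma == 0, so empty cells contribute nothing
    entropy = -torch.sum(torch.xlogy(gamma, gamma))
    return transport_cost - epsilon * entropy


# Toy problem (illustrative values): 2 source points, 3 target points
a = torch.tensor([0.5, 0.5])
b = torch.tensor([0.2, 0.3, 0.5])
C = torch.tensor([[0.0, 1.0, 2.0],
                  [2.0, 1.0, 0.0]])

# The independent coupling a b^T always satisfies both marginal constraints
gamma_indep = torch.outer(a, b)
row_marginal = gamma_indep.sum(dim=1)  # equals a
col_marginal = gamma_indep.sum(dim=0)  # equals b

plain_cost = torch.sum(gamma_indep * C)                        # Kantorovich objective
regularized = entropic_objective(gamma_indep, C, epsilon=0.1)  # entropic objective
```

Because the entropy of a probability matrix is non-negative, the regularized value is never larger than the plain transport cost of the same coupling.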
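The Sinkhorn Algorithm
The entropy-regularized problem can be solved efficiently with the well-known Sinkhorn algorithm, which alternately rescales the rows and columns of the Gibbs kernel $\exp(-C/\varepsilon)$ until the plan matches both marginals: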
```python
import torch
import torch.nn.functional as F


def sinkhorn(a, b, C, epsilon, num_iters=10):
    """
    Sinkhorn algorithm for entropic optimal transport.

    Args:
        a: Source distribution, shape (n,)
        b: Target distribution, shape (m,)
        C: Cost matrix, shape (n, m)
        epsilon: Regularization parameter
        num_iters: Number of iterations
    Returns:
        Approximate optimal transport plan gamma, shape (n, m)
    """
    K = torch.exp(-C / epsilon)  # Gibbs kernel

    # Initialize scalings
    u = torch.ones_like(a)
    v = torch.ones_like(b)

    for _ in range(num_iters):
        u = a / (K @ v + 1e-8)    # match the row marginals
        v = b / (K.T @ u + 1e-8)  # match the column marginals

    # Transport plan: diag(u) K diag(v)
    gamma = u.view(-1, 1) * K * v.view(1, -1)
    return gamma


def sinkhorn_divergence(p, q, cost_matrix, epsilon=0.1):
    """
    Compute the Sinkhorn divergence (debiased regularized OT distance):
        S_epsilon(p, q) = OT_epsilon(p, q) - 0.5*OT_epsilon(p, p) - 0.5*OT_epsilon(q, q)

    Assumes p and q live on the same support, so that one square cost
    matrix serves all three terms. Each OT_epsilon term is evaluated as
    the transport cost <gamma, C> of the corresponding Sinkhorn plan.
    """
    gamma_pq = sinkhorn(p, q, cost_matrix, epsilon)
    gamma_pp = sinkhorn(p, p, cost_matrix, epsilon)
    gamma_qq = sinkhorn(q, q, cost_matrix, epsilon)

    ot_pq = torch.sum(gamma_pq * cost_matrix)
    ot_pp = torch.sum(gamma_pp * cost_matrix)
    ot_qq = torch.sum(gamma_qq * cost_matrix)

    return ot_pq - 0.5 * ot_pp - 0.5 * ot_qq
```
Self-Attention as Semi-Relaxed OT
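Problem Formulation
Consider standard scaled dot-product attention:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
Let $Q \in \mathbb{R}^{n \times d_k}$ and $K \in \mathbb{R}^{m \times d_k}$ be the query and key matrices, and let $S = QK^\top / \sqrt{d_k}$ be the attention score matrix.
Core theorem: the attention weight matrix $A = \mathrm{softmax}(S/\varepsilon)$, with standard attention as the case $\varepsilon = 1$, is the closed-form solution of the following semi-relaxed entropic optimal transport problem:
$$A = \arg\min_{A \ge 0,\ A\mathbf{1} = \mathbf{1}} \; \langle A, -S \rangle - \varepsilon H(A), \qquad H(A) = -\sum_{i,j} A_{ij} \log A_{ij}$$
where the constraint $A\mathbf{1} = \mathbf{1}$ requires every row to sum to 1 (row-stochastic), while the column sums are left unconstrained.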
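Mathematical Proof
Lemma 1: For any matrix $S \in \mathbb{R}^{n \times m}$ and any $\varepsilon > 0$, the optimization problem
$$\max_{A \ge 0,\ A\mathbf{1} = \mathbf{1}} \; \langle A, S \rangle + \varepsilon H(A), \qquad H(A) = -\sum_{i,j} A_{ij} \log A_{ij}$$
has the unique solution $A^\star = \mathrm{softmax}(S/\varepsilon)$, with the softmax taken row-wise.
Proof: Introduce one multiplier $\lambda_i$ per row constraint and form the Lagrangian
$$\mathcal{L}(A, \lambda) = \sum_{i,j} A_{ij} S_{ij} - \varepsilon \sum_{i,j} A_{ij} \log A_{ij} - \sum_i \lambda_i \Big( \sum_j A_{ij} - 1 \Big)$$
Taking the partial derivative with respect to $A_{ij}$ and setting it to zero:
$$S_{ij} - \varepsilon (\log A_{ij} + 1) - \lambda_i = 0$$
Solving for $A_{ij}$ and fixing $\lambda_i$ through the row constraint $\sum_j A_{ij} = 1$ gives:
$$A_{ij} = \frac{\exp(S_{ij}/\varepsilon)}{\sum_k \exp(S_{ik}/\varepsilon)}$$
This is exactly the row-normalized softmax form. The objective is strictly concave in $A$, so this stationary point is the unique maximizer, and because the solution is strictly positive the constraint $A \ge 0$ is inactive.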
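Lemma 1 can also be checked numerically. The sketch below is an illustration under assumed sizes and an assumed temperature, and `semi_relaxed_objective` is a hypothetical helper name. It compares the objective value at the row-wise softmax against many random row-stochastic matrices, and against the value $\varepsilon \sum_i \log \sum_j \exp(S_{ij}/\varepsilon)$ obtained by substituting the softmax back into the objective.

```python
import torch


def semi_relaxed_objective(A, S, epsilon):
    """<A, S> + epsilon * H(A), with H(A) = -sum A log A and 0 log 0 := 0."""
    return torch.sum(A * S) - epsilon * torch.sum(torch.xlogy(A, A))


torch.manual_seed(0)
n, m, epsilon = 4, 6, 0.5  # illustrative sizes and temperature
S = torch.randn(n, m, dtype=torch.float64)

# Candidate optimum from the lemma: row-wise softmax of S / epsilon
A_star = torch.softmax(S / epsilon, dim=-1)
best = semi_relaxed_objective(A_star, S, epsilon)

# Optimal value after substituting A_star: epsilon * sum_i logsumexp_j(S_ij / epsilon)
closed_form = epsilon * torch.logsumexp(S / epsilon, dim=-1).sum()

# Random row-stochastic competitors should never score higher than A_star
competitors = torch.softmax(torch.randn(1000, n, m, dtype=torch.float64), dim=-1)
scores = torch.stack([semi_relaxed_objective(A, S, epsilon) for A in competitors])
gap = best - scores.max()  # non-negative if the lemma holds
```

A random search is of course not a proof; it only confirms that the closed form behaves as the derivation predicts on this instance.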
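Physical Interpretation
From a physical viewpoint:
- Source distribution: the query vectors $q_i$, viewed as "particles" that each carry unit mass
- Target distribution: the keys $k_j$ under a uniform reference measure; because the problem is semi-relaxed, the mass each key receives is not prescribed
- Cost matrix: $C = -S$ (the negative attention scores)
- Entropy term: encourages each query to spread its attention over several keys rather than concentrate on a single one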
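The roles listed above can be made tangible with a few lines of code. This is a sketch with assumed sizes and random inputs, not the author's implementation. It shows that every query sends out exactly one unit of mass, while the mass arriving at each key is whatever the scores dictate.

```python
import torch

torch.manual_seed(0)
n_queries, n_keys, d_k = 5, 7, 16  # illustrative sizes
Q = torch.randn(n_queries, d_k)
K = torch.randn(n_keys, d_k)

S = Q @ K.T / d_k ** 0.5       # attention scores
C = -S                         # transport cost = negative score
A = torch.softmax(-C, dim=-1)  # semi-relaxed plan at epsilon = 1

mass_sent = A.sum(dim=-1)      # per query: always 1 (row constraint)
mass_received = A.sum(dim=0)   # per key: unconstrained, totals n_queries
```

Inspecting `mass_received` on random inputs typically shows some keys collecting far more than the average share `n_queries / n_keys`, which is exactly the freedom the semi-relaxed formulation allows.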
```python
import torch
import torch.nn.functional as F


class AttentionAsOT(torch.nn.Module):
    """
    Attention mechanism interpreted as optimal transport
    """
    def __init__(self, d_model, n_heads, epsilon=1.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.epsilon = epsilon

        # Projections
        self.W_q = torch.nn.Linear(d_model, d_model)
        self.W_k = torch.nn.Linear(d_model, d_model)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape

        # Compute Q, K, V with shape (B, n_heads, N, d_k)
        Q = self.W_q(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, N, self.n_heads, self.d_k).transpose(1, 2)

        # OT interpretation: attention = soft sorting
        # S_ij = similarity between query i and key j; the transport cost is -S_ij
        S = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Row-wise softmax = solution to semi-relaxed entropic OT
        # Each query distributes its "mass" (1) across keys
        A = F.softmax(S / self.epsilon, dim=-1)  # Soft OT plan

        # Apply to values and merge the heads back into d_model
        out = torch.matmul(A, V)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        return self.W_o(out)
```
Sinkhorn Attention and Full Iteration
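Standard Attention vs. Sinkhorn Attention
| Variant | Mathematical form | OT interpretation |
|---|---|---|
| Standard attention | $\mathrm{softmax}(QK^\top/\sqrt{d_k})$ | A single Sinkhorn step (truncated iteration) |
| Sinkhorn Attention | Full Sinkhorn iteration | Converges to the entropic OT optimum |
| Sinkformers | Repeated row and column normalization | Closer to the theoretical optimum |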
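The contrast in the table between one normalization step and repeated iteration can be observed directly. The following sketch uses a hypothetical helper `sinkhorn_normalize` and an assumed square score matrix. It measures how far the column sums are from 1 after zero rounds (plain softmax) and after many rounds of alternating normalization.

```python
import torch


def sinkhorn_normalize(S, num_iters):
    """Alternate column and row normalization of exp(S) for num_iters rounds.

    num_iters = 0 returns the plain row softmax (standard attention).
    Each round ends with a row normalization, so rows always sum to 1.
    """
    A = torch.softmax(S, dim=-1)
    for _ in range(num_iters):
        A = A / A.sum(dim=-2, keepdim=True)  # columns sum to 1
        A = A / A.sum(dim=-1, keepdim=True)  # rows sum to 1 again
    return A


def column_error(A):
    """Largest deviation of any column sum from 1."""
    return (A.sum(dim=-2) - 1.0).abs().max().item()


torch.manual_seed(0)
S = torch.randn(8, 8, dtype=torch.float64)  # square score matrix, illustrative

A_softmax = sinkhorn_normalize(S, 0)     # standard attention
A_sinkhorn = sinkhorn_normalize(S, 200)  # near doubly stochastic

err_softmax = column_error(A_softmax)
err_sinkhorn = column_error(A_sinkhorn)
```

The square shape matters here: rows and columns can both sum to 1 only when the number of queries equals the number of keys.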
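```python
def sinkhorn_attention(Q, K, V, epsilon=0.1, num_iters=10):
    """
    Sinkhorn Attention: iterative OT solution with uniform marginals.

    Standard attention stops after a single row normalization;
    Sinkhorn attention alternates row and column scalings (num_iters >= 1).
    Shapes: Q (B, N_q, d), K (B, N_k, d), V (B, N_k, d_v).
    """
    n_q, n_k = Q.shape[1], K.shape[1]

    # Cost matrix (negative similarity = transport cost)
    C = -torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)

    # Gibbs kernel. Subtracting the row maximum avoids overflow for small
    # epsilon; the row-wise factor it introduces is absorbed by the scaling u.
    logits = -C / epsilon
    K_mat = torch.exp(logits - logits.max(dim=-1, keepdim=True).values)

    # Uniform marginals over queries and keys
    a = torch.full((Q.shape[0], n_q), 1.0 / n_q, device=Q.device, dtype=Q.dtype)
    b = torch.full((Q.shape[0], n_k), 1.0 / n_k, device=Q.device, dtype=Q.dtype)

    # Sinkhorn-Knopp iterations: u matches the rows, v matches the columns
    v = torch.ones_like(b)
    for _ in range(num_iters):
        u = a / (torch.matmul(K_mat, v.unsqueeze(-1)).squeeze(-1) + 1e-8)
        v = b / (torch.matmul(K_mat.transpose(-2, -1), u.unsqueeze(-1)).squeeze(-1) + 1e-8)

    # Transport plan diag(u) K diag(v); its rows sum to about 1 / N_q
    gamma = u.unsqueeze(-1) * K_mat * v.unsqueeze(-2)

    # Rescale so each query distributes unit mass, as in standard attention
    output = torch.matmul(gamma * n_q, V)
    return output


def compute_attention_ot_divergence(Q, K, epsilon=0.1):
    """
    Sinkhorn divergence between the query cloud and the key cloud.

    Both clouds carry uniform weights and the cost is the negative scaled
    dot product used by attention. Each of the three OT terms needs its own
    cost matrix (query-key, query-query, key-key); reusing one matrix with
    identical uniform marginals would always return zero.
    Shapes: Q (B, N_q, d), K (B, N_k, d). Returns a tensor of shape (B,).
    Relies on sinkhorn() defined earlier in this document.
    """
    scale = Q.shape[-1] ** 0.5
    p = torch.full((Q.shape[1],), 1.0 / Q.shape[1], device=Q.device, dtype=Q.dtype)
    q = torch.full((K.shape[1],), 1.0 / K.shape[1], device=K.device, dtype=K.dtype)

    def ot_cost(x, y, a, b):
        C = -(x @ y.T) / scale
        # Shifting the cost by a constant leaves the plan unchanged
        # and keeps exp(-C / epsilon) from overflowing
        plan = sinkhorn(a, b, C - C.min(), epsilon)
        return torch.sum(plan * C)

    divergences = [
        ot_cost(Qb, Kb, p, q) - 0.5 * ot_cost(Qb, Qb, p, p) - 0.5 * ot_cost(Kb, Kb, q, q)
        for Qb, Kb in zip(Q, K)
    ]
    return torch.stack(divergences)
```
An OT View of Attention Variants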
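OT Interpretation of Linear Attention
Linear attention avoids the $O(N^2)$ complexity through a kernel approximation:
$$\mathrm{Attention}(Q, K, V)_i = \frac{\phi(q_i)^\top \sum_j \phi(k_j) v_j^\top}{\phi(q_i)^\top \sum_j \phi(k_j)}$$
where $\kappa(q_i, k_j) = \phi(q_i)^\top \phi(k_j)$ is a positive-definite kernel function with feature map $\phi$. From the OT viewpoint, this corresponds to approximating the Gibbs kernel $\exp(S_{ij})$ behind the transport plan with a low-rank (for example, random-feature) kernel.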
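The reordering that makes linear attention cheap can be verified on a small example. In the sketch below the feature map `phi` (here `elu(x) + 1`) and all sizes are assumptions chosen for illustration; any strictly positive feature map would serve. The quadratic route builds the full kernel matrix, the linear route never does, and the two outputs agree.

```python
import torch


def phi(x):
    """Hypothetical positive feature map: elu(x) + 1 is strictly positive."""
    return torch.nn.functional.elu(x) + 1.0


torch.manual_seed(0)
N, d, d_v = 10, 4, 3  # illustrative sizes
Q = torch.randn(N, d, dtype=torch.float64)
K = torch.randn(N, d, dtype=torch.float64)
V = torch.randn(N, d_v, dtype=torch.float64)

# Quadratic route: build the N x N kernel matrix and row-normalize it
kernel = phi(Q) @ phi(K).T  # kernel_ij = phi(q_i)^T phi(k_j) > 0
A = kernel / kernel.sum(dim=-1, keepdim=True)
out_quadratic = A @ V

# Linear route: summarize the keys once, then reuse the summary for every query
kv = phi(K).T @ V      # (d, d_v) = sum_j phi(k_j) v_j^T
z = phi(K).sum(dim=0)  # (d,)     = sum_j phi(k_j)
out_linear = (phi(Q) @ kv) / (phi(Q) @ z).unsqueeze(-1)
```

The matrix `A` is still row-stochastic, so the semi-relaxed reading carries over; only the kernel that generates the plan has changed.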
```python
import torch


class LinearAttentionAsOT(torch.nn.Module):
    """
    Linear attention with OT-inspired kernel design
    """
    def __init__(self, d_model, num_features=64):
        super().__init__()
        self.d_model = d_model
        self.num_features = num_features

        # Feature maps for Q and K
        self.phi_q = torch.nn.Linear(d_model, num_features, bias=False)
        self.phi_k = torch.nn.Linear(d_model, num_features, bias=False)
        self.W_v = torch.nn.Linear(d_model, d_model)
        self.W_o = torch.nn.Linear(d_model, d_model)

        # OT-inspired: use ReLU kernel (positive, sparse)
        self.activation = torch.nn.ReLU()

    def forward(self, x):
        # Feature extraction
        q = self.phi_q(x)  # (B, N, F)
        k = self.phi_k(x)  # (B, N, F)
        v = self.W_v(x)    # (B, N, D)

        # OT kernel: sparse non-negative features
        q = self.activation(q)
        k = self.activation(k)

        # Linear-time attention: aggregate keys and values before touching the queries
        kv = torch.einsum('bnd,bnv->bdv', k, v)     # (B, F, D) = sum_j phi(k_j) v_j^T
        qkv = torch.einsum('bmd,bdv->bmv', q, kv)   # (B, N, D) unnormalized output
        Z = torch.einsum('bmd,bd->bm', q, torch.sum(k, dim=1))  # (B, N) row normalizer

        # The small constant guards against rows whose features are all zero
        out = qkv / (Z.unsqueeze(-1) + 1e-8)
        return self.W_o(out)
```
OT Interpretation of FlashAttention
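FlashAttention improves efficiency through IO-aware computation. From the OT viewpoint:
- Tile-wise processing: a local OT computation on each tile, followed by a global normalization
- Online algorithm: a streaming evaluation of the OT plan, one key block at a time; the result is exact rather than approximate
- Numerical stability: the running-maximum rescaling plays a role comparable to the rescaling steps of Sinkhorn iterations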
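The tile-wise and online points above can be sketched for a single query. This is a simplified illustration of the online-softmax idea, not FlashAttention's actual kernel; the function name `streaming_attention`, the block size, and the tensor sizes are assumptions. Keys are visited block by block while a running maximum, a running normalizer, and an unnormalized output are kept up to date.

```python
import torch


def streaming_attention(q, K, V, block_size):
    """Attention output for one query, visiting the keys block by block.

    Keeps a running maximum m, normalizer z, and unnormalized output acc,
    rescaling the accumulated values whenever a larger maximum appears.
    """
    scale = q.shape[-1] ** 0.5
    m = torch.tensor(float("-inf"), dtype=q.dtype)
    z = torch.zeros((), dtype=q.dtype)
    acc = torch.zeros(V.shape[-1], dtype=q.dtype)
    for start in range(0, K.shape[0], block_size):
        s = K[start:start + block_size] @ q / scale  # scores for this block
        m_new = torch.maximum(m, s.max())
        correction = torch.exp(m - m_new)            # rescales earlier blocks
        p = torch.exp(s - m_new)                     # local unnormalized weights
        z = z * correction + p.sum()
        acc = acc * correction + p @ V[start:start + block_size]
        m = m_new
    return acc / z                                   # global normalization


torch.manual_seed(0)
q = torch.randn(8, dtype=torch.float64)
K = torch.randn(10, 8, dtype=torch.float64)
V = torch.randn(10, 5, dtype=torch.float64)

reference = torch.softmax(K @ q / 8 ** 0.5, dim=0) @ V
streamed = streaming_attention(q, K, V, block_size=4)
```

The streamed result matches the reference up to floating-point rounding, which is why the tiling changes the cost of the computation but not the transport plan it represents.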
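What OT Theory Suggests for Attention Design
Effect of the Temperature Parameter
| Temperature $\varepsilon$ | Attention behavior | OT interpretation |
|---|---|---|
| $\varepsilon \to 0$ | Near one-hot (hard attention) | The entropy term vanishes; the plan tends to the optimal matching |
| $\varepsilon = 1$ | Standard softmax | Moderate regularization |
| $\varepsilon \to \infty$ | Uniform distribution | Strong regularization; fully random |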
```python
import torch
import torch.nn.functional as F


def analyze_attention_temperature(Q, K, V, epsilon_range=(0.01, 0.1, 0.5, 1.0, 2.0)):
    """
    Analyze how temperature epsilon affects attention behavior
    through OT lens.

    V is accepted for interface symmetry but is not used: all metrics
    depend only on the attention weights.
    """
    results = []
    # Scores do not depend on the temperature, so compute them once
    S = torch.matmul(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)

    for eps in epsilon_range:
        # Attention weights at this temperature
        A = F.softmax(S / eps, dim=-1)

        # OT metrics: entropy of each query's distribution over keys
        row_entropy = -torch.sum(A * torch.log(A + 1e-8), dim=-1)
        entropy = row_entropy.mean()

        # Concentration: how peaked is the attention?
        max_attention = A.max(dim=-1).values.mean()

        # Effective number of keys attended to (exponential of the row entropy)
        effective_size = torch.exp(row_entropy).mean()

        results.append({
            'epsilon': eps,
            'entropy': entropy.item(),
            'max_attention': max_attention.item(),
            'effective_size': effective_size.item()
        })
    return results
```
Design Principles
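Based on OT theory, a good attention design should satisfy the following:
- Cost function design: use a semantically meaningful distance measure
- Temperature scheduling: adapt $\varepsilon$ to the task
- Row stochasticity: keep the constraint that each row sums to 1
- Sparsity: encourage sparse solutions through the choice of regularization, for example a small $\varepsilon$, since entropic regularization alone yields dense plans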
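As one illustration of the cost-design principle, the sketch below swaps the dot-product cost for half the squared Euclidean distance. This particular cost, the sizes, and the omission of the $1/\sqrt{d_k}$ scaling are assumptions made for the example. Expanding the square shows that the resulting attention equals dot-product attention with an extra per-key bias, because the query-norm term is constant along each row and cancels in the softmax.

```python
import torch

torch.manual_seed(0)
n, m, d, epsilon = 4, 6, 8, 1.0  # illustrative sizes and temperature
Q = torch.randn(n, d, dtype=torch.float64)
K = torch.randn(m, d, dtype=torch.float64)

# Cost choice (assumption for illustration): half the squared Euclidean distance
C = 0.5 * torch.cdist(Q, K) ** 2  # C_ij = ||q_i - k_j||^2 / 2
A_distance = torch.softmax(-C / epsilon, dim=-1)

# Expanding the square: -C_ij = q_i . k_j - ||k_j||^2 / 2 - ||q_i||^2 / 2.
# The query-norm term is constant along each row, so the softmax cancels it.
key_bias = -0.5 * (K ** 2).sum(dim=-1)  # one bias per key
A_dot_biased = torch.softmax((Q @ K.T + key_bias) / epsilon, dim=-1)
```

Read this way, a distance-based cost penalizes keys with large norms, a design choice that the plain dot-product cost does not make.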
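Stability and Generalization Guarantees
Guarantees Offered by OT Theory
- Wasserstein stability: a small perturbation of the input changes the output by a bounded amount in Wasserstein distance
- Interpolation property: the attention solution interpolates between the uniform distribution and a one-hot distribution
- Geometric intuition: provides a physical and geometric picture of the attention mechanism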
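The interpolation property is easy to observe on a single row of scores. The values below are hand-picked for illustration, and the helper names `attention_row` and `entropy` are hypothetical. As the temperature grows, the row moves from nearly one-hot to nearly uniform and its entropy rises accordingly.

```python
import torch

# One query's scores against six keys (hand-picked illustrative values)
scores = torch.tensor([0.3, -1.2, 2.0, 0.5, -0.4, 1.1], dtype=torch.float64)


def attention_row(scores, epsilon):
    """Attention weights of one query at temperature epsilon."""
    return torch.softmax(scores / epsilon, dim=-1)


def entropy(p):
    """Shannon entropy with the convention 0 log 0 := 0."""
    return -torch.sum(torch.xlogy(p, p))


cold = attention_row(scores, 1e-2)  # close to one-hot on the argmax
warm = attention_row(scores, 1.0)   # standard softmax
hot = attention_row(scores, 1e3)    # close to uniform

entropies = [entropy(p).item() for p in (cold, warm, hot)]
```

The entropy of the hottest row approaches $\log 6$, the maximum for six keys, while the coldest row puts essentially all mass on the highest-scoring key.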
Code Implementation: A Complete Example
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class OTAttentionExperiment(nn.Module):
    """
    Experiment module for comparing standard vs OT-inspired attention
    """
    def __init__(self, d_model, n_heads=8):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.standard_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ot_attn = OptimalTransportAttention(d_model, n_heads, epsilon=0.1)

    def compare_attentions(self, x):
        """
        Compare standard and OT-inspired attention patterns
        """
        # Standard attention; the returned weights are averaged over heads: (B, N, N)
        std_out, std_weights = self.standard_attn(x, x, x, need_weights=True)

        # OT-inspired attention; weights are per head: (B, n_heads, N, N)
        ot_out, ot_weights = self.ot_attn(x, x, x, return_weights=True)

        # Compare
        print(f"Standard attention entropy: {self._compute_entropy(std_weights):.4f}")
        print(f"OT attention entropy: {self._compute_entropy(ot_weights):.4f}")
        return std_out, ot_out

    def _compute_entropy(self, weights):
        # Mean entropy of the per-query distributions over keys
        return (-weights * torch.log(weights + 1e-8)).sum(-1).mean()


class OptimalTransportAttention(nn.Module):
    """
    Attention mechanism with OT-inspired design choices
    """
    def __init__(self, d_model, n_heads, epsilon=0.1, sinkhorn_iters=3):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.sinkhorn_iters = sinkhorn_iters

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        # Learnable temperature (per head), stored before the softplus and
        # initialized so that softplus(log_epsilon) equals the requested epsilon
        init = math.log(math.expm1(epsilon))
        self.log_epsilon = nn.Parameter(torch.full((n_heads,), init))

    def forward(self, query, key, value, mask=None, return_weights=False):
        B, N_q, D = query.shape
        N_k = key.shape[1]

        # Project to (B, n_heads, N, d_k)
        Q = self.W_q(query).view(B, N_q, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(B, N_k, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(B, N_k, self.n_heads, self.d_k).transpose(1, 2)

        # Similarity scores; the transport cost is their negative
        S = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Boolean mask, True = may attend, broadcastable to (B, n_heads, N_q, N_k)
        if mask is not None:
            S = S.masked_fill(~mask, float("-inf"))

        # Learnable temperature per head (softplus keeps it positive)
        eps = F.softplus(self.log_epsilon).view(1, self.n_heads, 1, 1)

        # OT solution: row-wise softmax
        A = F.softmax(S / eps, dim=-1)

        # Optional: run Sinkhorn iterations for tighter OT solution
        if self.sinkhorn_iters > 0:
            A = self._sinkhorn_iteration(A)

        # Apply attention and merge the heads
        out = torch.matmul(A, V)
        out = out.transpose(1, 2).contiguous().view(B, N_q, D)
        output = self.W_o(out)

        if return_weights:
            return output, A
        return output

    def _sinkhorn_iteration(self, A):
        """
        Run additional Sinkhorn iterations for better OT approximation.
        Each round ends with a row normalization, so A stays row-stochastic.
        """
        for _ in range(self.sinkhorn_iters):
            # Column normalization pushes A toward a doubly stochastic matrix
            A = A / (A.sum(-2, keepdim=True) + 1e-8)
            # Row normalization restores the unit mass of every query
            A = A / (A.sum(-1, keepdim=True) + 1e-8)
        return A


# Experiment: compare OT attention with standard attention
def run_ot_attention_experiment():
    """
    Demonstrate OT-inspired attention on a simple task
    """
    torch.manual_seed(42)
    batch_size, seq_len, d_model = 4, 32, 128
    x = torch.randn(batch_size, seq_len, d_model)

    model = OTAttentionExperiment(d_model)
    std_out, ot_out = model.compare_attentions(x)

    print(f"\nOutput shape: {std_out.shape}")
    # The two modules have independent random weights, so this is expected to be False
    print(f"Outputs close: {torch.allclose(std_out, ot_out, atol=1e-5)}")


if __name__ == "__main__":
    run_ot_attention_experiment()
```
Summary
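Optimal transport theory contributes the following to the understanding of attention:
| Aspect | Contribution of the OT view |
|---|---|
| Geometric intuition | Attention = soft sorting/ranking |
| Mathematical foundation | A rigorous formulation as an optimization problem |
| Design guidance | Principles for temperature, sparsity, and stability |
| Generalization guarantees | Wasserstein stability |
| Unification of variants | Standard attention, Sinkhorn Attention, Linear Attention |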