通用神经最优传输(UNOT)

UNOT(Universal Neural Optimal Transport)是ICML 2025提出的新型神经最优传输框架,能够在给定代价函数下,准确预测不同分辨率离散测度之间的(熵)最优传输距离和传输计划。1

问题背景

最优传输的挑战

最优传输(Optimal Transport, OT)问题是许多应用的核心,但精确求解OT问题的计算代价极高:

  • 标准的Kantorovich问题的计算复杂度为 为分布支撑点数量)
  • 即使使用Sinkhorn算法的熵正则化方法,对于大规模问题仍需要多次迭代

现有方法的局限

现有的神经OT求解器存在以下问题:

  1. 泛化能力不足:针对单一数据集训练的模型难以迁移到新数据
  2. 分辨率固定:大多数方法只能处理固定大小的输入
  3. 跨维度困难:难以处理不同维度的分布

UNOT框架核心

Fourier神经算子基础

UNOT基于Fourier神经算子(Fourier Neural Operators, FNOs)构建,这是一种在函数空间之间映射的神经网络架构。2

核心理念

FNO通过以下方式处理函数空间的映射:

其中 是函数空间。对于OT问题,输入是两个概率测度,输出是对应的对偶势函数。

Fourier层

FNO的核心是Fourier层,通过傅里叶变换在频域中操作:

网络在频域中学习线性变换:

其中 是可学习的频域参数。

离散化不变性

离散化不变性(Discretization Invariance)是UNOT的关键特性,使得网络能够处理任意分辨率的输入。

对于不同分辨率的输入 ,FNO将其视为连续测度的离散化,从而实现跨分辨率泛化。

不同分辨率的输入:
┌─────────────────────────────────┐
│  32×32 图像  ──┐                │
│               ├──→  UNOT ──→ 势函数预测
│  64×64 图像  ──┘                │
│                                 │
│  128×128 图像 ─┘                │
└─────────────────────────────────┘

自对抗训练与自举损失

UNOT采用自对抗训练框架,包含两个网络:

  1. 预测网络 :预测对偶势函数
  2. 生成网络 :生成训练用的合成测度

熵正则化OT的对偶形式

给定代价函数 ,熵正则化OT问题为:

其对偶形式为:

自举损失函数

UNOT的核心创新是自举损失(Bootstrapping Loss),定义为:

其中 是经过少量Sinkhorn迭代后的伪标签势函数。

关键洞察:最小化自举损失等价于最小化真实损失(Proposition 5):

生成器的通用性

论文证明了生成网络 通用性:它能够生成任意固定维度的离散分布。这意味着训练过程中,网络见过的分布类型不影响其在测试集上的泛化能力。


理论保证

生成器的万能逼近

定理(生成器通用性):对于任意固定维度 ,存在神经网络 能够参数化任意离散概率分布对

这确保了训练数据覆盖所有可能的分布类型,从而保证测试时的泛化能力。

损失函数的理论连接

命题(损失函数连接):最小化自举损失 渐近等价于最小化关于真实势函数 的损失。

这为训练目标提供了坚实的理论基础,解释了为什么简单的自举策略能够有效。


实验结果

OT距离预测

UNOT在多个数据集上实现了精确的OT距离预测:

数据集相对误差
MNIST~1-3%
CIFAR-10~2-4%
2D高斯混合~1%

OT计划捕获

除了距离预测,UNOT还能准确捕获传输计划的几何结构,包括:

  • Wasserstein空间中的测地线
  • Wasserstein重心(Wasserstein Barycenters)

作为Sinkhorn初始化

UNOT最重要的应用之一是作为Sinkhorn算法的初始化

实验表明,使用UNOT初始化可以实现:

  • 加速比:最高
  • 收敛稳定性:更好的收敛性质
Sinkhorn收敛对比:
迭代次数    默认初始化    Gaussian初始化    UNOT初始化
   1         45.2%         38.1%           8.3%
   5         22.1%         15.2%           2.1%
  10         12.3%          8.4%           0.8%
  20          5.2%          3.1%           0.2%

与现有方法的对比

方法跨数据集泛化变分辨率理论保证Sinkhorn加速
经典Sinkhorn-
单数据集NOT部分
UNOT

关键优势

  1. 通用性:首个能跨数据集泛化的神经OT求解器
  2. 灵活性:支持变分辨率输入
  3. 可微性:保持Sinkhorn算法的可微性
  4. 高效性:作为初始化时可达7.4倍加速

代码示例

import torch
from src.evaluation.import_models import load_fno
 
# 加载预训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = load_fno("unot_fno", device=device)
 
# 输入:两个展平的概率测度
mu = ...  # shape (batch_size, resolution**2)
nu = ...  # shape (batch_size, resolution**2)
 
# 预测对偶势函数
g = model(mu, nu)  # shape (batch_size, resolution**2)
 
# 用于Sinkhorn初始化
K = torch.exp(-C / epsilon)  # Gibbs核
u = torch.exp(g[:, :n] / epsilon)  # 源侧缩放向量
v = torch.exp(g[:, n:] / epsilon)  # 目标侧缩放向量

总结

UNOT提出了一种基于Fourier神经算子的通用神经最优传输框架,具有以下贡献:

  1. 首个元OT求解器:能够在任意离散测度上泛化
  2. 离散化不变性:处理变分辨率输入
  3. 理论支撑:严格的通用性和损失连接保证
  4. 实际应用:可作为Sinkhorn算法的高效初始化

该工作为神经最优传输领域开辟了新的研究方向,特别是在需要快速、大规模OT计算的场景中具有重要价值。


参考文献

Footnotes

  1. Geuter, J., Kornhardt, G., Tomasson, I., & Laschos, V. (2025). Universal Neural Optimal Transport. Proceedings of the 42nd International Conference on Machine Learning (ICML 2025), 19196-19232. https://openreview.net/forum?id=t10fde8tQ7

  2. Kovachki, N., et al. (2024). Fourier Neural Operators. Neural Operators library documentation. https://neuraloperator.github.io/dev/theory_guide/fno.html