[论文解读] Optimal transport mapping via input convex neural networks
该论文提出一种最小极大框架,使用输入凸神经网络(ICNNs)在2-Wasserstein距离下学习分布之间的最优传输映射,将传输映射表示为凸函数梯度。
In this paper, we present a novel and principled approach to learn the optimal transport between two distributions, from samples. Guided by the optimal transport theory, we learn the optimal Kantorovich potential which induces the optimal transport map. This involves learning two convex functions, by solving a novel minimax optimization. Building upon recent advances in the field of input convex neural networks, we propose a new framework where the gradient of one convex function represents the optimal transport mapping. Numerical experiments confirm that we learn the optimal transport mapping. This approach ensures that the transport mapping we find is optimal independent of how we initialize the neural networks. Further, target distributions from a discontinuous support can be easily captured, as gradient of a convex function naturally models a {\em discontinuous} transport mapping.
研究动机与目标
- 动机:从样本中学习分布之间的最优传输映射,且不带正则化偏差。
- 提出一种最小极大表述,它对偶问题凸化并避免受约束的投影。
- 利用 ICNNs 参数化凸函数及其对偶凸共轭,以将传输映射复原为梯度。
- 在所提框架下建立学习映射的一致性与稳定性。
- 展示其在高维与真实世界数据集上的适用性,作为深度生成模型工具。
提出的方法
- 给出2-Wasserstein对偶问题的公式,并对约束进行凸化以得到最小极大目标(方程5)。
- 使用凸函数f及其对偶凸共轭进行重新参数化,通过一个单一的凸函数来表达W2^2(P,Q)(定理3.3)。
- 用输入凸神经网络(ICNNs)参数化凸函数f,以确保输入上的凸性。
- 将传输映射表示为凸函数梯度(在最小极大设置中的g的梯度)。
- 使用随机优化(Adam)解决得到的最小极大问题,强制f的ICNN非负权重,且为g提供可选的正则化以实现稳定性(方程9)。
- 给出理论保证:一致性(定理3.3)和稳定性界(定理3.6)。
实验结果
研究问题
- RQ1我们是否可以在不带偏置 primal 问题的正则化下,使用样本从Q到P学习最优传输映射T*?
- RQ2当用ICNNs参数化时,是否可通过对凸函数的最小极大表述得到精确的2-Wasserstein传输映射?
- RQ3将传输映射表示为凸函数的梯度,是否能实现不连续的传输映射与支撑之间更尖锐的边界?
- RQ4与标准的OT基方法或基于GAN的方法相比,该方法在高维和真实世界数据集上的表现如何?
- RQ5学得的传输映射对初始化和训练过程的鲁棒性如何,是否可以建立稳定性界?
主要发现
| 指标 | alpha=1 | alpha=5 | alpha=10 |
|---|---|---|---|
| ||μ_{T(Q)}−μ||^2 | 0.19±0.015 | 13.95±1.45 | 29.05±5.16 |
| 100·(||μ_{T(Q)}−μ||/||μ||)^2 | 0.02±0.001 | 0.07±0.005 | 0.04±0.006 |
- 提出的基于ICNN的最小极大框架产生的传输映射在视觉和定性上与简单分布与复杂分布之间的最优传输对齐。
- 方法对初始化鲁棒,与W1-LP和W2GAN基线相比,后者对初始化敏感。
- 凸函数梯度表示使传输映射能够不连续,产生断开支撑之间的尖锐边界。
- 在高维数据上的实验(包括高斯到高斯、高斯到混合以及基于MNIST的任务)表明该方法可扩展到复杂分布,精度在合理范围内。
- 关于728维高斯传输的定量结果(表1)报告了均值传输误差随目标位移大小(alpha)增加而增大,给出均值传输误差和相对误差的具体数值。
- 理论贡献包括一致性结果(定理3.3)和学习到的传输映射的稳定性界(定理3.6)。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。