[论文解读] 2-Wasserstein Approximation via Restricted Convex Potentials with Application to Improved Training for GANs
本文提出了一种新颖的框架,通过使用受限凸势函数(特别是输入-凸神经网络)来近似2- Wasserstein距离,以提升生成对抗网络(GAN)训练的稳定性和性能。通过利用Brenier定理的几何结构,该方法确保最优传输映射为凸势函数的梯度,从而实现高效的优化、统计泛化性,并在特定条件下实现精确的矩匹配。
We provide a framework to approximate the 2-Wasserstein distance and the optimal transport map, amenable to efficient training as well as statistical and geometric analysis. With the quadratic cost and considering the Kantorovich dual form of the optimal transportation problem, the Brenier theorem states that the optimal potential function is convex and the optimal transport map is the gradient of the optimal potential function. Using this geometric structure, we restrict the optimization problem to different parametrized classes of convex functions and pay special attention to the class of input-convex neural networks. We analyze the statistical generalization and the discriminative power of the resulting approximate metric, and we prove a restricted moment-matching property for the approximate optimal map. Finally, we discuss a numerical algorithm to solve the restricted optimization problem and provide numerical experiments to illustrate and compare the proposed approach with the established regularization-based approaches. We further discuss practical implications of our proposal in a modular and interpretable design for GANs which connects the generator training with discriminator computations to allow for learning an overall composite generator.
研究动机与目标
- 通过提供一种比标准GAN损失函数更稳定且具有几何合理性的替代方案,解决GAN训练中的不稳定性和模式崩溃问题。
- 开发一种框架,利用受限凸势函数(特别是输入-凸神经网络)来近似2-Wasserstein距离及最优传输映射。
- 通过受限凸函数类上优化的理论分析,确保所得度量具有统计泛化性和判别能力。
- 通过传输映射的组合学习实现生成器与判别器的联合优化,构建模块化且可解释的GAN架构。
- 为受限凸势函数类上的优化提供理论保证,包括矩匹配性和近似能力。
提出的方法
- 通过Kantorovich对偶性表述2-Wasserstein距离,并利用Brenier定理将最优传输映射表示为凸势函数的梯度。
- 将优化限制在参数化的凸函数类中,特别关注输入-凸神经网络(ICNNs),以保证凸性与可微性。
- 利用最优传输的对偶形式,推导出在受限凸势函数类上的可处理优化问题,从而实现高效训练。
- 引入基于同伦的渐进训练策略,通过逐步增加势函数类的复杂度来加速收敛。
- 通过近似传输映射将生成器训练与判别器计算直接关联,实现组合式GAN。
- 利用矩阵微积分推导出对偶目标函数的闭式梯度,从而通过随机梯度下降实现高效优化。
实验结果
研究问题
- RQ1受限凸势函数能否提供对2-Wasserstein距离的稳定、可微且具有统计泛化能力的近似?
- RQ2凸势函数类的选择(例如ICNNs、分段线性-二次函数)如何影响近似质量与矩匹配特性?
- RQ3受限最优势函数与真实最优传输映射在几何与统计保真度方面有何关系?
- RQ4所提出的框架是否能通过实现组合式、模块化设计,将生成器与判别器学习相连接,从而改善GAN训练?
- RQ5在受限凸势函数类上的优化中,可建立哪些理论保证(例如矩匹配性、可近似性)?
主要发现
- 所提方法在最优点处的凸势函数类切空间上实现了精确的矩匹配,确保了统计一致性。
- 对于高斯分布,当使用仿射凸势函数时,最优传输映射可被精确恢复,且目标函数在参数(A, b)上为凸函数,确保唯一全局最小值点。
- 当使用输入-凸神经网络时,该方法在最优参数配置下可精确匹配特定统计量(例如,x_i 1_{x_i ≥ 0},x_i 1_{x_i ≤ 0}),且当目标分布位于梯度映射的像集中时,可实现度量的精确恢复。
- 该框架为近似误差提供了上界:对于两点分布,误差被限制在2α|v|以内,其中α衡量对称性偏离程度。
- 通过同伦方法实现渐进训练,可逐步增加势函数类的复杂度,从而加速收敛。
- 数值实验表明,与基于正则化的基线方法相比,所提方法在训练稳定性与生成样本质量方面表现更优,尤其在模式覆盖与收敛速度方面优势显著。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。