[论文解读] On the Convergence and Robustness of Training GANs with Regularized Optimal Transport
本文证明在使用正则化最优传输(Wasserstein)目标进行 GAN 训练时,全局收敛到平稳点,并展示通过对偶判别器求解可以有效获得梯度信息,同时引入 Sinkhorn 损失以提高鲁棒性。
Generative Adversarial Networks (GANs) are one of the most practical methods for learning data distributions. A popular GAN formulation is based on the use of Wasserstein distance as a metric between probability distributions. Unfortunately, minimizing the Wasserstein distance between the data distribution and the generative model distribution is a computationally challenging problem as its objective is non-convex, non-smooth, and even hard to compute. In this work, we show that obtaining gradient information of the smoothed Wasserstein GAN formulation, which is based on regularized Optimal Transport (OT), is computationally effortless and hence one can apply first order optimization methods to minimize this objective. Consequently, we establish theoretical convergence guarantee to stationarity for a proposed class of GAN optimization algorithms. Unlike the original non-smooth formulation, our algorithm only requires solving the discriminator to approximate optimality. We apply our method to learning MNIST digits as well as CIFAR-10images. Our experiments show that our method is computationally efficient and generates images comparable to the state of the art algorithms given the same architecture and computational power.
研究动机与目标
- 推动使用正则化 Wasserstein 距离来训练 GAN,并解决原始 Wasserstein 目标的非光滑性。
- 证明正则化 OT 目标对生成器参数的光滑性,并在对判别器近似求解时建立梯度误差界限。
- 在近似判别器解下,展示基于 SGD 的 GAN 训练全局收敛到一个平稳点。
- 提出鲁棒的 Sinkhorn 损失,以在 λ 不小的时候减小偏差,同时保持有意义的距离度量。
- 就平衡判别器精度与生成器步骤以提升收敛性提供算法指引。
提出的方法
- 定义带 KL 或范数-2 正则化的正则化 OT (dc,λ) 及其对偶公式。
- 表示 hλ(θ)=dc,λ(Gθ(q),p) 对 θ 光滑,并界定最优传输计划 π* 对 θ 的依赖。
- 建立对偶的近似求解产生 hλ 的近似梯度,误差为 δ,从而实现带收敛保证的 SGD。
- 提出 Algorithm 1: Oracle-based non-convex SGD,使用 ε-精准的对偶解来获得梯度,并在标准条件下证明收敛到平稳点。
- 引入 Sinkhorn 损失 Lλ(p,q) 以减少来自较大 λ 的偏差,并推导一个带有类似收敛保证的基于 SGD 的方法(Algorithm 2)。
- 提供将判别器精度 ε、梯度方差 σ2 与收敛到平稳解的速度联系起来的理论结果。
实验结果
研究问题
- RQ1正则化 Wasserstein GAN 目标是否能提供关于生成器参数的平滑梯度?
- RQ2判别器(对偶)近似求解如何影响梯度准确性和 GAN 训练中的 SGD 收敛?
- RQ3在实际(近似)判别器求解下,提出的正则化 OT 目标是否保证 GAN 全局收敛到平稳点?
- RQ4Sinkhorn 损失是否对正则化参数 λ 的选择具有鲁棒性,而不偏向学习到的生成器?
- RQ5在使用正则化 OT 的 GAN 中,关于判别器精度与生成器步数的实际指导有哪些能改进收敛?
主要发现
- 在温和假设下,正则化 Wasserstein 距离 hλ(θ) 对生成器参数是光滑的。
- 对偶的ε-精确解给出 hλ 的梯度,误差 δ = O(sqrt(ε/λ));这使得 SGD 收敛到近似平稳性。
- 普通的 SGD 风格方法(Algorithm 1)收敛到近似的平稳解,其速率取决于梯度 Lipschitz 常数、期望精度和判别器误差。
- Sinkhorn 损失 Lλ 提供对 λ 值的鲁棒性,避免偏差,同时保留有意义的距离行为,促进稳定训练。
- 在 MNIST 与 CIFAR-10 的实验显示 SWGAN 方法在正则化 OT 框架下可计算高效并产生具有竞争力的图像,性能取决于成本函数和潜在表示。
- 理论结果扩展到基于 Sinkhorn 损失的 SGD 方法(Algorithm 2),具有类似的收敛保证。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。