QUICK REVIEW

[Paper Review] On the Convergence and Robustness of Training GANs with Regularized Optimal Transport

Maziar Sanjabi, Jimmy Ba | arXiv (Cornell University) | Feb 22, 2018
Generative Adversarial Networks and Image Synthesis | Cited by 65
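One-Line Summary

This paper proves global convergence to stationary points for GAN training with a regularized optimal transport (Wasserstein) objective, shows that gradient information can be obtained efficiently by solving the dual (discriminator) problem, and introduces a Sinkhorn loss for robustness.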

ABSTRACT

Generative Adversarial Networks (GANs) are one of the most practical methods for learning data distributions. A popular GAN formulation is based on the use of Wasserstein distance as a metric between probability distributions. Unfortunately, minimizing the Wasserstein distance between the data distribution and the generative model distribution is a computationally challenging problem as its objective is non-convex, non-smooth, and even hard to compute. In this work, we show that obtaining gradient information of the smoothed Wasserstein GAN formulation, which is based on regularized Optimal Transport (OT), is computationally effortless and hence one can apply first order optimization methods to minimize this objective. Consequently, we establish theoretical convergence guarantee to stationarity for a proposed class of GAN optimization algorithms. Unlike the original non-smooth formulation, our algorithm only requires solving the discriminator to approximate optimality. We apply our method to learning MNIST digits as well as CIFAR-10 images. Our experiments show that our method is computationally efficient and generates images comparable to the state of the art algorithms given the same architecture and computational power.
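The abstract's claim that gradient information is cheap rests on the regularized OT problem having an easily solved dual. The sketch below illustrates that inner solve under assumptions that are not taken from the review: uniform empirical measures on two small point clouds, a squared Euclidean cost, and a KL regularizer taken relative to the product of the two marginals. It runs log-domain Sinkhorn iterations, which are exact block coordinate ascent on the dual; the potentials `f` and `g` stand in for the dual "discriminator". The function name `regularized_ot` and its defaults are hypothetical, and this is not presented as the authors' solver.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def regularized_ot(x, y, lam, n_iters=500):
    """KL-regularized OT between uniform empirical measures on x (n, d) and y (m, d).

    Returns (value, plan, f, g): the dual objective value, the transport plan,
    and the dual potentials f (on x) and g (on y). Cost and regularizer are assumptions.
    """
    n, m = len(x), len(y)
    log_a = np.full(n, -np.log(n))  # log of uniform weights on x
    log_b = np.full(m, -np.log(m))  # log of uniform weights on y
    C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean cost
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        # Each half-step maximizes the dual over one potential, making one marginal exact.
        f = -lam * logsumexp(log_b[None, :] + (g[None, :] - C) / lam, axis=1)
        g = -lam * logsumexp(log_a[:, None] + (f[:, None] - C) / lam, axis=0)
    # Primal-dual link: plan_ij = a_i * b_j * exp((f_i + g_j - C_ij) / lam)
    plan = np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / lam)
    # Dual objective: <a, f> + <b, g> - lam * (total mass of plan) + lam
    value = np.exp(log_a) @ f + np.exp(log_b) @ g - lam * plan.sum() + lam
    return value, plan, f, g


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(64, 2))
    y = rng.normal(loc=1.0, size=(64, 2))
    value, plan, f, g = regularized_ot(x, y, lam=1.0)
    print(value, plan.sum(axis=1).min(), plan.sum(axis=1).max())
```

Because each update makes one marginal exact, stopping after a few iterations gives a plan with exact column sums and only approximately correct row sums, which is one concrete reading of solving the discriminator "to approximate optimality".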

Motivation and Objectives

  • Motivate the use of regularized Wasserstein distance to train GANs and address non-smoothness of the original Wasserstein objective.
  • Prove smoothness of the regularized OT objective with respect to generator parameters and establish gradient error bounds when the discriminator is solved approximately.
  • Demonstrate global convergence of SGD-based GAN training to a stationary point under approximate discriminator solutions.
  • Propose a robust Sinkhorn loss that mitigates the bias introduced when λ is not small, while still behaving as a meaningful distance measure.
  • Provide algorithmic guidance on balancing discriminator accuracy and generator steps for improved convergence.

Proposed Method

  • Define the regularized OT distance $d_{c,\lambda}$ with a KL or ℓ2-norm regularizer, together with its dual formulation.
  • Show that $h_\lambda(\theta) = d_{c,\lambda}(G_\theta(q), p)$ is smooth in θ, and bound how the optimal transport plan $\pi^*$ depends on θ.
  • Establish that approximately solving the dual yields an approximate gradient of $h_\lambda$ with error δ, which enables SGD with convergence guarantees.
  • Propose Algorithm 1, an oracle-based non-convex SGD method that obtains gradients from an ε-accurate dual solution, and prove that it converges to a stationary point under standard conditions.
  • Introduce the Sinkhorn loss $L_\lambda(p, q)$ to reduce the bias caused by large λ, and derive an SGD-based method (Algorithm 2) with analogous convergence guarantees.
  • Provide theoretical results linking the discriminator accuracy ε, the gradient variance $\sigma^2$, and the rate of convergence to stationary solutions.
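A loop in the spirit of Algorithm 1 (an inexact dual solve, then a stochastic gradient step on the generator) can be sketched as follows. Everything concrete here is an assumption for illustration: a hypothetical 1-D generator G(z) = shift + scale * z, toy data drawn from N(2, 0.5²), a squared cost, a KL regularizer relative to uniform sample weights, and a fixed budget of Sinkhorn updates standing in for the ε-accurate discriminator. Because that regularizer does not depend on θ, the envelope (Danskin) argument gives $\nabla_\theta h_\lambda(\theta) = \sum_{ij} \pi^*_{ij} \nabla_\theta c(G_\theta(z_i), y_j)$, so the plan from the dual solve is all that is needed for the gradient. The names `approx_plan` and `train` are hypothetical.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def approx_plan(gen, data, lam, n_iters):
    """Fixed budget of log-domain Sinkhorn (dual ascent) updates; returns the plan.

    Stopping after `n_iters` stands in for an inexact, epsilon-accurate dual solve.
    """
    n, m = len(gen), len(data)
    log_a, log_b = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    C = (gen[:, None] - data[None, :]) ** 2  # squared cost between 1-D samples
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        f = -lam * logsumexp(log_b[None, :] + (g[None, :] - C) / lam, axis=1)
        g = -lam * logsumexp(log_a[:, None] + (f[:, None] - C) / lam, axis=0)
    return np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / lam)


def train(steps=200, batch=64, lam=0.1, lr=0.05, dual_iters=100, seed=0):
    """SGD on a hypothetical 1-D generator G(z) = shift + scale * z."""
    rng = np.random.default_rng(seed)
    shift, scale = 0.0, 1.0
    for _ in range(steps):
        z = rng.normal(size=batch)                      # latent minibatch
        y = rng.normal(loc=2.0, scale=0.5, size=batch)  # toy "real" data, N(2, 0.5^2)
        gen = shift + scale * z
        plan = approx_plan(gen, y, lam, dual_iters)     # approximate discriminator
        # Envelope argument: grad h = sum_ij plan_ij * d/dtheta (G(z_i) - y_j)^2
        dc_dgen = 2.0 * (gen[:, None] - y[None, :])
        grad_shift = np.sum(plan * dc_dgen)
        grad_scale = np.sum(plan * dc_dgen * z[:, None])
        shift -= lr * grad_shift                        # stochastic gradient step
        scale -= lr * grad_scale
    return shift, scale


if __name__ == "__main__":
    print(train())
```

`shift` should settle near 2. For Gaussian marginals, a short calculation with the Gaussian entropic coupling suggests that the population objective is minimized at a scale of 0.5·sqrt(1 − λ/0.5), roughly 0.45 for λ = 0.1, so a learned `scale` somewhat below 0.5 is expected: a small instance of the λ-induced bias that the Sinkhorn loss is meant to reduce.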

Experimental Results

Research Questions

  • RQ1: Can the regularized Wasserstein GAN objective provide smooth gradients with respect to the generator parameters?
  • RQ2: How does approximately solving the discriminator (dual) problem affect gradient accuracy and SGD convergence in GAN training?
  • RQ3: Does the proposed regularized OT objective guarantee global convergence to a stationary point for GANs under practical (approximate) discriminator solves?
  • RQ4: Does the Sinkhorn loss offer robustness to the choice of the regularization parameter λ without biasing the learned generator?
  • RQ5: What practical guidance on discriminator accuracy versus generator steps improves convergence in GANs trained with regularized OT?

Key Findings

  • The regularized Wasserstein distance $h_\lambda(\theta)$ is smooth in the generator parameters under mild assumptions.
  • An ε-accurate solution to the dual regularized OT problem yields a gradient of $h_\lambda$ with error $\delta = O(\sqrt{\varepsilon/\lambda})$; this enables SGD with convergence to approximate stationarity.
  • A vanilla SGD-like method (Algorithm 1) converges to an approximate stationary solution at a rate that depends on the gradient Lipschitz constant, the desired accuracy, and the discriminator error.
  • The Sinkhorn loss $L_\lambda$ provides robustness to the choice of λ and avoids bias while preserving meaningful distance behavior, which facilitates stable training.
  • Experiments on MNIST and CIFAR-10 show that SWGAN approaches can be computationally efficient and produce competitive images under the regularized OT framework, with performance depending on the cost function and the latent representation.
  • The theoretical results extend to a Sinkhorn-loss-based SGD method (Algorithm 2) with analogous convergence guarantees.
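The Sinkhorn-loss finding can be made concrete with a small numerical check. This sketch assumes the debiased form $L_\lambda(p, q) = d_{c,\lambda}(p, q) - \tfrac{1}{2}\big(d_{c,\lambda}(p, p) + d_{c,\lambda}(q, q)\big)$ that is commonly used for Sinkhorn divergences; the review does not state the paper's exact definition, and its constant factors may differ. It also assumes uniform point clouds, a squared Euclidean cost, and a KL regularizer relative to the product of the marginals. The names `reg_ot` and `sinkhorn_loss` are hypothetical.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def reg_ot(x, y, lam, n_iters=1000):
    """Dual value of KL-regularized OT between uniform point clouds x (n, d) and y (m, d)."""
    n, m = len(x), len(y)
    log_a, log_b = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    C = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean cost
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):  # log-domain Sinkhorn = coordinate ascent on the dual
        f = -lam * logsumexp(log_b[None, :] + (g[None, :] - C) / lam, axis=1)
        g = -lam * logsumexp(log_a[:, None] + (f[:, None] - C) / lam, axis=0)
    plan = np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - C) / lam)
    # Dual objective <a, f> + <b, g> - lam * (mass of plan) + lam; uniform weights give means.
    return f.mean() + g.mean() - lam * plan.sum() + lam


def sinkhorn_loss(x, y, lam):
    """Assumed debiased form: d(x, y) - (d(x, x) + d(y, y)) / 2."""
    return reg_ot(x, y, lam) - 0.5 * (reg_ot(x, x, lam) + reg_ot(y, y, lam))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(50, 2))
    t = np.array([1.0, 1.0])
    print(reg_ot(x, x, lam=1.0))             # strictly positive: the self-distance bias
    print(sinkhorn_loss(x, x + t, lam=1.0))  # approximately |t|^2 = 2.0 for a pure translation
```

The self-term `reg_ot(x, x)` is strictly positive, which is the bias: the plain regularized objective does not vanish when the two distributions coincide. Subtracting the self-terms removes it. For a pure translation by t with squared cost, every coupling's cost increases by exactly |t|² while the KL term is unchanged, so the debiased loss reduces to |t|².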


This review was generated by AI and checked by a human editor.