QUICK REVIEW

[Paper Review] Wasserstein Dependency Measure for Representation Learning

Sherjil Ozair, Corey Lynch|arXiv (Cornell University)|Mar 28, 2019
Domain Adaptation and Few-Shot Learning | 42 references | 20 citations
TL;DR

This paper introduces the Wasserstein Dependency Measure (WDM), a representation learning objective that replaces the KL divergence in mutual information estimation with the Wasserstein distance and estimates it with a Lipschitz-constrained neural network, borrowing constraint techniques from the GAN literature. The proposed method, Wasserstein Predictive Coding (WPC), learns substantially better representations than contrastive predictive coding (CPC) on tasks with high mutual information, especially when the data's structure does not align with the inductive biases of the network architecture.
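The swap of divergences described above can be written compactly. The second line uses the Kantorovich-Rubinstein dual form of the 1-Wasserstein distance, where the supremum runs over 1-Lipschitz functions f; this dual form is why a Lipschitz-constrained network (a "critic") appears in the method. The notation (I_W, W_1, f) is chosen for this summary, and the paper's own symbols may differ.

```latex
\begin{aligned}
I(x;y) &= D_{\mathrm{KL}}\big(p(x,y)\,\|\,p(x)\,p(y)\big) \\
I_{\mathcal{W}}(x;y) &= \mathcal{W}_1\big(p(x,y),\,p(x)\,p(y)\big)
  = \sup_{\|f\|_{L} \le 1} \; \mathbb{E}_{p(x,y)}\big[f(x,y)\big] - \mathbb{E}_{p(x)p(y)}\big[f(x,y)\big]
\end{aligned}
```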

ABSTRACT

Mutual information maximization has emerged as a powerful learning objective for unsupervised representation learning obtaining state-of-the-art performance in applications such as object recognition, speech recognition, and reinforcement learning. However, such approaches are fundamentally limited since a tight lower bound of mutual information requires sample size exponential in the mutual information. This limits the applicability of these approaches for prediction tasks with high mutual information, such as in video understanding or reinforcement learning. In these settings, such techniques are prone to overfit, both in theory and in practice, and capture only a few of the relevant factors of variation. This leads to incomplete representations that are not optimal for downstream tasks. In this work, we empirically demonstrate that mutual information-based representation learning approaches do fail to learn complete representations on a number of designed and real-world tasks. To mitigate these problems we introduce the Wasserstein dependency measure, which learns more complete representations by using the Wasserstein distance instead of the KL divergence in the mutual information estimator. We show that a practical approximation to this theoretically motivated solution, constructed using Lipschitz constraint techniques from the GAN literature, achieves substantially improved results on tasks where incomplete representations are a major challenge.

Motivation & Objective

  • Address the fundamental limitation of mutual information maximization in unsupervised representation learning, where tight lower bounds require exponentially large sample sizes relative to mutual information.
  • Identify that mutual information-based methods fail to learn complete representations in high-mutual-information tasks such as video understanding and reinforcement learning.
  • Propose a new learning objective based on the Wasserstein distance to overcome the theoretical and practical shortcomings of KL-based mutual information estimators.
  • Demonstrate empirically that WPC, a practical implementation of the Wasserstein dependency measure, learns more complete and robust representations than CPC, especially under challenging data distributions.
  • Show that WPC is less sensitive to minibatch size and generalizes better when the data's structure does not align with the inductive bias of convolutional networks.
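The exponential sample-size requirement in the first bullet has a simple concrete form for the InfoNCE objective that CPC minimizes. This is a standard bound, sketched here with notation chosen for this summary: a critic f scores one positive pair against N - 1 negatives, the loss is nonnegative because the positive term also appears in the denominator, and so the resulting mutual information estimate can never exceed log N.

```latex
\begin{aligned}
\mathcal{L}_N &= -\,\mathbb{E}\left[\log \frac{e^{f(x,y_1)}}{\sum_{j=1}^{N} e^{f(x,y_j)}}\right] \ge 0,
  \qquad y_1 = y \ \text{(positive)},\quad y_2,\dots,y_N \sim p(y) \\
\hat{I}_N &= \log N - \mathcal{L}_N \;\le\; I(x;y),
  \qquad \hat{I}_N \le \log N \\
\hat{I}_N &= I(x;y) \;\Longrightarrow\; N \ge e^{I(x;y)}
\end{aligned}
```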

Proposed method

  • Replace the KL divergence in mutual information estimation with the Wasserstein distance to define a new dependency measure, termed the Wasserstein dependency measure (WDM).
  • Construct a practical estimator by enforcing Lipschitz continuity on the neural network used in the mutual information estimator, drawing on techniques from the GAN literature.
  • Use a contrastive predictive coding (CPC)-style framework but replace the mutual information objective with the WDM objective to train representation models.
  • Apply weight clipping or gradient penalty to enforce the Lipschitz constraint, ensuring stable and meaningful gradient updates during training.
  • Train the representation model to maximize the WDM between the context and future representation, encouraging the model to capture more factors of variation.
  • Evaluate the method on synthetic and real-world datasets with high mutual information, including MultiOmniglot, CelebA, and MultiviewShapes3D, comparing performance to CPC.
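The estimator behind the method bullets can be sketched in a few lines. This is a minimal NumPy illustration under stated assumptions, not the authors' implementation: it uses a hypothetical bilinear critic f(x, y) = xᵀWy on toy vectors instead of learned encoders, forms product-of-marginals samples by shuffling y within the batch, and enforces the Lipschitz constraint softly with a WGAN-GP-style gradient penalty (weight 10, the value commonly used with WGAN-GP) rather than weight clipping. It maximizes the dual objective E_p(x,y)[f] - E_p(x)p(y)[f] directly; the paper's WPC loss is described as CPC-style and may differ in form. All function names (`critic`, `objective_and_grad`, `sample`, `train_wdm_critic`) are hypothetical, and gradients are derived by hand so only NumPy is needed.

```python
import numpy as np

# Minimal sketch (assumptions, not the paper's code): estimate the Wasserstein
# dependency measure between x and y via the dual objective
#   E_{p(x,y)}[f] - E_{p(x)p(y)}[f]
# with a hypothetical bilinear critic f(x, y) = x^T W y and a WGAN-GP-style
# gradient penalty as a soft, local Lipschitz constraint.


def critic(W, x, y):
    """Bilinear critic f(x, y) = x^T W y, evaluated row by row."""
    return np.einsum("bi,ij,bj->b", x, W, y)


def objective_and_grad(W, x, y, rng, gp_weight=10.0):
    """Return (penalized objective, unpenalized estimate, gradient w.r.t. W)."""
    b = x.shape[0]
    # Shuffling y within the batch turns joint pairs into approximate
    # samples from the product of marginals p(x)p(y).
    y_marg = y[rng.permutation(b)]
    estimate = critic(W, x, y).mean() - critic(W, x, y_marg).mean()
    grad_est = (x.T @ y - x.T @ y_marg) / b

    # Gradient penalty at random interpolates between joint and product
    # samples. Both sets share x, so only y is interpolated.
    t = rng.uniform(size=(b, 1))
    y_hat = t * y + (1.0 - t) * y_marg
    gx = y_hat @ W.T  # row i: df/dx = W y_hat_i
    gy = x @ W        # row i: df/dy = W^T x_i
    norm = np.sqrt((gx ** 2).sum(axis=1) + (gy ** 2).sum(axis=1)) + 1e-12
    penalty = ((norm - 1.0) ** 2).mean()
    # d/dW of mean (norm - 1)^2, using d||W y||^2/dW = 2 (W y) y^T and
    # d||W^T x||^2/dW = 2 x (W^T x)^T.
    coef = ((norm - 1.0) / norm)[:, None]
    grad_pen = 2.0 * ((gx * coef).T @ y_hat + x.T @ (gy * coef)) / b

    objective = estimate - gp_weight * penalty
    return objective, estimate, grad_est - gp_weight * grad_pen


def sample(rng, n, d, dependent):
    """Hypothetical toy data: y is a noisy copy of x, or independent of x."""
    x = rng.standard_normal((n, d))
    if dependent:
        y = x + 0.1 * rng.standard_normal((n, d))
    else:
        y = rng.standard_normal((n, d))
    return x, y


def train_wdm_critic(dependent, d=4, steps=500, batch=256, lr=0.01, seed=0):
    """Gradient ascent on the penalized objective; returns a held-out estimate."""
    rng = np.random.default_rng(seed)
    W = 0.01 * rng.standard_normal((d, d))
    for _ in range(steps):
        x, y = sample(rng, batch, d, dependent)
        _, _, grad = objective_and_grad(W, x, y, rng)
        W += lr * grad  # ascent: maximize the penalized dual objective
    # Report the unpenalized dual estimate on a fresh, larger batch.
    x, y = sample(rng, 4096, d, dependent)
    _, estimate, _ = objective_and_grad(W, x, y, rng)
    return estimate


if __name__ == "__main__":
    print("dependent:  ", train_wdm_critic(dependent=True))
    print("independent:", train_wdm_critic(dependent=False))
```

Because the penalty only keeps the critic's gradient norm near 1 around sampled points, the result is a penalized estimate of the WDM restricted to bilinear critics, not the exact distance. On this toy data the dependent pair should give a clearly positive estimate while the independent pair stays near zero; with real encoders one would hand the gradient computation to an autodiff framework instead.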

Experimental results

Research questions

  • RQ1: Why do mutual information-based representation learning methods fail to learn complete representations in high-mutual-information settings such as video or reinforcement learning?
  • RQ2: Can replacing the KL divergence with the Wasserstein distance in mutual information estimation lead to more robust and complete representations?
  • RQ3: How does the performance of the proposed Wasserstein predictive coding (WPC) method compare to contrastive predictive coding (CPC) across different data distributions and network architectures?
  • RQ4: To what extent does the Lipschitz constraint improve the stability and generalization of representation learning in low-data or high-mutual-information regimes?
  • RQ5: Does WPC maintain superior performance across varying minibatch sizes and data structures that do not align with the inductive bias of convolutional networks?

Key findings

  • On the SplitCelebA dataset with high mutual information (~34.43 nats), WPC achieved 0.87 accuracy using fully connected networks, outperforming CPC’s 0.85.
  • On the same dataset, WPC maintained consistent performance across different network architectures (fully connected and convolutional), while CPC’s performance dropped significantly with convolutional networks.
  • On StackedMultiOmniglot, where the data structure does not align with CNN inductive bias, WPC outperformed CPC by a wider margin than on SpatialMultiOmniglot, indicating robustness to architectural mismatch.
  • WPC achieved optimal performance with a minibatch size of 32 and showed minimal improvement with larger batches, unlike CPC, which required larger batches to stabilize.
  • On MultiviewShapes3D, WPC consistently outperformed CPC across all tested dataset and minibatch sizes, demonstrating generalization across diverse data distributions.
  • The results support the authors' argument that WPC mitigates the fundamental limitation of mutual information estimation (exponential sample complexity) by using the Wasserstein distance, leading to more complete representations in high-information settings.
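The sample-size limitation behind these findings can be made concrete with a little arithmetic, assuming the InfoNCE ceiling that CPC-style estimates obey: with N candidates per positive (one positive plus N - 1 negatives), the estimate cannot exceed log N nats. The sketch plugs in the ~34.43 nats reported for SplitCelebA and the minibatch size of 32 from the findings, taking N equal to the minibatch size as when negatives come from the rest of the batch (an assumption about the setup). The helper names are hypothetical, and the arithmetic describes the InfoNCE bound only; it says nothing about WPC's own estimator.

```python
import math


def infonce_ceiling_nats(num_candidates):
    # With one positive and num_candidates - 1 negatives, the InfoNCE/CPC
    # estimate log N - L_N can never exceed log N, because L_N >= 0.
    return math.log(num_candidates)


def candidates_needed(mi_nats):
    # Smallest N whose ceiling log N reaches mi_nats.
    return math.ceil(math.exp(mi_nats))


def power_of_two_batch_needed(mi_nats):
    # Smallest k with log(2**k) >= mi_nats, i.e. a batch of 2**k candidates.
    return math.ceil(mi_nats / math.log(2))


if __name__ == "__main__":
    print(f"ceiling at batch 32: {infonce_ceiling_nats(32):.2f} nats")  # 3.47 nats
    print(f"candidates for 34.43 nats: {candidates_needed(34.43):.3e}")
    print(f"as a power of two: 2**{power_of_two_batch_needed(34.43)}")  # 2**50
```

A minibatch of 32 caps the estimate near 3.5 nats, while reaching about 34 nats would take on the order of 9 × 10^14 candidates, which illustrates why an InfoNCE-based estimate cannot track that much dependence at practical batch sizes.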


This review was created by AI and reviewed by human editors.