Skip to main content
QUICK REVIEW

[论文解读] Wasserstein Dependency Measure for Representation Learning

Sherjil Ozair, Corey Lynch|arXiv (Cornell University)|Mar 28, 2019
Domain Adaptation and Few-Shot Learning参考文献 42被引用 20
一句话总结

本文提出 Wasserstein 依赖度量(WDM)作为一种新型表示学习目标,用 Wasserstein 距离替代互信息估计中的 KL 散度,通过使用利普希茨连续神经网络来稳定训练。所提出的方法——Wasserstein 预测编码(WPC)——在高互信息任务中,尤其当数据结构与神经网络归纳偏置不一致时,其表示质量显著优于对比预测编码(CPC)。

ABSTRACT

Mutual information maximization has emerged as a powerful learning objective for unsupervised representation learning obtaining state-of-the-art performance in applications such as object recognition, speech recognition, and reinforcement learning. However, such approaches are fundamentally limited since a tight lower bound of mutual information requires sample size exponential in the mutual information. This limits the applicability of these approaches for prediction tasks with high mutual information, such as in video understanding or reinforcement learning. In these settings, such techniques are prone to overfit, both in theory and in practice, and capture only a few of the relevant factors of variation. This leads to incomplete representations that are not optimal for downstream tasks. In this work, we empirically demonstrate that mutual information-based representation learning approaches do fail to learn complete representations on a number of designed and real-world tasks. To mitigate these problems we introduce the Wasserstein dependency measure, which learns more complete representations by using the Wasserstein distance instead of the KL divergence in the mutual information estimator. We show that a practical approximation to this theoretically motivated solution, constructed using Lipschitz constraint techniques from the GAN literature, achieves substantially improved results on tasks where incomplete representations are a major challenge.

研究动机与目标

  • 解决无监督表示学习中互信息最大化的基本局限性,即紧致下界需要相对于互信息呈指数增长的样本量。
  • 指出基于互信息的方法在视频理解与强化学习等高互信息任务中无法学习完整表示。
  • 提出一种基于 Wasserstein 距离的新学习目标,以克服基于 KL 散度的互信息估计器在理论与实践上的不足。
  • 通过实证证明,WPC(Wasserstein 依赖度量的实际实现)学习到的表示比 CPC 更完整、更鲁棒,尤其在具有挑战性的数据分布下表现更优。
  • 表明 WPC 对小批量大小不敏感,且在数据结构与卷积网络归纳偏置不匹配时泛化能力更强。

提出的方法

  • 在互信息估计中用 Wasserstein 距离替代 KL 散度,定义一种新的依赖度量,称为 Wasserstein 依赖度量(WDM)。
  • 通过在互信息估计器中使用的神经网络强制实现利普希茨连续性,构建实用的估计器,借鉴生成对抗网络(GAN)文献中的技术。
  • 采用类似对比预测编码(CPC)的框架,但将互信息目标替换为 WDM 目标,以训练表示模型。
  • 应用权重裁剪或梯度惩罚来强制实现利普希茨约束,确保训练过程中梯度更新稳定且有意义。
  • 训练表示模型以最大化上下文与未来表示之间的 WDM,从而鼓励模型捕捉更多变化因素。
  • 在具有高互信息的合成数据集与真实世界数据集(包括 MultiOmniglot、CelebA 和 MultiviewShapes3D)上评估该方法,并与 CPC 进行性能比较。

实验结果

研究问题

  • RQ1为何基于互信息的表示学习方法在高互信息场景(如视频或强化学习)中无法学习完整表示?
  • RQ2在互信息估计中,用 Wasserstein 距离替代 KL 散度是否能带来更鲁棒、更完整的表示?
  • RQ3所提出的 Wasserstein 预测编码(WPC)方法在不同数据分布与网络架构下,与对比预测编码(CPC)相比表现如何?
  • RQ4利普希茨约束在低数据量或高互信息场景下,对表示学习的稳定性与泛化能力改善程度如何?
  • RQ5当数据结构与卷积网络归纳偏置不匹配时,WPC 是否仍能保持优于 CPC 的性能?

主要发现

  • 在 SplitCelebA 数据集(高互信息,约 34.43 nats)上,WPC 使用全连接网络实现 0.87 的准确率,优于 CPC 的 0.85。
  • 在同一数据集上,WPC 在不同网络架构(全连接与卷积)下均保持一致性能,而 CPC 在使用卷积网络时性能显著下降。
  • 在 StackedMultiOmniglot 上,由于数据结构与 CNN 归纳偏置不匹配,WPC 相较于 CPC 的性能优势比在 SpatialMultiOmniglot 上更明显,表明其对架构不匹配具有更强鲁棒性。
  • WPC 在小批量大小为 32 时达到最优性能,且更大批次下性能提升微乎其微;而 CPC 需要更大批次才能稳定。
  • 在 MultiviewShapes3D 上,WPC 在所有测试的数据集与小批量大小下均持续优于 CPC,表明其在多样化数据分布下的泛化能力。
  • 结果证实,通过使用 Wasserstein 距离,WPC 有效缓解了互信息估计的根本局限——指数级样本复杂度,从而在高信息量场景下实现更完整的表示。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。