Skip to main content
QUICK REVIEW

[论文解读] Learning to Draw Samples: With Application to Amortized MLE for Generative Adversarial Learning

Dilin Wang, Qiang Liu|arXiv (Cornell University)|Nov 6, 2016
Generative Adversarial Networks and Image Synthesis参考文献 33被引用 79
一句话总结

该论文提出了一种新颖的方法,通过使用Stein变分梯度下降训练神经采样器,从复杂的目标分布中生成样本,从而实现深度能量模型的近似最大似然估计(MLE)。通过沿Stein梯度迭代调整网络参数,该方法在图像生成质量方面达到当前最优水平,且相比GANs具有更高的特征信息保留能力,如在模拟数据上的分类准确率表现更优。

ABSTRACT

We propose a simple algorithm to train stochastic neural networks to draw samples from given target distributions for probabilistic inference. Our method is based on iteratively adjusting the neural network parameters so that the output changes along a Stein variational gradient that maximumly decreases the KL divergence with the target distribution. Our method works for any target distribution specified by their unnormalized density function, and can train any black-box architectures that are differentiable in terms of the parameters we want to adapt. As an application of our method, we propose an amortized MLE algorithm for training deep energy model, where a neural sampler is adaptively trained to approximate the likelihood function. Our method mimics an adversarial game between the deep energy model and the neural sampler, and obtains realistic-looking images competitive with the state-of-the-art results.

研究动机与目标

  • 开发一种通用、可微的神经网络训练方法,用于从任意目标分布中采样,而无需显式计算提议分布的密度。
  • 通过消除对提议密度 $ q_\eta(x) $ 的计算需求,解决传统变分推断和重要性采样方法的局限性,实现黑盒、可扩展的推断。
  • 通过学习一个能高效近似似然函数的神经采样器,实现深度能量模型的近似MLE训练。
  • 在需要在相似分布上重复采样的场景(如在线学习或MLE优化)中,提升概率推断的效率与泛化能力。
  • 开发一种新型GAN变体——SteinGAN,其在模仿对抗训练的同时,利用Stein梯度提升样本的多样性和质量。

提出的方法

  • 该方法通过沿Stein变分梯度(SVGD)方向迭代更新参数,以最小化神经采样器输出分布与目标分布 $ p(x) $ 之间的KL散度。
  • SVGD更新方向包含一个排斥项,可确保样本多样性,防止模式崩溃并促进对目标分布的充分覆盖。
  • 神经网络通过反向传播端到端训练,即使在提议密度不可计算的情况下,也能通过重参数化技巧和自动微分计算梯度。
  • 该方法通过训练神经采样器生成在能量模型下似然度最大的真实样本,实现对深度能量模型的近似MLE。
  • 采用联合模型将生成器条件化于类别标签,从而实现高保真度的条件图像生成。
  • 训练过程交替更新能量模型(判别器)和神经采样器(生成器),并采用自适应学习率以保持真实数据与生成数据能量的一致性。

实验结果

研究问题

  • RQ1是否可以训练神经网络从任意目标分布中采样,而无需显式计算提议密度?
  • RQ2是否可以通过神经网络对Stein变分梯度下降进行近似,以实现在相似分布上快速、重复的推断?
  • RQ3所提出的深度能量模型近似MLE方法是否在捕捉有意义的数据结构方面优于标准GANs,其性能可通过下游任务表现衡量?
  • RQ4该方法是否能生成与当前最优GANs相当的高质量、多样化图像,同时保持更优的特征表征?
  • RQ5在模拟数据上的inception分数和分类准确率方面,该方法与基线方法相比表现如何?

主要发现

  • 当在CIFAR-10的50,000张模拟图像上训练ResNet时,SteinGAN的测试准确率达到63.81%,显著优于DCGAN(44.78%)和500张图像重复基线(44.96%),表明其具备更强的信息捕获能力。
  • 在ImageNet预训练的inception模型上,SteinGAN的inception分数为6.351,与DCGAN的6.581相近,表明其具有较高的感知质量。
  • 在CelebA和LSUN数据集上,SteinGAN生成的图像在视觉上与真实图像无法区分,且在潜在空间中表现出平滑、可控的插值,证明其具有解耦且有意义的表征。
  • 该方法在MNIST和CIFAR-10上成功生成了高保真度图像,定性结果与或优于DCGAN。
  • 使用通过Stein梯度下降训练的神经采样器,可有效实现深度能量模型的近似MLE训练,即使不显式计算似然度,也能取得具有竞争力的结果。
  • 消融研究证实,与标准GANs相比,该方法能更充分地捕捉训练数据中的结构信息,其在下游任务中更高的分类准确率即为证据。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。