Skip to main content
QUICK REVIEW

[论文解读] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow

Tuan Anh Le, Adam R. Kosiorek|arXiv (Cornell University)|May 26, 2018
Gaussian Processes and Bayesian Inference参考文献 31被引用 24
一句话总结

本文重新审视了用于训练随机控制流模型(scfms)的重加权唤醒睡眠(rws)算法,证明rws在性能上优于当前最先进的方法,如重要性加权自编码器(IWAE)和连续松弛方法。与IWAE不同,rws在粒子数增加时能同时提升模型和推理网络的质量,提供更低方差的梯度估计器,并在多种scfm架构中表现出更强的鲁棒性。

ABSTRACT

Stochastic control-flow models (SCFMs) are a class of generative models that involve branching on choices from discrete random variables. Amortized gradient-based learning of SCFMs is challenging as most approaches targeting discrete variables rely on their continuous relaxations---which can be intractable in SCFMs, as branching on relaxations requires evaluating all (exponentially many) branching paths. Tractable alternatives mainly combine REINFORCE with complex control-variate schemes to improve the variance of naive estimators. Here, we revisit the reweighted wake-sleep (RWS) (Bornschein and Bengio, 2015) algorithm, and through extensive evaluations, show that it outperforms current state-of-the-art methods in learning SCFMs. Further, in contrast to the importance weighted autoencoder, we observe that RWS learns better models and inference networks with increasing numbers of particles. Our results suggest that RWS is a competitive, often preferable, alternative for learning SCFMs.

研究动机与目标

  • 为解决在随机控制流模型(scfms)中基于摊销的梯度学习问题,其中离散分支阻止了标准的连续松弛技术应用。
  • 评估重加权唤醒睡眠(rws)是否能在scfm学习中超越现有最先进方法,如带有控制变量的IWAE或连续松弛方法。
  • 研究现有方法(如唤醒睡眠,ws)和加权唤醒睡眠(ww)的失败模式,特别是低粒子数情形下的分支剪枝问题。
  • 提出并验证一种防御性采样扩展方法δ-ww,以减轻低粒子数下ww训练的偏差,并提升推理网络质量。

提出的方法

  • 重新审视重加权唤醒睡眠(rws)算法,该算法通过基于多个粒子的重加权估计器,在生成模型和推理网络之间交替进行优化。
  • 采用自归一化重要性采样估计器来计算模型和推理网络参数的梯度,相比朴素的reinforce估计器,可降低方差。
  • 提出δ-ww,一种加权唤醒睡眠的变体,通过将推理网络与均匀提议分布结合(qϕ,δ(z|x) = (1−δ)qϕ(z|x) + δUniform(z)),以减少低粒子数情形下的偏差。
  • 使用多种粒子设置(K=2, K=4, K=8)评估不同计算预算下的可扩展性和性能提升。
  • 将rws应用于三个基准任务:概率上下文无关文法(PCFG)、用于多数字MNIST的Attend, Infer, Repeat(AIR)模型,以及高斯混合模型(GMM),以分析失败模式。
  • 采用模型证据下界(ELBO)最大化作为训练目标,通过重加权重要性采样进行梯度估计,以处理离散随机控制流。

实验结果

研究问题

  • RQ1rws是否在学习随机控制流模型方面优于当前最先进的方法,如带有控制变量的IWAE或连续松弛方法?
  • RQ2rws的性能如何随粒子数增加而变化?其是否能同时提升模型和推理网络的质量?
  • RQ3唤醒睡眠变体中的分支剪枝失败模式由何原因引起?能否通过防御性采样加以缓解?
  • RQ4在何种情形下(如低粒子数与高粒子数)更适合使用ws或ww?数据分布偏差如何影响学习结果?
  • RQ5像δ-ww这样的简单修改能否在不牺牲高粒子数性能的前提下,提升ww在低粒子数设置下的稳定性和性能?

主要发现

  • rws在所有评估任务中均一致优于基于IWAE的方法,包括vimco、relax和带控制变量的reinforce,在模型似然和推理网络质量方面表现更优。
  • 与IWAE不同,rws在粒子数增加时,模型和推理网络质量均呈单调提升;而IWAE在粒子数较多时(尤其在AIR任务中)推理网络质量反而下降。
  • 在GMM实验中,标准ww在低粒子数情形(K=2)下出现由偏差引起的分支剪枝失败模式,模型坍缩至狭窄支持区域,无法充分探索潜在空间。
  • 所提出的δ-ww变体成功缓解了该偏差问题,在低粒子数情形(K=2)下优于所有其他方法,同时在高粒子数下仍保持强劲性能。
  • rws在连续松弛方法失效的场景中表现有效,例如具有潜在无限递归的PCFG,证明其适用于复杂的控制流结构。
  • 本研究确认,ws与ww变体的选择取决于梯度偏差的主要来源:若为数据分布偏差,则ww更优;若为自归一化估计器偏差,则ws更优。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。