[论文解读] ARM: Augment-REINFORCE-Merge Gradient for Stochastic Binary Networks
该论文提出了一种用于训练随机二值网络的增强-REINFORCE-合并(ARM)梯度估计器,通过结合变量增强、REINFORCE和重参数化方法,生成一个无偏、低方差的梯度估计器,且计算开销极小。ARM在具有二值层的离散潜在变量模型中,于变分自编码与最大似然估计任务上实现了最先进性能,其在负对数似然与证据下界指标上均优于有偏与无偏基线方法。
To backpropagate the gradients through stochastic binary layers, we propose the augment-REINFORCE-merge (ARM) estimator that is unbiased, exhibits low variance, and has low computational complexity. Exploiting variable augmentation, REINFORCE, and reparameterization, the ARM estimator achieves adaptive variance reduction for Monte Carlo integration by merging two expectations via common random numbers. The variance-reduction mechanism of the ARM estimator can also be attributed to either antithetic sampling in an augmented space, or the use of an optimal anti-symmetric "self-control" baseline function together with the REINFORCE estimator in that augmented space. Experimental results show the ARM estimator provides state-of-the-art performance in auto-encoding variational inference and maximum likelihood estimation, for discrete latent variable models with one or multiple stochastic binary layers. Python code for reproducible research is publicly available.
研究动机与目标
- 解决在训练具有随机二值层的离散潜在变量模型时梯度估计器方差过高的问题。
- 开发一种无偏且计算高效的梯度估计器,避免对可学习基线参数的依赖。
- 实现深度随机二值前馈网络在变分推断与最大似然估计任务中的有效训练。
- 通过基于公共随机数与扩展空间采样的新型方差减少机制,降低蒙特卡洛积分的方差,且不引入偏差。
提出的方法
- ARM估计器使用变量增强,在扩展潜在空间中构建相关二值向量的联合分布。
- 在增强空间中应用REINFORCE估计器,利用公共随机数以减少方差。
- 通过重参数化与对称采样结合两个期望,有效构建自控基线函数。
- 梯度计算为均匀噪声与在两个相关二值样本上函数值差值的乘积。
- 该估计器避免学习基线参数,保持低计算成本的同时实现自适应方差减少。
- 其推导基于重参数化技巧与得分函数估计器的结合,方差减少归因于增强空间中最优反对称基线。
实验结果
研究问题
- RQ1能否设计一种无偏梯度估计器,用于随机二值网络,实现低方差且无需可学习基线参数?
- RQ2ARM估计器与现有方法(如REINFORCE、REBAR和Gumbel-Softmax)相比,在收敛速度与测试性能方面表现如何?
- RQ3ARM估计器是否能仅通过单一样本有效降低离散潜在变量模型中蒙特卡洛积分的方差?
- RQ4ARM估计器是否在保持低计算复杂度的同时,优于有偏与无偏基线方法,在最大似然估计与变分推断任务中表现更优?
主要发现
- 在MNIST条件分布估计基准上,ARM估计器实现了57.9 ± 0.1的最低测试负对数似然,优于所有对比方法。
- 与标准得分函数(SF)估计器相比,ARM显著降低了方差,后者得到的负对数似然为72.0。
- 该方法在自编码变分推断任务中收敛更快,并实现了更低的测试负证据下界。
- ARM通过避免估计基线参数,保持了低计算复杂度,而REBAR与RELAX则需进行此类估计。
- 由于使用了相关样本与公共随机数,该估计器即使仅使用单个蒙特卡洛样本也保持无偏且低方差。
- 实证结果证实,ARM在包含一个或多个随机二值层的多个基准测试中均提供了最先进性能。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。