[论文解读] Learning Wake-Sleep Recurrent Attention Models
本文提出了一种用于随机硬注意力网络的训练方法——唤醒-睡眠循环注意力模型(WS-RAM),通过重加权唤醒-睡眠学习与控制变量,改进后验推断并降低梯度方差。该方法在图像分类与图像字幕生成任务中实现了与变分推断相当的性能,且训练速度显著更快,展现出最先进的训练效率。
Despite their success, convolutional neural networks are computationally expensive because they must examine all image locations. Stochastic attention-based models have been shown to improve computational efficiency at test time, but they remain difficult to train because of intractable posterior inference and high variance in the stochastic gradient estimates. Borrowing techniques from the literature on training deep generative models, we present the Wake-Sleep Recurrent Attention Model, a method for training stochastic attention networks which improves posterior inference and which reduces the variability in the stochastic gradients. We show that our method can greatly speed up the training time for stochastic attention networks in the domains of image classification and caption generation.
研究动机与目标
- 解决随机硬注意力模型训练中的挑战,此类模型存在后验推断不可计算及梯度方差过高的问题。
- 在不牺牲图像分类与图像字幕生成任务性能的前提下,提升基于注意力机制模型的训练效率。
- 开发一种统一的训练流程,结合推理网络、重加权唤醒-睡眠学习与通过控制变量实现的方差减少。
- 与现有变分基线相比,实现更快的收敛速度与更优的注意力策略探索能力。
提出的方法
- WS-RAM使用生成网络建模注意力策略,同时使用独立的推理网络近似对快照位置的后验分布,且在训练过程中可访问标签信息。
- 应用重加权唤醒-睡眠算法联合训练生成网络与推理网络,通过迭代优化改进后验近似。
- 在训练过程中,利用来自推理网络的提议分布进行重要性采样,以估计不可计算的后验期望。
- 引入控制变量以降低随机梯度估计的方差,从而加速收敛。
- 引入探索启发式方法,防止在变分基线中过早收敛至次优策略。
- 通过基于重要性采样与控制变量的梯度估计,实现端到端的随机反向传播训练。
实验结果
研究问题
- RQ1重加权唤醒-睡眠方法是否能改进随机硬注意力模型中的后验推断?
- RQ2控制变量的使用是否显著降低了注意力模型训练中的梯度方差?
- RQ3WS-RAM能否在显著更短的训练时间内实现与变分推断相当的性能?
- RQ4引入具有标签访问权限的推理网络在注意力策略学习中产生了何种影响?
- RQ5探索启发式方法在多大程度上提升了随机注意力模型的训练稳定性和收敛性?
主要发现
- 在经过1000万次更新后,WS-RAM在翻译并缩放后的MNIST数据集上达到1.62%的测试错误率,优于变分基线(3.11%)以及未使用控制变量的消融WS-RAM(1.85%)。
- 与变分基线相比,WS-RAM显著缩短了训练时间,在MNIST与Flickr8k数据集的训练曲线中均显示出更快的收敛速度,且性能相当。
- 与基线方法相比,控制变量使梯度方差降低了40%-50%,表现为更低的梯度方差估计值与更高的有效样本量(ESS)。
- 推理网络改善了后验近似,但这一改进并未始终体现为更高的ESS,表明方差减少主要由控制变量驱动。
- 与变分基线不同,WS-RAM无需探索启发式方法即可避免陷入局部极小值,后者在无启发式方法时会退化为单一快照尺度。
- 在Flickr8k数据集上,WS-RAM的BLEU-1、BLEU-2、BLEU-3与BLEU-4得分分别为61.1、40.4、26.9与17.8,与变分方法的性能(62.3、41.6、26.9、17.2)相当,但训练速度更快。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。