[论文解读] Not All Samples Are Created Equal: Deep Learning with Importance Sampling
本文提出了一种用于深度学习中 SGD 的基于原理的的重要性采样方案,该方案使用一个可高效计算的上界来对每样本梯度范数进行估计,以将计算聚焦在信息量大的样本上,从而在 CNN、RNN 以及微调任务中实现方差降维和墙钟时间加速。
Deep neural network training spends most of the computation on examples that are properly handled, and could be ignored. We propose to mitigate this phenomenon with a principled importance sampling scheme that focuses computation on "informative" examples, and reduces the variance of the stochastic gradients during training. Our contribution is twofold: first, we derive a tractable upper bound to the per-sample gradient norm, and second we derive an estimator of the variance reduction achieved with importance sampling, which enables us to switch it on when it will result in an actual speedup. The resulting scheme can be used by changing a few lines of code in a standard SGD procedure, and we demonstrate experimentally, on image classification, CNN fine-tuning, and RNN training, that for a fixed wall-clock time budget, it provides a reduction of the train losses of up to an order of magnitude and a relative improvement of test errors between 5% and 17%.
研究动机与目标
- 动机:解释深度网络中 SGD 的均匀采样低效之处,并旨在通过聚焦于信息量大的样本来加速训练。
- 推导一个可处理的每样本梯度范数的上界,该上界可以在一次前向传播中计算。
- 量化来自重要性采样的方差化减,并建立在有益时才开启 IS 的标准。
- 提供一个简单、通用的算法,可以插入到标准 SGD 训练中以实现加速。
- 在图像分类、微调和序列分类任务上进行实证验证。
提出的方法
- 推导一个上界 hat{G}_{i},用于每样本梯度范数且可在前向传播中计算(方程 13–20)。
- 提出一个两阶段采样方案:先抽取一个较大批次 B,计算一个与 hat{G}_{i} 成正比的分布,然后从该分布中抽取一个较小的批次 b。
- 形式化一个盈利性测试,以判定何时方差降才值得使用 IS,使用由方程 27 推导出的等效批量大小增加 tau。
- 给出算法 1,根据阈值 tau_th 以及 tau 的指数移动平均来在均匀采样和重要性采样之间切换。
- 证明用于方差降的最优采样与每样本梯度范数成正比,但使用一个可处理的上界以实现实际部署。
实验结果
研究问题
- RQ1是否能高效地计算一个可处理的每样本梯度范数的上界,以在深度网络中指导重要性采样?
- RQ2基于该上界的重要性采样是否在固定的墙钟时间预算下,降低梯度方差并在 CNN、RNN 和微调情境中加速训练?
- RQ3在训练过程中何时开启重要性采样有益,以及如何可靠地检测?
- RQ4在相同时间预算下,所提 IS 方案在训练损失和测试误差方面与基于损失的采样和均匀采样相比如何?
- RQ5哪些实际指南(例如预采样大小 B、较小的批次 b、阈值 tau_th)在各类架构上能带来稳定的加速?
主要发现
- 基于上界的 IS 实现了方差降,与基于梯度范数的采样高度相关,且与真实的每样本梯度范数高度相关。
- 在 CIFAR10/CIFAR100 上,该方法实现了墙钟时间加速,并且在某些情况下获得更低的训练损失和更好的测试误差,相较于均匀或基于损失的采样(例如 CIFAR100 显示更快的收敛和在测试误差上的 5%–? 提升)。
- 在微调中,该方法相对于均匀采样在半小时内加速收敛并降低测试误差(如 MIT67 数据集结果)。
- 对于带有 LSTM 的逐像素 MNIST,在固定时间预算内实现更低的训练损失和更好的测试误差,而基于损失的采样可能会降低性能。
- 该算法设计为仅需替换一行代码即可在标准 SGD 工作流中启用重要性采样,并且能够在训练过程中自适应变化的模型参数。
- 方差降可以解释为等效地增大了批量大小,并有一个可计算的准则(tau)以确保加速。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。