Skip to main content
QUICK REVIEW

[论文解读] Biased Importance Sampling for Deep Neural Network Training

Angelos Katharopoulos, François Fleuret|arXiv (Cornell University)|May 31, 2017
Advanced Neural Network Applications参考文献 18被引用 48
一句话总结

本文提出了一种用于深度神经网络训练的有偏重要性采样方法,该方法以损失值作为重要性度量,并通过轻量级辅助网络近似计算以降低计算成本。与均匀采样相比,该方法将训练速度提升了20–30%,改善了泛化性能,并实现了更低方差的更快收敛,尤其在CIFAR10上的卷积神经网络(CNNs)和Penn Treebank上的循环神经网络(RNNs)中表现显著。

ABSTRACT

Importance sampling has been successfully used to accelerate stochastic optimization in many convex problems. However, the lack of an efficient way to calculate the importance still hinders its application to Deep Learning. In this paper, we show that the loss value can be used as an alternative importance metric, and propose a way to efficiently approximate it for a deep model, using a small model trained for that purpose in parallel. This method allows in particular to utilize a biased gradient estimate that implicitly optimizes a soft max-loss, and leads to better generalization performance. While such method suffers from a prohibitively high variance of the gradient estimate when using a standard stochastic optimizer, we show that when it is combined with our sampling mechanism, it results in a reliable procedure. We showcase the generality of our method by testing it on both image classification and language modeling tasks using deep convolutional and recurrent neural networks. In particular, our method results in 30% faster training of a CNN for CIFAR10 than when using uniform sampling.

研究动机与目标

  • 为解决在大规模数据集上训练深度神经网络时的高计算成本问题。
  • 克服在深度学习中计算精确重要性权重(如梯度范数)的不可行性。
  • 开发一种可扩展、低开销的重要性采样方案,以提升训练收敛速度和泛化性能。
  • 在不增加标准采样之外的计算开销的前提下,降低梯度方差并加速训练。
  • 使该方法在不同架构(CNNs、RNNs)和任务(图像分类、语言建模)中具有通用性。

提出的方法

  • 该方法使用损失值作为重要性的代理指标,构建一种采样分布,其梯度方差相比均匀采样更小。
  • 在主模型并行训练一个小型辅助网络,以预测每个训练样本的损失,从而高效近似重要性权重。
  • 重要性采样方案被实现为一种有偏梯度估计器,隐式最小化软最大损失,促进更好的泛化性能。
  • 通过平滑机制在线更新采样分布,以在训练迭代过程中稳定重要性估计。
  • 通过用基于损失的近似替代基于梯度范数的采样,避免了昂贵的二阶计算。
  • 该方法与标准优化器(如Adam)兼容,并可无缝集成到现有训练流程中。

实验结果

研究问题

  • RQ1损失值能否作为深度学习中重要性采样的有效且计算可行的代理?
  • RQ2轻量级辅助网络能否以极低的计算开销准确近似大型深度模型的损失?
  • RQ3基于损失的重要性采样在实践中是否能降低梯度方差并加速训练收敛?
  • RQ4该方法是否能在不增加过拟合风险的前提下提升泛化性能?
  • RQ5该方法在不同架构和数据集(包括CNNs和RNNs)上的可扩展性如何?

主要发现

  • 与均匀采样相比,该方法在CIFAR10上的CNN模型训练速度提升了30%。
  • 在Penn Treebank语言建模任务中,尽管每轮训练仅多花费10%的时间,该方法仍比均匀采样减少了20%的总训练时间(节省近2小时)。
  • 使用轻量级辅助网络进行损失近似,使训练时间减少了20%,同时保持或提升了泛化性能。
  • 在第5轮训练时,MNIST上的测试误差降低了0.2%;在第30轮训练时,CIFAR10上的测试误差降低了约1%。
  • 当k=0.5(平滑参数)时,该方法对噪声重要的估计具有鲁棒性,且对超参数调优需求更低。
  • 与依赖梯度范数或经验性超参数的先前方法相比,该方法在复杂数据集(如Penn Treebank)上表现更优。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。