Skip to main content
QUICK REVIEW

[论文解读] Improving Generalization Performance by Switching from Adam to SGD

Nitish Shirish Keskar, Richard Socher|arXiv (Cornell University)|Dec 20, 2017
Neural Networks and Applications参考文献 3被引用 403
一句话总结

本文提出 SWATS,一种自动混合优化器,起始使用 Adam,当梯度子空间投影准则满足时切换到 SGD,从而在多任务上提高泛化能力。

ABSTRACT

Despite superior training outcomes, adaptive optimization methods such as Adam, Adagrad or RMSprop have been found to generalize poorly compared to Stochastic gradient descent (SGD). These methods tend to perform well in the initial portion of training but are outperformed by SGD at later stages of training. We investigate a hybrid strategy that begins training with an adaptive method and switches to SGD when appropriate. Concretely, we propose SWATS, a simple strategy which switches from Adam to SGD when a triggering condition is satisfied. The condition we propose relates to the projection of Adam steps on the gradient subspace. By design, the monitoring process for this condition adds very little overhead and does not increase the number of hyperparameters in the optimizer. We report experiments on several standard benchmarks such as: ResNet, SENet, DenseNet and PyramidNet for the CIFAR-10 and CIFAR-100 data sets, ResNet on the tiny-ImageNet data set and language modeling with recurrent networks on the PTB and WT2 data sets. The results show that our strategy is capable of closing the generalization gap between SGD and Adam on a majority of the tasks.

研究动机与目标

  • 激发自适应方法(Adam)与 SGD 之间的泛化差距。
  • 提出一种混合训练策略,将 Adam 的快速初始进展与 SGD 的泛化能力结合起来。
  • 开发一个无需额外超参数的自动切换机制。
  • 在图像分类和语言建模基准上展示该方法。

提出的方法

  • 将 SWATS 定义为两阶段优化器,起始使用 Adam,当投影基准触发时切换到 SGD。
  • 计算 Adam 步长 p_k 和梯度 g_k,从非正交投影导出 SGD 学习率 gamma_k,确保 SGD 方向与 Adam 步长对齐。
  • 维护 gamma_k 的指数滑动平均 lambda_k,以估计切换后的 SGD 速率。
  • 当 |lambda_k/(1-beta2^k) - gamma_k| < epsilon 时触发切换,得到 SGD 学习率 Lambda = lambda_k/(1-beta2^k)。
  • 除了 Adam 中的超参数外,不引入额外的超参数;在切换前使用带偏差修正、带动量的 Adam 更新。
  • 在 CIFAR-10/100 上对 DenseNet、ResNet、PyramidNet、SENet,以及 Tiny-ImageNet 上,并在 PTB 和 WT2 上的语言模型,评估 SWATS 与 SGD 与 Adam 的比较。

实验结果

研究问题

  • RQ1将 Adam 和 SGD 结合的混合优化器是否能够在保持 Adam 快速初始进展的同时,让泛化接近 SGD?
  • RQ2什么自动切换准则可以在不添加超参数的情况下确定最佳切换点?
  • RQ3与纯 Adam 或 SGD 相比,SWATS 在不同任务(图像分类和语言建模)中的表现如何?

主要发现

模型数据集SGDMAdamSWATSLambda切换点(epochs)
ResNet-32CIFAR-100.10.0010.0010.521.37
DenseNetCIFAR-100.10.0010.0010.7911.54
PyramidNetCIFAR-100.10.0010.00070.854.94
SENetCIFAR-100.10.0010.0010.5424.19
ResNet-32CIFAR-1000.30.0020.0021.2210.42
DenseNetCIFAR-1000.10.0010.0010.5111.81
PyramidNetCIFAR-1000.10.0010.0010.7618.54
SENetCIFAR-1000.10.0010.0011.392.04
LSTMPTB55†0.0030.0037.52186.03
QRNNPTB35†0.0020.0024.61184.14
LSTMWT-260†0.0030.0031.11259.47
QRNNWT-260†0.0030.00414.46295.71
  • SWATS 在多种架构和数据集上通常达到 SGD 和 Adam 中的最佳表现之一。
  • 对于 CIFAR 数据集,切换通常发生在前 20 个 epoch 内;对于 Tiny-ImageNet,大约在第 49 epoch 切换,切换时偶有短暂下降,随后恢复。
  • 切换后的 SGD 学习率 Lambda 与跨任务调优的 SGD 速率一致(如表 1 所示)。
  • Adam 展现出强劲的初期进展,但与 SGD 相比泛化较差;SWATS 通过在知情点切换到 SGD 来缩小这一差距。
  • 在语言建模任务中,SWATS 实现的泛化与 Adam 相当,但达到最佳性能可能需要更少的训练 epoch。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。