Skip to main content
QUICK REVIEW

[论文解读] Learning an Adaptive Learning Rate Schedule

Zhen Xu, Andrew M. Dai|arXiv (Cornell University)|Sep 20, 2019
Domain Adaptation and Few-Shot Learning参考文献 19被引用 39
一句话总结

本论文提出一种强化学习框架,自动学习一个自适应学习率调度器,能够对训练动态做出响应,并在数据集与架构之间展示改进的结果与迁移性。

ABSTRACT

The learning rate is one of the most important hyper-parameters for model training and generalization. However, current hand-designed parametric learning rate schedules offer limited flexibility and the predefined schedule may not match the training dynamics of high dimensional and non-convex optimization problems. In this paper, we propose a reinforcement learning based framework that can automatically learn an adaptive learning rate schedule by leveraging the information from past training histories. The learning rate dynamically changes based on the current training dynamics. To validate this framework, we conduct experiments with different neural network architectures on the Fashion MINIST and CIFAR10 datasets. Experimental results show that the auto-learned learning rate controller can achieve better test results. In addition, the trained controller network is generalizable -- able to be trained on one data set and transferred to new problems.

研究动机与目标

  • 由于高维非凸优化中的多样化训练动态,需要灵活的学习率调度而非固定参数形式的驱动
  • 提出一个强化学习框架,基于以往的训练历史自动调整学习率
  • 定义适当的状态特征、奖励信号和动作设计,以实现稳定的学习率控制
  • 证明学习控制器在数据集与架构之间的泛化能力与迁移性有所提升

提出的方法

  • 一个强化学习控制器基于从受训网络观察到的训练动态提出学习率缩放因子
  • 状态观测包括训练/验证损失、预测方差,以及最终层权重的统计信息,以及前一步的学习率
  • 奖励为逐步的验证损失,以便为信用分配提供较为频繁的反馈
  • 动作是对前一步学习率的学习率缩放因子,用于实现预热与衰减
  • 控制器采用 Proximal Policy Optimization (PPO) 来学习一个最小化累计验证损失的策略
  • 实验将自动学习的时间表与基线的分步衰减在 Fashion-MNIST 和 CIFAR-10 上使用 CNN 与 ResNet 架构进行对比

实验结果

研究问题

  • RQ1一种基于 RL 的控制器能否比固定步长的参数化调度更有效地自适应学习率?
  • RQ2学习到的控制器是否能在不同数据集与模型架构之间泛化?
  • RQ3将逐步的验证损失作为奖励是否比仅使用最终奖励能改善信用分配?
  • RQ4学习率缩放动作是否比直接输出原始学习率更稳定、具备更强的可迁移性?

主要发现

DatasetModelTest Loss (Baseline)Test Accuracy (Baseline)Test Loss (Auto-learned)Test Accuracy (Auto-learned)
Fashion MNISTCNN0.2497 (0.0042)0.9102 (0.0019)0.2351 ∗ (0.0038)0.9201 ∗ (0.0022)
Fashion MNISTResNet0.2346 (0.0074)0.9188 (0.0029)0.2296 (0.0069)0.9192 (0.0028)
CIFAR-10CNN0.9539 (0.0140)0.6759 (0.0048)0.9361 ∗ (0.0104)0.6787 (0.0041)
CIFAR-10ResNet0.8317 (0.0155)0.7395 (0.0206)0.6288 ∗ (0.0196)0.8181 ∗ (0.0069)
  • 自动学习的时间表在所有测试任务中均比基线分步衰减的测试损失与准确率表现更好
  • 控制器呈现多样化的学习模式(如先热身再衰减,或先平坦后热身与衰减),体现对模型/数据集的动态适应
  • 迁移实验表明在 CIFAR-10 训练的控制器可有效迁移到 Fashion-MNIST,且优于迁移的基线
  • 逐步奖励信号提升了训练动态并使学习率控制更稳定,相较于仅使用最终奖励
  • 该方法在两种数据集上的 CNN 和 ResNet 架构上均具备泛化性

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。