[논문 리뷰] Improving Generalization Performance by Switching from Adam to SGD
이 논문은 Adam으로 시작하고 그래디언트-하위공간 투영 기준이 만족되면 SGD로 전환하는 자동 하이브리드 옵티마이저 SWATS를 제안하여 여러 작업에서 일반화 성능을 향상시킵니다.
Despite superior training outcomes, adaptive optimization methods such as Adam, Adagrad or RMSprop have been found to generalize poorly compared to Stochastic gradient descent (SGD). These methods tend to perform well in the initial portion of training but are outperformed by SGD at later stages of training. We investigate a hybrid strategy that begins training with an adaptive method and switches to SGD when appropriate. Concretely, we propose SWATS, a simple strategy which switches from Adam to SGD when a triggering condition is satisfied. The condition we propose relates to the projection of Adam steps on the gradient subspace. By design, the monitoring process for this condition adds very little overhead and does not increase the number of hyperparameters in the optimizer. We report experiments on several standard benchmarks such as: ResNet, SENet, DenseNet and PyramidNet for the CIFAR-10 and CIFAR-100 data sets, ResNet on the tiny-ImageNet data set and language modeling with recurrent networks on the PTB and WT2 data sets. The results show that our strategy is capable of closing the generalization gap between SGD and Adam on a majority of the tasks.
연구 동기 및 목표
- 적응적 방법(Adam)과 SGD 사이의 일반화 격차를 설명한다.
- Adam의 빠른 초기 진행과 SGD의 일반화를 결합한 하이브리드 학습 전략을 제안한다.
- 추가 하이퍼파라미터 없이 자동 전환 메커니즘을 개발한다.
- 이미지 분류 및 언어 모델 벤치마크에서 방법을 검증한다.
제안 방법
- SWATS를 Adam으로 시작해 투영 기반 기준이 작동할 때 SGD로 전환하는 두 단계의 옵티마이저로 정의한다.
- Adam 스텝 p_k와 그래디언트 g_k를 계산하고, 비정규(Non-orthogonal) 프로젝션을 이용해 SGD 방향이 Adam 스텝과 정렬되도록 하는 SGD 학습률 gamma_k를 도출한다.
- 전환 후 SGD 비율을 추정하기 위해 gamma_k의 지수 평균 lambda_k를 유지한다.
- |lambda_k/(1-beta2^k) - gamma_k| < epsilon 이 성립할 때 전환을 트리거하여 Lambda = lambda_k/(1-beta2^k)의 SGD 학습률을 얻는다.
- Adam에 있는 것 이외의 추가 하이퍼파라미터는 도입되지 않으며, 전환 전에는 바이어스 보정된 모멘텀 기반의 Adam 업데이트를 사용한다.
- DenseNet, ResNet, PyramidNet, SENet, Tiny-ImageNet 및 언어 모델(Permutorial PTB 및 WT2)에서 CIFAR-10/100에 대해 SWATS를 SGD 및 Adam과 비교 평가한다.
실험 결과
연구 질문
- RQ1Can a hybrid optimizer combining Adam and SGD achieve generalization closer to SGD while retaining Adam’s fast initial progress?
- RQ2What automatic switching criterion can determine the optimal switch point without adding hyperparameters?
- RQ3How does SWATS perform across diverse tasks (image classification and language modeling) compared to pure Adam or SGD?
주요 결과
| 모델 | 데이터 셋 | SGDM | Adam | SWATS | Lambda | 전환 시점(에포크) |
|---|---|---|---|---|---|---|
| ResNet-32 | CIFAR-10 | 0.1 | 0.001 | 0.001 | 0.52 | 1.37 |
| DenseNet | CIFAR-10 | 0.1 | 0.001 | 0.001 | 0.79 | 11.54 |
| PyramidNet | CIFAR-10 | 0.1 | 0.001 | 0.0007 | 0.85 | 4.94 |
| SENet | CIFAR-10 | 0.1 | 0.001 | 0.001 | 0.54 | 24.19 |
| ResNet-32 | CIFAR-100 | 0.3 | 0.002 | 0.002 | 1.22 | 10.42 |
| DenseNet | CIFAR-100 | 0.1 | 0.001 | 0.001 | 0.51 | 11.81 |
| PyramidNet | CIFAR-100 | 0.1 | 0.001 | 0.001 | 0.76 | 18.54 |
| SENet | CIFAR-100 | 0.1 | 0.001 | 0.001 | 1.39 | 2.04 |
| LSTM | PTB | 55† | 0.003 | 0.003 | 7.52 | 186.03 |
| QRNN | PTB | 35† | 0.002 | 0.002 | 4.61 | 184.14 |
| LSTM | WT-2 | 60† | 0.003 | 0.003 | 1.11 | 259.47 |
| QRNN | WT-2 | 60† | 0.003 | 0.004 | 14.46 | 295.71 |
- SWATS generally matches the best performance among SGD and Adam across multiple architectures and datasets.
- Switching often occurs within the first 20 epochs for CIFAR datasets and around epoch 49 for Tiny-ImageNet, with occasional brief degradation during switching that is recovered later.
- The learned SGD learning rate after switch, Lambda, aligns with tuned SGD rates across tasks (as shown in Table 1).
- Adam shows strong initial progress but poorer generalization compared to SGD; SWATS closes this gap by switching to SGD at an informed point.
- In language modeling tasks, SWATS achieves comparable generalization to Adam while potentially requiring fewer training epochs to reach peak performance.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.