[논문 리뷰] Sharpness-Aware Minimization for Efficiently Improving Generalization
SAM은 훈련 손실뿐 아니라 손실 지형의 샤프니즈를 최적화하여 일반화를 개선하고 CIFAR, ImageNet 및 전이 작업에서 테스트 성능을 향상시킵니다.
In today's heavily overparameterized models, the value of the training loss provides few guarantees on model generalization ability. Indeed, optimizing only the training loss value, as is commonly done, can easily lead to suboptimal model quality. Motivated by prior work connecting the geometry of the loss landscape and generalization, we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets (e.g., CIFAR-10, CIFAR-100, ImageNet, finetuning tasks) and models, yielding novel state-of-the-art performance for several. Additionally, we find that SAM natively provides robustness to label noise on par with that provided by state-of-the-art procedures that specifically target learning with noisy labels. We open source our code at \url{https://github.com/google-research/sam}.
연구 동기 및 목표
- 초과 매개변수화된 모델에서 손실 지형의 기하와 일반화의 관계에 대한 동기를 제공합니다.
- 손실 값과 손실 샤프니스를 모두 최소화하는 실용적인 최적화 목표를 제안합니다.
- 샤프니스를 최적화하는 것이 다양한 데이터셋과 아키텍처에 걸쳐 일반화를 향상시킨다는 것을 입증합니다.
제안 방법
- Introduce SAM: min_w max_{||epsilon||_p <= rho} L_S(w+epsilon) + lambda||w||^2.
- 내부 최대화를 미분하고 1차 테일러 전개를 사용하여 효율적인 기울기 근사를 도출합니다.
- hat{epsilon}(w)을 p=2의 경우 학습 손실 그래디언트의 rho-스케일 정규화로 계산합니다.
- hat{epsilon}(w)로 평가된 기울기를 사용하여 SAM 목표에 대해 SGD로 w를 업데이트합니다.
- Pseudo-code(Algorithm 1)를 제공하고 병렬 처리 및 2차 항 소거를 포함한 실용적 구현 세부사항을 논의합니다.
- m-샤프니스(서브배치 교란)와 해essian 스펙트럼을 분석하여 샤프니스와 일반화 간의 연결을 제시합니다.
실험 결과
연구 질문
- RQ1샤프니스 항을 학습 목표에 포함시키면 표준 비전 벤치마크에서 일반화가 향상됩니까?
- RQ2SAM은 CIFAR-10/100, ImageNet 및 파인튜닝 작업에서 SGD에 비해 어떻게 성능합니까?
- RQ3이웃 크기 rho와 m-샤프니스 변형이 성능 및 일반화에 미치는 영향은 무엇입니까?
- RQ4SAM은 노이즈 라벨 메서드와 비교해 라벨 노이즈에 대한 견고성을 제공합니까?
주요 결과
| 모델 | 증강 | CIFAR-10 (SAM) | CIFAR-10 (SGD) | CIFAR-100 (SAM) | CIFAR-100 (SGD) |
|---|---|---|---|---|---|
| WRN-28-10 (200 epochs) | Basic | 2.7±0.1 | 3.5±0.1 | 16.5±0.2 | 18.8±0.2 |
| WRN-28-10 (200 epochs) | Cutout | 2.3±0.1 | 2.6±0.1 | 14.9±0.2 | 16.9±0.1 |
| WRN-28-10 (200 epochs) | AA | 2.1±<0.1 | 2.3±0.1 | 13.6±0.2 | 15.8±0.2 |
| WRN-28-10 (1800 epochs) | Basic | 2.4±0.1 | 3.5±0.1 | 16.3±0.2 | 19.1±0.1 |
| WRN-28-10 (1800 epochs) | Cutout | 2.1±0.1 | 2.7±0.1 | 14.0±0.1 | 17.4±0.1 |
| WRN-28-10 (1800 epochs) | AA | 1.6±<0.1 | 1.9±<0.1 | 11.3±0.1 | 14.6±0.1 |
| Shake-Shake (26 2x96d) | Basic | 2.3±<0.1 | 2.7±0.1 | 15.1±0.1 | 17.0±0.1 |
| Shake-Shake (26 2x96d) | Cutout | 2.0±<0.1 | 2.3±0.1 | 14.2±0.2 | 15.7±0.2 |
| Shake-Shake (26 2x96d) | AA | 1.6±<0.1 | 1.9±0.1 | 12.8±0.1 | 14.1±0.2 |
| PyramidNet | Basic | 2.7±0.1 | 4.0±0.1 | 14.6±0.4 | 19.7±0.3 |
| PyramidNet | Cutout | 1.9±0.1 | 2.5±0.1 | 12.6±0.2 | 16.4±0.1 |
| PyramidNet | AA | 1.6±0.1 | 1.9±0.1 | 11.6±0.1 | 14.6±0.1 |
| PyramidNet+ShakeDrop | Basic | 2.1±0.1 | 2.5±0.1 | 13.3±0.2 | 14.5±0.1 |
| PyramidNet+ShakeDrop | Cutout | 1.6±<0.1 | 1.9±0.1 | 11.3±0.1 | 11.8±0.2 |
| PyramidNet+ShakeDrop | AA | 1.4±<0.1 | 1.6±<0.1 | 10.3±0.1 | 10.6±0.1 |
- SAM은 표준 SGD에 비해 CIFAR-10/100, ImageNet 및 파인튜닝 작업에서 일반화를 일관되게 향상시킵니다.
- CIFAR-10/100에서 SAM은 여러 모델과 증강(예: WRN, Shake-Shake, PyramidNet 및 조합)에서 최첨단과 유사한 결과를 달성하며, 예를 들어 특정 설정에서 CIFAR-10의 오류율은 1.6% 수준, AA가 있는 일부 구성에서 CIFAR-100은 11.3% 수준입니다.
- SAM은 명시적 노이즈 라벨 전략이 없는 baselines를 능가하는 경우가 많으며, 노이즈 라벨에 대한 견고성을 제공합니다.
- m-샤프니스 변형은 더 작은 m(가속기 하위 배치)이 일반화에 더 좋고 실제 일반화 격차와 더 강한 상관관계를 보임을 보여줍니다.
- ImageNet에서 SAM으로 학습된 ResNet 변종은 상위-1 및 상위-5 정확도를 향상시키며, 예를 들어 ResNet-152의 상위-1 오차가 400 에포크에서 SAM 없이 20.3%에서 18.4%로 감소합니다.
- 해essian 분석은 SAM이 곡률이 훨씬 낮은 최솟값으로 수렴하고(예: lambda_max가 SAM은 약 1.0, 비SAM은 약 24), 더 완만한 스펙트럼을 가짐을 확인합니다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.