[논문 리뷰] A Simple Baseline for Bayesian Uncertainty in Deep Learning
SWAG는 SWA 평균과 SGD 반복에서 추정된 저랭크 플러스 대각 공분산으로 구성된 확장 가능한 신경망 가중치에 대한 가우시안 사후분포를 도입하여, Bayesian 모델 평균화와 시각 작업 전반에 걸친 불확실성 보정 개선을 가능하게 한다.
We propose SWA-Gaussian (SWAG), a simple, scalable, and general purpose approach for uncertainty representation and calibration in deep learning. Stochastic Weight Averaging (SWA), which computes the first moment of stochastic gradient descent (SGD) iterates with a modified learning rate schedule, has recently been shown to improve generalization in deep learning. With SWAG, we fit a Gaussian using the SWA solution as the first moment and a low rank plus diagonal covariance also derived from the SGD iterates, forming an approximate posterior distribution over neural network weights; we then sample from this Gaussian distribution to perform Bayesian model averaging. We empirically find that SWAG approximates the shape of the true posterior, in accordance with results describing the stationary distribution of SGD iterates. Moreover, we demonstrate that SWAG performs well on a wide variety of tasks, including out of sample detection, calibration, and transfer learning, in comparison to many popular alternatives including MC dropout, KFAC Laplace, SGLD, and temperature scaling.
연구 동기 및 목표
- 고위험 영역에서의 의사결정을 돕기 위해 딥러닝에서 신뢰할 수 있는 불확실성 표현의 필요성을 제시한다.
- 네트워크 가중치의 사후를 근사하기 위해 SGD 궤적을 활용하는 확장 가능한 베이지안 추론 방법을 제안한다.
- SWA와 저랭크 플러스 대각 공분산을 결합하여 가우시안 사후를 형성하는 실용 알고리즘(SWAG)을 개발한다.
- SWAG가 시각 벤치마크 전반에서 잘 보정된 예측과 경쟁력 있거나 우수한 불확실성 추정치를 제공함을 입증한다.
제안 방법
- SWA(Steochastic Weight Averaging)를 기반으로 SWA 평균을 사후 평균으로 사용한다.
- SGD 반복의 두 번째 모멘트를 실행하며 대각 공분산을 추정한다.
- SGD 반복에서 마지막 K 편차 벡터를 사용하여 저랭크 공분산을 구성한다.
- 가우시안 사후분포 N(theta_SWA, 1/2*(Sigma_diag + Sigma_low_rank))를 형성한다.
- 예측을 위한 베이지안 모델 평균화를 수행하기 위해 가우시안에서 샘플링한다.
- 필요한 통계를 최소한의 오버헤드로 업데이트하고 저장하는 온라인 절차를 제공한다.
실험 결과
연구 질문
- RQ1SGD 궤적을 사용하여 심층 신경망에서 사후의 국지 기하를 근사할 수 있는가?
- RQ2SWAG 기반 가우시안 사후가 시각 작업 전반에서 기존 기준선보다 더 나은 불확실성 보정을 제공하는가?
- RQ3MC 드롭아웃이나 SGLD와 같은 대안에 비해 SWAG가 도메인 외 탐지 및 전이 학습에 효과적인가?
- RQ4저랭크 플러스 대각 근사가 실제로 대각선만 공분산과 비교하여 어떤 차이가 있는가?
- RQ5언어 모델링 및 회귀 벤치마크에서 보정 및 예측 성능을 더 넓은 기준으로 SWAG가 향상시킬 수 있는가?
주요 결과
- SWAG는 SGD 반복으로 확장된 부분공간에서 사후의 국소 기하를 밀접하게 포착한다.
- SWAG는 CIFAR-10/100 및 ImageNet에서 여러 기준선보다 잘 보정된 불확실성 추정과 더 높은 테스트 로그가능도(test log-likelihood)를 제공한다.
- SWAG는 불확실성 보정에서 MC dropout, SGLD, KFAC-Laplace, SWA 등 많은 대안들보다 우수하다.
- SWAG는 전이 학습 성능과 도메인 외 탐지에서 여러 경쟁자에 비해 향상된다.
- SWAG는 또한 언어 모델링 퍼플렉시티(perplexities)에서 개선을 가져오고 회귀 과제에서 경쟁력 있는 결과를 보인다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.