[論文レビュー] A Simple Baseline for Bayesian Uncertainty in Deep Learning
SWAGは、SWA平均とSGD反復から推定された低秩と対角共分散を組み合わせて構築された、ニューラルネットワーク重みのスケーラブルなGaussian事後分布を導入し、視覚タスク全体でのベイズ的モデルアベレージングと不確実性キャリブレーションの改善を可能にします。
We propose SWA-Gaussian (SWAG), a simple, scalable, and general purpose approach for uncertainty representation and calibration in deep learning. Stochastic Weight Averaging (SWA), which computes the first moment of stochastic gradient descent (SGD) iterates with a modified learning rate schedule, has recently been shown to improve generalization in deep learning. With SWAG, we fit a Gaussian using the SWA solution as the first moment and a low rank plus diagonal covariance also derived from the SGD iterates, forming an approximate posterior distribution over neural network weights; we then sample from this Gaussian distribution to perform Bayesian model averaging. We empirically find that SWAG approximates the shape of the true posterior, in accordance with results describing the stationary distribution of SGD iterates. Moreover, we demonstrate that SWAG performs well on a wide variety of tasks, including out of sample detection, calibration, and transfer learning, in comparison to many popular alternatives including MC dropout, KFAC Laplace, SGLD, and temperature scaling.
研究の動機と目的
- 高リスク領域での意思決定を支援するための深層学習における信頼性の高い不確実性表現の必要性を動機づける。
- ネットワーク重みの事後分布を近似するためにSGD軌跡を活用する、スケーラブルなベイズ推論手法を提案する。
- SWAと低ランク+対角共分散を組み合わせてガウス後方分布を形成する実用的なアルゴリズム(SWAG)を開発する。
- SWAGが視覚ベンチマーク全体で良く校正された予測と競争力のある、あるいは優れた不確実性推定をもたらすことを示す。
提案手法
- SWA(Stochastic Weight Averaging)を基盤として、SWA平均を事後平均として用いる。
- SGD反復の2次モーメントを実行中に推定して、対角共分散を見積もる。
- SGD反復からの最後のK個の偏差ベクトルを用いて低ランク共分散を構築する。
- Gaussian後方分布N(theta_SWA, 1/2*(Sigma_diag + Sigma_low_rank))を形成する。
- 予測のためのベイズ的モデルアベレージングを行うためにGaussianからサンプルをとる。
- 最小限のオーバーヘッドで必要な統計を更新・保存するオンライン手続きを提供する。
実験結果
リサーチクエスチョン
- RQ1SGD軌跡は深層ネットワークにおける事後分布の局所幾何を近似できるか。
- RQ2SWAGベースのGaussian後方分布は、視覚タスク全体で既存のベースラインよりも不確実性キャリブレーションを改善するか。
- RQ3SWAGはMCドロップアウトやSGLDなどの代替法と比較して、アウト・オブ・ドメイン検出と転移学習に有効か。
- RQ4低ランク+対角近似は実務上、対角のみの共分散と比較してどのようになるか。
- RQ5SWAGは言語モデリングや回帰ベンチマークにおける較正と予測性能を、より広いベースラインとして改善できるか。
主な発見
- SWAGはSGD反復によって張られる部分空間における後方分布の局所幾何を緊密に捉える。
- SWAGはCIFAR-10/100とImageNetで、良く校正された不確実性推定と他のいくつかのベースラインよりも高いテスト対数尤度を提供する。
- SWAGはMCドロップアウト、SGLD、KFAC-Laplace、SWAなどの多くの代替法よりも不確実性キャリブレーションで優れる。
- SWAGは転移学習性能とアウト・オブ・ドメイン検出を、いくつかの競合他社と比較して改善する。
- SWAGは言語モデリングのパープレクシティを改善し、回帰タスクで競争力のある結果を示す。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。