[論文レビュー] Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning
本稿では、ラプラス法とガウス・ニュートンおよび経験的フィッシャー・ヘッシアン近似を用いた、ベイズ的ディープラーニングモデル選択のスケーラブルでオンラインな周辺尤度推定手法を提案する。訓練データのみを用いてハイパーパrameterおよびアーキテクチャ選択が可能であり、交差検証や手動チューニングに比べて、特に検証データが少ない状況下で、キャリブレーションおよびOOD検出の面で優れた性能を発揮する。
Marginal-likelihood based model-selection, even though promising, is rarely used in deep learning due to estimation difficulties. Instead, most approaches rely on validation data, which may not be readily available. In this work, we present a scalable marginal-likelihood estimation method to select both hyperparameters and network architectures, based on the training data alone. Some hyperparameters can be estimated online during training, simplifying the procedure. Our marginal-likelihood estimate is based on Laplace's method and Gauss-Newton approximations to the Hessian, and it outperforms cross-validation and manual-tuning on standard regression and image classification datasets, especially in terms of calibration and out-of-distribution detection. Our work shows that marginal likelihoods can improve generalization and be useful when validation data is unavailable (e.g., in nonstationary settings).
研究の動機と目的
- 周辺尤度推定が不確実であるため、ディープラーニングにおけるスケーラブルなベイズ的モデル選択の欠如に対処すること。
- 検証データセットに依存せずに、訓練データのみを用いてハイパーパrameterおよびアーキテクチャ選択を可能にすること。
- 現代のディープニューラルネットワークに適した、計算的に効率的でオンラインな周辺尤度推定手法を開発すること。
- 周辺尤度推定が、交差検証や手動チューニングといった標準的手法に比べて、モデルの一般化性能および不確実性キャリブレーションの面で優れていることを実証すること。
提案手法
- ヘッシアンの2次情報を利用することで、周辺尤度を近似するラプラス法を用いる。
- スケーラビリティを確保するため、一般化ガウス・ニュートン(GGN)および経験的フィッシャー(EF)近似をヘッシアンに適用する。
- ヘッシアン推定の計算コストを削減するために、対角行列およびブロック対角行列近似を適用する。
- 勾配ベースの更新を用いて、訓練中に微分可能なハイパーパrameter(例:事前分散、ノイズ分散)をオンラインで最適化する。
- 訓練後、推定された周辺尤度に基づいてモデルをランク付けすることで、離散的アーキテクチャ選択を実行する。
- Kronecker因数分解近似を10エポックごとに使用し、最小限のオーバーヘッドで標準的な訓練パイプラインに統合する。
実験結果
リサーチクエスチョン
- RQ1現代のディープラーニングモデルに適したスケーラブルで実用的な周辺尤度推定は可能か?
- RQ2検証データが利用できない状況下でも、周辺尤度が交差検証や手動チューニングを上回るモデル選択性能を示せるか?
- RQ3周辺尤度は、実世界のベンチマークにおいてテスト精度および不確実性キャリブレーションと相関しているか?
- RQ4周辺尤度推定を用いて、訓練中にハイパーパrameterをオンラインで最適化できるか?
主な発見
- 提案手法は、回帰および画像分類ベンチマークにおいて、交差検証と同等またはそれ以上の性能を発揮し、特にキャリブレーションおよびOOD検出の面で優れた結果を示した。
- CIFAR-10およびCIFAR-100では、パラメータ数が指数関数的に多いにもかかわらず、ResNetsが標準CNNよりも高い周辺尤度を達成しており、より良い一般化性能を示している。
- CIFAR-10/100では、テスト精度と周辺尤度の順位相関係数(スピアマンのρ)が97%に達し、モデル性能と強く整合していることが示された。
- FashionMNISTでは、同程度の精度を達成するもとで、CNNがMLPよりも高い周辺尤度を達成しており、より低いモデル複雑度がより良い周辺尤度をもたらす傾向があることを示唆している。
- 一般化性能が向上し、一般化ギャップが縮小され、NLLおよびECEにおいてベースライン(データオーグメンテーション付き)に比べて最大2倍の性能向上が達成された。
- オンラインモデル選択には、単一実行の2倍程度のトレーニング時間で実現可能であり、時間的効率性において交差検証を上回った。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。