[論文レビュー] Practical Deep Learning with Bayesian Principles
この論文は VOGN による自然勾配法を用いた変分推論を実践的な深いネットワーク訓練に適用し、CIFAR-10 および ImageNet において Adam/SGD に対して競合的な性能を達成すると同時に、校正された予測やOOD の不確実性改善、継続学習といったベイズ的利点を保持する。
Bayesian methods promise to fix many shortcomings of deep learning, but they are impractical and rarely match the performance of standard methods, let alone improve them. In this paper, we demonstrate practical training of deep networks with natural-gradient variational inference. By applying techniques such as batch normalisation, data augmentation, and distributed training, we achieve similar performance in about the same number of epochs as the Adam optimiser, even on large datasets such as ImageNet. Importantly, the benefits of Bayesian principles are preserved: predictive probabilities are well-calibrated, uncertainties on out-of-distribution data are improved, and continual-learning performance is boosted. This work enables practical deep learning while preserving benefits of Bayesian principles. A PyTorch implementation is available as a plug-and-play optimiser.
研究の動機と目的
- 実用的なベイズ深層学習を動機づけ、拡張性と性能のギャップを埋める。
- 自然勾配 variational inference (VOGN) が標準的な深層学習のコツ(バッチ正規化、データ拡張、分散訓練)を用いて大規模網を効率的に訓練できることを示す。
- 校正された予測確率、分布外の不確実性の改善、継続学習行動の改善といったベイズ的利点を維持すること。
- CIFAR-10、ImageNet など複数のアーキテクチャとデータセットで、非ベイズ基準と競合する性能を示すエビデンスを提供。)
提案手法
- 深層学習をGaussian後方分布 q(w) による変分推論としてベイズ推論として定式化。
- VI に自然勾配更新を用い、SG/DL オプティマの形に近い更新を得る(VOGN)。
- 収束を加速するためにバッチ正規化、データ拡張、モーメント、分散訓練を採用。
- 実用的な二次VI法を得るために対角線 Sigma を持つ Gauss-Newton ベースの分散更新を採用。
- ベイズ訓練の有効データセットサイズを補うデータ拡張スケーリング(rho)を導入。
- ImageNet 専用にデータとMCサンプルの並列性を組み合わせた分散訓練方式を提供。)
実験結果
リサーチクエスチョン
- RQ1VOGN が大規模データセットで Adam/SGD と同等の性能を持って深ネットを訓練できるか?
- RQ2VI via VOGN によるベイズ後方推定が校正された予測と分布外の不確実性の改善をもたらし、実務的訓練ダイナミクスを維持するか?
- RQ3ベイズ原理が逐次タスクの継続学習と知識保持に与える影響は何か?
- RQ4標準的な深層学習技術(バッチ正規化、データ拡張、分散訓練)は VI とどのように相互作用し、実用的なベイズ深層学習を提供するか?
- RQ5VOGN を従来のオプティマとMCドロップアウトと比較した場合のトレードオフ(速度、校正、不確実性の質)は何か?
主な発見
| データセット/アーキテクチャ | Optimiser | 訓練/検証精度(%) | Validation NLL | エポック数 | 1エポックあたりの時間(秒) | ECE | AUROC |
|---|---|---|---|---|---|---|---|
| CIFAR-10/ LeNet-5 (no DA) | Adam | 71.98 / 67.67 | 0.937 | 210 | 6.96 | 0.021 | 0.794 |
| CIFAR-10/ LeNet-5 (no DA) | BBB | 66.84 / 64.61 | 1.018 | 800 | 11.43 | 0.045 | 0.784 |
| CIFAR-10/ LeNet-5 (no DA) | MC-dropout | 68.41 / 67.65 | 0.990 | 210 | 6.95 | 0.087 | 0.797 |
| CIFAR-10/ AlexNet (no DA) | Adam | 100.0 / 67.94 | 2.83 | 161 | 3.12 | 0.262 | 0.793 |
| CIFAR-10/ AlexNet (no DA) | MC-dropout | 97.56 / 72.20 | 1.077 | 160 | 3.25 | 0.140 | 0.818 |
| CIFAR-10/ AlexNet | VOGN | 81.15 / 75.48 | 0.703 | 160 | 10.02 | 0.016 | 0.832 |
| CIFAR-10/ ResNet-18 | Adam | 97.74 / 86.00 | 0.550 | 160 | 11.97 | 0.082 | 0.877 |
| CIFAR-10/ ResNet-18 | MC-dropout | 88.23 / 82.85 | 0.510 | 161 | 12.51 | 0.166 | 0.768 |
| CIFAR-10/ ResNet-18 | VOGN | 91.62 / 84.27 | 0.477 | 161 | 53.14 | 0.040 | 0.876 |
| ImageNet/ ResNet-18 | SGD | 82.63 / 67.79 | 1.38 | 90 | 44.13 | 0.067 | 0.856 |
| ImageNet/ ResNet-18 | Adam | 80.96 / 66.39 | 1.44 | 90 | 44.40 | 0.064 | 0.855 |
| ImageNet/ ResNet-18 | MC-dropout | 72.96 / 65.64 | 1.43 | 90 | 45.86 | 0.012 | 0.856 |
| ImageNet/ ResNet-18 | OGN | 85.33 / 65.76 | 1.60 | 90 | 63.13 | 0.128 | 0.854 |
| ImageNet/ ResNet-18 | VOGN | 73.87 / 67.38 | 1.37 | 90 | 76.04 | 0.029 | 0.854 |
- VOGN は CIFAR-10 と ImageNet の複数のアーキテクチャで Adam/SGD と同等の収束と性能を達成。
- VOGN は well-calibrated な予測確率と分布外データでの不確実性改善を提供。
- バッチ正規化とデータ拡張を用いた VOGN は大規模タスクで標準オプティマと同等のスピード/エポックを実現するが、VI 計算のため1エポックあたりのコストは高い。
- BBB や MC-dropout と比べて、VOGN は特に ImageNet と ResNet-18 でキャリブレーションが良く、過信が少ない。
- 継続学習タスクで、VOGN は既存のベイズ継続学習法(例: VCL)と比べて精度で競争力があり、一部設定でタスクごとの訓練が速い。
- Table 1 は CIFAR-10(LeNet-5, AlexNet, ResNet-18)と ImageNet(ResNet-18)に対して、Adam、SGD、MC-dropout、OGN、K-FAC、Noisy K-FAC に対して競合的または最高クラスの指標を示す。)
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。