Skip to main content
QUICK REVIEW

[论文解读] Practical Deep Learning with Bayesian Principles

Kazuki Osawa, Siddharth Swaroop|arXiv (Cornell University)|Jun 6, 2019
Domain Adaptation and Few-Shot Learning参考文献 50被引用 50
一句话总结

该论文展示了使用自然梯度变分推断(VOGN)进行深度网络的实用训练,达到在 CIFAR-10 和 ImageNet 上对 Adam/SGD 的竞争性能,同时保留贝叶斯的好处,如校准的预测与更好的 OOD 不确定性以及持续学习。

ABSTRACT

Bayesian methods promise to fix many shortcomings of deep learning, but they are impractical and rarely match the performance of standard methods, let alone improve them. In this paper, we demonstrate practical training of deep networks with natural-gradient variational inference. By applying techniques such as batch normalisation, data augmentation, and distributed training, we achieve similar performance in about the same number of epochs as the Adam optimiser, even on large datasets such as ImageNet. Importantly, the benefits of Bayesian principles are preserved: predictive probabilities are well-calibrated, uncertainties on out-of-distribution data are improved, and continual-learning performance is boosted. This work enables practical deep learning while preserving benefits of Bayesian principles. A PyTorch implementation is available as a plug-and-play optimiser.

研究动机与目标

  • 通过解决可扩展性和性能差距,推动并实现实用的贝叶斯深度学习。
  • 展示自然梯度变分推断(VOGN)可以利用标准深度学习技巧(批量归一化、数据增强、分布式训练)高效训练大型网络。
  • 展示保留的贝叶斯优势:校准的预测概率、改进的分布外不确定性,以及改进的持续学习行为。
  • 提供跨多种体系结构和数据集(CIFAR-10、ImageNet)的实证证据,显示与非贝叶斯基线竞争的性能。

提出的方法

  • 通过对高斯后验 q(w) 的变分推断,将深度学习公式化为贝叶斯推断。
  • 对 VI 使用自然梯度更新,得到的更新形式与 SG/DL 优化器类似(VOGN)。
  • 采用批量归一化、数据增强、动量和分布式训练来加速收敛。
  • 采用基于高斯-牛顿的方差更新(对角 Sigma)以获得实用的二阶 VI 方法。
  • 引入数据增强缩放(rho)以补偿贝叶斯训练中的有效数据集大小。
  • 提供一种结合数据和 MC-样本并行的分布式训练方案,以扩展到 ImageNet。

实验结果

研究问题

  • RQ1自然梯度变分推断(VOGN)是否能在大规模数据集上达到与 Adam/SGD 相当的性能来训练深度网络?
  • RQ2通过 VOGN 的贝叶斯后验近似是否能在保持实际训练动态的同时,产生校准的预测和改进的分布外不确定性?
  • RQ3贝叶斯原理对顺序任务中的持续学习和知识保持有何影响?
  • RQ4标准深度学习技术(批量归一化、数据增强、分布式训练)如何与 VI 交互以实现实用的贝叶斯深度学习?
  • RQ5使用 VOGN 相对于传统优化器和 MC-dropout 的权衡有哪些(速度、校准、不确定性质量)?

主要发现

数据集/架构优化器训练/验证准确率 (%)验证 NLL训练轮次每轮时间(s)ECEAUROC
CIFAR-10/ LeNet-5 (no DA)Adam71.98 / 67.670.9372106.960.0210.794
CIFAR-10/ LeNet-5 (no DA)BBB66.84 / 64.611.01880011.430.0450.784
CIFAR-10/ LeNet-5 (no DA)MC-dropout68.41 / 67.650.9902106.950.0870.797
CIFAR-10/ AlexNet (no DA)Adam100.0 / 67.942.831613.120.2620.793
CIFAR-10/ AlexNet (no DA)MC-dropout97.56 / 72.201.0771603.250.1400.818
CIFAR-10/ AlexNetVOGN81.15 / 75.480.70316010.020.0160.832
CIFAR-10/ ResNet-18Adam97.74 / 86.000.55016011.970.0820.877
CIFAR-10/ ResNet-18MC-dropout88.23 / 82.850.51016112.510.1660.768
CIFAR-10/ ResNet-18VOGN91.62 / 84.270.47716153.140.0400.876
ImageNet/ ResNet-18SGD82.63 / 67.791.389044.130.0670.856
ImageNet/ ResNet-18Adam80.96 / 66.391.449044.400.0640.855
ImageNet/ ResNet-18MC-dropout72.96 / 65.641.439045.860.0120.856
ImageNet/ ResNet-18OGN85.33 / 65.761.609063.130.1280.854
ImageNet/ ResNet-18VOGN73.87 / 67.381.379076.040.0290.854
  • VOGN 在 CIFAR-10 和 ImageNet 上横跨多种架构实现了与 Adam/SGD 相似的收敛性和性能。
  • VOGN 提供良好校准的预测概率,并在分布外数据上相较非贝叶斯方法具有改进的不确定性。
  • 在大规模任务上,带有批量归一化和数据增强的 VOGN 在速度/训练轮次上与标准优化器相同,尽管由于 VI 计算每轮成本较高。
  • 与 BBB 和 MC-dropout 相比,VOGN 往往具有更好的校准和更低的过度自信,尤其是在 ImageNet 和 ResNet-18 上。
  • 在持续学习任务中,VOGN 在准确性方面与现有的贝叶斯持续学习方法(如 VCL)相比具有竞争力,且在某些设置下每个任务训练速度更快。
  • Table 1 显示 VOGN 在 CIFAR-10(LeNet-5、AlexNet、ResNet-18)和 ImageNet(ResNet-18)对比 Adam、SGD、MC-dropout、OGN、K-FAC、Noisy K-FAC 时达到具有竞争力或行业内最佳的指标。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。