Skip to main content
QUICK REVIEW

[论文解读] Scalable Marginal Likelihood Estimation for Model Selection in Deep Learning

Alexander Immer, Matthias Bauer|arXiv (Cornell University)|Apr 11, 2021
Machine Learning and Data Classification被引用 24
一句话总结

该论文提出了一种可扩展的、在线的边缘似然估计方法,用于贝叶斯深度学习模型选择,采用拉普拉斯方法结合高斯-牛顿(Gauss-Newton)和经验费舍尔(Empirical Fisher)海森矩阵近似。该方法仅使用训练数据即可实现超参数和网络结构的选择,在校准和分布外(OOD)检测方面优于交叉验证和人工调参,尤其在验证数据较少的场景下表现更优。

ABSTRACT

Marginal-likelihood based model-selection, even though promising, is rarely used in deep learning due to estimation difficulties. Instead, most approaches rely on validation data, which may not be readily available. In this work, we present a scalable marginal-likelihood estimation method to select both hyperparameters and network architectures, based on the training data alone. Some hyperparameters can be estimated online during training, simplifying the procedure. Our marginal-likelihood estimate is based on Laplace's method and Gauss-Newton approximations to the Hessian, and it outperforms cross-validation and manual-tuning on standard regression and image classification datasets, especially in terms of calibration and out-of-distribution detection. Our work shows that marginal likelihoods can improve generalization and be useful when validation data is unavailable (e.g., in nonstationary settings).

研究动机与目标

  • 解决由于边缘似然估计不可计算而导致的深度学习中可扩展的贝叶斯模型选择问题。
  • 仅使用训练数据实现超参数和网络结构的选择,避免对验证集的依赖。
  • 开发一种计算高效的在线方法,用于边缘似然估计,适用于现代深度神经网络。
  • 证明边缘似然估计在模型泛化和不确定性校准方面可优于标准方法(如交叉验证和人工调参)。

提出的方法

  • 使用拉普拉斯方法近似边缘似然,利用海森矩阵的二阶信息。
  • 采用广义高斯-牛顿(Generalized Gauss-Newton, GGN)和经验费舍尔(Empirical Fisher, EF)近似来估计海森矩阵,以实现可扩展性。
  • 应用对角矩阵和块对角矩阵近似,降低海森矩阵估计的计算成本。
  • 通过基于梯度的更新,在训练过程中在线优化可微分的超参数(如先验方差、噪声方差)。
  • 在训练结束后,通过基于估计边缘似然的排序完成离散的网络结构选择。
  • 将该方法集成到标准训练流程中,仅引入少量开销,每 F=10 个周期使用克罗内克分解近似。

实验结果

研究问题

  • RQ1边缘似然估计能否被改进为适用于现代深度学习模型的可扩展且实用的方法?
  • RQ2在缺乏验证数据的情况下,边缘似然估计能否优于交叉验证和人工调参进行模型选择?
  • RQ3边缘似然是否与真实世界基准中的测试准确率和不确定性校准存在相关性?
  • RQ4能否在训练过程中通过边缘似然估计在线优化超参数?

主要发现

  • 所提方法在回归和图像分类基准上的表现与交叉验证相当或更优,尤其在校准和分布外(OOD)检测方面表现突出。
  • 在 CIFAR-10 和 CIFAR-100 上,ResNets 的边缘似然高于标准 CNN,即使参数量呈指数级增长,表明其泛化能力更强。
  • 在 CIFAR-10/100 上,测试准确率与边缘似然的等级相关系数达到 97%(Spearman’s ρ),表明与模型性能高度一致。
  • 在 FashionMNIST 上,CNN 的边缘似然高于 MLP,即使两者准确率相近,表明更低的模型复杂度更有利于获得更高的边缘似然。
  • 该方法提升了泛化能力并减小了泛化差距,在 NLL 和 ECE 指标上相比基线(使用数据增强)最高提升达 2 倍。
  • 使用该方法进行在线模型选择的耗时仅约为单次训练的 2 倍,时间效率优于交叉验证。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。