Skip to main content
QUICK REVIEW

[论文解读] Predictive Uncertainty Estimation via Prior Networks

Andrey Malinin, Mark Gales|arXiv (Cornell University)|Feb 28, 2018
Adversarial Robustness in Machine Learning参考文献 28被引用 357
一句话总结

这篇论文引入 Prior Networks (PNs) 来明确建模分布不确定性,与数据不确定性和模型不确定性分离,从而在分布外检测和错误分类检测方面实现更好的效果,并将 Dirichlet Prior Networks (DPNs) 应用于 MNIST 和 CIFAR-10。

ABSTRACT

Estimating how uncertain an AI system is in its predictions is important to improve the safety of such systems. Uncertainty in predictive can result from uncertainty in model parameters, irreducible data uncertainty and uncertainty due to distributional mismatch between the test and training data distributions. Different actions might be taken depending on the source of the uncertainty so it is important to be able to distinguish between them. Recently, baseline tasks and metrics have been defined and several practical methods to estimate uncertainty developed. These methods, however, attempt to model uncertainty due to distributional mismatch either implicitly through model uncertainty or as data uncertainty. This work proposes a new framework for modeling predictive uncertainty called Prior Networks (PNs) which explicitly models distributional uncertainty. PNs do this by parameterizing a prior distribution over predictive distributions. This work focuses on uncertainty for classification and evaluates PNs on the tasks of identifying out-of-distribution (OOD) samples and detecting misclassification on the MNIST dataset, where they are found to outperform previous methods. Experiments on synthetic and MNIST and CIFAR-10 data show that unlike previous non-Bayesian methods PNs are able to distinguish between data and distributional uncertainty.

研究动机与目标

  • 动机:分离三种预测不确定性的需求:模型不确定性( epistemic)、数据不确定性( aleatoric)以及分布不确定性(数据集位移)。
  • 提出 Prior Networks,使其参数化对预测分布的分布,以隔离分布不确定性。
  • 开发并评估 Dirichlet Prior Networks (DPNs) 用于分类任务,重点是 OOD 检测和错分检测。
  • 给出源自 PN 框架的不确定性度量,并与贝叶斯/多模型基线进行比较。

提出的方法

  • 引入 Prior Networks (PNs),它明确建模对预测分布 p(mu|x, theta) 的分布。
  • 使用 Dirichlet 分布来参数化 p(mu|x; theta),其中 alpha = f(x; theta),实现对内在域预测的尖锐边角,以及对分布外输入的平坦先验。
  • 通过多任务目标训练 Dirichlet Prior Networks (DPNs),使其最小化对 in-domain 数据的 sharp Dirichlet 目标的 KL 散度,以及对 out-of-domain 数据的 flat Dirichlet 目标的 KL 散度(eq. 12)。
  • 对同分布内目标进行正则化和平滑(eq. 15),以避免 delta-function 目标,并可选地使用 teacher-student 平滑。
  • 讨论 PN 分层的不同边缘化(数据、分布不确定性、模型不确定性),并从这些边缘化推导不确定性度量(熵、互信息)。
  • 在合成数据、MNIST 和 CIFAR-10 上评估 PN/Dirichlet PN,并与标准 DNNs 和 MC-Dropout 集成进行比较。

实验结果

研究问题

  • RQ1Prior Networks 能否在分类任务中分别建模数据不确定性、分布不确定性和模型不确定性?
  • RQ2相较于如 DNNs 和 MC-Dropout 集成等基线,Dirichlet Prior Networks 是否改善分布外检测和错误分类检测?
  • RQ3在 PN 框架下,哪些不确定性度量(熵、互信息、微分熵)最能反映不同来源的不确定性?
  • RQ4PN 基方法在 MNIST 和 CIFAR-10 上的表现如何,包括嘈杂/增强场景以及各种真实世界 OOD 数据集?

主要发现

  • Dirichlet Prior Networks 提供的分布不确定性估计在 MNIST/CIFAR-10 的 OOD 检测中比 MC-Dropout 和标准 DNNs 更加准确。
  • PNs 在 MNIST 和 CIFAR-10 的错误分类检测上优于基线。
  • Dirichlet prior 的微分熵在类别区分弱或嘈杂时对 OOD 检测特别有效。
  • 来自 PN 框架的不确定性度量可以在测试时解析计算,成本低于集成。
  • 在合成数据下,当类别重叠度高时,PN 的区分同分布与非同分布的能力有所提升,这是标准熵度量所不及。
  • 熵和最大后验概率仍是强有力的简单指标,在某些 OOD 场景(尤其是类别不那么鲜明时)微分熵具有优势。)

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。