[论文解读] Improving Output Uncertainty Estimation and Generalization in Deep Learning via Neural Network Gaussian Processes
该论文提出了一种混合模型,通过将深度神经网络(DNNs)用作高斯过程(GPs)的均值函数,结合了DNN与GP的优点,实现了精确的输出不确定性估计和改进的泛化能力。该方法采用随机变分推断和随机梯度下降进行可扩展训练,在真实世界时空数据集上的不确定性校准和预测准确性方面,均优于独立的DNN和GP。
We propose a simple method that combines neural networks and Gaussian processes. The proposed method can estimate the uncertainty of outputs and flexibly adjust target functions where training data exist, which are advantages of Gaussian processes. The proposed method can also achieve high generalization performance for unseen input configurations, which is an advantage of neural networks. With the proposed method, neural networks are used for the mean functions of Gaussian processes. We present a scalable stochastic inference procedure, where sparse Gaussian processes are inferred by stochastic variational inference, and the parameters of neural networks and kernels are estimated by stochastic gradient descent methods, simultaneously. We use two real-world spatio-temporal data sets to demonstrate experimentally that the proposed method achieves better uncertainty estimation and generalization performance than neural networks and Gaussian processes.
研究动机与目标
- 为解决深度神经网络在安全关键应用中缺乏可靠的输出不确定性估计的问题。
- 通过利用深度神经网络的表征能力,克服高斯过程在数据稀疏区域的泛化能力差的问题。
- 开发一种可扩展的推断方法,使此类混合模型能够在精确GP推断不可行的大规模数据集上进行训练。
- 结合深度学习(对未见输入的泛化能力)和高斯过程(灵活的局部插值与不确定性量化)的优势。
- 通过实证结果证明,该混合模型在不确定性估计和点预测方面均优于独立的DNN和GP。
提出的方法
- 所提出的方法使用深度神经网络作为高斯过程的均值函数,从而实现灵活、数据驱动的均值预测。
- 在非线性函数上施加高斯过程先验,实现对预测结果的贝叶斯推断与不确定性量化。
- 采用带诱导点的稀疏高斯过程以降低计算复杂度,实现对大规模数据集的可扩展性。
- 应用随机变分推断以近似GP函数的后验分布,支持小批量训练。
- 使用随机梯度下降联合优化神经网络参数与核超参数。
- 该方法支持通过GP推断和DNN组件的反向传播实现端到端训练。
实验结果
研究问题
- RQ1能否有效将深度神经网络用作高斯过程的均值函数,以改善不确定性估计?
- RQ2将深度神经网络的表征能力与高斯过程的不确定性量化能力相结合,是否能提升在未见数据上的泛化能力?
- RQ3能否开发一种可扩展的推断方法,支持此类混合模型的大规模训练?
- RQ4在预测准确性和不确定性校准方面,该方法与独立的DNN和GP相比表现如何?
- RQ5模型架构、核函数选择以及诱导点数量对性能有何影响?
主要发现
- 在USHCN数据集上,该方法在所有缺失数据情景下均达到最低的测试均方误差(0.041),优于GP(0.054)和NN(0.048)基线。
- 在CC数据集上,该方法实现了最佳的不确定性校准,95%置信水平下的平均覆盖率为0.355,优于GP(0.412)和NN(0.364)。
- 在USHCN数据集上,当缺失数据比例为90%时,该方法将均方误差较GP降低最多25%,较NN降低最多15%。
- 即使在数据稀疏区域,该模型仍保持了较高的不确定性校准能力,在CC数据集80%缺失数据情况下,95%预测区间覆盖了93.8%的真实值。
- 计算时间具有竞争力:在USHCN数据集95%缺失数据情况下,该方法耗时1374秒,优于GP的1854秒,也快于NN的226秒。
- 该模型对缺失数据表现出鲁棒性,在两个数据集的所有缺失水平(50%、80%、95%)下均保持一致的性能。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。