[论文解读] Practical Deep Heteroskedastic Regression
这篇论文提出一种在保留数据集的后验线性方差头,用于深层异方差回归中的不确定性校准,利用中间潜在表示,在QM9和OMol25数据集上实现了具有竞争力甚至更优的不确定性量化,同时保持均值预测准确性。
Uncertainty quantification (UQ) in deep learning regression is of wide interest, as it supports critical applications including sequential decision making and risk-sensitive tasks. In heteroskedastic regression, where the uncertainty of the target depends on the input, a common approach is to train a neural network that parameterizes the mean and the variance of the predictive distribution. Still, training deep heteroskedastic regression models poses practical challenges in the trade-off between uncertainty quantification and mean prediction, such as optimization difficulties, representation collapse, and variance overfitting. In this work we identify previously undiscussed fallacies and propose a simple and efficient procedure that addresses these challenges jointly by post-hoc fitting a variance model across the intermediate layers of a pretrained network on a hold-out dataset. We demonstrate that our method achieves on-par or state-of-the-art uncertainty quantification on several molecular graph datasets, without compromising mean prediction accuracy and remaining cheap to use at prediction time.
研究动机与目标
- 在训练深层异方差回归模型时识别核心挑战。
- 提出一个在保留数据上拟合的实用后验方差头。
- 利用中间潜在表示来预测方差并实现集成。
- 在分子数据集上证明在保持均值预测质量的同时改善不确定性量化。
提出的方法
- 像往常一样训练均值预测器并保持其参数固定。
- 附加一个线性方差头,该头以中间潜在表示 zl 为输入。
- 将 σ^2ϕ(x*) 计算为在选定潜在层上的线性投影之和: σ^2ϕ(x*) = sp Σl∈Lσ Wl^T zl(x*).
- 在保留数据集上使用负对数似然损失拟合方差头,与均值训练解耦。
- 可选地通过对多个潜在表示特定估计器取平均,形成高斯混合: p(y*|x*) = (1/|Lσ|) Σl∈Lσ N(y*|μθ(x*), σl(x*)^2).
- 将从单个表示学习的方差估计器进行集成,以提升标定和鲁棒性。
实验结果
研究问题
- RQ1后验、保留数据校准的方差头是否能提供与端到端均值-方差训练相当或更优的不确定性估计?
- RQ2使用中间潜在表示是否比仅使用最终潜在表示更能改善方差预测?
- RQ3表示选择和集成对分子属性任务的标定度量和负对数似然(NLL)有何影响?
- RQ4该方法是否可扩展到大型预训练模型和数据集,同时不影响均值准确性或预测速度?
主要发现
- 后验方差集成在NLL上通常与端到端的均值-方差模型不相上下甚至优于,同时保持均值 MAE 表现。
- 使用较早的潜在表示进行方差预测通常比使用后期表示更有效;跨表示的集成可获得最佳结果。
- 该方法在预测阶段开销极小且几乎无额外超参数,利用保留数据进行校准,保持实用性。
- 该方法从 QM9 泛化到大型 OMol25 预训练模型,实现了标定的不确定性估计,并在基线上显著提升 NLL。
- 后验集成曲线与理想观测结果对齐,表明在主动学习或贝叶斯优化中具有可靠的不确定性排序。
- 对方差预测器的集成(高斯混合)对异常值和小型保留集具有鲁棒性,但在标定(ECE)与尖锐度之间存在权衡。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。