[论文解读] SLANG: Fast Structured Covariance Approximations for Bayesian Deep Learning with Natural Gradient
SLANG 提出了一种快速、随机、低秩的近似自然梯度方法,用于贝叶斯深度学习中的变分推断,仅通过网络对数似然的反向传播梯度来估计结构化协方差矩阵(对角加低秩)。与平均场方法相比,该方法实现了更快的收敛速度和更准确的不确定性估计,在标准基准测试上的性能与当前最先进方法相当。
Uncertainty estimation in large deep-learning models is a computationally challenging task, where it is difficult to form even a Gaussian approximation to the posterior distribution. In such situations, existing methods usually resort to a diagonal approximation of the covariance matrix despite, the fact that these matrices are known to result in poor uncertainty estimates. To address this issue, we propose a new stochastic, low-rank, approximate natural-gradient (SLANG) method for variational inference in large, deep models. Our method estimates a "diagonal plus low-rank" structure based solely on back-propagated gradients of the network log-likelihood. This requires strictly less gradient computations than methods that compute the gradient of the whole variational objective. Empirical evaluations on standard benchmarks confirm that SLANG enables faster and more accurate estimation of uncertainty than mean-field methods, and performs comparably to state-of-the-art methods.
研究动机与目标
- 为解决大规模深度神经网络中高效且准确的不确定性估计挑战。
- 克服平均场变分推断的局限性,后者由于对角协方差近似而低估不确定性。
- 开发一种可扩展至深度模型的方法,同时保持低内存和计算成本。
- 在无需完整变分目标梯度的情况下,实现结构化协方差近似(对角加低秩)。
- 在计算开销更低的前提下,实现与当前最先进方法相当的性能。
提出的方法
- SLANG 使用一种近似自然梯度算法来优化变分参数,仅依赖于网络对数似然的反向传播梯度。
- 它估计一个由对角项和低秩分量组成的结构化协方差矩阵,直接从梯度统计中学习。
- 该方法避免计算完整变分目标的梯度,相比重参数化方法显著降低了计算成本。
- 它采用一种随机、迭代的优化方案,通过小批量梯度逐步构建协方差近似。
- 该算法使用自适应学习率和动量,超参数通过贝叶斯优化和交叉验证进行调优。
- 该方法应用于使用全批量或小批量训练的贝叶斯神经网络,通过蒙特卡洛采样进行推理。
实验结果
研究问题
- RQ1与平均场近似相比,低秩加对角协方差结构是否能改善深度贝叶斯神经网络中的不确定性估计?
- RQ2是否可以通过仅使用对数似然梯度,使自然梯度优化在大规模深度模型中计算上变得高效?
- RQ3SLANG 在标准基准测试上是否比平均场方法和当前最先进方法实现更快的收敛速度和更好的不确定性估计?
- RQ4该方法是否能有效扩展到深度网络,同时保持极低的内存和计算开销?
- RQ5SLANG 的性能如何随不同低秩维度和超参数设置而变化?
主要发现
- SLANG 显著改善了平均场方法的不确定性估计,特别是在减少方差低估方面,USPS 数据集上的结果已证明这一点。
- 在 MNIST 和 UCI 回归基准上,SLANG 的性能与当前最先进方法相当,且收敛速度更快。
- 在 MNIST 上,当 L=32 时,SLANG 达到了 97.8% 的测试准确率和 0.138 的负对数似然,优于平均场基线方法。
- 在 UCI 数据集上,SLANG 的测试 NLL 始终低于 Bayes-by-Backprop 和平均场 VI,在某些情况下改善幅度高达 15%。
- 所有数据集的最佳学习率均为 α=β=0.02154435,表明其对超参数选择具有鲁棒性。
- 由于避免了完整变分目标的梯度计算,该方法所需的梯度计算次数少于重参数化方法。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。