[论文解读] Transformers Can Do Bayesian Inference
该论文提出先验数据拟合网络(PFNs),一种利用Transformer通过在从先验分布中抽取的数据点集合输入上进行上下文学习来执行贝叶斯推断的方法。PFNs在MCMC和NUTS基础上实现高达200倍的加速,近似高斯过程的后验预测分布,并在表格回归、少样本图像分类和贝叶斯神经网络等多种任务中表现出色。
Currently, it is hard to reap the benefits of deep learning for Bayesian methods, which allow the explicit specification of prior knowledge and accurately capture model uncertainty. We present Prior-Data Fitted Networks (PFNs). PFNs leverage in-context learning in large-scale machine learning techniques to approximate a large set of posteriors. The only requirement for PFNs to work is the ability to sample from a prior distribution over supervised learning tasks (or functions). Our method restates the objective of posterior approximation as a supervised classification problem with a set-valued input: it repeatedly draws a task (or function) from the prior, draws a set of data points and their labels from it, masks one of the labels and learns to make probabilistic predictions for it based on the set-valued input of the rest of the data points. Presented with a set of samples from a new supervised learning task as input, PFNs make probabilistic predictions for arbitrary other data points in a single forward propagation, having learned to approximate Bayesian inference. We demonstrate that PFNs can near-perfectly mimic Gaussian processes and also enable efficient Bayesian inference for intractable problems, with over 200-fold speedups in multiple setups compared to current methods. We obtain strong results in very diverse areas such as Gaussian process regression, Bayesian neural networks, classification for small tabular data sets, and few-shot image classification, demonstrating the generality of PFNs. Code and trained PFNs are released at https://github.com/automl/TransformersCanDoBayesianInference.
研究动机与目标
- 通过利用深度学习实现高效的后验近似,解决在低数据场景下应用贝叶斯方法的挑战。
- 克服高斯过程和贝叶斯神经网络等复杂模型中精确贝叶斯推断的不可计算性。
- 开发一种通用框架,使任何可采样的先验分布均可用于后验近似。
- 仅通过先验的采样机制,无需解析形式或复杂近似,实现可扩展、可微分且灵活的贝叶斯推断。
- 在多种任务中验证该方法的有效性,包括小样本表格数据、少样本学习和回归任务,具备优异的校准性和不确定性量化能力。
提出的方法
- PFNs将后验近似重构为具有集合输入的监督分类问题:对于每个训练任务,从先验中采样一个函数,收集一组(x, y)对,掩码其中一个标签,并训练模型以概率方式预测被掩码的标签。
- PFN的输入是一组(x, y)对,其中一个标签被掩码;模型利用Transformer架构中的注意力机制学习预测缺失标签的分布。
- PFN通过在多个采样任务上最小化−∑ log qθ(y_test|x_test, D_train)的对数似然目标进行端到端训练。
- 推理时,将真实数据集D_train和测试点x_test输入已训练的PFN,模型在一次前向传播中输出完整的预测分布qθ*(y_test|x_test, D_train)。
- 为回归任务引入一种新型预测分布,实现对连续输出的正确不确定性估计。
- 该方法支持灵活的先验:任何可采样的分布(包括不可计算或复杂的先验,如贝叶斯神经网络或高斯过程)均可使用。
实验结果
研究问题
- RQ1Transformer能否在无需后验解析形式的情况下,有效用于贝叶斯推断中后验预测分布的近似?
- RQ2在多种低数据机器学习任务中,Transformer的上下文学习能力在多大程度上可被用于执行贝叶斯推断?
- RQ3在准确性、速度和校准性方面,PFNs与MCMC+NUTS和SVI+Bayes-by-Backprop等成熟贝叶斯基线方法相比表现如何?
- RQ4PFNs能否在不同类型的先验(包括不可计算或标准方法难以近似的先验)上实现泛化?
- RQ5架构选择(如注意力头数量、位置编码、激活函数)对PFNs在后验近似中的性能有何影响?
主要发现
- PFNs在Dionis数据集上实现0.981的平均AUC,在jannis数据集上实现0.996的平均AUC,近乎完美地复现了高斯过程的预测结果,优于所有基线方法。
- PFN-BNN在21个表格数据集上平均AUC达到0.855(每类30个训练样本),显著优于T-BNN(0.654)及其他基线,在准确性和校准性方面均表现更优。
- 与MCMC+NUTS相比,PFNs在推理时间上实现高达200倍的加速,相同基准测试中,GPU上仅需13秒,而NUTS耗时超过12小时。
- 该方法具备优异的校准性,PFN-BNN的期望校准误差(ECE)为0.025,显著低于基线逻辑回归模型的0.157。
- 在Omniglot少样本图像分类任务中,PFNs实现0.981的平均AUC,优于KNN(0.871)和CatBoost(0.945),表明其在少样本设置下具有强大的泛化能力。
- PFN框架可泛化至多种先验:仅通过先验的采样机制,成功近似了高斯过程、贝叶斯神经网络及其他不可计算模型的后验分布。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。