[论文解读] TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second
TabPFN 是一个预训练的 Transformer,在不进行超参数调优的情况下,在一秒内完成小型表格分类,并通过在上下文学习中逼近贝叶斯后验预测分布,匹配数值数据集上的最先进 AutoML 性能。
We present TabPFN, a trained Transformer that can do supervised classification for small tabular datasets in less than a second, needs no hyperparameter tuning and is competitive with state-of-the-art classification methods. TabPFN performs in-context learning (ICL), it learns to make predictions using sequences of labeled examples (x, f(x)) given in the input, without requiring further parameter updates. TabPFN is fully entailed in the weights of our network, which accepts training and test samples as a set-valued input and yields predictions for the entire test set in a single forward pass. TabPFN is a Prior-Data Fitted Network (PFN) and is trained offline once, to approximate Bayesian inference on synthetic datasets drawn from our prior. This prior incorporates ideas from causal reasoning: It entails a large space of structural causal models with a preference for simple structures. On the 18 datasets in the OpenML-CC18 suite that contain up to 1 000 training data points, up to 100 purely numerical features without missing values, and up to 10 classes, we show that our method clearly outperforms boosted trees and performs on par with complex state-of-the-art AutoML systems with up to 230$ imes$ speedup. This increases to a 5 700$ imes$ speedup when using a GPU. We also validate these results on an additional 67 small numerical datasets from OpenML. We provide all our code, the trained TabPFN, an interactive browser demo and a Colab notebook at https://github.com/automl/TabPFN.
研究动机与目标
- 证明单个预训练 Transformer 在没有数据集特定调优的情况下,能够在不到一秒的时间内解决小型表格分类任务。
- 开发离线训练的 Prior-Data Fitted Network (PFN),以在表格数据先验下近似贝叶斯推断。
- 结合对因果关系有感知的先验(SCMs 和 BNNs),以建模表格数据的多样化生成机制。
- 展示 TabPFN 在 OpenML-CC18 数值数据集上优于提升树并与 AutoML 系统具有竞争力。
- 提供开源代码、预训练的 TabPFN 和演示,以实现复现和社区验证。
提出的方法
- 训练一个 12 层 Transformer 作为 PFN,在新的表格先验下近似后验预测分布。
- 从结构化因果模型(SCMs)和贝叶斯神经网络(BNNs)的混合中构建先验,以建模简单、因果和多样化的数据生成过程。
- 离线在从先验生成的合成数据集上进行训练,以在保留未见合成点上最小化交叉熵,从而实现一次性在线预测。
- 在推断阶段,将训练集和测试特征作为集合值输入,以单次前向传递获得 PPD 预测。
- 实现对变长训练数据和测试样本的置换不变处理,使用零填充以适应不同的特征数量。
- 可选地结合 32 次前向传递并进行数据变换以提高稳定性。
实验结果
研究问题
- RQ1单个预训练 Transformer 是否能够在不进行逐数据集调优的情况下,学习对小型表格数据集执行贝叶斯风格后验预测推断?
- RQ2基于 SCMs 和 BNNs 的先验是否有助于产生更简单的因果解释并提高对小型表格数据的预测性能?
- RQ3在严格意义上的小型数值型表格数据集上,TabPFN 在准确性和速度方面与提升树及 AutoML 系统相比如何?
- RQ4关于分类特征和缺失值,TabPFN 存在哪些限制?集成或先验调整是否可以缓解这些问题?
主要发现
- TabPFN 在 OpenML-CC18 数值数据集上与最先进的 AutoML 系统的准确性相当,数据集包含最多 1,000 个训练点和 100 个数值特征,且每个数据集在不到一秒内完成。
- TabPFN 在小型数据集预测方面比基于 CPU 的 AutoML 流水线快得多(约 230 倍),比基于 GPU 的加速高达约 5,700 倍。
- 该方法在包含分类特征或缺失值的数据集上通常表现较差,但与其他方法进行集成的 TabPFN 可以带来进一步提升。
- TabPFN 的归纳偏置 toward 简单、因果解释,在预测的定性分析和鲁棒性检查中得到证明的益处。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。