[论文解读] TabTransformer: Tabular Data Modeling Using Contextual Embeddings
TabTransformer 使用 Transformer 层的上下文嵌入来建模表格数据,在准确性方面优于 ML 基线并可媲美 GBDT,对于缺失/嘈杂数据具有鲁棒处理,并采用两阶段半监督预训练方法。
We propose TabTransformer, a novel deep tabular data modeling architecture for supervised and semi-supervised learning. The TabTransformer is built upon self-attention based Transformers. The Transformer layers transform the embeddings of categorical features into robust contextual embeddings to achieve higher prediction accuracy. Through extensive experiments on fifteen publicly available datasets, we show that the TabTransformer outperforms the state-of-the-art deep learning methods for tabular data by at least 1.0% on mean AUC, and matches the performance of tree-based ensemble models. Furthermore, we demonstrate that the contextual embeddings learned from TabTransformer are highly robust against both missing and noisy data features, and provide better interpretability. Lastly, for the semi-supervised setting we develop an unsupervised pre-training procedure to learn data-driven contextual embeddings, resulting in an average 2.1% AUC lift over the state-of-the-art methods.
研究动机与目标
- 通过为分类特征学习上下文嵌入,缩小表格数据上 MLP 与梯度提升决策树(GBDT)之间的性能差距。
- 利用 Transformer 基于自注意力,将列嵌入转换为上下文表示,以提高预测准确性。
- 展示对缺失和嘈杂分类特征的鲁棒性,并对学习到的嵌入进行可解释性分析。
- 提出一个两阶段半监督学习流程(在未标注数据上进行预训练,然后进行微调),以在标注数据稀缺时提升性能。
提出的方法
- 为每个分类特征嵌入一个专用的列嵌入表,其中包含一个缺失值嵌入。
- 将嵌入序列通过 N 层 Transformer 处理(多头自注意力后接前馈块)。
- 将来自顶层 Transformer 的上下文嵌入与连续特征拼接后送入 MLP 进行最终预测。
- 可选地在未标注数据上使用 MLM(掩码语言建模)或 RTD(替换单词检测)任务对 Transformer 层进行预训练,然后再用带标注的数据进行微调。
- 通过基于梯度的学习端到端优化,以最小化标准监督损失(分类的交叉熵,回归的均方误差)。
- 在半监督设置中,执行两阶段工作流:(i)在未标注数据上进行预训练,(ii)在带标注数据上进行微调。
实验结果
研究问题
- RQ1基于 Transformer 的分类特征上下文嵌入是否能在表格数据上超越传统的 MLP?
- RQ2相对于基线神经模型,上下文嵌入是否对缺失和嘈杂的分类特征具有鲁棒性?
- RQ3在多样化数据集上,TabTransformer 相对于树模型(GBDT)及其他深度表格模型的表现如何?
- RQ4在标签数据有限的情况下,两阶段半监督预训练/微调流程是否在 AUC 上带来可衡量的提升?
主要发现
| 模型名称 | 平均 AUC (%) | 标准差 (%) |
|---|---|---|
| TabTransformer | 82.8 | 0.4 |
| MLP | 81.8 | 0.4 |
| GBDT | 82.9 | 0.4 |
| Sparse MLP | 81.4 | 0.4 |
| Logistic Regression | 80.4 | 0.4 |
| TabNet | 77.1 | 0.5 |
| VIB | 80.5 | 0.4 |
- TabTransformer 在15个数据集中的14个上比基线 MLP 提升,平均 AUC 提升1.0%。
- TabTransformer 在监督学习中与 GBDT 相匹配或接近超越,超过若干深度表格基线(如 TabNet、VIB)。
- 上下文嵌入在 Transformer 层之间变得更具预测力,使对嵌入的线性模型也能接近端到端的性能。
- 模型对嘈杂和缺失的分类特征具有鲁棒性,随着噪声或缺失程度增加,优于 MLP。
- 在半监督设置中,TabTransformer-RTD/MLM 的预训练在未标注数据充足时相对于竞争对手带来有意义的 AUC 提升(平均提升最高可达 2.1%)。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。