[论文解读] SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training
SAINT 在特征自注意力与跨样本注意力的基础上,结合对比式自监督预训练,在表格数据上优于传统提升方法。它在多样基准上常常超越 XGBoost、CatBoost 和 LightGBM。
Tabular data underpins numerous high-impact applications of machine learning from fraud detection to genomics and healthcare. Classical approaches to solving tabular problems, such as gradient boosting and random forests, are widely used by practitioners. However, recent deep learning methods have achieved a degree of performance competitive with popular techniques. We devise a hybrid deep learning approach to solving tabular data problems. Our method, SAINT, performs attention over both rows and columns, and it includes an enhanced embedding method. We also study a new contrastive self-supervised pre-training method for use when labels are scarce. SAINT consistently improves performance over previous deep learning methods, and it even outperforms gradient boosting methods, including XGBoost, CatBoost, and LightGBM, on average over a variety of benchmark tasks.
研究动机与目标
- 为表格数据提出一种能处理异质特征类型且不依赖固有列顺序信息的神经网络方法。
- 提出 SAINT,一种基于 transformer 的架构,在特征上应用自注意力并在行之间应用跨样本注意力。
- 引入对比自监督预训练阶段,以在半监督设置下提升性能。
- 在广泛的基准测试中展示 SAINT 相对于基于树的方法和先前的深度表格模型的经验改进。
提出的方法
- 将连续特征和分类特征投射到一个共同的密集嵌入空间。
- 用结合自注意力和新颖的跨样本注意力(跨批次行)的 transformer 编码器处理嵌入。
- 在 transformer 处理之前,通过每个特征的可学习投影对连续特征进行嵌入。
- 使用混合目标进行预训练:对比损失(InfoNCE)加上来自增强视图的去噪损失(输入空间中的 CutMix 和嵌入空间中的 mixup)。
- 通过 MLP 从 [CLS] 嵌入预测目标进行微调。
- 提供消融研究和注意力可视化以解释模型行为。
实验结果
研究问题
- RQ1SAINT 的自注意力与跨样本注意力的结合是否能在表格数据建模方面超越传统的提升方法?
- RQ2对比预训练在表格数据的半监督设置中是否带来提升?
- RQ3将连续特征进行嵌入相比于以往的表格变换器,其性能如何?
- RQ4跨样本注意力在何种情形最有益(例如大量特征、少量标签)?
主要发现
- SAINT 的变体在 14 个二分类数据集的 AUROC 上通常优于基线模型,SAINT 常常取得最佳结果。
- 平均而言,SAINT 相对于传统提升方法(XGBoost、LightGBM、CatBoost)以及其他深度表格模型有提升。
- 在半监督设置中,预训练 SAINT(同时含自注意力和跨样本注意力)获得最佳结果,尤其是在标注数据有限的情况下。
- 将连续特征嵌入显著提升性能,TabTransformer 的比较也证明了这一点。
- 跨样本注意力对噪声具有鲁棒性,在特征数量较多或数据稀缺时尤为有用。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。