Skip to main content
QUICK REVIEW

[论文解读] When Do Neural Nets Outperform Boosted Trees on Tabular Data?

Duncan C. McElfresh, Sujay Khandagale|arXiv (Cornell University)|May 4, 2023
Explainable Artificial Intelligence (XAI)被引用 71
一句话总结

本论文对19种算法在176个表格数据集上进行了大规模比较,发现NN vs. GBDT的争论常被夸大,简单的基线或对GBDT的轻度超参数调优在许多数据集上就能与NN的性能相匹配甚至超过;TabPFN在小数据集上常表现突出,而GBDT在较大或不规则数据集上占主导;作者发布了 TabZilla 作为基准测试套件。

ABSTRACT

Tabular data is one of the most commonly used types of data in machine learning. Despite recent advances in neural nets (NNs) for tabular data, there is still an active discussion on whether or not NNs generally outperform gradient-boosted decision trees (GBDTs) on tabular data, with several recent works arguing either that GBDTs consistently outperform NNs on tabular data, or vice versa. In this work, we take a step back and question the importance of this debate. To this end, we conduct the largest tabular data analysis to date, comparing 19 algorithms across 176 datasets, and we find that the 'NN vs. GBDT' debate is overemphasized: for a surprisingly high number of datasets, either the performance difference between GBDTs and NNs is negligible, or light hyperparameter tuning on a GBDT is more important than choosing between NNs and GBDTs. A remarkable exception is the recently-proposed prior-data fitted network, TabPFN: although it is effectively limited to training sets of size 3000, we find that it outperforms all other algorithms on average, even when randomly sampling 3000 training datapoints. Next, we analyze dozens of metafeatures to determine what properties of a dataset make NNs or GBDTs better-suited to perform well. For example, we find that GBDTs are much better than NNs at handling skewed or heavy-tailed feature distributions and other forms of dataset irregularities. Our insights act as a guide for practitioners to determine which techniques may work best on their dataset. Finally, with the goal of accelerating tabular data research, we release the TabZilla Benchmark Suite: a collection of the 36 'hardest' of the datasets we study. Our benchmark suite, codebase, and all raw results are available at https://github.com/naszilla/tabzilla.

研究动机与目标

  • 质疑在表格数据设置中对 NN 与 GBDT 性能的强调。
  • 评估在多样化数据集上算法选择还是对超参数的调优驱动性能提升。
  • 识别预测 NN 或 GBDT 表现更好的数据集特征(元特征)。
  • 为从业者就表格数据的方法选择和调优提供实践性指导。

提出的方法

  • 在来自 OpenML 数据集的176个表格数据集上评估19种算法(GBDT、NN、TabPFN 和基线)。
  • 使用 Optuna 对每个数据集进行多达30个设置的超参数调优,且每次运行最多耗时10小时。
  • 对每个数据集进行10折交叉验证,报告测试准确率和对数损失作为主要指标。
  • 使用 PyMFE 计算共计965个元特征,以分析数据集特征。
  • 使用 Friedman 和 Wilcoxon 符号秩检验并进行 Holm-Bonferroni 校正以评估统计显著性。
  • 发布 TabZilla 基准套件,包括36个具有挑战性数据集,并提供开源代码和结果。

实验结果

研究问题

  • RQ1不同算法族(GBDT 与 NN)在大量且多样的表格数据集上的性能相对关系如何?
  • RQ2数据集规模、不规则性或其他元特征是否能预测在何时 NN 或 GBDT 表现更好?
  • RQ3简单基线或对强模型的轻量级超参数调优是否常常胜过跨族算法选择?
  • RQ4哪些数据集属性最能解释特定方法的成功或失败,并如何为新数据集的实际选择提供指导?

主要发现

  • 在176个数据集上没有单一算法占绝对优势;CatBoost 经常领先,但仍有数据集被他人击败。
  • TabPFN 平均上取得了最佳表现,且训练时间也非常快;在小数据集(≤1250 个样本)时,TabPFN 在推理速度也很快的情况下可能优于其他方法。
  • 在98个数据集(排除极端内存/时间问题的子集)中,TabPFN 在平均水平上优于所有其他方法且具有统计显著性。
  • 对强基线(如 CatBoost)的超参数调优带来的增益大于在约三分之一的数据集上在 GBDT 与 NN 之间切换。
  • GBDTs 往往在更大且更不规则的数据集上优于 NN(例如特征分布呈严重重尾或偏斜)。
  • 从业者指导:先从简单的基线开始,然后对 CatBoost 进行轻量调优,并使用元特征来指导新数据的算法选择。
  • TabZilla 基准套件包含36个困难的数据集,已发布以加速表格研究,代码和结果公开可用。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。