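[Paper Summary] Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data
NODE is a differentiable deep architecture that extends oblivious decision trees to multi-layer ensembles. It achieves state-of-the-art results on tabular data and often outperforms tuned gradient boosting methods.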
Nowadays, deep neural networks (DNNs) have become the main instrument for machine learning tasks within a wide range of domains, including vision, NLP, and speech. Meanwhile, in an important case of heterogeneous tabular data, the advantage of DNNs over shallow counterparts remains questionable. In particular, there is no sufficient evidence that deep learning machinery allows constructing methods that outperform gradient boosting decision trees (GBDT), which are often the top choice for tabular problems. In this paper, we introduce Neural Oblivious Decision Ensembles (NODE), a new deep learning architecture, designed to work with any tabular data. In a nutshell, the proposed NODE architecture generalizes ensembles of oblivious decision trees, but benefits from both end-to-end gradient-based optimization and the power of multi-layer hierarchical representation learning. With an extensive experimental comparison to the leading GBDT packages on a large number of tabular datasets, we demonstrate the advantage of the proposed NODE architecture, which outperforms the competitors on most of the tasks. We open-source the PyTorch implementation of NODE and believe that it will become a universal framework for machine learning on tabular data.
Motivation and Goals
- Motivate the need for deep learning on heterogeneous tabular data where traditional DNNs underperform compared with GBDTs.
- Introduce NODE, a differentiable ensemble of oblivious decision trees trained end-to-end.
- Show that NODE outperforms leading GBDT packages across multiple tabular datasets.
- Demonstrate efficiency and practicality, including end-to-end training and inference considerations.
Proposed Method
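- Define differentiable oblivious decision trees (ODTs), in which all internal nodes at the same depth share one splitting feature and threshold. Entmax-based soft decisions replace the hard splits.
- Build a NODE layer as a differentiable ensemble of m ODTs of depth d. Each tree is parameterized by a learnable feature-selection matrix F, a threshold vector b, and a response tensor R with one entry per leaf (2^d leaves).
- Relax the Heaviside split functions to differentiable two-class entmax decisions. Combine the per-depth decisions with an outer product into a choice tensor C that softly routes each sample to the leaves.
- Stack multiple NODE layers in a DenseNet-like architecture, where each layer takes the concatenation of the input features and the outputs of all previous layers. The final prediction averages the outputs of all trees in all layers.
- Preprocess features with a quantile transform and initialize split thresholds and temperatures in a data-aware way. Train end-to-end with mini-batch Quasi-Hyperbolic Adam (QHAdam), averaging model parameters over consecutive checkpoints.
- Speed up inference by precomputing the entmax feature-selection weights once, since they do not depend on the input, and by exploiting their sparsity.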
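The method bullets above can be made concrete with a small plain-Python sketch (no PyTorch). It rests on several assumptions:

- α is fixed to 1.5.
- Each leaf holds a scalar response.
- The two-class entmax is taken as the first component of entmax([t, 0]).
- The names `entmax15`, `sigma15`, `ObliviousTree`, and `dense_node` are hypothetical and are not taken from the official implementation.

It is a minimal illustration of the exact sort-based 1.5-entmax, the soft per-depth splits, leaf routing through the outer-product choice tensor, sparse selectors precomputed for inference, and DenseNet-style stacking with averaging over all trees.

```python
import itertools
import math


def entmax15(z):
    """Exact 1.5-entmax of a list of scores via the sort-based algorithm."""
    x = [v / 2.0 for v in z]  # (alpha - 1) * z with alpha = 1.5
    s = q = 0.0
    tau_star = None
    for k, v in enumerate(sorted(x, reverse=True), start=1):
        s += v
        q += v * v
        mean = s / k
        # Smaller root of sum over the top-k of (x_i - tau)^2 = 1.
        tau = mean - math.sqrt(max(mean * mean - q / k + 1.0 / k, 0.0))
        if tau <= v:  # the k-th largest score is still in the support
            tau_star = tau
    return [max(v - tau_star, 0.0) ** 2 for v in x]


def sigma15(t):
    """Two-class entmax, a sparse soft replacement for the Heaviside step."""
    return entmax15([t, 0.0])[0]


class ObliviousTree:
    """Soft oblivious tree of depth d: one (feature, threshold) pair per depth."""

    def __init__(self, F, b, tau, R):
        # F: d rows of feature logits, b: d thresholds, tau: d temperatures,
        # R: 2**d leaf responses (a scalar per leaf in this sketch).
        self.b, self.tau, self.R = b, tau, R
        # At inference F is fixed, so entmax(F_i) is computed once and kept
        # sparse as (feature index, weight) pairs.
        self.selectors = [
            [(j, w) for j, w in enumerate(entmax15(row)) if w > 0.0] for row in F
        ]

    def __call__(self, x):
        # Soft split decision at each depth: c_i = sigma((f_i(x) - b_i) / tau_i).
        c = []
        for sel, b_i, tau_i in zip(self.selectors, self.b, self.tau):
            f_i = sum(w * x[j] for j, w in sel)
            c.append(sigma15((f_i - b_i) / tau_i))
        # Choice tensor C = [c_1, 1 - c_1] (x) ... (x) [c_d, 1 - c_d], visited
        # leaf by leaf (bit 0 takes c_i, bit 1 takes 1 - c_i); the output is
        # the sum over leaves of C[leaf] * R[leaf].
        out = 0.0
        for leaf, bits in enumerate(itertools.product((0, 1), repeat=len(c))):
            weight = math.prod(ci if bit == 0 else 1.0 - ci for ci, bit in zip(c, bits))
            out += weight * self.R[leaf]
        return out


def dense_node(x, layers):
    """DenseNet-style stack: each layer sees the inputs plus all earlier tree outputs.

    `layers` is a list of layers, each a list of trees; the prediction is the
    mean of every tree output across all layers.
    """
    features = list(x)
    outputs = []
    for trees in layers:
        layer_out = [tree(features) for tree in trees]
        outputs.extend(layer_out)
        features = features + layer_out  # widen the input of the next layer
    return sum(outputs) / len(outputs)
```

For example, with `F=[[5.0, 0.0], [0.0, 5.0]]`, entmax returns a one-hot row, so each depth selects exactly one feature. Then `ObliviousTree(F, b=[0.0, 0.0], tau=[1.0, 1.0], R=[1.0, 2.0, 3.0, 4.0])` behaves as follows:

- On input `[10.0, -10.0]`, it routes all weight to one leaf and returns 2.0.
- On input `[0.0, 0.0]`, the point sits on both thresholds, so it returns the leaf average 2.5 (up to rounding).

Looping over leaves keeps the sketch readable. A tensor implementation would build C with batched outer products instead.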
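The preprocessing and training bullet above can be illustrated with three small helpers. This is a hedged sketch with the following assumptions:

- The helper names are hypothetical.
- Mapping features to a standard normal distribution is an assumption, since the bullet only says "quantile transform".
- The threshold and temperature rule paraphrases the data-aware idea rather than reproducing the official code.
- Checkpoint averaging is shown as a plain element-wise mean of saved parameter lists.

QHAdam itself is not sketched.

```python
import random
from statistics import NormalDist


def quantile_normalize(column):
    """Map one feature column to approximately standard-normal values by rank."""
    n = len(column)
    ranks = [0] * n
    for r, i in enumerate(sorted(range(n), key=lambda k: column[k])):
        ranks[i] = r
    std_normal = NormalDist()  # mean 0, standard deviation 1
    # Rank r -> quantile (r + 0.5) / n, which stays strictly inside (0, 1).
    return [std_normal.inv_cdf((r + 0.5) / n) for r in ranks]


def data_aware_init(f_batch, rng):
    """Initialise one split's threshold b and temperature tau from the first batch.

    f_batch holds the projected feature values f_i(x) of the first mini-batch.
    """
    b = rng.choice(f_batch)  # threshold = a randomly chosen observed value
    spread = max(abs(f - b) for f in f_batch)
    # Scale so every |(f - b) / tau| <= 1; with alpha = 1.5 the two-class entmax
    # is non-saturated for |t| < 2, so no sample starts with a zero gradient.
    tau = spread if spread > 0 else 1.0
    return b, tau


def average_checkpoints(checkpoints):
    """Element-wise mean of parameter dicts saved at consecutive checkpoints."""
    n = len(checkpoints)
    return {
        name: [sum(vals) / n for vals in zip(*(ckpt[name] for ckpt in checkpoints))]
        for name in checkpoints[0]
    }
```

In this sketch, tied values receive distinct ranks. Transforming unseen data would also require storing the fitted quantiles, as library implementations such as scikit-learn's `QuantileTransformer` do.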
Experimental Results
Research Questions
- RQ1: Can differentiable, end-to-end trainable oblivious decision ensembles outperform tuned GBDT methods on tabular data?
- RQ2: Does stacking NODE layers improve expressive power for tabular problems without sacrificing train/inference efficiency?
- RQ3: What role does entmax play in learning sparse, effective feature splits within differentiable decision trees?
- RQ4: How do NODE-based models compare to CatBoost, XGBoost, and neural baselines across diverse tabular datasets?
Key Findings
Test results with default hyperparameters (lower is better; OOM = out of memory; a dash marks a missing result):
| Dataset (metric) | CatBoost | XGBoost | NODE | mGBDT | DeepForest |
|---|---|---|---|---|---|
| Epsilon (error) | 0.1119±2e-4 | 0.1144 | 0.1034±3e-4 | OOM | 0.1179 |
| YearPrediction (MSE) | 80.68±0.04 | 81.11 | 77.43±0.09 | 80.67 | — |
| Higgs (error) | 0.2434±2e-4 | 0.2600 | 0.2412±5e-4 | OOM | 0.2391 |
| Microsoft (MSE) | 0.5587±2e-4 | 0.5637 | 0.5584±3e-4 | OOM | — |
| Yahoo (MSE) | 0.5781±3e-4 | 0.5756 | 0.5666±5e-4 | OOM | — |
| Click (error) | 0.3438±1e-3 | 0.3461 | 0.3309±3e-4 | OOM | 0.3333 |
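
The snippet below puts the gaps in perspective. It computes NODE's relative error reduction over the better of CatBoost and XGBoost, using only the mean values from the table and ignoring the ± spreads. The dictionary layout and the `relative_gain` helper are illustrative choices, not part of the paper.

```python
# Mean values copied from the table above (lower is better for every metric).
RESULTS = {
    "Epsilon": {"CatBoost": 0.1119, "XGBoost": 0.1144, "NODE": 0.1034},
    "YearPrediction": {"CatBoost": 80.68, "XGBoost": 81.11, "NODE": 77.43},
    "Higgs": {"CatBoost": 0.2434, "XGBoost": 0.2600, "NODE": 0.2412},
    "Microsoft": {"CatBoost": 0.5587, "XGBoost": 0.5637, "NODE": 0.5584},
    "Yahoo": {"CatBoost": 0.5781, "XGBoost": 0.5756, "NODE": 0.5666},
    "Click": {"CatBoost": 0.3438, "XGBoost": 0.3461, "NODE": 0.3309},
}


def relative_gain(dataset):
    """Relative error reduction of NODE over the better of CatBoost and XGBoost."""
    row = RESULTS[dataset]
    best_gbdt = min(row["CatBoost"], row["XGBoost"])
    return (best_gbdt - row["NODE"]) / best_gbdt


for name in RESULTS:
    print(f"{name}: {relative_gain(name):.1%}")
```

The gain is about 7.6% on Epsilon and 4.0% on YearPrediction but only about 0.05% on Microsoft. On Microsoft, the 0.0003 gap is the same size as the reported ± spreads.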
- Under default hyperparameters, NODE beats both CatBoost and XGBoost on all six datasets in the table. Among all baselines, only DeepForest does better, and only on Higgs (0.2391 vs 0.2412).
- With tuned hyperparameters, NODE still outperforms the competitors on most tasks. The exceptions are Yahoo and Microsoft, where tuned XGBoost performs best.
- An ablation shows that entmax (α=1.5) gives better results than softmax, Gumbel-Softmax, and sparsemax across tree depths and datasets.
- Feature importance analysis indicates that earlier layers provide most of the input features, while deeper layers contribute more to the final prediction.
- NODE reaches competitive training and inference times, with inference on par with optimized GBDT libraries on GPU/CPU setups.
- NODE-based methods robustly handle tabular data and offer a scalable, end-to-end differentiable alternative to GBDTs.
This summary was generated by AI and reviewed by human editors.