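[Paper Summary] Gradient Boosting Decision Trees on Medical Diagnosis over Tabular Data
This study shows empirically that gradient boosting decision trees (LightGBM, XGBoost, CatBoost) rank above traditional ML and tabular DL models on average across seven medical tabular datasets, with favorable training times.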
Medical diagnosis is a crucial task in the medical field, since accurate classification determines the appropriate treatment. Decisions based on a correct diagnosis can affect a patient's life, and a misclassification may, in extreme cases, be catastrophic. Several traditional machine learning (ML) methods, such as support vector machines (SVMs) and logistic regression, and state-of-the-art tabular deep learning (DL) methods, including TabNet and TabTransformer, have been proposed and used on tabular medical datasets. More recently, ensemble methods have also been adopted in the field because of their superior performance, lower computational cost, and easier optimization across tasks. They offer a powerful alternative for supporting medical decision-making in several diagnosis tasks. In this study, we investigate the benefits of ensemble methods, especially Gradient Boosting Decision Tree (GBDT) algorithms, in medical classification tasks over tabular data, focusing on XGBoost, CatBoost, and LightGBM. The experiments demonstrate that GBDT methods outperform traditional ML and deep neural network architectures and achieve the best average ranks across several benchmark tabular medical diagnosis datasets. Furthermore, they require much less computational power than DL models, making them a strong choice in terms of high performance and low complexity.
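Research Motivation and Objectives
- Evaluate the performance of GBDT models (XGBoost, LightGBM, CatBoost) on a diverse set of tabular medical diagnosis datasets.
- Compare GBDT with traditional ML models and state-of-the-art tabular DL models.
- Analyze the trade-off between training time and performance for practical clinical deployment.
- Provide guidance on model selection for medical tabular data based on dataset size and feature characteristics.
Proposed Method
- Preprocess categorical features with ordinal encoding and standardize numerical features.
- Evaluate 5 traditional ML models, 5 DL models, and 4 ensemble models (3 of them GBDTs), using ROC AUC as the metric.
- Run 8-fold stratified cross-validation to assess generalization.
- Hyperparameter optimization: evaluate about 36 combinations per model, selected by mean ROC AUC across folds.
- Compare models along two dimensions: performance and average training time.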
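The evaluation protocol above can be sketched with scikit-learn. This is a minimal illustration under stated assumptions, not the authors' code: the data is synthetic with hypothetical column names (`age`, `blood_pressure`, `sex`), scikit-learn's `GradientBoostingClassifier` stands in for the XGBoost, LightGBM, and CatBoost models used in the paper, and the grid has only 4 combinations rather than roughly 36.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder, StandardScaler

# Synthetic stand-in for a medical table (hypothetical columns, not from the paper).
rng = np.random.default_rng(0)
n = 240
X = pd.DataFrame({
    "age": rng.normal(60, 10, n),
    "blood_pressure": rng.normal(120, 15, n),
    "sex": rng.choice(["F", "M"], n),
})
# Binary label driven by the numeric features plus noise, split at the median.
signal = (X["age"] - 60) / 10 + (X["blood_pressure"] - 120) / 15 + rng.normal(0, 0.5, n)
y = (signal > signal.median()).astype(int).to_numpy()

# Standardize numeric features and ordinal-encode categorical ones.
preprocess = ColumnTransformer([
    ("num", StandardScaler(), ["age", "blood_pressure"]),
    ("cat", OrdinalEncoder(), ["sex"]),
])

# Stand-in GBDT; the paper's models are XGBoost, LightGBM, and CatBoost.
pipeline = Pipeline([
    ("preprocess", preprocess),
    ("model", GradientBoostingClassifier(n_estimators=50, random_state=0)),
])

# Illustrative grid (4 combinations); the values are assumptions.
param_grid = {
    "model__learning_rate": [0.05, 0.1],
    "model__max_depth": [2, 3],
}

# 8-fold stratified cross-validation, scored by ROC AUC averaged over folds.
cv = StratifiedKFold(n_splits=8, shuffle=True, random_state=0)
search = GridSearchCV(pipeline, param_grid, scoring="roc_auc", cv=cv)
search.fit(X, y)

best = search.best_index_
print("best params:", search.best_params_)
print("mean ROC AUC:", round(search.best_score_, 3))
print("mean fit time (s):", round(search.cv_results_["mean_fit_time"][best], 4))
```

Putting the preprocessing inside the `Pipeline` means the scaler and encoder are refit on each training fold, so no statistics leak from the validation fold. The `mean_fit_time` entry of `cv_results_` gives a per-combination training time that can be compared alongside the ROC AUC.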
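Experimental Results
Research Questions
- RQ1: Do GBDT models achieve higher ROC AUC than traditional ML and tabular DL models across diverse medical datasets?
- RQ2: Which GBDT implementation (XGBoost, LightGBM, CatBoost) offers the best trade-off between performance and training time?
- RQ3: How does model performance on medical tabular data scale as dataset size and feature dimensionality increase?
- RQ4: What are the practical implications for model selection in clinical decision support, based on accuracy and efficiency?
Key Findings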
| Model | CD | Heart Failure | Parkinson's | EEG Eye State | Eye Movements | Arcene | Prostate | Avg. Rank |
|---|---|---|---|---|---|---|---|---|
| SVM | 78.715 ± 0.005 | 86.389 ± 0.048 | 88.791 ± 0.068 | 70.752 ± 0.013 | 78.405 ± 0.007 | 87.094 ± 0.043 | 91.419 ± 0.096 | 9.857 |
| Logistic Reg. | 78.435 ± 0.005 | 87.571 ± 0.051 | 90.875 ± 0.041 | 61.125 ± 0.014 | 71.180 ± 0.009 | 95.211 ± 0.031 | 95.089 ± 0.065 | 8.143 |
| KNN | 69.611 ± 0.006 | 77.529 ± 0.067 | 96.857 ± 0.023 | 91.185 ± 0.005 | 72.448 ± 0.009 | 90.869 ± 0.065 | 87.822 ± 0.112 | 9.857 |
| Random Forest | 77.464 ± 0.005 | 91.233 ± 0.038 | 96.068 ± 0.033 | 98.404 ± 0.002 | 87.234 ± 0.007 | 91.153 ± 0.034 | 93.155 ± 0.078 | 6.000 |
| Decision Tree | 63.325 ± 0.006 | 71.646 ± 0.051 | 81.287 ± 0.060 | 83.781 ± 0.008 | 70.951 ± 0.009 | 72.037 ± 0.116 | 80.357 ± 0.106 | 12.714 |
| LDA | 70.363 ± 0.005 | 87.896 ± 0.053 | 88.609 ± 0.060 | 67.130 ± 0.014 | 71.273 ± 0.010 | 69.927 ± 0.124 | 93.849 ± 0.060 | 10.571 |
| MLP [60] | 80.090 ± 0.005 | 87.288 ± 0.056 | 97.186 ± 0.022 | 95.513 ± 0.006 | 73.397 ± 0.015 | 93.669 ± 0.042 | 89.881 ± 0.108 | 6.429 |
| STG [37] | 79.667 ± 0.004 | 86.241 ± 0.058 | 95.352 ± 0.038 | 84.854 ± 0.011 | 80.780 ± 0.006 | 90.584 ± 0.062 | 94.048 ± 0.094 | 7.857 |
| TabNet [9] | 77.757 ± 0.004 | 93.319 ± 0.037 | 99.446 ± 0.012 | 62.441 ± 0.040 | 87.673 ± 0.008 | 87.662 ± 0.098 | 66.865 ± 0.205 | 7.429 |
| TabTransformer [36] | 71.327 ± 0.123 | 87.642 ± 0.069 | 96.625 ± 0.027 | 79.646 ± 0.039 | 70.534 ± 0.010 | 94.724 ± 0.051 | 92.956 ± 0.107 | 8.571 |
| VIME [38] | 78.882 ± 0.004 | 85.758 ± 0.047 | 98.532 ± 0.016 | 92.473 ± 0.005 | 81.918 ± 0.008 | 91.721 ± 0.070 | 52.679 ± 0.164 | 7.429 |
| XGBoost [49] | 79.745 ± 0.004 | 90.478 ± 0.025 | 97.265 ± 0.023 | 98.331 ± 0.002 | 89.675 ± 0.008 | 89.123 ± 0.047 | 94.940 ± 0.055 | 4.429 |
| LightGBM [50] | 80.296 ± 0.004 | 91.490 ± 0.027 | 98.623 ± 0.015 | 97.008 ± 0.004 | 89.059 ± 0.007 | 91.883 ± 0.043 | 95.486 ± 0.052 | 2.571 |
| CatBoost [51] | 80.378 ± 0.004 | 91.056 ± 0.034 | 97.740 ± 0.014 | 97.739 ± 0.003 | 88.954 ± 0.006 | 91.396 ± 0.040 | 96.379 ± 0.053 | 3.143 |
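- Across the seven datasets, the GBDT models have the best average ranks (LightGBM 2.571, CatBoost 3.143, XGBoost 4.429), ahead of every traditional ML and tabular DL model. They are not the top scorer on every dataset: TabNet leads on Heart Failure and Parkinson's, Random Forest on EEG Eye State, and logistic regression on Arcene.
- LightGBM achieves the best average rank among the evaluated models and is reported to have favorable training time.
- In general, GBDTs deliver better performance than DL architectures at lower computational cost.
- The best-performing GBDT variant differs by dataset (CatBoost on CD and Prostate, XGBoost on Eye Movements), but LightGBM ranks second on five of the seven datasets and fourth on the other two.
- DL models tend to train longer because of their model complexity, whereas GBDTs strike a balance between accuracy and efficiency.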
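The average-rank column of the results table can be reproduced from the per-dataset ROC AUC values. The sketch below assumes the column is the plain mean of each model's per-dataset rank (rank 1 is the highest ROC AUC); this is an assumption about how the column was computed, although the values it yields agree with the table. The function name `average_ranks` is illustrative, and ties are not handled because the table contains none.

```python
# ROC AUC means copied from the results table, in the table's column order:
# CD, Heart Failure, Parkinson's, EEG Eye State, Eye Movements, Arcene, Prostate.
SCORES = {
    "SVM":            [78.715, 86.389, 88.791, 70.752, 78.405, 87.094, 91.419],
    "Logistic Reg.":  [78.435, 87.571, 90.875, 61.125, 71.180, 95.211, 95.089],
    "KNN":            [69.611, 77.529, 96.857, 91.185, 72.448, 90.869, 87.822],
    "Random Forest":  [77.464, 91.233, 96.068, 98.404, 87.234, 91.153, 93.155],
    "Decision Tree":  [63.325, 71.646, 81.287, 83.781, 70.951, 72.037, 80.357],
    "LDA":            [70.363, 87.896, 88.609, 67.130, 71.273, 69.927, 93.849],
    "MLP":            [80.090, 87.288, 97.186, 95.513, 73.397, 93.669, 89.881],
    "STG":            [79.667, 86.241, 95.352, 84.854, 80.780, 90.584, 94.048],
    "TabNet":         [77.757, 93.319, 99.446, 62.441, 87.673, 87.662, 66.865],
    "TabTransformer": [71.327, 87.642, 96.625, 79.646, 70.534, 94.724, 92.956],
    "VIME":           [78.882, 85.758, 98.532, 92.473, 81.918, 91.721, 52.679],
    "XGBoost":        [79.745, 90.478, 97.265, 98.331, 89.675, 89.123, 94.940],
    "LightGBM":       [80.296, 91.490, 98.623, 97.008, 89.059, 91.883, 95.486],
    "CatBoost":       [80.378, 91.056, 97.740, 97.739, 88.954, 91.396, 96.379],
}


def average_ranks(scores):
    """Mean per-dataset rank of each model (1 = highest score; assumes no ties)."""
    models = list(scores)
    n_datasets = len(next(iter(scores.values())))
    totals = {m: 0 for m in models}
    for j in range(n_datasets):
        # Order models from best to worst on dataset j and accumulate their ranks.
        ordered = sorted(models, key=lambda m: scores[m][j], reverse=True)
        for rank, m in enumerate(ordered, start=1):
            totals[m] += rank
    return {m: totals[m] / n_datasets for m in models}


avg_rank = average_ranks(SCORES)
for model, rank in sorted(avg_rank.items(), key=lambda kv: kv[1]):
    print(f"{model:15s} {rank:.3f}")
```

Because the rank only uses the ordering within each dataset, a model that fails badly on one dataset (for example TabNet on Prostate) is penalized by at most one last-place rank, which is why average rank and average ROC AUC can order models differently.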
This summary was generated by AI and reviewed by human editors.