[论文解读] Regularization Learning Networks: Deep Learning for Tabular Datasets
本文提出了正则化学习网络(RLNs),一种深度学习框架,为神经网络中的每个权重分配独立的正则化系数,从而在特征重要性差异显著的表格数据集上实现性能提升。通过在训练过程中使用一种新型反事实损失(Counterfactual Loss)优化这些系数——无需依赖验证集——RLNs实现了与梯度提升树(GBTs)相当的性能,生成高度稀疏且可解释的模型,并在表格数据上显著优于标准深度神经网络(DNNs)。
Despite their impressive performance, Deep Neural Networks (DNNs) typically underperform Gradient Boosting Trees (GBTs) on many tabular-dataset learning tasks. We propose that applying a different regularization coefficient to each weight might boost the performance of DNNs by allowing them to make more use of the more relevant inputs. However, this will lead to an intractable number of hyperparameters. Here, we introduce Regularization Learning Networks (RLNs), which overcome this challenge by introducing an efficient hyperparameter tuning scheme which minimizes a new Counterfactual Loss. Our results show that RLNs significantly improve DNNs on tabular datasets, and achieve comparable results to GBTs, with the best performance achieved with an ensemble that combines GBTs and RLNs. RLNs produce extremely sparse networks, eliminating up to 99.8% of the network edges and 82% of the input features, thus providing more interpretable models and reveal the importance that the network assigns to different inputs. RLNs could efficiently learn a single network in datasets that comprise both tabular and unstructured data, such as in the setting of medical imaging accompanied by electronic health records. An open source implementation of RLN can be found at https://github.com/irashavitt/regularization_learning_networks.
研究动机与目标
- 为解决深度神经网络(DNNs)在表格数据集上的表现不如梯度提升树(GBTs)的问题,特别是由于输入特征重要性存在高度变异性。
- 探究为每个权重分配唯一正则化系数是否能提升DNN在非分布式表示(如表格数据中的表示)上的性能。
- 开发一种高效的超参数调优方法,避免因调优数百万个独立正则化系数而带来的不可行复杂度。
- 实现混合数据任务的联合学习,例如将表格电子健康记录与医学影像等非结构化数据结合。
- 生成稀疏、可解释的模型,揭示有意义的特征重要性并支持特征选择。
提出的方法
- 引入一种新型损失函数——反事实损失($\mathcal{L}_{CF}$),用于在训练过程中联合优化正则化系数与网络权重。
- 在对数空间中优化正则化系数,并在每次更新后应用投影操作,以防止系数消失。
- 通过在反事实损失中直接引导反向传播过程中的超参数调优,消除了对独立验证集的需求。
- 为网络中的每个权重分配唯一的正则化系数,实现适应特征重要性变异性的模块化正则化。
- 使用基于梯度的优化方法,端到端地同时更新权重和正则化系数。
- 在训练后施加稀疏性约束,使网络能够消除高达99.8%的连接和82%的输入特征,从而增强可解释性。
实验结果
研究问题
- RQ1为每个权重分配独立正则化系数是否能提升DNN在具有高度可变输入特征重要性的表格数据集上的性能?
- RQ2是否能够高效优化数百万个正则化系数,而无需依赖验证集或无导数的超参数调优方法?
- RQ3反事实损失如何实现深度网络中权重与正则化系数的有效联合优化?
- RQ4RLNs在多大程度上能生成稀疏、可解释的模型,准确反映表格数据中真实的特征重要性?
- RQ5RLNs能否与GBTs有效结合形成集成模型,在表格预测任务中实现最先进性能?
主要发现
- RLNs显著提升了DNN在表格数据集上的性能,与标准DNN相比,解释方差提高了2.75±0.05倍。
- RLNs实现了与梯度提升树(GBTs)相当的性能,尤其在输入特征重要性变异性较高的场景中表现优异。
- RLNs与GBTs的集成模型在4项特征中的3项上优于所有其他集成模型,并在微生物组预测任务中除一项特征外均达到最先进水平。
- RLNs生成了极稀疏的网络结构,可消除高达99.8%的网络连接和82%的输入特征,且稀疏性在训练前10至20个周期内即已实现。
- RLNs推导出的特征重要性相较于DNNs的Jensen-Shannon散度降低48%±1%,相较于语言模型(LMs)降低54%±2%,表明其具有更高的稳定性和可解释性。
- RLNs中特征重要性的熵为4.6比特,而DNNs中为9.5比特,表明其特征重要性分布更具意义且非均匀。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。