[论文解读] NBDT: Neural-Backed Decision Trees
NBDTs 将神经网络的最终层替换为一个可微分的斜决策树,在保持高准确性的同时提供可解释的、基于路径的解释。
Machine learning applications such as finance and medicine demand accurate and justifiable predictions, barring most deep learning methods from use. In response, previous work combines decision trees with deep learning, yielding models that (1) sacrifice interpretability for accuracy or (2) sacrifice accuracy for interpretability. We forgo this dilemma by jointly improving accuracy and interpretability using Neural-Backed Decision Trees (NBDTs). NBDTs replace a neural network's final linear layer with a differentiable sequence of decisions and a surrogate loss. This forces the model to learn high-level concepts and lessens reliance on highly-uncertain decisions, yielding (1) accuracy: NBDTs match or outperform modern neural networks on CIFAR, ImageNet and better generalize to unseen classes by up to 16%. Furthermore, our surrogate loss improves the original model's accuracy by up to 2%. NBDTs also afford (2) interpretability: improving human trustby clearly identifying model mistakes and assisting in dataset debugging. Code and pretrained NBDTs are at https://github.com/alvinwan/neural-backed-decision-trees.
研究动机与目标
- 旨在共同提升图像分类任务的预测准确性与可解释性。
- 用一个可微分的斜决策树替换神经网络的最终线性层。
- 引入树监督损失和诱导层次结构以学习高层概念。
- 实现基于路径概率的推理,容忍不确定的中间决策。
- 展示对未见类别的改进泛化能力以及更可信的解释。
提出的方法
- 用一个可微分的斜决策树替换最终线性层,叶权重与类别预测绑定。
- 使用软性(概率性)路径遍历,以便从不确定的早期决策中恢复(软推理)。
- 用相应的神经网络权重对节点权重进行种子初始化,并通过 softmax 内积计算子节点概率。
- 通过对预训练的类别权重向量进行分层聚类并对叶权重求平均以形成内结点权重,来构建诱导层次结构。
- 在内部结点上使用 WordNet 概念进行标注,以在可能的地方提供语义含义。
- 使用树监督损失进行训练,将标准交叉熵与基于路径概率的层次路径分布交叉熵相结合,权重随时间变化。
实验结果
研究问题
- RQ1神经网络背书的决策树是否能够在 CIFAR、TinyImageNet 和 ImageNet 上达到与现代神经网络相当或更高的准确性?
- RQ2基于模型权重构建的诱导层次结构是否在 NBDTs 中优于数据驱动或基于 WordNet 的层次结构?
- RQ3树监督损失是否提升原模型的准确性并帮助学习高层次决策?
- RQ4与显著性图相比,NBDTs 是否提供更有用、更可信的解释,以识别错误分类和模糊标签?
- RQ5与标准神经网络相比,NBDTs 是否在未见类别上具有更好的泛化能力?
主要发现
- NBDTs 在 CIFAR、TinyImageNet 和 ImageNet 的准确性上与现代网络相当或更高。
- NBDTs 对未见类别的泛化可达 16%,并且可将原始模型的准确性提高至多 2%。
- 由预训练权重构建的诱导层次结构在准确性方面优于 WordNet 和数据驱动的层次结构。
- 带路径概率的树监督能改善学习并获得比分层 softmax 更好的性能。
- NBDT 的解释帮助用户比显著性图更准确地识别模型错误,并在具有挑战性的任务中提升信任度。
- 零-shot 超类别泛化显示 NBDT 在若干超类别区分上优于骨干网络。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。