Skip to main content
QUICK REVIEW

[论文解读] Distilling a Neural Network Into a Soft Decision Tree

Nicholas Frosst, Geoffrey E. Hinton|arXiv (Cornell University)|Nov 27, 2017
Machine Learning and Data Classification参考文献 5被引用 266
一句话总结

本文提出了一种将训练好的神经网络中的知识蒸馏为一种软决策树的方法,使其做出层次化决策,在提高可解释性的同时保持合理的准确性。

ABSTRACT

Deep neural networks have proved to be a very effective way to perform classification tasks. They excel when the input data is high dimensional, the relationship between the input and the output is complicated, and the number of labeled training examples is large. But it is hard to explain why a learned network makes a particular classification decision on a particular test case. This is due to their reliance on distributed hierarchical representations. If we could take the knowledge acquired by the neural net and express the same knowledge in a model that relies on hierarchical decisions instead, explaining a particular decision would be much easier. We describe a way of using a trained neural net to create a type of soft decision tree that generalizes better than one learned directly from the training data.

研究动机与目标

  • 激发深度网络的泛化能力与可解释性之间的张力。
  • 提出一个由神经网络蒸馏而来的软层次决策树。
  • 证明蒸馏得到的树在泛化方面优于直接在原始数据上训练的树。
  • 在 MNIST 及其他数据集上展示该方法,具有定性可解释性方面的好处。

提出的方法

  • 使用在内部节点具有学习到的过滤器的软二叉决策树,以及叶子节点的类别分布 Q_ell。
  • 每个内部节点 i 计算 p_i(x) = sigma(beta(x w_i + b_i)) 作为向右走的概率。
  • 叶子持有类别分布 Q^ell_k = exp(phi^ell_k) / sum_k' exp(phi^ell_k').
  • 通过小批量梯度下降训练树以最小化 L(x) = -log( sum_ell P^ell(x) sum_k T_k log Q^ell_k ).
  • 通过深度相关的交叉熵惩罚来鼓励对子树的均衡使用,该惩罚与 alpha_i(到节点 i 的平均路径概率)相关联。
  • 可选地通过使用软目标 T,将真实标签与神经网络输出混合,以从神经网络预测中进行蒸馏。
  • 在测试时,使用具有最大路径概率的叶子节点来得到最终的预测分布。

实验结果

研究问题

  • RQ1在保持可解释性的同时,软决策树能否仿效神经网络的输入输出函数?
  • RQ2从神经网络进行蒸馏是否能提高软决策树的准确性,相较于直接在数据上训练?
  • RQ3正则化项和与深度相关的惩罚项如何影响学习与泛化?

主要发现

  • 在 MNIST 上,深度为 8 的软决策树在真实标签上训练,测试准确率达到 94.45%。
  • 带有卷积层的神经网络在 MNIST 上达到 99.21%,高于该软树。
  • 来自神经网络的软目标将树提升到 96.76% 的测试准确率,介于 NN 和在真实目标上训练的树之间。
  • 软树比直接在数据上训练的树具有更好的泛化能力,这归因于下游节点的数据分布稀疏性。
  • 在各数据集上,蒸馏使得在可解释模型下仍具备合理的准确性,例如 Connect4: 80.60% 与 78.63%(无 NN 基线);Letter: 78.0%(深度9,原始)和 81.0%(从 NN 集成蒸馏得来)。
  • 该方法产生可解释的决策路径和学习到的滤波器的可视化,有助于解释单个预测。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。