Skip to main content
QUICK REVIEW

[논문 리뷰] Distilling a Neural Network Into a Soft Decision Tree

Nicholas Frosst, Geoffrey E. Hinton|arXiv (Cornell University)|2017. 11. 27.
Machine Learning and Data Classification참고 문헌 5인용 수 266
한 줄 요약

이 논문은 학습된 신경망에서 소프트 의사결정 트리로 지식을 증류하여 계층적 의사결정을 가능하게 하고, 해석가능성을 높이면서도 합리적인 정확도를 유지하는 방법을 제시한다.

ABSTRACT

Deep neural networks have proved to be a very effective way to perform classification tasks. They excel when the input data is high dimensional, the relationship between the input and the output is complicated, and the number of labeled training examples is large. But it is hard to explain why a learned network makes a particular classification decision on a particular test case. This is due to their reliance on distributed hierarchical representations. If we could take the knowledge acquired by the neural net and express the same knowledge in a model that relies on hierarchical decisions instead, explaining a particular decision would be much easier. We describe a way of using a trained neural net to create a type of soft decision tree that generalizes better than one learned directly from the training data.

연구 동기 및 목표

  • 깊은 신경망의 일반화와 해석가능성 사이의 긴장을 자극한다.
  • 신경망에서 증류된 소프트 계층적 의사결정 트리를 제안한다.
  • 증류된 트리가 원데이터로 훈련된 트리보다 일반화가 더 잘 되는지 보인다.
  • MNIST 및 다른 데이터셋에서 직관적으로 해석가능한 이점을 보이는 접근법을 시연한다.

제안 방법

  • 내부 노드에서 학습된 필터와 클래스에 대한 리프 분포 Q_ell를 갖는 소프트 이진 의사결정 트리를 사용한다.
  • 각 내부 노드 i는 p_i(x) = sigma(beta(x w_i + b_i))를 오른쪽으로 갈 확률로 계산한다.
  • 리프는 클래스 분포 Q^ell_k = exp(phi^ell_k) / sum_k' exp(phi^ell_k')를 보유한다.
  • 손실 L(x) = -log( sum_ell P^ell(x) sum_k T_k log Q^ell_k )를 최소화하도록 미니배치 경사 하강법으로 트리를 훈련한다.
  • 루브를 통해 하위 트리의 균형 사용을 장려하는 정규화로, 노드 i에 대한 평균 경로 확률과 깊이 의존적인 교차 엔트로피 페널티(alpha_i)에 연결한다.
  • 선택적으로 소프트 타깃 T를 사용하여 신경망 예측에서 얻은 라벨과 NN 출력의 혼합으로 증류한다.
  • 추론 시 최종 예측 분포를 얻기 위해 경로 확률이 가장 높은 리프를 사용한다.

실험 결과

연구 질문

  • RQ1소프트 의사결정 트리가 해석가능하게 유지되면서 신경망의 입력-출력 함수를 모방할 수 있는가?
  • RQ2신경망으로부터의 증류가 데이터에서 직접 훈련하는 경우보다 소프트 의사결정 트리의 정확도를 향상시키는가?
  • RQ3정규화 항과 깊이 관련 페널티가 학습 및 일반화에 어떤 영향을 미치는가?

주요 결과

  • MNIST에서 참 타깃으로 훈련된 깊이 8의 소프트 의사결정 트리는 테스트 정확도 94.45%에 도달한다.
  • CNN 층을 갖춘 신경망은 MNIST에서 99.21%의 성능으로 소프트 트리보다 높다.
  • 신경망에서 얻은 소프트 타깃은 트리의 테스트 정확도를 96.76%로 올려, 참 타깃으로 훈련된 트리와 NN의 중간 성능을 보인다.
  • 낮은 노드에서의 데이터 분포 희소성으로 인해 소프트 트리는 데이터에 직접 훈련된 트리보다 일반화가 더 잘된다.
  • 데이터셋 전반에 걸쳐 증류는 해석가능한 모델임에도 합리적인 정확도를 가능하게 하였으며, 예: Connect4: 80.60% 대 78.63% (NN-없는 기본선); Letter: 78.0% (깊이 9, 원본) 및 81.0% (NN 앙상블에서 증류).
  • 이 접근법은 의사결정 경로와 학습된 필터의 해설을 돕는 해석가능한 시각화를 제공한다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.