Skip to main content
QUICK REVIEW

[논문 리뷰] Deep Learning for Case-Based Reasoning through Prototypes: A Neural Network that Explains Its Predictions

Oscar Li, Hao Liu|arXiv (Cornell University)|2017. 10. 13.
Explainable Artificial Intelligence (XAI)인용 수 158
한 줄 요약

잠재 공간에서 학습된 프로토타입을 통해 예측을 설명하기 위해 오토인코더와 프로토타입 기반 계층을 결합한 해석 가능한 신경망을 소개한다.

ABSTRACT

Deep neural networks are widely used for classification. These deep models often suffer from a lack of interpretability -- they are particularly difficult to understand because of their non-linear nature. As a result, neural networks are often treated as "black box" models, and in the past, have been trained purely to optimize the accuracy of predictions. In this work, we create a novel network architecture for deep learning that naturally explains its own reasoning for each prediction. This architecture contains an autoencoder and a special prototype layer, where each unit of that layer stores a weight vector that resembles an encoded training input. The encoder of the autoencoder allows us to do comparisons within the latent space, while the decoder allows us to visualize the learned prototypes. The training objective has four terms: an accuracy term, a term that encourages every prototype to be similar to at least one encoded input, a term that encourages every encoded input to be close to at least one prototype, and a term that encourages faithful reconstruction by the autoencoder. The distances computed in the prototype layer are used as part of the classification process. Since the prototypes are learned during training, the learned network naturally comes with explanations for each prediction, and the explanations are loyal to what the network actually computes.

연구 동기 및 목표

  • 딥러닝에서 해석 가능한 예측의 필요성을 동기화하고 표준 신경망의 해석 가능성 부족을 해결한다.
  • 오토인코더와 프로토타입 계층을 통합하여 사례 기반 설명을 제공하는 신경망 아키텍처를 제안한다.
  • 잠재 공간 프로토타입을 입력 공간으로 디코딩하여 학습된 프로토타입의 시각화를 가능하게 한다.
  • 모델이 비교적 해석 가능성을 높이는 전용 정규화 항들을 통해 예측 성능을 유지하도록 한다.

제안 방법

  • 두 구성 요소 아키텍처: 오토인코더(인코더 f와 디코더 g)와 잠재 공간에서의 프로토타입 분류 네트워크 h.
  • Prototype layer p는 인코딩 입력 z=f(x)와 R^q의 m개 프로토타입 p1,...,pm 사이의 제곱 L2 거리를 계산한다; 완전 연결 계층 W가 거리를 클래스 로짓으로 결합하고 소프트맥스가 뒤따른다.
  • 학습 목표는 교차 엔트로피 손실 E, 재구성 손실 R, 그리고 두 가지 해석 가능성 정규화 항 R1과 R2를 결합한 것으로, 총 손실 L = E(h∘f,D) + λR(g∘f,D) + λ1R1(...) + λ2R2(...)로 하이퍼파라미터화한다.
  • R1은 잠재 공간에서 각 프로토타입이 적어도 하나의 인코딩된 입력과 가깝도록 유도하고; R2는 모든 인코딩된 입력이 적어도 하나의 프로토타입과 가깝도록 유도한다.
  • 프로토타입 벡터는 잠재 공간에 존재하여 입력 공간으로 디코딩하여 시각화할 수 있으며; W는 프로토타입-클래스 간 관계를 반영하도록 학습될 수 있다.

실험 결과

연구 질문

  • RQ1학습된 프로토타입을 잠재 공간에서 사례 기반 추론으로 이용해 예측을 설명하도록 설계된 신경망이 가능할까?
  • RQ2잠재 공간 프로토타입과 명시적 정규화 항이 정확도를 크게 손상시키지 않으면서 의미 있고 시각화 가능한 설명을 제공할까?
  • RQ3R1과 R2 해석 가능성 항이 프로토타입의 품질과 데이터 세트 간 일반화에 어떤 영향을 미칠까?
  • RQ4프로토타입-클래스 가중치 매트릭스 W를 학습하는 것이 분류 동작과 해석 가능성에 어떤 영향을 미칠까?
  • RQ5비해석 가능 네트워크와 비교했을 때 이 아키텍처는 표준 이미지 분류 벤치마크에서 어떤 성능을 보일까?

주요 결과

  • 모델은 MNIST(train 99.53%, test 99.22%), Fashion-MNIST(89.95%), Cars 데이터셋에서 경쟁력 있는 정확도를 달성하면서 프로토타입을 통해 내재적 설명을 제공한다.
  • 해독된 프로토타입은 R1과 R2가 가능하게 한 잠재 공간 표현을 바탕으로 실제 숫자 및 의류 아이템을 시각적으로 유사하게 나타낸다.
  • 특정 아블레이션 연구에서 프로토타입 계층이나 디코더를 제거하면 비해석 가능 기초모형과 유사한 정확도를 보이며, 해석 가능성이 이 작업들에서 성능을 크게 감소시키지 않음을 시사한다.
  • 학습된 가중치 행렬 W는 각 클래스에 가장 큰 영향을 미치는 프로토타입을 드러내며 클래스 간 관계 및 프로토타입의 유용성에 대한 통찰을 제공한다.
  • 프로토타입 시각화는 클래스 내 변이(예: 6과 3의 다양한 필기 스타일) 및 교차 클래스 애매성을 보여주며 사례 기반 추론과 일치한다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.