[论文解读] Deep Learning for Case-Based Reasoning through Prototypes: A Neural Network that Explains Its Predictions
引入一种可解释的神经网络,将自编码器与基于原型的层结合,通过潜在空间中学习的原型来解释预测。
Deep neural networks are widely used for classification. These deep models often suffer from a lack of interpretability -- they are particularly difficult to understand because of their non-linear nature. As a result, neural networks are often treated as "black box" models, and in the past, have been trained purely to optimize the accuracy of predictions. In this work, we create a novel network architecture for deep learning that naturally explains its own reasoning for each prediction. This architecture contains an autoencoder and a special prototype layer, where each unit of that layer stores a weight vector that resembles an encoded training input. The encoder of the autoencoder allows us to do comparisons within the latent space, while the decoder allows us to visualize the learned prototypes. The training objective has four terms: an accuracy term, a term that encourages every prototype to be similar to at least one encoded input, a term that encourages every encoded input to be close to at least one prototype, and a term that encourages faithful reconstruction by the autoencoder. The distances computed in the prototype layer are used as part of the classification process. Since the prototypes are learned during training, the learned network naturally comes with explanations for each prediction, and the explanations are loyal to what the network actually computes.
研究动机与目标
- 在深度学习中需要可解释的预测,并解决标准神经网络缺乏可解释性的问题。
- 提出一种神经架构,将自编码器与原型层整合,以提供基于案例的解释。
- 通过将潜在空间中的原型解码回输入空间,使学习到的原型可视化。
- 在通过专门的正则化项提升可解释性的同时,确保模型保持有竞争力的预测性能。
提出的方法
- 两组件架构:一个自编码器(编码器 f 和解码器 g)和潜在空间中的原型分类网络 h。
- 原型层 p 计算编码输入 z=f(x) 与 R^q 中的 m 个原型 p1,...,pm 的平方 L2 距离;一个全连接层 W 将距离整合为类别对数,再经过 softmax。
- 训练目标将交叉熵损失 E、重构损失 R 与两个可解释性正则化项 R1、R2 结合起来,再加上超参数化的总损失 L = E(h∘f,D) + λR(g∘f,D) + λ1R1(...) + λ2R2(...).
- R1 促使每个原型在潜在空间中接近至少一个编码输入;R2 促使每个编码输入接近至少一个原型。
- 原型向量位于潜在空间中,通过解码回输入空间实现可视化;W 可以学习以反映原型与类别之间的关系。
实验结果
研究问题
- RQ1是否可以设计一个神经网络,通过在潜在空间中使用学习到的原型进行基于案例的推理来解释其预测?
- RQ2潜在空间中的原型与显式正则化项结合,是否能够在不牺牲准确性的情况下产生有意义且可视化的解释?
- RQ3解释性项 R1 和 R2 如何影响原型质量以及在各数据集上的泛化?
- RQ4学习原型到类别的权重矩阵 W 对分类行为和可解释性有什么影响?
- RQ5与非可解释网络相比,该架构在标准图像分类基准上的表现如何?
主要发现
- 模型在 MNIST(train 99.53%, test 99.22%)、Fashion-MNIST(89.95%)和 Cars 数据集上达到有竞争力的准确率,同时通过原型提供内生解释。
- 解码后的原型在视觉上类似真实数字和服装物品,展示了由 R1 和 R2 启用的有意义的潜在空间表征。
- 消融研究表明,移除原型层或解码器会得到接近非可解释基线的准确率,说明在这些任务中可解释性并不显著降低性能。
- 学习得到的权重矩阵 W 揭示了哪些原型对每个类别影响最大,提供了类别关系和原型效用的洞见。
- 原型可视化展示了类内变异(例如数字 6 与 3 的不同书写风格)以及跨类别的歧义,与基于案例的推理一致。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。