Skip to main content
QUICK REVIEW

[论文解读] Distilling the Knowledge in a Neural Network

Geoffrey E. Hinton, Oriol Vinyals|arXiv (Cornell University)|Mar 9, 2015
Neural Networks and Applications参考文献 9被引用 13,896
一句话总结

论文展示了如何通过蒸馏(使用软目标)将来自大集成模型或高度正则化网络的知识转移到一个更小的单一模型,在 MNIST、语音识别和大规模图像数据集上实现显著的性能提升。

ABSTRACT

A very simple way to improve the performance of almost any machine learning algorithm is to train many different models on the same data and then to average their predictions. Unfortunately, making predictions using a whole ensemble of models is cumbersome and may be too computationally expensive to allow deployment to a large number of users, especially if the individual models are large neural nets. Caruana and his collaborators have shown that it is possible to compress the knowledge in an ensemble into a single model which is much easier to deploy and we develop this approach further using a different compression technique. We achieve some surprising results on MNIST and we show that we can significantly improve the acoustic model of a heavily used commercial system by distilling the knowledge in an ensemble of models into a single model. We also introduce a new type of ensemble composed of one or more full models and many specialist models which learn to distinguish fine-grained classes that the full models confuse. Unlike a mixture of experts, these specialist models can be trained rapidly and in parallel.

研究动机与目标

  • 通过利用集成或大模型训练,在有限的延迟和资源下部署准确模型的必要性。
  • 引入一个蒸馏框架,使用软目标将泛化能力从笨重模型转移到小模型。
  • 展示蒸馏在 MNIST、语音识别和具有专业化集成的大规模图像数据集上的实际收益。

提出的方法

  • 通过在 softmax 中提高温度 T 来定义软目标,以产生更柔和的输出分布。
  • 在笨重模型产生的软目标上训练蒸馏模型,并可选地结合硬目标使用加权目标函数。
  • 证明在高温下匹配 logits 是蒸馏的一个特例,并讨论随 T 的梯度缩放(梯度 ~ 1/T^2)。
  • 使用可标签或无标签的转移集合;当存在标签时,以适当的权重和缩放将软目标损失与硬目标损失混合。
  • 提出由专门处理易混淆类别子集的专家组成的集成,初始自通用模型,并通过尘箱调整来实现平衡,防止过拟合。

实验结果

研究问题

  • RQ1一个小模型是否能通过软目标学习到大型集成的泛化行为?
  • RQ2应如何配置蒸馏(温度、损失权重)以最大化知识转移?
  • RQ3在 MNIST、语音识别和具有专业化集成的极大数据集上应用蒸馏时有哪些收益?
  • RQ4匹配 logits 是否是蒸馏的一个特例,温度如何影响 logits 所携带的信息?
  • RQ5对于极大的标签空间,专业化集成及其蒸馏有多有效?

主要发现

系统测试帧准确率WER
Baseline58.9%10.9%
10xEnsemble61.1%10.7%
Distilled Single model60.8%10.7%
  • 在 MNIST 上,使用软目标的蒸馏使较小的网络从硬目标的 146 误差显著提升到 74 错误,接近大模型的性能。
  • 在语音识别中,蒸馏后的单模型取得了与 10 模型集合相似的提升,基线帧准确率 58.9%,WER 10.9%,而蒸馏后为 60.8% 帧准确率和 10.7% WER。
  • 蒸馏将大部分集成的提升转移到单模型;对于ASR,蒸馏模型捕获了超过80%的集成改进。
  • 在 JFT 数据集上,将通用模型与 61 个专家模型结合,在 top-1 准确率上相对基线提升 4.4%。
  • 在易混淆子集上训练的专家可以独立训练,蒸馏后仍保留其收益且训练成本不过高。
  • 软目标作为强正则化器,即使转移数据极少(在 ASR 类设置中仅约 3% 的数据),也能实现良好的泛化。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。