[论文解读] FitNets: Hints for Thin Deep Nets
本文通过使用教师网络的中间线索来训练更深更瘦的学生网络(FitNets),扩展了知识蒸馏的应用,使在参数显著更少、推理更快的条件下实现高精度。
While depth tends to improve network performances, it also makes gradient-based training more difficult since deeper networks tend to be more non-linear. The recently proposed knowledge distillation approach is aimed at obtaining small and fast-to-execute models, and it has shown that a student network could imitate the soft output of a larger teacher network or ensemble of networks. In this paper, we extend this idea to allow the training of a student that is deeper and thinner than the teacher, using not only the outputs but also the intermediate representations learned by the teacher as hints to improve the training process and final performance of the student. Because the student intermediate hidden layer will generally be smaller than the teacher's intermediate hidden layer, additional parameters are introduced to map the student hidden layer to the prediction of the teacher hidden layer. This allows one to train deeper students that can generalize better or run faster, a trade-off that is controlled by the chosen student capacity. For example, on CIFAR-10, a deep student network with almost 10.4 times less parameters outperforms a larger, state-of-the-art teacher network.
研究动机与目标
- 推动对宽大且深的网络进行压缩,以提高内存和计算效率。
- 提出一种使用教师派生提示来训练瘦而深的学生网络的方法。
- 将知识蒸馏与中间表示结合来引导训练。
- 证明更深更瘦的模型可以在标准基准上达到甚至超过教师模型的性能。
- 展示更好的优化的分阶段训练和课程学习视角。
提出的方法
- 对知识蒸馏(KD)的回顾,其中学生通过温度参数 tau 模拟教师的软化输出。
- 引入基于提示的训练,其中教师的隐藏层(提示)通过一个回归器在尺寸不同时引导学生的对应隐藏层(受引导隐藏层)。
- 使用卷积回归器将学生的受引导层映射到教师的提示层,从而减少参数的增长。
- 描述一个分阶段训练过程:首先用提示训练到受引导层,然后使用 KD 损失训练完整的 FitNet。
- 给出损失函数 L_KD,将标准交叉熵与教师输出的软化项结合起来,由 lambda 平衡;L_HT 用于教师提示与学生受引导表示之间的基于提示的映射。
- 讨论与课程学习的关系,其中教师的置信度充当课程信号,训练过程中对 lambda 进行退火。
实验结果
研究问题
- RQ1是否可以通过利用教师的中间表示作为提示来有效训练更深更瘦的学生网络?
- RQ2基于提示的训练再加上 KD 是否在训练深度瘦网络方面优于标准反向传播和纯 KD?
- RQ3在使用 FitNets 时,模型深度、参数数量与推理效率之间的权衡是什么?
- RQ4与教师及其他压缩方法相比,FitNets 在标准视觉基准上的泛化能力如何?
主要发现
- 深厚、瘦小的学生网络在参数和计算量显著更少的条件下可以超过教师。
- 基于提示的训练(HT)使得比仅使用 KD 更深的网络成为可能,带来更好的泛化。
- 在 CIFAR-10 上,深度 11 层、约 250K 参数的 FitNet 达到 89.01% 的准确率,超过教师并实现显著的加速和压缩。
- 在 CIFAR-10 的较大 FitNets(例如 11–19 层),准确率达到 91.61%,参数约 250 万,相较于教师(约 900 万参数)在准确度上有明显提升,尽管容量大幅降低。
- 在 CIFAR-100 上,FitNets 再次超越教师,参数显著减少(大约少 3 倍)且仍具竞争力的准确度。
- 在 SVHN 上,参数约 3 万到 150 万的 FitNets 获得接近甚至优于教师的有竞争力的错误率,同时仅使用了一小部分参数。
- MNIST 测试表明 HT 加 KD 能带来显著提升,某 FitNet 使用比教师少 12 倍的参数实现 0.51% 的错误分类率。
- AFLW 实验表明提示在更瘦的架构中带来显著改进,在若干情况下 HT 的表现优于 KD。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。