[论文解读] Hydra: Preserving Ensemble Diversity for Model Distillation
Hydra 使用一个共享主体并设多个头来蒸馏集合,保留各成员的预测和不确定性,相较于标准蒸馏在预测性能和不确定性量化方面均有提升。
Ensembles of models have been empirically shown to improve predictive performance and to yield robust measures of uncertainty. However, they are expensive in computation and memory. Therefore, recent research has focused on distilling ensembles into a single compact model, reducing the computational and memory burden of the ensemble while trying to preserve its predictive behavior. Most existing distillation formulations summarize the ensemble by capturing its average predictions. As a result, the diversity of the ensemble predictions, stemming from each member, is lost. Thus, the distilled model cannot provide a measure of uncertainty comparable to that of the original ensemble. To retain more faithfully the diversity of the ensemble, we propose a distillation method based on a single multi-headed neural network, which we refer to as Hydra. The shared body network learns a joint feature representation that enables each head to capture the predictive behavior of each ensemble member. We demonstrate that with a slight increase in parameter count, Hydra improves distillation performance on classification and regression settings while capturing the uncertainty behavior of the original ensemble over both in-domain and out-of-distribution tasks.
研究动机与目标
- 受到在蒸馏后保留集成不确定性需求的启发。
- 提出一种多头蒸馏架构以保留每个成员的行为。
- 在分类与回归任务中,将 Hydra 与标准蒸馏和先前的网络进行比较评估。
提出的方法
- 引入 Hydra:一个具有一个共享主体和 M 个头的神经网络(每个头对应一个集成成员)。
- 每个头模仿一个特定的集成成员;主体提供共享的特征表示。
- 通过最小化每个头与其相应集成成员之间的平均 KL 散度(分类)或高斯输出之间的散度(回归)来训练。
- 在训练期间使用温度 T 来升温分布以提高跨分布覆盖。
- 两阶段训练:首先模仿平均集成(Hinton head),然后训练所有头以匹配各个成员。
- 在各数据集上与 Knowledge Distillation 和 Prior Networks 进行比较;报告 NLL、Brier 得分、准确率和模型不确定性。
实验结果
研究问题
- RQ1与基于平均的蒸馏相比,Hydra 是否能够真实地保留集成多样性?
- RQ2Hydra 是否在域内和域外数据上提升预测性能与不确定性量化?
- RQ3Hydra 如何在参数效率与对集成多样性的保真之间进行权衡?
- RQ4Hydra 对分类和回归任务的影响是什么?
主要发现
- Hydra 在 MNIST 和 CIFAR-10 上达到或超过集成的预测性能。
- 在 MNIST 上,Hydra 的 NLL 为 0.0465,Brier 为 −0.9776,接近集成的 NLL 0.0439 和 Brier −0.9780,MU 为 2.28e-5。
- 在 CIFAR-10 上,Hydra 的 ACC 为 0.8992,NLL 为 0.3179,更接近集成而非其他蒸馏方法,MU 为 0.0074。
- 在若干指标上,Hydra 优于 Knowledge Distillation 和 Prior Networks,特别在不确定性量化(MU)和 NLL 上。
- Hydra 在参数增量适中与对集成多样性的保真之间提供了一个实际的平衡。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。