Skip to main content
QUICK REVIEW

[论文解读] Wasserstein Distance Rivals Kullback-Leibler Divergence for Knowledge Distillation

Jiayi Lv, Hee-Chul Yang|arXiv (Cornell University)|Dec 11, 2024
Brain Tumor Detection and Classification被引用 5
一句话总结

本文提出 WKD,一种基于 Wasserstein 距离的知识蒸馏方法,能够实现对数 logits 的跨类别关系推理(WKD-L)以及对中间特征的连续分布匹配(WKD-F),在 ImageNet、CIFAR-100 和 MS-COCO 上优于 KL-Div 变体和当前最先进的蒸馏方法。

ABSTRACT

Since pioneering work of Hinton et al., knowledge distillation based on Kullback-Leibler Divergence (KL-Div) has been predominant, and recently its variants have achieved compelling performance. However, KL-Div only compares probabilities of the corresponding category between the teacher and student while lacking a mechanism for cross-category comparison. Besides, KL-Div is problematic when applied to intermediate layers, as it cannot handle non-overlapping distributions and is unaware of geometry of the underlying manifold. To address these downsides, we propose a methodology of Wasserstein Distance (WD) based knowledge distillation. Specifically, we propose a logit distillation method called WKD-L based on discrete WD, which performs cross-category comparison of probabilities and thus can explicitly leverage rich interrelations among categories. Moreover, we introduce a feature distillation method called WKD-F, which uses a parametric method for modeling feature distributions and adopts continuous WD for transferring knowledge from intermediate layers. Comprehensive evaluations on image classification and object detection have shown (1) for logit distillation WKD-L outperforms very strong KL-Div variants; (2) for feature distillation WKD-F is superior to the KL-Div counterparts and state-of-the-art competitors. The source code is available at https://peihuali.org/WKD

研究动机与目标

  • 通过利用跨类别的类别关系(IRs),推动超越按类别的 KL-Divergence 的知识蒸馏改进。
  • 提出基于 WD 的对 logits(WKD-L)和中间特征(WKD-F)的蒸馏方法。
  • 用 Centered Kernel Alignment (CKA) 对类别 IRs 进行建模,并将其转换为 logit 蒸馏的 WD 运输成本。
  • 用高斯分布建模中间层特征分布,并将 WD 计算为黎曼度量以蒸馏特征。

提出的方法

  • 通过对教师特征计算的 CKA 定义类别间的 IRs,并将 IRs 转换为 logit 蒸馏的 WD 运输成本。
  • 使用基于熵正则化的运输问题,将教师与学生 logits 之间的离散 WD 表述为一个成本来自 IR 基于相似性的成本。
  • 在 logits 中加入目标与非目标的分离,通过一个两项损失:对非目标的 WD 与对目标的交叉熵相结合。
  • 对于特征,将教师和学生的分布建模为高斯分布(均值和协方差),并使用高斯间的闭式 WD(均值项与协方差项之和)。
  • 可选地应用空间金字塔化,并对实用性在 Gaussian Diag 与 Full 协方差之间进行选择,通过 gamma 参数平衡均值与协方差的贡献。
(a) Real-world categories exhibit rich interrelations (IRs) in feature space, e.g., dog is near other mammal while far from artifact like car. We quantify pairwise IRs as feature similarities among categories. Best viewed by zooming in .
(a) Real-world categories exhibit rich interrelations (IRs) in feature space, e.g., dog is near other mammal while far from artifact like car. We quantify pairwise IRs as feature similarities among categories. Best viewed by zooming in .

实验结果

研究问题

  • RQ1WD 基于蒸馏是否能利用跨类别关系在 logit 蒸馏中超过以 KL-Div 为基础的方法?
  • RQ2将中间层特征建模为高斯分布并使用 WD,相较于 KL-Div 和非参数方法,是否能提升知识转移?
  • RQ3IR 建模方法(CKA 配合不同核函数)对 WKD-L 性能的影响是?
  • RQ4WKD-L 与 WKD-F 单独及组合在分类和检测任务上的表现如何?

主要发现

  • WKD-L 在 ImageNet 和 CIFAR-100 的 logit 蒸馏中优于强大的 KL-Div 变体。
  • WKD-F 在特征蒸馏方面超过 KL-Div 对手,其中高斯(Diag)常被偏好以提升鲁棒性和效率。
  • 使用 CKA 建模类别关系(尤其是 RBF 或线性核)提升基于 WD 的 logit 蒸馏。
  • 结合 WKD-L 与 WKD-F 在分类和检测任务上比单独使用任一方法有进一步提升。
  • 在 MS-COCO 目标检测上,基于 WD 的蒸馏相对于基于 KL-Div 的方法表现出有竞争力的提升。
(b) For logit distillation, discrete WD performs cross-category comparison by exploiting pairwise IRs, in contrast to KL-Div that is a category-to-category measure and lacks a mechanism to use such IRs (cf. Figure 2 ). For feature distillation, we use Gaussians for distribution modeling and continuo
(b) For logit distillation, discrete WD performs cross-category comparison by exploiting pairwise IRs, in contrast to KL-Div that is a category-to-category measure and lacks a mechanism to use such IRs (cf. Figure 2 ). For feature distillation, we use Gaussians for distribution modeling and continuo

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。