Skip to main content
QUICK REVIEW

[论文解读] metric-learn: Metric Learning Algorithms in Python

William de Vazelhes, CJ Carey|arXiv (Cornell University)|Aug 13, 2019
Machine Learning and Algorithms参考文献 17被引用 31
一句话总结

metric-learn 是一个 Python 软件包,实现了使用统一的 scikit-learn 兼容 API 的监督式和弱监督式度量学习算法,可无缝集成到机器学习流程中,用于分类、聚类和检索等任务。它支持通过 LMNN、MMC 和三元组/四元组学习器等算法实现马氏距离学习,完全兼容交叉验证、超参数调优和流水线构建。

ABSTRACT

metric-learn is an open source Python package implementing supervised and weakly-supervised distance metric learning algorithms. As part of scikit-learn-contrib, it provides a unified interface compatible with scikit-learn which allows to easily perform cross-validation, model selection, and pipelining with other machine learning estimators. metric-learn is thoroughly tested and available on PyPi under the MIT licence.

研究动机与目标

  • 提供一个统一的、可投入生产的 Python 软件包,用于度量学习算法,使其能与 scikit-learn 生态系统无缝集成。
  • 在一个一致的接口中支持多种监督类型——类别标签、成对样本、三元组和四元组。
  • 支持端到端的机器学习工作流,包括交叉验证、模型选择以及与 k-NN 和 LDA 等估计器的流水线构建。
  • 提供可扩展、文档齐全且积极维护的开源解决方案,采用 MIT 许可证。
  • 通过支持标准和新兴的度量学习方法,弥合度量学习研究与实际部署之间的差距。

提出的方法

  • 该软件包通过学习一个线性变换矩阵 L 来实现马氏距离学习,其中点 x 和 x′ 之间的距离定义为 D_L(x,x′) = ||Lx - Lx′||₂。
  • 所有度量学习器都继承自 scikit-learn 的基类(如 TransformerMixin),从而实现与流水线、交叉验证和模型选择的兼容性。
  • 监督式学习器使用类别标签来优化一个度量,使同类别样本更接近,异类样本更远离。
  • 成对学习器使用点对的二元标签(+1 表示相似,-1 表示不相似)来学习一个度量,使相似点对之间的距离最小化。
  • 三元组和四元组学习器分别使用三元组和四元组的点,强制要求第一个点比第三个点更接近第二个点(或前两个点比后两个点更接近)。
  • 该软件包包含一种阈值校准方法(calibrate_threshold),可自动为弱监督式预测任务设置距离阈值。

实验结果

研究问题

  • RQ1如何在单一、与 scikit-learn 兼容的接口下统一度量学习算法,以提升其在机器学习工作流中的可用性和集成性?
  • RQ2在真实应用场景中,与完全监督方法相比,弱监督式度量学习(如成对、三元组、四元组监督)的性能影响如何?
  • RQ3一个单一、可扩展的 Python 软件包能否高效支持跨不同监督级别和数据规模的多种度量学习算法?
  • RQ4将度量学习与 scikit-learn 流水线集成,对下游任务(如 k-NN 分类)的模型准确性和开发效率有何影响?
  • RQ5构建一个可投入生产的开源度量学习库的关键设计原则是什么,以同时支持经典和新兴算法?

主要发现

  • metric-learn 通过一致的、与 scikit-learn 兼容的 API,支持 14 种度量学习算法,涵盖四种监督类型——监督式、成对、三元组和四元组。
  • 该软件包可与 scikit-learn 工具完全集成,包括交叉验证、网格搜索和流水线组合,如在 Wine 数据集上使用 LMNN-kNN 流水线所展示的那样。
  • 弱监督式学习器(如 MMC)在 LFW 成对数据集上的交叉验证中实现了 0.893 ± 0.025 的平均 ROC-AUC 得分,表明在极少监督下仍具有强大性能。
  • 该软件包可在 PyPI 和 conda-forge 上获取,具备完整的测试覆盖率并处于积极开发中,支持 Python 3.6 及以上版本。
  • components_ 和 get_mahalanobis_matrix 的实现使用户能够提取并解释学习到的变换矩阵和马氏距离矩阵,以供分析和可视化。
  • 未来扩展将包括用于大规模数据的随机求解器,以及对多标签、高维和非线性度量学习算法的支持。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。