Skip to main content
QUICK REVIEW

[论文解读] Prototypical Networks for Few-shot Learning

Jake Snell, Kevin Swersky|arXiv (Cornell University)|Mar 15, 2017
Domain Adaptation and Few-Shot Learning被引用 5,185
一句话总结

原型网络学会一种简单的嵌入表示,每个类别由其样本的均值(原型)来表示;分类通过最近的原型使用欧氏距离实现,在小样本和零样本任务中达到最先进的结果。

ABSTRACT

We propose prototypical networks for the problem of few-shot classification, where a classifier must generalize to new classes not seen in the training set, given only a small number of examples of each new class. Prototypical networks learn a metric space in which classification can be performed by computing distances to prototype representations of each class. Compared to recent approaches for few-shot learning, they reflect a simpler inductive bias that is beneficial in this limited-data regime, and achieve excellent results. We provide an analysis showing that some simple design decisions can yield substantial improvements over recent approaches involving complicated architectural choices and meta-learning. We further extend prototypical networks to zero-shot learning and achieve state-of-the-art results on the CU-Birds dataset.

研究动机与目标

  • 激励一种简单、数据高效的归纳偏置,用于减少在有限数据下的过拟合。
  • 提出一种基于度量的办法,使每个类别在嵌入空间中由一个原型表示。
  • 表示欧氏距离到类别原型能带来强性能,并通过混合密度和聚类概念来解释该方法。
  • 通过对类别元数据进行嵌入形成原型,将该方法扩展到零样本学习,并在标准基准上进行评估。

提出的方法

  • 学习一个嵌入函数 f_phi,将输入映射到一个 M 维空间。
  • 为每个类别 k 定义一个原型 c_k,作为嵌入支持样本的均值:c_k = (1/|S_k|) sum_{(x_i,y_i) in S_k} f_phi(x_i)。
  • 通过 p_phi(y=k|x) 与 exp(-d(f_phi(x), c_k)) 成正比来对查询 x 进行分类,使用距离 d(主要是平方欧氏距离)。
  • 通过最小化真实类别的负对数概率来训练,训练过程的 episode 抽取子集类别和样本作为支持集和查询集。
  • 给出一个概率解释:对于常规的 Bregman 发散,模型对应一个有限混合,其原型均值作为聚类中心。
  • 通过设定 c_k = g_theta(v_k) 将其扩展到零样本学习,其中 v_k 是类别元数据,g_theta 是一个学习的嵌入;在需要时固定原型范数。

实验结果

研究问题

  • RQ1简单的原型基嵌入,且每个类别固定数量的原型,能否在少样本设置中对未见类别泛化?
  • RQ2距离度量的选择如何影响少样本学习中的基于原型的分类性能?
  • RQ3使用 episodic 方案和更高-way 的 episode 训练是否改善少样本任务的泛化?
  • RQ4原型框架是否能有效扩展到使用类别元数据的零样本学习?

主要发现

  • 在 Omniglot 上,使用欧氏距离的 ProtNets 在 1-shot 达到 98.8%,在 5-shot 达到 99.7%(5-way),在某些设定下为 96.0%/98.9%(20-way)。
  • 在 miniImageNet 上,ProtNets 的 1-shot: 49.42%,5-shot: 68.20%(5-way 设置),超过基线包括 Matching Networks 和 Meta-Learner LSTM。
  • 在 CUB 零样本上,使用 GoogLeNet 特征和 312 维属性的 ProtNets 达到 54.6% 的 50 类准确率,超过多种属性基和嵌入方法。
  • 欧氏距离在此框架中一向优于余弦距离,且更高-way 的训练 episode 可以提升泛化。
  • 该方法比许多元学习方法更简单高效,同时在多个基准上达到最先进的结果。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。