[Paper Summary] Learning to Propagate for Graph Meta-Learning
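This paper proposes the Gated Propagation Network (GPN), a meta-learner that propagates class prototypes over a class graph to improve few-shot classification. GPN also updates prototypes over its lifetime through a memory and uses multi-head attention with gating. It achieves consistent gains over baselines on two graph datasets derived from ImageNet.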
Meta-learning extracts common knowledge from learning different tasks and uses it for unseen tasks. It can significantly improve tasks that suffer from insufficient training data, e.g., few-shot learning. In most meta-learning methods, tasks are implicitly related by sharing parameters or an optimizer. In this paper, we show that a meta-learner that explicitly relates tasks on a graph describing the relations of their output dimensions (e.g., classes) can significantly improve few-shot learning. The graph's structure is usually free or cheap to obtain but has rarely been explored in previous works. We develop a novel meta-learner of this type for prototype-based classification, in which a prototype is generated for each class, such that the nearest neighbor search among the prototypes produces an accurate classification. The meta-learner, called "Gated Propagation Network (GPN)", learns to propagate messages between prototypes of different classes on the graph, so that learning the prototype of each class benefits from the data of other related classes. In GPN, an attention mechanism aggregates messages from neighboring classes of each class, with a gate choosing between the aggregated message and the message from the class itself. We train GPN on a sequence of tasks from many-shot to few-shot generated by subgraph sampling. During training, it is able to reuse and update previously achieved prototypes from the memory in a life-long learning cycle. In experiments, under different training-test discrepancy and test task generation settings, GPN outperforms recent meta-learning methods on two benchmark datasets. The code of GPN and dataset generation is available at https://github.com/liulu112601/Gated-Propagation-Net.
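The abstract describes prototype-based classification: each class gets a prototype, and a query receives the label of its nearest prototype. Below is a minimal sketch of that rule. It assumes the embeddings have already been computed (GPN learns them with a neural network, which is omitted here). It also assumes squared Euclidean distance, as used in Prototypical Networks. The class names and 2-D vectors are invented for illustration.

```python
from typing import Dict, List, Sequence

Vector = List[float]


def mean(vectors: Sequence[Vector]) -> Vector:
    """Element-wise mean of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance (assumed metric for this sketch)."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def class_prototypes(support: Dict[str, List[Vector]]) -> Dict[str, Vector]:
    """Prototype of each class = mean of its K support embeddings."""
    return {label: mean(shots) for label, shots in support.items()}


def classify(query: Vector, prototypes: Dict[str, Vector]) -> str:
    """Nearest-prototype rule: return the class whose prototype is closest."""
    return min(prototypes, key=lambda label: sq_dist(query, prototypes[label]))


# Toy 2-way 2-shot task; the 2-D "embeddings" are made-up numbers.
support = {
    "cat": [[0.0, 0.0], [0.2, 0.0]],
    "dog": [[1.0, 1.0], [1.2, 1.0]],
}
protos = class_prototypes(support)
print(classify([0.1, 0.1], protos))  # → cat
print(classify([0.9, 1.1], protos))  # → dog
```

GPN keeps this classification rule. It changes how the prototypes are produced: they are refined by propagation on the class graph before the nearest-neighbor search.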
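Research Motivation and Goals
- Exploit the graph of relations between classes to improve few-shot learning by propagating prototypes among related classes.
- Develop a meta-learner that updates each class prototype by integrating information from neighboring classes on the graph.
- Enable lifelong learning by maintaining a memory of prototypes that supports learning on new tasks.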
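Proposed Method
- Each class prototype is initialized as the average of the embeddings of its K support samples (K-shot) and is then refined by graph-based propagation.
- When messages are aggregated, multi-head attention weights the neighboring prototypes.
- A gate chooses between the message propagated from the neighbors and the class's own message.
- Propagation runs for multiple steps on the graph, which allows relational reasoning beyond direct neighbors.
- A prototype memory is maintained so that propagation can still take place when a class has no neighbors among the classes of the current task.
- Training follows a curriculum that includes auxiliary supervised tasks and annealing, and a maximum spanning tree (MST) is built to constrain the propagation paths.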
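The propagation steps listed above can be sketched in plain Python. This is a simplified illustration under stated assumptions, not the authors' implementation. The following are all our own choices:

- the attention score (cosine similarity of two projected prototypes, scaled by a temperature);
- the gate (a two-way softmax that compares each candidate message with the class's initial prototype);
- averaging over heads;
- the temperature `gamma=5.0`;
- the helper names `propagate_step` and `propagate`.

Classes outside the current task take their starting prototypes from a hypothetical `memory` dict. This is how a task class with no in-task neighbor can still receive messages.

```python
import math
from typing import Dict, List, Tuple

Vector = List[float]
Matrix = List[List[float]]
Head = Tuple[Matrix, Matrix]  # (query projection, key projection) of one attention head


def matvec(w: Matrix, v: Vector) -> Vector:
    return [sum(a * b for a, b in zip(row, v)) for row in w]


def cosine(a: Vector, b: Vector) -> float:
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return 0.0 if na == 0 or nb == 0 else sum(x * y for x, y in zip(a, b)) / (na * nb)


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]


def propagate_step(protos: Dict[str, Vector], init: Dict[str, Vector],
                   graph: Dict[str, List[str]], heads: List[Head],
                   gamma: float = 5.0) -> Dict[str, Vector]:
    """One step: each class attends over its neighbors, then gates the result (hypothetical form)."""
    new: Dict[str, Vector] = {}
    for y, p_y in protos.items():
        nbrs = [z for z in graph.get(y, []) if z in protos]
        if not nbrs:  # nothing to aggregate: keep the class's own message
            new[y] = list(p_y)
            continue
        outputs = []
        for w_q, w_k in heads:
            q = matvec(w_q, p_y)
            attn = softmax([gamma * cosine(q, matvec(w_k, protos[z])) for z in nbrs])
            msg = [sum(a * protos[z][i] for a, z in zip(attn, nbrs)) for i in range(len(p_y))]
            # Gate: weigh own vs. neighbor message by similarity to the initial prototype.
            g_self, g_nbr = softmax([gamma * cosine(init[y], p_y), gamma * cosine(init[y], msg)])
            outputs.append([g_self * s + g_nbr * m for s, m in zip(p_y, msg)])
        new[y] = [sum(col) / len(outputs) for col in zip(*outputs)]  # average over heads
    return new


def propagate(init: Dict[str, Vector], graph: Dict[str, List[str]],
              heads: List[Head], steps: int = 2) -> Dict[str, Vector]:
    """Repeat the step so information can travel beyond direct neighbors."""
    protos = {y: list(p) for y, p in init.items()}
    for _ in range(steps):
        protos = propagate_step(protos, init, graph, heads)
    return protos


# Toy example with made-up 2-D prototypes.
identity = [[1.0, 0.0], [0.0, 1.0]]
heads = [(identity, identity), (identity, identity)]  # placeholders for learned projections
memory = {"wolf": [0.9, 0.9]}                 # hypothetical stored prototype, not in this task
task = {"dog": [1.0, 0.6], "cat": [0.0, 0.2]}  # initial prototypes = support means
graph = {"dog": ["wolf"], "wolf": ["dog"], "cat": []}
out = propagate({**memory, **task}, graph, heads, steps=2)
print(out["cat"])  # → [0.0, 0.2]
print(out["dog"])
```

In this toy run, `cat` has no neighbors and keeps its prototype. `dog` is pulled toward the memorized `wolf` prototype. Only the task classes (`dog`, `cat`) would then enter nearest-prototype classification.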
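Experimental Results
Research Questions
- RQ1: Do task relations defined explicitly on a class graph improve few-shot classification in meta-learning?
- RQ2: Does propagating class prototypes along graph edges yield better decision boundaries than a standard Prototypical Network?
- RQ3: How do the memory, the MST-based propagation paths, and multi-head gating contribute to accuracy and efficiency?
- RQ4: How does the distance between training and test classes on the graph (measured in hops) affect GPN's performance?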
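RQ3 and RQ4 refer to two graph quantities: a maximum spanning tree that restricts propagation paths, and the hop distance between classes. The sketch below computes both. It assumes an undirected class graph whose edges carry a similarity weight; how GPN weights its edges is not stated in this summary, and the class names and weights are invented. Kruskal's algorithm keeps the heaviest edges that do not close a cycle, and breadth-first search counts hops.

```python
from collections import deque
from typing import Dict, List, Tuple

Edge = Tuple[float, str, str]  # (weight, u, v); weights are hypothetical similarities


def maximum_spanning_tree(nodes: List[str], edges: List[Edge]) -> List[Edge]:
    """Kruskal's algorithm, taking edges from heaviest to lightest."""
    parent = {n: n for n in nodes}

    def find(x: str) -> str:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    tree = []
    for w, u, v in sorted(edges, reverse=True):
        ru, rv = find(u), find(v)
        if ru != rv:  # edge joins two components, so it cannot form a cycle
            parent[ru] = rv
            tree.append((w, u, v))
    return tree


def hop_distance(adj: Dict[str, List[str]], src: str, dst: str) -> int:
    """Breadth-first search: edges on the shortest path, or -1 if unreachable."""
    seen = {src: 0}
    queue = deque([src])
    while queue:
        node = queue.popleft()
        if node == dst:
            return seen[node]
        for nxt in adj.get(node, []):
            if nxt not in seen:
                seen[nxt] = seen[node] + 1
                queue.append(nxt)
    return -1


nodes = ["animal", "dog", "cat", "wolf"]
edges = [(0.9, "dog", "wolf"), (0.5, "dog", "cat"), (0.4, "cat", "wolf"),
         (0.8, "animal", "dog"), (0.3, "animal", "cat")]
tree = maximum_spanning_tree(nodes, edges)
adj: Dict[str, List[str]] = {}
for _, u, v in tree:
    adj.setdefault(u, []).append(v)
    adj.setdefault(v, []).append(u)
print(len(tree))                          # → 3
print(hop_distance(adj, "cat", "wolf"))   # → 2
```

The cat-wolf edge is one hop in the full graph but is dropped from the tree, so the two classes end up two hops apart. This shows how an MST keeps only n-1 of the strongest links as propagation paths.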
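Key Findings
- GPN outperforms several recent few-shot baselines across multiple settings on tieredImageNet-Close and tieredImageNet-Far.
- On tieredImageNet-Close with random sampling, 5-way 1-shot accuracy is 48.37% for GPN and 50.54% for GPN+, vs. 42.87% for Prototypical Net.
- In the same setting, 5-way 5-shot accuracy is 64.14% for GPN and 65.74% for GPN+, vs. 62.68% for Prototypical Net.
- With snowball sampling (classes within a task are more closely related), 5-way 1-shot accuracy is 39.56% for GPN and 40.78% for GPN+, vs. 35.27% for Prototypical Net.
- On tieredImageNet-Far with random sampling, 5-way 1-shot accuracy is 47.54% for GPN and 47.49% for GPN+, vs. 44.30% for Prototypical Net.
- In every reported setting, GPN improves over the baselines, and the gain is larger when training and test classes are closer on the graph. GPN+ usually adds a further gain, but not always (47.49% vs. 47.54% on tieredImageNet-Far, 5-way 1-shot).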
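The gains behind the last bullet can be checked directly from the accuracies listed above. The short script below only performs the subtractions and adds no new results.

```python
# Accuracies (%) copied from the findings above: (setting, Prototypical Net, GPN, GPN+)
results = [
    ("tieredImageNet-Close, random, 5-way 1-shot", 42.87, 48.37, 50.54),
    ("tieredImageNet-Close, random, 5-way 5-shot", 62.68, 64.14, 65.74),
    ("snowball sampling, 5-way 1-shot", 35.27, 39.56, 40.78),
    ("tieredImageNet-Far, random, 5-way 1-shot", 44.30, 47.54, 47.49),
]

for name, proto, gpn, gpn_plus in results:
    # Percentage-point differences, formatted with an explicit sign.
    print(f"{name}: GPN - ProtoNet = {gpn - proto:+.2f}, GPN+ - GPN = {gpn_plus - gpn:+.2f}")
```

The 1-shot gain over Prototypical Net is +5.50 points on tieredImageNet-Close but +3.24 on tieredImageNet-Far. GPN+ trails GPN by 0.05 points in the Far setting.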
This summary was generated by AI and reviewed by human editors.