[论文解读] Memory Matching Networks for One-Shot Image Recognition
本文提出了一种新型深度学习架构——记忆匹配网络(MM-Net),用于少样本图像识别。该模型通过引入记忆模块和上下文学习器,动态地实时预测网络参数,从而增强卷积神经网络(CNN)。通过使用每类一个或少数几个样本的支持集进行训练,并利用双向长短期记忆网络(bi-LSTM)生成自适应的卷积层权重,MM-Net在Omniglot数据集上实现了99.28%的准确率,在mini-ImageNet数据集上实现了53.37%的准确率,达到当前最优性能。
In this paper, we introduce the new ideas of augmenting Convolutional Neural Networks (CNNs) with Memory and learning to learn the network parameters for the unlabelled images on the fly in one-shot learning. Specifically, we present Memory Matching Networks (MM-Net) --- a novel deep architecture that explores the training procedure, following the philosophy that training and test conditions must match. Technically, MM-Net writes the features of a set of labelled images (support set) into memory and reads from memory when performing inference to holistically leverage the knowledge in the set. Meanwhile, a Contextual Learner employs the memory slots in a sequential manner to predict the parameters of CNNs for unlabelled images. The whole architecture is trained by once showing only a few examples per class and switching the learning from minibatch to minibatch, which is tailored for one-shot learning when presented with a few examples of new categories at test time. Unlike the conventional one-shot learning approaches, our MM-Net could output one unified model irrespective of the number of shots and categories. Extensive experiments are conducted on two public datasets, i.e., Omniglot and \emph{mini}ImageNet, and superior results are reported when compared to state-of-the-art approaches. More remarkably, our MM-Net improves one-shot accuracy on Omniglot from 98.95% to 99.28% and from 49.21% to 53.37% on \emph{mini}ImageNet.
研究动机与目标
- 解决少样本学习中训练与推理之间的差异问题,即标准的小批量训练无法匹配少样本测试设置。
- 克服在仅有一个或少数几个样本可用于新类别时,微调和迁移学习方法的局限性。
- 开发一种统一模型,无需重新训练即可在不同样本数和类别数下实现泛化。
- 利用基于循环记忆的上下文学习器,实现对CNN的动态、上下文感知的参数预测。
- 通过整体利用所有支持集类别的知识,提升特征表示与相似性匹配能力。
提出的方法
- 为CNN增加一个记忆模块,通过写入和读取控制器,从标记的支持集图像中存储并检索特征。
- 使用双向长短期记忆网络(bi-LSTM)作为上下文学习器,按顺序处理记忆槽,并为未标记图像预测卷积层参数。
- 通过最小化每个批次中基于支持集条件的未标记图像分类误差,端到端训练整个网络。
- 构建包含混合样本数与类别数设置(如2–5类、1–5样本)的训练批次,以提升在多样化测试场景下的泛化能力。
- 通过未标记图像嵌入与支持集嵌入之间的点积计算相似度得分,以分配预测标签。
- 通过上下文学习器实时计算网络参数,消除微调需求,实现在无需重新训练的情况下对新类别进行推理。
实验结果
研究问题
- RQ1能否训练一个统一的深度神经网络,在无需微调的情况下,实现对不同样本数和类别数的少样本学习任务的泛化?
- RQ2如何使训练过程与推理条件对齐,以提升少样本设置下的泛化能力?
- RQ3带有顺序上下文学习器的记忆增强架构是否能够提升特征表示与相似性匹配能力?
- RQ4混合训练策略(样本数与类别数变化)对模型泛化能力和性能有何影响?
- RQ5与先前方法相比,记忆模块与参数预测机制如何增强判别性特征学习?
主要发现
- 在Omniglot数据集上,MM-Net实现了99.28%的top-1准确率,显著优于之前最优结果98.95%。
- 在mini-ImageNet数据集上,MM-Net在5类1样本评估设置下,少样本准确率达到53.37%,超越了此前最优结果49.21%。
- 混合训练策略(Mixed C-way k-shot)优于所有统一训练策略(固定样本数或类别数),表明其在多样化测试场景下具有更优的泛化能力。
- 上下文学习器中bi-LSTM隐藏层大小对性能影响极小,128至1024个单元之间的差异小于0.013,表明对超参数选择具有鲁棒性。
- t-SNE可视化显示,与匹配网络(MN)相比,MM-Net学习到的图像表征更具语义分离性,类别聚类更清晰。
- 相似度矩阵可视化表明,与MN相比,MM-Net产生的类内相似度更高、类间相似度更低,表明其具有更强的判别性特征学习能力。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。