Skip to main content
QUICK REVIEW

[论文解读] Meta Continual Learning

Risto Vuorio, Dong-Yeon Cho|arXiv (Cornell University)|Jun 11, 2018
Domain Adaptation and Few-Shot Learning参考文献 28被引用 27
一句话总结

本文提出元持续学习(Meta Continual Learning),一种元学习方法,通过训练神经网络以预测最优参数更新步长,从而在持续学习中最小化灾难性遗忘。该方法通过学习基于过往任务中参数重要性的更新调整,实现了在顺序MNIST任务上的优异性能,优于标准SGD,并在准确率上匹配或超越多个基线方法。

ABSTRACT

Using neural networks in practical settings would benefit from the ability of the networks to learn new tasks throughout their lifetimes without forgetting the previous tasks. This ability is limited in the current deep neural networks by a problem called catastrophic forgetting, where training on new tasks tends to severely degrade performance on previous tasks. One way to lessen the impact of the forgetting problem is to constrain parameters that are important to previous tasks to stay close to the optimal parameters. Recently, multiple competitive approaches for computing the importance of the parameters with respect to the previous tasks have been presented. In this paper, we propose a learning to optimize algorithm for mitigating catastrophic forgetting. Instead of trying to formulate a new constraint function ourselves, we propose to train another neural network to predict parameter update steps that respect the importance of parameters to the previous tasks. In the proposed meta-training scheme, the update predictor is trained to minimize loss on a combination of current and past tasks. We show experimentally that the proposed approach works in the continual learning setting.

研究动机与目标

  • 为解决持续学习中灾难性遗忘问题,即神经网络在学习新任务时性能在旧任务上下降的问题。
  • 开发一种通用且自动化的持续学习方法,避免依赖手工设计的正则化假设。
  • 探索在持续学习背景下学习优化方法,利用元学习训练更新预测器。
  • 证明一种与任务无关的更新规则可行,且能尊重先前任务的性能表现。

提出的方法

  • 训练一个元网络(更新预测器)以预测每个参数的更新步长,从而最小化对先前任务的遗忘。
  • 采用元训练方案,其中更新预测器在当前任务与过去任务的组合上进行优化。
  • 更新预测器基于对先前任务中参数重要性的估计,输出每个参数梯度更新的缩放因子。
  • 使用结合当前任务和先前任务性能的损失函数,端到端训练模型。
  • 该方法动态调整更新幅度:对重要参数采用较小更新,对灵活参数采用较大更新。
  • 该方法无需显式记忆过去数据或任务特定正则化,而是依赖于学习到的更新指导。

实验结果

研究问题

  • RQ1是否可以通过学习到的更新预测器,在无需显式记忆或手工设计正则化的情况下,有效缓解持续学习中的灾难性遗忘?
  • RQ2元学习得到的优化规则在具有共享数据分布的顺序任务上,泛化能力如何?
  • RQ3更新预测器是否学会识别对先前任务至关重要的参数,并在学习新任务时加以保护?
  • RQ4与现有持续学习基线相比,元学习得到的更新规则性能如何?

主要发现

  • 所提方法在不相交MNIST上达到82.3% ± 0.92的测试准确率,显著优于SGD(47.72%),并匹配或超越多个基线方法。
  • 在打乱MNIST上,该方法达到95.5% ± 0.58的准确率,接近最先进方法(如IMM的98.3% ± 0.08和EWC的98.2%)。
  • 更新预测器学会将接近零的更新分配给对过去任务至关重要的参数,这由元训练过程中三峰分布的演化得到证实。
  • 模型表现出有效的知识保留,随着元训练的推进,非关键参数的输出值逐渐增加。
  • 该方法在原则上可推广至更长的任务序列,尽管当前实验仅限于三个任务。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。