[论文解读] Pseudo-Recursal: Solving the Catastrophic Forgetting Problem in Deep Neural Networks
本文提出伪递归训练(pseudo-recursal),将伪再训练与GAN生成的伪项结合,在深度网络上实现针对 CIFAR-10、SVHN 和 MNIST 的连续学习,而不增加任务特定的记忆容量,显著减少遗忘。
In general, neural networks are not currently capable of learning tasks in a sequential fashion. When a novel, unrelated task is learnt by a neural network, it substantially forgets how to solve previously learnt tasks. One of the original solutions to this problem is pseudo-rehearsal, which involves learning the new task while rehearsing generated items representative of the previous task/s. This is very effective for simple tasks. However, pseudo-rehearsal has not yet been successfully applied to very complex tasks because in these tasks it is difficult to generate representative items. We accomplish pseudo-rehearsal by using a Generative Adversarial Network to generate items so that our deep network can learn to sequentially classify the CIFAR-10, SVHN and MNIST datasets. After training on all tasks, our network loses only 1.67% absolute accuracy on CIFAR-10 and gains 0.24% absolute accuracy on SVHN. Our model's performance is a substantial improvement compared to the current state of the art solution.
研究动机与目标
- 解决 DNNs 序列任务学习中的灾难性遗忘。
- 提出一种与任务无关且节省内存的持续学习方法。
- 利用 Generative Adversarial Networks 来生成用于复现的代表性伪项。
- 展示将递归伪复现(pseudo-recursal)同时应用于分类器和生成器。
- 在图像数据集上与 Elastic Weight Consolidation 和标准伪复现进行比较。
提出的方法
- 在固定架构下形式化对序列任务的伪复现。
- 使用 GAN 生成表示过去任务的伪图像以进行复现。
- 对GAN递归应用伪复现,以覆盖多个任务而不需要每个任务额外的内存。
- 在当前任务上训练分类器和 GAN,同时使用表示先前任务的伪项。
- 在 CIFAR-10、SVHN 和 MNIST 上进行多种实验条件(std、reh、pseudo_rec、ewc、ewc_c10、rote_learn)的评估。
- 在学习新任务后衡量对先前任务的准确率保持情况;报告绝对准确率的变化。
实验结果
研究问题
- RQ1在不增加每个任务内存的前提下,使用带有 GAN 生成的伪图像的伪复现是否能在 CIFAR-10、SVHN 和 MNIST 上防止灾难性遗忘?
- RQ2对分类器和 GAN 同时递归应用伪复现是否比标准伪复现或 EWC 更能保持记忆?
- RQ3在序贯学习后相对于基线,伪递归在保持先前任务性能方面的表现如何?
- RQ4使用基于 GAN 的伪复现方法时,训练时间和内存的权衡有哪些?
主要发现
- 该方法在学习完所有任务后,CIFAR-10 的准确率仅下降 1.67%(绝对值)。
- 在后续任务训练后,SVHN 的准确率绝对提升 0.24%。
- 伪递归在保持先前任务性能方面超过 EWC,适用于 CIFAR-10 和 SVHN。
- 相比机械学习基线,伪递归在维持过去任务准确率方面有显著提升(例如在 CIFAR-10 和 SVHN 上分别提升 9.6% 和 13.11%)。
- 对过去任务递归训练 GAN 能在不增加每个任务内存的情况下生成代表性伪项。
- 该方法展示了有效的持续学习,并且对中间神经元没有硬性约束,也不存储过去任务数据。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。