Skip to main content
QUICK REVIEW

[论文解读] Remember the Past: Distilling Datasets into Addressable Memories for Neural Networks

Zhiwei Deng, Olga Russakovsky|arXiv (Cornell University)|Jun 6, 2022
Domain Adaptation and Few-Shot Learning被引用 22
一句话总结

本论文提出一种基于记忆的数据集蒸馏方法,其中一组共享基底(记忆)通过可学习的寻址进行组合,以合成训练数据,从而实现与类别数无关的压缩,并实现强大的持续学习性能。

ABSTRACT

We propose an algorithm that compresses the critical information of a large dataset into compact addressable memories. These memories can then be recalled to quickly re-train a neural network and recover the performance (instead of storing and re-training on the full original dataset). Building upon the dataset distillation framework, we make a key observation that a shared common representation allows for more efficient and effective distillation. Concretely, we learn a set of bases (aka ``memories'') which are shared between classes and combined through learned flexible addressing functions to generate a diverse set of training examples. This leads to several benefits: 1) the size of compressed data does not necessarily grow linearly with the number of classes; 2) an overall higher compression rate with more effective distillation is achieved; and 3) more generalized queries are allowed beyond recalling the original classes. We demonstrate state-of-the-art results on the dataset distillation task across six benchmarks, including up to 16.5% and 9.7% in retained accuracy improvement when distilling CIFAR10 and CIFAR100 respectively. We then leverage our framework to perform continual learning, achieving state-of-the-art results on four benchmarks, with 23.2% accuracy improvement on MANY. The code is released on our project webpage https://github.com/princetonvisualai/RememberThePast-DatasetDistillation.

研究动机与目标

  • 推动将大数据集压缩成能够在重新训练时保留性能的紧凑记忆。
  • 提出一种记忆寻址形式,在跨类使用共享基底以改进压缩。
  • 证明带动量的时间反向传播和较长展开能带来最先进的蒸馏结果。
  • 表明该框架泛化到持续学习以及超出离散标签的灵活查询类型。

提出的方法

  • 用一组基底 M = {b1,...,bK} 表示数据集,存储在记忆中。
  • 使用一个可学习的寻址函数 A(y) 将基底线性组合成每个查询 y 的合成数据 x'。
  • 用方程 x'^{T} = y^{T} A_i [b1;...;bK]^T 定义 x' 以对每个 y 生成 r 个样本。
  • 通过带时间反向传播(BPTT)的双层优化来训练记忆和寻址。
  • 在内部循环中,进行带动量的 SGD 和较长的展开(例如 150–200 步)以产生有信息量的梯度。
  • 允许生成合成数据的通用化、非类别离散的查询。

实验结果

研究问题

  • RQ1跨类共享记忆表示是否可以提高数据集蒸馏的压缩率?
  • RQ2记忆寻址形式是否能够实现更灵活的(非独热编码)查询以用于蒸馏并提高重新训练性能?
  • RQ3相较于单步梯度方法,BPTT 中对动量、展开长度等优化选择如何影响蒸馏?
  • RQ4该方法在持续学习和基于记忆的回忆场景中的有效性如何?

主要发现

  • 在六个数据集蒸馏基准上达到最先进结果,例如 CIFAR10,在每类1张图片时恢复精度达到 66.4%。
  • 在多种预算下在 CIFAR10 取得 66.4%,在 TinyImageNet 取得 34.0%,表明强的压缩性能。
  • 用简单的“先压缩再回忆”方法展示了强烈的持续学习提升,例如在 MANY 上保留的准确性提升为 23.2%。
  • 通过共同记忆表示在跨类之间实现信息共享,显著优于基于类别的记忆的压缩。
  • 带动量和长展开的 BPTT 相对于先前的梯度匹配基线显著提升性能。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。