[论文解读] Remember the Past: Distilling Datasets into Addressable Memories for Neural Networks
本论文提出一种基于记忆的数据集蒸馏方法,其中一组共享基底(记忆)通过可学习的寻址进行组合,以合成训练数据,从而实现与类别数无关的压缩,并实现强大的持续学习性能。
We propose an algorithm that compresses the critical information of a large dataset into compact addressable memories. These memories can then be recalled to quickly re-train a neural network and recover the performance (instead of storing and re-training on the full original dataset). Building upon the dataset distillation framework, we make a key observation that a shared common representation allows for more efficient and effective distillation. Concretely, we learn a set of bases (aka ``memories'') which are shared between classes and combined through learned flexible addressing functions to generate a diverse set of training examples. This leads to several benefits: 1) the size of compressed data does not necessarily grow linearly with the number of classes; 2) an overall higher compression rate with more effective distillation is achieved; and 3) more generalized queries are allowed beyond recalling the original classes. We demonstrate state-of-the-art results on the dataset distillation task across six benchmarks, including up to 16.5% and 9.7% in retained accuracy improvement when distilling CIFAR10 and CIFAR100 respectively. We then leverage our framework to perform continual learning, achieving state-of-the-art results on four benchmarks, with 23.2% accuracy improvement on MANY. The code is released on our project webpage https://github.com/princetonvisualai/RememberThePast-DatasetDistillation.
研究动机与目标
- 推动将大数据集压缩成能够在重新训练时保留性能的紧凑记忆。
- 提出一种记忆寻址形式,在跨类使用共享基底以改进压缩。
- 证明带动量的时间反向传播和较长展开能带来最先进的蒸馏结果。
- 表明该框架泛化到持续学习以及超出离散标签的灵活查询类型。
提出的方法
- 用一组基底 M = {b1,...,bK} 表示数据集,存储在记忆中。
- 使用一个可学习的寻址函数 A(y) 将基底线性组合成每个查询 y 的合成数据 x'。
- 用方程 x'^{T} = y^{T} A_i [b1;...;bK]^T 定义 x' 以对每个 y 生成 r 个样本。
- 通过带时间反向传播(BPTT)的双层优化来训练记忆和寻址。
- 在内部循环中,进行带动量的 SGD 和较长的展开(例如 150–200 步)以产生有信息量的梯度。
- 允许生成合成数据的通用化、非类别离散的查询。
实验结果
研究问题
- RQ1跨类共享记忆表示是否可以提高数据集蒸馏的压缩率?
- RQ2记忆寻址形式是否能够实现更灵活的(非独热编码)查询以用于蒸馏并提高重新训练性能?
- RQ3相较于单步梯度方法,BPTT 中对动量、展开长度等优化选择如何影响蒸馏?
- RQ4该方法在持续学习和基于记忆的回忆场景中的有效性如何?
主要发现
- 在六个数据集蒸馏基准上达到最先进结果,例如 CIFAR10,在每类1张图片时恢复精度达到 66.4%。
- 在多种预算下在 CIFAR10 取得 66.4%,在 TinyImageNet 取得 34.0%,表明强的压缩性能。
- 用简单的“先压缩再回忆”方法展示了强烈的持续学习提升,例如在 MANY 上保留的准确性提升为 23.2%。
- 通过共同记忆表示在跨类之间实现信息共享,显著优于基于类别的记忆的压缩。
- 带动量和长展开的 BPTT 相对于先前的梯度匹配基线显著提升性能。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。