[论文解读] Low-Memory Neural Network Training: A Technical Report
本文量化了训练的内存需求,并评估四种技术——稀疏、低精度、微批处理和梯度检查点——以在对 WideResNet 和 DynamicConv Transformer 的可控精度损失下降低训练内存。
Memory is increasingly often the bottleneck when training neural network models. Despite this, techniques to lower the overall memory requirements of training have been less widely studied compared to the extensive literature on reducing the memory requirements of inference. In this paper we study a fundamental question: How much memory is actually needed to train a neural network? To answer this question, we profile the overall memory usage of training on two representative deep learning benchmarks -- the WideResNet model for image classification and the DynamicConv Transformer model for machine translation -- and comprehensively evaluate four standard techniques for reducing the training memory requirements: (1) imposing sparsity on the model, (2) using low precision, (3) microbatching, and (4) gradient checkpointing. We explore how each of these techniques in isolation affects both the peak memory usage of training and the quality of the end model, and explore the memory, accuracy, and computation tradeoffs incurred when combining these techniques. Using appropriate combinations of these techniques, we show that it is possible to the reduce the memory required to train a WideResNet-28-2 on CIFAR-10 by up to 60.7x with a 0.4% loss in accuracy, and reduce the memory required to train a DynamicConv model on IWSLT'14 German to English translation by up to 8.7x with a BLEU score drop of 0.15.
研究动机与目标
- 量化神经网络训练中涉及的内存组件(模型、优化器、激活)并识别它们的相对贡献。
- 单独评估四种内存降低技术对训练内存和模型精度的影响。
- 探究将这些技术结合起来对内存、精度和计算的影响。
- 在具有代表性的基准测试上展示潜在的内存节省(CIFAR-10 上的 WideResNet 和 IWSLT'14 De→En 的 DynamicConv)。
提出的方法
- 在训练期间将内存使用量划分为模型、优化器和激活组件进行分析。
- 通过动态稀疏重新参数化评估稀疏性,并在不同非零百分比下测量精度。
- 评估使用 FP16 的低精度训练及动态损失缩放对精度和 FLOPs 的影响。
- 通过模拟更小的微批量来测试微批处理,并在考虑批量归一化的情况下分析精度。
- 分析梯度检查点策略以降低激活记忆并在内存减少时量化 FLOPs 的权衡。
实验结果
研究问题
- RQ1在具有代表性的模型中,训练过程的主导内存组件是什么?
- RQ2稀疏、低精度、微批处理和检查点各自如何影响内存与精度?
- RQ3应用这些技术时在计算方面存在哪些权衡?
- RQ4如何将这些技术组合起来,在给定精度约束下最大化内存减少?
- RQ5结合后的技术是否能够在 WideResNet 和 DynamicConv Transformer 上实现显著降低的训练内存?
主要发现
- 在所研究的设置中,激活记忆占总训练记忆的主导地位。
- WideResNet 的稀疏度高达 70%,DC-Transformer 高达 60% 时,仍能维持较小的精度/BLEU 损失(分别为 0.3% 和 0.8 BLEU)。
- FP16 训练对 WideResNet 的精度影响不显著,对 DC-Transformer 仅造成 0.15 BLEU 的下降。
- 微批大小最小可达到 10 时仍能保持 WideResNet 的精度;当极小尺寸时,精度会下降。
- 梯度检查点可以将激活内存最多降低约 5.8 倍,同时增加 FLOPs,并且不损失精度。
- 将多种技术结合可将 WideResNet 训练内存最多降低 60.7 倍,精度损失为 0.4%;对 DC-Transformer 内存最多降低 8.7 倍,损失为 0.15 BLEU。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。