[論文レビュー] Training Deep Neural Networks with 8-bit Floating Point Numbers
本論文は、チャンクベースの蓄積と浮動小数点確率的丸めによって可能となる、FP16の蓄積とFP16の重み更新を伴う8ビット浮動小数点数 (FP8) を用いたさまざまなDNNの訓練を成功裏に示し、FP32のベースラインと同等の精度を達成しつつ、メモリと計算要件を低減することを示している。
The state-of-the-art hardware platforms for training Deep Neural Networks (DNNs) are moving from traditional single precision (32-bit) computations towards 16 bits of precision -- in large part due to the high energy efficiency and smaller bit storage associated with using reduced-precision representations. However, unlike inference, training with numbers represented with less than 16 bits has been challenging due to the need to maintain fidelity of the gradient computations during back-propagation. Here we demonstrate, for the first time, the successful training of DNNs using 8-bit floating point numbers while fully maintaining the accuracy on a spectrum of Deep Learning models and datasets. In addition to reducing the data and computation precision to 8 bits, we also successfully reduce the arithmetic precision for additions (used in partial product accumulation and weight updates) from 32 bits to 16 bits through the introduction of a number of key ideas including chunk-based accumulation and floating point stochastic rounding. The use of these novel techniques lays the foundation for a new generation of hardware training platforms with the potential for 2-4x improved throughput over today's systems.
研究の動機と目的
- モデルの精度を損なうことなく、訓練精度を8ビットへと削減する動機付け。
- 蓄積と更新の課題を克服するために、FP8/FP16形式と手法を導入する。
- 標準データセット上でCNNおよびDNN全般にわたる広範な実証検証を示す。
- スループットとエネルギー効率を2〜4倍改善するハードウェア効率的アプローチを提案する。
提案手法
- データと蓄積のためにFP8 (1,5,2) および FP16 (1,6,9) 形式を定義する。
- 長いドット積を分割し、スワンピング誤差を低減するためにチャンクベースの蓄積を使用する。
- 丸め時の情報損失を保つために、重み更新に浮動小数点の確率的丸めを適用する。
- Softmax計算を安定化させるため、最終層GEMMにはFP16を維持する。
- backpropagation中の小さな勾配を保持するためにロススケーリングを採用する。
- 複数のネットワークとデータセットに渡るエミュレートされた低精度実験で検証する。
実験結果
リサーチクエスチョン
- RQ1多様なモデルとデータセットにわたって、8ビット浮動小数点表現を使用して精度の損失なしにDNNを訓練できるだろうか?
- RQ2訓練時に低精度フォーマットを使用する場合、swampingおよび蓄積誤差をどのように緩和できるか?
- RQ3メモリ、帯域幅、エネルギー効率の観点から、FP8訓練の実用的なハードウェア影響は何か?
- RQ4FP8訓練の成功において、先頭層と最終層の精度はどのような役割を果たすか?
- RQ5丸めモードはFP8訓練の精度にどのように影響するか?
主な発見
- FP8訓練は、FP16の蓄積とFP16のウェイト更新を伴い、CIFAR-10 CNN、CIFAR-10 ResNet、BN50-DNN、AlexNet、ResNet-18、ResNet-50に対して FP32ベースラインと同等の検証精度を達成する。
- FP8ウェイトとFP16マスターコピーにより、ウェイトとマスターコピーのメモリを約2×削減。
- チャンクベースの蓄積と浮動小数点確率的丸めは、swampingを有効に緩和し、堅牢な8ビット訓練を可能にする。
- ロススケーリングと最後の層のGEMMsにFP16を温存することは、ImageNetのような大規模データセットでの訓練を安定化させる。
- 最近傍丸めは精度を低下させる;確率的丸めはFP16の重み更新時にベースラインの性能を維持する。
- ハードウェアデモは、FP8エンジンがFP16と比較して2〜4×エネルギー効率を高められることを示唆している。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。