[論文レビュー] Mixed Precision Training With 8-bit Floating Point
この論文は、重み、活性化、誤差、勾配を8ビットの FP8 で計算し、32ビットのアキュムレータを用いて深層ネットワークを訓練し、Imagenet-1K および WMT16 の複数モデル・タスクで最先端の精度を達成します。
Reduced precision computation for deep neural networks is one of the key areas addressing the widening compute gap driven by an exponential growth in model size. In recent years, deep learning training has largely migrated to 16-bit precision, with significant gains in performance and energy efficiency. However, attempts to train DNNs at 8-bit precision have met with significant challenges because of the higher precision and dynamic range requirements of back-propagation. In this paper, we propose a method to train deep neural networks using 8-bit floating point representation for weights, activations, errors, and gradients. In addition to reducing compute precision, we also reduced the precision requirements for the master copy of weights from 32-bit to 16-bit. We demonstrate state-of-the-art accuracy across multiple data sets (imagenet-1K, WMT16) and a broader set of workloads (Resnet-18/34/50, GNMT, Transformer) than previously reported. We propose an enhanced loss scaling method to augment the reduced subnormal range of 8-bit floating point for improved error propagation. We also examine the impact of quantization noise on generalization and propose a stochastic rounding technique to address gradient noise. As a result of applying all these techniques, we report slightly higher validation accuracy compared to full precision baseline.
研究の動機と目的
- 低精度トレーニングを動機づけ、深層学習における計算ギャップの拡大に対処する。
- 重要な計算パスで確率的丸めを行わず、重み・活性化・誤差・勾配の FP8 計算を提案する。
- FP8 を用いた大規模データセットとモデルで訓練を実証しつつ、マスターコピーの重み精度を低減する。
- ロススケールの課題と量子化ノイズに対処し、精度を維持または向上させる。
提案手法
- 重み、活性化、誤差、勾配に FP8(s=1,e=5,m=2)を使用し、32ビット FP アキュムレータを用いる。
- 前方伝搬、後方伝搬、重み更新パスに量子化操作を挿入し、32ビット出力をFP8へダウンコンバートする。
- 勾配の下振れを防ぎ、最適化の安定性を保つためにロススケーリングを適用する。
- FP16でマスタ重みを保存し、計算パス上はFP32を用いて更新を行い、FP16ストレージへ戻す。
- 丸めモードを検討し、勾配ノイズを軽減し一般化を改善するために確率的丸めを導入する。
実験結果
リサーチクエスチョン
- RQ1FP8 混合精度トレーニングは、畳み込みアーキテクチャ(ResNet 系列)および NLP/Seq2Seq モデルで、FP32 のベースラインと同等またはそれを上回る精度を達成できるのか。
- RQ2訓練時の FP8 の縮小サブノーマルレンジに最適に対処するロススケーリング戦略と丸め方法は何か。
- RQ3FP8 は Imagenet-1K および WMT16 のような大規模データセットに対して収束、一般化、メモリ効率にどのような影響を与えるのか。
主な発見
| モデル | データセット | バッチサイズ | エポック数 | FP32 (top-1 %) | FP8 (top-1 %) |
|---|---|---|---|---|---|
| Resnet-18 | imagenet-1K | 256 | 100 | 69.23 | 69.71 |
| Resnet-34 | imagenet-1K | 256 | 100 | 72.96 | 72.95 |
| Resnet-50 | imagenet-1K | 256 | 100 | 75.47 | 75.70 |
- FP8 訓練は、強化されたロススケーリングとともに、Imagenet-1K の ResNet-18/34/50 に対して FP32 ベースラインにほぼ近いまたはわずかに上回るトップ1精度を達成する(69.71 vs 69.23; 72.95 vs 72.96; 75.70 vs 75.47)。
- FP8 訓練は FP32 アキュムレータを用いて、ResNet 作業負荷および GNMT/Transformer 翻訳タスクの WMT16 で安定した収束と精度を維持する。
- WMT16 の BLEU スコアは FP8 が FP32 ベースラインと同等である(GNMT 24.6 vs 24.3; Transformer 23.0 vs 23.6)。
- 一部のモデル(例:GNMT)では、収束を阻害せず一般化を向上させるために FP8 の動的ロススケーリング戦略が必要。
- activations/gradients の確率的丸めは、決定論的丸めと比較して一般化を改善し、検証性能をわずかに向上させる可能性がある。
- FP16 マスターコピーと FP8 計算は、精度を劣化させることなくマスタ重みの保管量を50%削減可能。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。