Skip to main content
QUICK REVIEW

[論文レビュー] Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training

Yuanzhong Xu, HyoukJoong Lee|arXiv (Cornell University)|Apr 28, 2020
Advanced Neural Network Applications参考文献 19被引用数 19
ひとこと要約

この論文では、データ並列学習における重み更新計算の自動的かつクロスレプリカ分散処理を提案し、すべてのレプリカで繰り返し非分割更新が行われることによる性能ボトルネックを軽減する。XLAにおける静的解析とグラフ変換を活用し、最適化された通信を伴って重みと最適化器の補助変数をレプリカ間で効率的に分散処理することで、コード変更や追加ハードウェアなしで、Transformerのような大規模モデルで最大45%の高速化を達成した。

ABSTRACT

In data-parallel synchronous training of deep neural networks, different devices (replicas) run the same program with different partitions of the training batch, but weight update computation is repeated on all replicas, because the weights do not have a batch dimension to partition. This can be a bottleneck for performance and scalability in typical language models with large weights, and models with small per-replica batch size which is typical in large-scale training. This paper presents an approach to automatically shard the weight update computation across replicas with efficient communication primitives and data formatting, using static analysis and transformations on the training computation graph. We show this technique achieves substantial speedups on typical image and language models on Cloud TPUs, requiring no change to model code. This technique helps close the gap between traditionally expensive (ADAM) and cheap (SGD) optimizers, as they will only take a small part of training step time and have similar peak memory usage. It helped us to achieve state-of-the-art training performance in Google's MLPerf 0.6 submission.

研究の動機と目的

  • データ並列学習における繰り返し発生する非分割重み更新計算が引き起こすパフォーマンスボトルネックを解消すること。
  • ADAMのような高コストな最適化手法が、すべてのレプリカで全重み更新を実行するため、学習時間が圧倒的に長くなるのを軽減すること。
  • 追加のデバイスを追加せずに、既存のレプリカ上で重みと補助変数(例:モーメンタム、分散)を効率的に分散処理すること。
  • 知的な分散戦略と通信パターンの選択により、通信とデータフォーマットのオーバーヘッドを最小限に抑えること。
  • 既存のモデルコードとの互換性を維持しながら、大規模モデルにおける顕著な高速化とメモリ節約を達成すること。

提案手法

  • XLA計算グラフに対する静的解析を実施し、分散処理に適した繰り返し発生する演算(例:重み更新)を同定すること。
  • 制御フロー解析を用いて、最適な通信ポイントを特定し、分散処理候補演算のパフォーマンス向上を推定すること。
  • 戦略的な位置に分散演算と通信プリミティブ(例:all-gather, all-reduce)を挿入するため、計算グラフを変換すること。
  • アクセラレータメモリレイアウト(例:タイルメモリ)に整合するデータ分散フォーマットを設計し、通信とメモリアクセスコストを最小限に抑えること。
  • 小規模学習(スハードサイズを最小限に)と大規模学習(通信遅延を最小限に)のシナリオに応じて、異なる分散戦略を採用すること。
  • XLAの関数型IRを活用し、副作用を最小限に抑えることで、解析を簡素化し、重みテンソルのライブレンジを短縮する高度な最適化を可能にすること。

実験結果

リサーチクエスチョン

  • RQ1データ並列学習における重み更新計算を、モデルコードを変更せずにレプリカ間で自動的に分散処理できるか?
  • RQ2重みと補助変数をレプリカ間で分散処理する際のパフォーマンスと通信オーバーヘッドのトレードオフは何か?
  • RQ3自動的分散処理は、重みと補助テンソルのライブレンジの短縮により、特にピークメモリ使用量にどのように影響を与えるか?
  • RQ4大規模学習において、ADAMの実行時間はSGDに比べてどれほど短縮できるか?
  • RQ5重みサイズやバッチサイズが異なるモデルにおいて、分散戦略の有効性にどのような差が生じるか?

主な発見

  • 1024 TPUv3チップを用いたTransformerモデルでは、ステップ時間に45%の短縮が達成され、46.5msから25.6msに低下した。
  • ResNet-50のような小規模モデルでも、1024チップでのスケーリング時に6%の高速化が観測され、広範な適用可能性が示された。
  • 言語モデルのTransformerでは、16チップの小規模スケールで9%の改善が得られ、大容量重みを持つモデルに与える影響の大きさが明らかになった。
  • 最適化により、補助変数用のバッファの再利用が可能になり、NCFのような大規模補助変数オーバーヘッドを持つモデルではピークメモリ使用量が削減された。
  • ADAMとSGDの間のメモリギャップが効果的に埋められ、両者のピークメモリ使用量が類似した水準に収束した。
  • このアプローチはモデルコードの変更を一切不要とし、追加のハードウェアやインfraストラクチャを追加せずに、既存のレプリカのみを活用した。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。