[論文レビュー] Efficiently Scaling Transformer Inference
ペーパーは、大規模なTransformerモデルを複数のTPUチップに分割して推論遅延とFLOPS利用率を最適化するためのエンジニアリングフレームワークを提示し、500B超のパラメータモデルで新しいパレート前線を達成し、マルチクエリアテンションで長いコンテキストを可能にします。
We study the problem of efficient generative inference for Transformer models, in one of its most challenging settings: large deep models, with tight latency targets and long sequence lengths. Better understanding of the engineering tradeoffs for inference for large Transformer-based models is important as use cases of these models are growing rapidly throughout application areas. We develop a simple analytical model for inference efficiency to select the best multi-dimensional partitioning techniques optimized for TPU v4 slices based on the application requirements. We combine these with a suite of low-level optimizations to achieve a new Pareto frontier on the latency and model FLOPS utilization (MFU) tradeoffs on 500B+ parameter models that outperforms the FasterTransformer suite of benchmarks. We further show that with appropriate partitioning, the lower memory requirements of multiquery attention (i.e. multiple query heads share single key/value head) enables scaling up to 32x larger context lengths. Finally, we achieve a low-batch-size latency of 29ms per token during generation (using int8 weight quantization) and a 76% MFU during large-batch-size processing of input tokens, while supporting a long 2048-token context length on the PaLM 540B parameter model.
研究の動機と目的
- アプリケーションのニーズに基づいて、大規模Transformer推論のための多軸分割戦略を選択する簡易な解析モデルを開発する。
- メモリと低レベルスケジューリング最適化と分割を組み合わせて、500B超のパラメータモデルにおけるMFUとレイテンシを改善する。
- マルチクエリアテンションがKVキャッシュのメモリコストを削減し、長いコンテキスト長を可能にすることを示す。
- PaLM 540B(64 TPU v4 チップ)上で検証された実用的で構成可能な推論フレームワークを実証する。
- 異なるワークロード要件の下で、プリフィルと生成フェーズにおける分割レイアウトの選択ガイドラインを提供する。
提案手法
- 指標を定義する:レイテンシ、スループット、モデルFLOPS利用率(MFU)。
- フィードフォワード層のための1D/2Dウェイトステーショナリ―および重み集約方式を含む分割レイアウトを開発する。
- KVキャッシュをバッチ間でまとめてメモリ時間を削減することでマルチクエリアテンション分割を提案する。
- 並列のアテンション/フィードフォワード構成を活用して、演算の融合と通信の削減を図る。
- 低レベル最適化(Looped CollectiveEinsum、非同期集団演算)とint8ウェイト量子化を適用して性能を向上させる。
- 生成時、文脈長2048トークンまで、int8で1トークンあたり29 msのレイテンシを達成するPaLM 540B(64 TPU v4チップ)で検証する。
実験結果
リサーチクエスチョン
- RQ1大規模Transformer推論における分割戦略はレイテンシ、MFU、メモリトラフィックにどのように影響するか。
- RQ2さまざまなバッチサイズとコンテキスト長で、フィードフォワード層の1D/2D/重み集約レイアウトの最適な組み合わせは何か。
- RQ3マルチクエリアテンションはKVキャッシュのメモリ負荷を大幅に削減し、過度な通信を伴わずに長いコンテキストを可能にできるか。
- RQ4並列アテンション/フィードフォワードの定式化は、直列実装と比較してレイテンシとMFUにどのような影響を与えるか。
- RQ5PaLM規模のモデルで最高の実用的なパレート前線をもたらす量子化と低レベル最適化は何か。
主な発見
- 単純な解析的分割フレームワークは、所与のモデルサイズ、コンテキスト長、チップ数に対してほぼ最適な多軸分割を特定できる。
- 2D重み固定と重み集約のフィードフォワードレイアウトはバッチサイズの増大に伴って切り替わり、大きなバッチでは重み集約レイアウトが優れ、MFU最大で76%に達する。
- マルチクエリアテンションはKVキャッシュメモリを最大でn_chips倍削減し、報告された設定でマルチヘッド構成よりも長いコンテキスト(32–64倍長く)を可能にする。
- parallelアテンション/フィードフォワード定式はレイテンシと通信を削減し、直列版と比較してFLOPS利用率を高める。
- PaLM 540B(64 TPUs)では、プリフィルのレイテンシと生成スループットが、大規模バッチ下で1トークンあたり29 msの低バッチレイテンシ(int8)とMFU 76%を達成し、2048トークンのコンテキストを実現。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。