Skip to main content
QUICK REVIEW

[論文レビュー] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning

Tri Dao|arXiv (Cornell University)|Jul 17, 2023
Advanced Neural Network Applications被引用数 140
ひとこと要約

FlashAttention-2 は、並列処理と作業分割の再設計によって Transformer の attention を高速化し、FlashAttention に対して約 2 倍のスピードアップを達成し、順方向で理論上の最大 FLOPs の最大 73%、逆方向で最大 63%、エンドツーエンドの GPT-スタイル学習を A100 あたり最大 225 TFLOPs/s に達成します。

ABSTRACT

Scaling Transformers to longer sequence lengths has been a major problem in the last several years, promising to improve performance in language modeling and high-resolution image understanding, as well as to unlock new applications in code, audio, and video generation. The attention layer is the main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in the sequence length. FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving (linear instead of quadratic) and runtime speedup (2-4$ imes$ compared to optimized baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40\% of the theoretical maximum FLOPs/s. We observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes. We propose FlashAttention-2, with better work partitioning to address these issues. In particular, we (1) tweak the algorithm to reduce the number of non-matmul FLOPs (2) parallelize the attention computation, even for a single head, across different thread blocks to increase occupancy, and (3) within each thread block, distribute the work between warps to reduce communication through shared memory. These yield around 2$ imes$ speedup compared to FlashAttention, reaching 50-73\% of the theoretical maximum FLOPs/s on A100 and getting close to the efficiency of GEMM operations. We empirically validate that when used end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s per A100 GPU (72\% model FLOPs utilization).

研究の動機と目的

  • 注意機構のボトルネックを減らすことで、より長い文脈長に対する Transformer のスケーリングを動機づける。
  • スレッドブロックとワープ全体での作業分割を再考することによって、GPU の利用率を向上させる。
  • 非 matmul FLOPs を削減し、ほとんどの時間を高速な matmul 演算に保つ。
  • シーケンス長、バッチ、ヘッド間の並列性を高めて占有率を向上させる。
  • GPT-スタイルのモデルでエンドツーエンドのトレーニング高速化を検証する。

提案手法

  • 出力を変えずに非 matmul FLOPs を減らすよう FlashAttention の forward/backward アルゴリズムを微調整する。
  • 占有率を高めるため、バッチとヘッドに加えてシーケンス長にわたる attention を並列化する。
  • スレッドブロック内のワープ間で作業を分散し、共有メモリのトラフィックを最小化する。
  • 正しい最終出力を得られるよう、オンラインソフトマックスのタイル化を用いてブロック単位の計算を可能にする。
  • 適切な場合にはブロックをスキップして、冗長なマスキング作業を減らすことで因果マスキングを効率的に適用する。
  • ブロック/ワープ分割を用いた forward および backward パスを説明するアルゴリズム(Algorithm 1 および Algorithm 2)を提供する。

実験結果

リサーチクエスチョン

  • RQ1FlashAttention-2 は FlashAttention と比較して、前向きおよび後向きの attention パスの両方でより高い GPU 透過性を達成できるか?
  • RQ2作業分割と非 matmul FLOPs の変更は、実行時間とメモリトラフィックをどれだけ削減するか?
  • RQ3異なるコンテキスト長にわたって GPT-スタイルのモデルで FlashAttention-2 を使用した場合、エンドツーエンドのトレーニングのスピードアップはどの程度実現されるか?
  • RQ4近代的な GPU(例えば A100, H100)上で、FlashAttention-2 は理論上の最大 FLOPs/s にどれだけ近づけるだろうか?

主な発見

モデルFlashAttentionなし (TFLOPs/s)FlashAttention (TFLOPs/s)FlashAttention-2 (TFLOPs/s)
GPT3-1.3B 2k context142189196
GPT3-1.3B 8k context72170220
GPT3-2.7B 2k context149189205
GPT3-2.7B 8k context80175225
  • FlashAttention-2 はベンチマークで FlashAttention に対して約 2 倍の速度アップを実現する。
  • 前方パスのスループットは A100 で理論上の最大 FLOPs/s の最大 73%、後方パスは最大 63%。
  • エンドツーエンドの学習は GPT-スタイルのモデルで A100 GPU あたり最大 225 TFLOPs/s に達し(モデル FLOPs の 72% 利用)。
  • GPT-3 1.3B/2.7B スケールでは、2k および 8k の文脈長でベースラインおよび FlashAttention に対して FlashAttention-2 が大幅な改善を達成する。
  • H100 GPU では、前方+後方の speedup と生のスループットは新しいハードウェア機能によりさらに向上すると期待される。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。