Skip to main content
QUICK REVIEW

[論文レビュー] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Tri Dao, Daniel Y. Fu|arXiv (Cornell University)|May 27, 2022
Advanced Neural Network Applications被引用数 457
ひとこと要約

FlashAttention は、タイル分割と再計算によりメモリIOを大幅に削減して正確な注意計算を実現し、訓練を速くし、より長い文脈を可能にする。さらに速度向上のためのブロックスパース変種も提供。

ABSTRACT

Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3$ imes$ speedup on GPT-2 (seq. length 1K), and 2.4$ imes$ speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).

研究の動機と目的

  • GPU上の自己注意におけるメモリ IO をボトルネックとして動機づけ、IOを意識した正確な注意法を提案する。
  • 入力をタイル化してソフトマックスを逐次実行することで、大規模な N×N の注意行列の読み書きを削減する。
  • バックワードで全注意行列を保持せず、オンチップ統計と出力から再計算して再現する。
  • 長いシーケンス長での高速な近似注意を可能にするためブロックスパース注意へ拡張する。
  • オープンソース実装とベースラインと長文脈タスクでの実証的検証を提供する。

提案手法

  • 注意をタイル分割して K と V をブロック単位で SRAM にロードし、Q ブロックごとに O を蓄積する。
  • 代数的集約を用いてブロック内でソフトマックスを計算し、数値安定性のために m と ell を維持する。
  • バックワード時には O とソフトマックス統計のみを保存して再計算により S と P を必要に応じて再構築する。
  • 全てのステップを1つの CUDA カーネルに統合してメモリ転送を最小化し、全ての N×N 行列を実際に展開しない。
  • IO複雑性の分析を提供し、FlashAttention の場合は O(N²d²/M) の HDD アクセス、標準注意の場合は Ω(Nd+N²) を示す。
  • 固定のスパースマスクを用いたブロックスパース FlashAttention に拡張し、スパース性に比例して IO を削減。

実験結果

リサーチクエスチョン

  • RQ1GPU の HBM アクセスを最小化しつつ正確に注意を計算するにはどうすればよいか。
  • RQ2タイル分割と再計算は正確性を損なうことなく標準の注意に対して壁時速を上げられるか。
  • RQ3ブロックスパース FlashAttention は IO 効率と速度のために精度をどのように妥協するか。
  • RQ4SRAM サイズ全体での正確な注意の IO の下界は何か、実用的なアルゴリズムはそれに近づけるか。
  • RQ5IO意識の実装は長い文脈と高品質な Transformer モデルを実務的に可能にするか。)

主な発見

  • FlashAttention は attention 計算で GPT-2 ベースラインに対して最大 7.6× のスピードアップを達成し、HBM の読み書きを大幅に削減した。
  • 典型的なヘッド次元と SRAM サイズに対して、FlashAttention は標準の注意よりもはるかに少ない HBM アクセスを必要とし、メモリ・フットプリントにも効率的である(入力/出力を超える O(N))。
  • 訓練速度の向上: BERT-large は MLPerf 1.1 記録より 15% 速い; GPT-2 は HuggingFace ベースラインより最大 3× 速い; LRA は 2.4× 速い。
  • 長い文脈の恩恵として GPT-2 のパープレキシティが 0.7 改善、長文分類で 6.4 ポイントの上昇。Path-X と Path-256 は長いシーケンスで乱数より良い性能を達成。
  • ブロックスパース FlashAttention は FlashAttention より 2–4× の高速化を実現し、64K のシーケンスへ拡張可能で品質を維持。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。