Skip to main content
QUICK REVIEW

[논문 리뷰] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Tri Dao, Daniel Y. Fu|arXiv (Cornell University)|2022. 05. 27.
Advanced Neural Network Applications인용 수 457
한 줄 요약

FlashAttention은 타일링과 재계산을 통해 메모리 IO를 실질적으로 감소시키며 정확한 어텐션을 계산하고, 더 빠른 학습과 더 긴 컨텍스트를 가능하게 한다. 또한 추가 속도 향상을 위한 블록-희소 변형이 있다.

ABSTRACT

Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3$ imes$ speedup on GPT-2 (seq. length 1K), and 2.4$ imes$ speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).

연구 동기 및 목표

  • GPU에서 자기-주목(self-attention)의 메모리 IO를 병목의 주된 원인으로 제시하고 IO-의식적인 정확한 어텐션 방법을 제안한다.
  • 입력을 타일링하고 소프트맥스를 점진적으로 수행하여 큰 N×N 어텐션 매트릭스의 읽기/쓰기 수를 줄인다.
  • 온칩 통계 및 출력으로부터 재계산하여 역전파를 위한 전체 어텐션 매트릭스 저장을 피한다.
  • 긴 시퀀스 길이에서 더 빠른 근사 어텐션으로의 확장을 위한 블록-희소 어텐션.
  • 오픈 소스 구현과 베이스라인 및 장문 맥락 작업에 대한 실증적 검증을 제공한다.

제안 방법

  • K와 V를 블록 단위로 SRAM에 로드하고 Q 블록에 걸쳐 O를 누적하기 위해 어텐션을 타일링 방식으로 재정의한다.
  • 수치적 안정성을 위해 m과 ell라는 유지 통계를 갖춘 블록 단위의 대수적 집계를 사용하여 소프트맥스를 계산한다.
  • 역전파 시 재계산을 적용하여 O와 소프트맥스 통계만 저장하고 필요 시 S와 P를 재구성한다.
  • 메모리 트래픽을 최소화하고 전체 N×N 매트릭스를 물리화하는 것을 피하기 위해 모든 단계를 하나의 CUDA 커널로 합친다.
  • FlashAttention에 대한 IO-복잡도 분석을 제공하여 HBM 접근이 O(N²d²/M)인 반면 표준 어텐션은 Ω(Nd+N²)임을 보여준다.
  • 정해진 희소성 마스크를 가진 블록-희소 FlashAttention으로 확장하여 희소도에 비례해 IO를 감소시킨다.

실험 결과

연구 질문

  • RQ1얼마나 정확도를 해치지 않으면서 GPU HBM 접근을 최소화하며 어텐션을 정확하게 계산할 수 있는가?
  • RQ2타일링과 재계산이 정확성을 해치지 않으면서 표준 어텐션보다 월-시간상의 속도향상을 낼 수 있는가?
  • RQ3블록-희소 FlashAttention이 IO 효율성과 속도를 위해 정확도와 어떤 트레이드-오프를 보이는가?
  • RQ4다양한 SRAM 크기에서 정확한 어텐션의 IO 하한은 무엇이며, 실용적인 알고리즘으로 그것에 접근할 수 있는가?
  • RQ5IO-의식적인 구현이 실제로 더 긴 컨텍스트와 더 높은 품질의 Transformer 모델을 가능하게 하는가?

주요 결과

  • FlashAttention은 주의 계산에서 GPT-2 베이스라인 대비 최대 7.6×의 속도 향상을 달성하고 HBM 읽기/쓰기를 크게 줄인다.
  • 일반적인 헤드 차원 및 SRAM 크기에 대해 FlashAttention은 표준 어텐션보다 훨씬 적은 HBM 접근이 필요하며 메모리 발자국 측면에서도 효율적이다 (입력/출력 외에는 O(N)).
  • 학습 속도가 향상된다: BERT-large가 MLPerf 1.1 레코드보다 15% 빠르고; GPT-2는 HuggingFace 베이스라인보다 최대 3배 빠르다; LRA는 2.4배 빠르다.
  • 장문 컨텍스트의 이점으로 GPT-2에서 0.7 perplexity 개선 및 장문 문서 분류에서 6.4포인트 증가; Path-X와 Path-256은 긴 시퀀스에서 난수 이상 성능을 달성한다.
  • 블록-희소 FlashAttention은 FlashAttention보다 2–4× 더 빠르며 64K 시퀀스까지 확장하면서 품질은 비슷하게 유지된다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.