Skip to main content
QUICK REVIEW

[论文解读] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Tri Dao, Daniel Y. Fu|arXiv (Cornell University)|May 27, 2022
Advanced Neural Network Applications被引用 457
一句话总结

FlashAttention 通过分块和重新计算在极大减少内存 IO 的情况下计算精确注意力,实现更快训练并支持更长上下文,以及一个用于进一步加速的块稀疏变体。

ABSTRACT

Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3$ imes$ speedup on GPT-2 (seq. length 1K), and 2.4$ imes$ speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).

研究动机与目标

  • 提出内存 IO 作为 GPU 自注意力瓶颈的动机,并提出一种 IO 感知的精确注意力方法。
  • 通过对输入分块并逐步执行 softmax 来减少大型 N×N 注意力矩阵的读写。
  • 通过从片上统计信息和输出重新计算来避免在反向传播中存储完整的注意力矩阵。
  • 扩展到块稀疏注意力,以在较长序列长度下实现更快的近似注意力。
  • 提供开源实现,并在基线和长上下文任务上进行实证验证。

提出的方法

  • 通过分块重塑注意力,将 K 和 V 以块的方式加载到 SRAM,并在 Q 块之间累积 O。
  • 使用代数聚合在块中计算 softmax,并维护 m 和 ell 作为数值稳定性统计量。
  • 在反向传播中通过仅存储 O 和 softmax 统计量来重新计算,以在需要时重构 S 和 P。
  • 将所有步骤融合到一个单独的 CUDA 内核中,以尽量减少内存传输并避免将完整的 N×N 矩阵物化。
  • 提供 IO 复杂度分析,显示 FlashAttention 的 HBM 访问量为 O(N²d²/M),而标准注意力为 Ω(Nd+N²)。
  • 扩展到带固定稀疏掩码的块稀疏 FlashAttention,以按稀疏度减少 IO。

实验结果

研究问题

  • RQ1如何在尽量减少 GPU HBM 访问的同时实现精确的注意力计算?
  • RQ2分块和重新计算是否能在不牺牲精确性的前提下带来相对于标准注意力的实用速度提升?
  • RQ3块稀疏 FlashAttention 如何在准确性、IO 效率和速度之间权衡?
  • RQ4跨 SRAM 大小的精确注意力的 IO 下界是多少,实际算法是否能接近?
  • RQ5IO 感知实现是否在实践中能实现更长的上下文和更高质量的 Transformer 模型?

主要发现

  • FlashAttention 在注意力计算上相对于 GPT-2 基线实现最高 7.6× 的加速,并显著减少 HBM 的读写。
  • 对于典型的头部维度和 SRAM 大小,FlashAttention 需要的 HBM 访问远少于标准注意力,且内存占用以输入/输出之外的 O(N)。
  • 训练速度提高:BERT-large 比 MLPerf 1.1 记录快 15%;GPT-2 最大比 HuggingFace 基线快 3×;LRA 提速 2.4×。
  • 长上下文收益包括在 GPT-2 上困惑度降低 0.7,长文档分类提升 6.4 点;Path-X 与 Path-256 在长序列上实现优于随机的性能。
  • 块稀疏 FlashAttention 比 FlashAttention 快 2–4×,可扩展到 64K 序列,同时保持可比的质量。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。