[论文解读] FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
FlashAttention-2 通过重新设计并行性和工作划分来加速 Transformer 的注意力计算,在 FlashAttention 上实现大约两倍加速,在前向中达到理论最大 FLOPs 的约 73%,在后向中 63%,端到端的 GPT 风格训练每张 A100 最高可达 225 TFLOPs/s。
Scaling Transformers to longer sequence lengths has been a major problem in the last several years, promising to improve performance in language modeling and high-resolution image understanding, as well as to unlock new applications in code, audio, and video generation. The attention layer is the main bottleneck in scaling to longer sequences, as its runtime and memory increase quadratically in the sequence length. FlashAttention exploits the asymmetric GPU memory hierarchy to bring significant memory saving (linear instead of quadratic) and runtime speedup (2-4$ imes$ compared to optimized baselines), with no approximation. However, FlashAttention is still not nearly as fast as optimized matrix-multiply (GEMM) operations, reaching only 25-40\% of the theoretical maximum FLOPs/s. We observe that the inefficiency is due to suboptimal work partitioning between different thread blocks and warps on the GPU, causing either low-occupancy or unnecessary shared memory reads/writes. We propose FlashAttention-2, with better work partitioning to address these issues. In particular, we (1) tweak the algorithm to reduce the number of non-matmul FLOPs (2) parallelize the attention computation, even for a single head, across different thread blocks to increase occupancy, and (3) within each thread block, distribute the work between warps to reduce communication through shared memory. These yield around 2$ imes$ speedup compared to FlashAttention, reaching 50-73\% of the theoretical maximum FLOPs/s on A100 and getting close to the efficiency of GEMM operations. We empirically validate that when used end-to-end to train GPT-style models, FlashAttention-2 reaches training speed of up to 225 TFLOPs/s per A100 GPU (72\% model FLOPs utilization).
研究动机与目标
- 通过降低注意力瓶颈来推动对更长上下文长度的 Transformers 的扩展。
- 通过重新思考跨线程块和 warps 的工作划分来改善 GPU 利用率。
- 减少非矩阵乘法的 FLOPs,以使大部分时间用于快速的矩阵乘法操作。
- 增加序列长度、批量和注意头之间的并行性以提升利用率。
- 在 GPT 风格模型上验证端到端的训练加速。
提出的方法
- 微调 FlashAttention 的前向/后向算法,在不改变输出的情况下减少非矩阵乘法 FLOPs。
- 除了批量和头之外,还在序列长度上并行化注意力以提高占用率。
- 在一个线程块内在 warp 之间分配工作以最小化共享内存访问。
- 使用在线 softmax 瓷砖化以实现分块计算并获得正确的最终输出。
- 通过在合适时跳过区块并减少冗余的 masking 工作来高效应用因果掩蔽。
- 提供描述前向和后向传递的算法(算法 1 和算法 2),包含分块/warp 分区。
实验结果
研究问题
- RQ1FlashAttention-2 是否在前向和后向注意力传递上比 FlashAttention 实现更高的 GPU 吞吐量?
- RQ2工作划分和非矩阵乘法 FLOPs 的变化在多大程度上降低运行时和内存带宽?
- RQ3在不同上下文长度的 GPT 风格模型上使用 FlashAttention-2 可以实现怎样的端到端训练加速?
- RQ4在现代 GPU(如 A100、H100)上,FlashAttention-2 可以在理论最大 FLOPs/s 的多大程度上接近?
主要发现
| Model | Without FlashAttention (TFLOPs/s) | FlashAttention (TFLOPs/s) | FlashAttention-2 (TFLOPs/s) |
|---|---|---|---|
| GPT3-1.3B 2k context | 142 | 189 | 196 |
| GPT3-1.3B 8k context | 72 | 170 | 220 |
| GPT3-2.7B 2k context | 149 | 189 | 205 |
| GPT3-2.7B 8k context | 80 | 175 | 225 |
- FlashAttention-2 在基准测试中实现了大约对 FlashAttention 的两倍加速。
- 前向吞吐量在 A100 上最高达到理论最大 FLOPs/s 的 73%;后向最高达到 63%。
- 使用 GPT 风格模型的端到端训练在每张 A100 GPU 上达到高达 225 TFLOPs/s(模型 FLOPs 利用率 72%)。
- 在 GPT-3 1.3B/2.7B 规模下,FlashAttention-2 在 2k 和 8k 上下文中相对于基线和 FlashAttention 均有显著改进。
- 在 H100 GPU 上,预计前向+后向加速和原始吞吐量会随着新硬件特性的提升而进一步改进。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。