[论文解读] SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills
Sarathi 引入分块预填和解码最大化批处理,使解码能够搭载在预填之上,从而实现更高的解码和端到端吞吐量,并减少跨模型和硬件的流水线气泡。
Large Language Model (LLM) inference consists of two distinct phases - prefill phase which processes the input prompt and decode phase which generates output tokens autoregressively. While the prefill phase effectively saturates GPU compute at small batch sizes, the decode phase results in low compute utilization as it generates one token at a time per request. The varying prefill and decode times also lead to imbalance across micro-batches when using pipeline parallelism, resulting in further inefficiency due to bubbles. We present SARATHI to address these challenges. SARATHI employs chunked-prefills, which splits a prefill request into equal sized chunks, and decode-maximal batching, which constructs a batch using a single prefill chunk and populates the remaining slots with decodes. During inference, the prefill chunk saturates GPU compute, while the decode requests 'piggyback' and cost up to an order of magnitude less compared to a decode-only batch. Chunked-prefills allows constructing multiple decode-maximal batches from a single prefill request, maximizing coverage of decodes that can piggyback. Furthermore, the uniform compute design of these batches ameliorates the imbalance between micro-batches, significantly reducing pipeline bubbles. Our techniques yield significant improvements in inference performance across models and hardware. For the LLaMA-13B model on A6000 GPU, SARATHI improves decode throughput by up to 10x, and accelerates end-to-end throughput by up to 1.33x. For LLaMa-33B on A100 GPU, we achieve 1.25x higher end-to-end-throughput and up to 4.25x higher decode throughput. When used with pipeline parallelism on GPT-3, SARATHI reduces bubbles by 6.29x, resulting in an end-to-end throughput improvement of 1.91x.
研究动机与目标
- 解决由于 prefill 与 decode 计算之间的不匹配而造成的 LLM 推理低效。
- 通过创建统一、计算饱和的批次,减少流水线并行部署中的流水线气泡。
- 引入能够让解码在预填上搭载的技术,以提高整体吞吐量。
- 在不同模型和硬件上评估性能,以展示可扩展性和实用性。
提出的方法
- 引入分块预填,将预填请求拆分为等大小的计算块。
- 提出 decode-maximal batching 以形成仅一个预填块、其余槽位上叠加解码的批次。
- 将解码的线性运算与预填融合,将解码时间从内存瓶颈转变为计算瓶颈。
- 使用 P:D 比和 tile-size 考虑分析预填块大小与叠加解码数量之间的权衡。
- 在 nanoGPT 上实现并与 Orca 类的迭代级调度进行比较。
实验结果
研究问题
- RQ1 Sarathi 对不同序列长度、批量大小和 P:D 比率下的解码吞吐量和端到端 LLM 吞吐量的影响是什么?
- RQ2与 Orca 等现有的迭代级调度机制相比,Sarathi 在吞吐量和效率方面表现如何?
- RQ3Sarathi 对 GPU 流水线气泡和流水线并行模型吞吐量的影响如何?
- RQ4分块预填和解码最大化批处理相关的开销是什么,且它们如何扩展?
主要发现
- 在 A6000 上的 LLaMA-13B 上,解码吞吐量可提升多达 10×,端到端吞吐量提升至 1.33×。
- 在 A100 上的 LLaMA-33B,解码吞吐量提升至多 4.25×,端到端吞吐量提升至 1.25×。
- 在 GPT-3 的流水线并行中使用时,流水线气泡减少 6.29×,端到端吞吐量提升为 1.91×。
- 在各评估场景中,Sarathi 在多种模型和硬件上展示了显著的解码吞吐量和端到端吞吐量提升。
- 解码最大化批处理通过复用模型权重和融合计算,将解码工作从内存瓶颈转变为计算瓶颈。
- 分块预填和解码最大化批处理使单个预填即可产生多组混合批次,从而提高叠加解码的覆盖率。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。