[论文解读] Prompt Cache: Modular Attention Reuse for Low-Latency Inference
Prompt Cache 重用可复用提示模块的预计算注意力状态,在GPU/CPU上实现8×–60×TTFT速度提升,且不损失准确性。
We present Prompt Cache, an approach for accelerating inference for large language models (LLM) by reusing attention states across different LLM prompts. Many input prompts have overlapping text segments, such as system messages, prompt templates, and documents provided for context. Our key insight is that by precomputing and storing the attention states of these frequently occurring text segments on the inference server, we can efficiently reuse them when these segments appear in user prompts. Prompt Cache employs a schema to explicitly define such reusable text segments, called prompt modules. The schema ensures positional accuracy during attention state reuse and provides users with an interface to access cached states in their prompt. Using a prototype implementation, we evaluate Prompt Cache across several LLMs. We show that Prompt Cache significantly reduce latency in time-to-first-token, especially for longer prompts such as document-based question answering and recommendations. The improvements range from 8x for GPU-based inference to 60x for CPU-based inference, all while maintaining output accuracy and without the need for model parameter modifications.
研究动机与目标
- 通过利用系统消息、模板或文档等重复的提示片段,推动大语言模型推理延迟的降低。
- 引入正式的提示模块架构(PML),以显式定义可重用的提示组件及其位置。
- 对提示模块的注意力状态进行预计算与缓存,并在不同提示之间重复使用,以减少注意力计算。
- 证明模块化注意力重用在实现显著延迟减少的同时,仍然维持输出质量,在多种模型/数据集上具有可比性。
提出的方法
- 使用提示标记语言(PML)及一个架构来定义提示模块,以实现模块化注意力重用。
- 在CPU或GPU内存中对提示模块的注意力状态进行编码和缓存。
- 将缓存的提示模块状态与新计算的片段拼接,形成推断时的完整提示注意力状态。
- 调整 Transformer 架构以支持模块化注意力重用所需的非连续位置ID。
- 基于 HuggingFace Transformers 构建的原型实现,并在 LongBench 数据集上的 Llama2、Falcon 和 MPT 进行评估。
实验结果
研究问题
- RQ1注意力状态是否可以在共享常见文本片段的不同提示之间重复使用?
- RQ2如何编码和管理可重复使用的提示模块,以在不牺牲正确性的前提下最大化注意力重用?
- RQ3Prompt Cache 对延迟(首次令牌时间)和端到端生成在 CPU 和 GPU 部署中的影响是什么?
- RQ4存储提示模块的预计算注意力状态的内存开销是多少?它如何随模型规模和提示长度扩大而扩展?
- RQ5Prompt Cache 是否在多样的 LongBench 任务(问答、摘要、代码等)中保持输出质量?
主要发现
- Prompt Cache 显著降低首次令牌延迟,在 GPU 上实现 8×–10× 的提升,在 CPU 上在不同内存布局下可达 20×–70×。
- 输出准确性在 LongBench 任务和多种模型架构(Llama2、MPT、Falcon)之间与基线 KV Cache 相当。
- 随着提示变长和缓存片段增大,延迟收益增加,这是由于自注意力成本的二次方级别相比缓存内存拷贝成本的线性关系。
- 在可能的情况下,将提示模块存储在 GPU 内存中比存储在 CPU 内存中获得更高的加速,但 CPU 内存为非常大的模块库提供更大容量。
- 在 PML 中对提示模块和联合进行参数化实现了灵活、可重复、可组合的提示设计,同时保持注意力状态的可复用性。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。