[论文解读] Retentive Network: A Successor to Transformer for Large Language Models
Retentive Network (RetNet) 引入了一种多尺度保留机制来替代注意力,实现并行训练、O(1) 推理,以及对长序列的线性记忆,与 Transformer 的性能相当。
In this work, we propose Retentive Network (RetNet) as a foundation architecture for large language models, simultaneously achieving training parallelism, low-cost inference, and good performance. We theoretically derive the connection between recurrence and attention. Then we propose the retention mechanism for sequence modeling, which supports three computation paradigms, i.e., parallel, recurrent, and chunkwise recurrent. Specifically, the parallel representation allows for training parallelism. The recurrent representation enables low-cost $O(1)$ inference, which improves decoding throughput, latency, and GPU memory without sacrificing performance. The chunkwise recurrent representation facilitates efficient long-sequence modeling with linear complexity, where each chunk is encoded parallelly while recurrently summarizing the chunks. Experimental results on language modeling show that RetNet achieves favorable scaling results, parallel training, low-cost deployment, and efficient inference. The intriguing properties make RetNet a strong successor to Transformer for large language models. Code will be available at https://aka.ms/retnet.
研究动机与目标
- 在不牺牲性能的前提下,推动在大语言模型部署中降低推理成本和内存使用。
- 开发保持训练并行性的 Transformer 替代架构。
- 引入支持并行、递归和分块递归表示的保留机制,以优化训练和推理。
提出的方法
- 提出一个多尺度保留(MSR)模块,用于替代多头注意力。
- 推导保留的双重表示:并行表示(便于训练)和递归表示(便于推理)。
- 实现三种计算范式:并行保留、递归保留和分块递归保留,适用于长序列。
- 结合门控(swish)和多头衰减(gamma)以提升表达能力和训练稳定性。
- 使用 GroupNorm 处理由于多尺度头带来的每头方差。
- 提供一个端到端的 RetNet 架构,包含 MSR + FFN 块及训练/推理策略。
- 在规模、训练成本和推理指标上,将 RetNet 与 Transformer 和高效 Transformer 变体进行比较。
实验结果
研究问题
- RQ1与 Transformer 相比,RetNet 是否能够在保持或提升推理效率的同时实现训练并行性?
- RQ2保留机制在降低训练和部署中的内存、延迟和计算成本的同时,是否能提供可比的语言模型性能?
- RQ3并行、递归和分块递归表示如何影响长序列建模和可扩展性?
- RQ4对于大模型和长上下文,RetNet 能实现哪些在内存、吞吐量和延迟方面的增益?
- RQ5相对于 Transformer,RetNet 在零-shot/少样本下游任务上的表现如何?
主要发现
- 相较于 Transformer,RetNet 实现了有利的扩展性、并行训练、低成本部署以及高效推理。
- 对于一个 7B 模型和 8k 上下文,RetNet 的解码速度比带键值缓存的 Transformer 快 8.4×,内存节省 70%。
- 在训练期间,RetNet 节省 25–50% 内存,较标准 Transformer 提速 7×,并且与 FlashAttention 相媲美。
- RetNet 的推理延迟与序列长度无关,对批量大小的敏感性较低,从而在解码阶段实现更高吞吐。
- RetNet 在语言建模困惑度方面与 Transformer 相当,并在若干任务上展现出有利的零-shot/少样本学习表现。
- 消融研究表明门控、GroupNorm 以及多尺度衰减对性能提升有贡献;更大的头维度会提升结果。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。