Skip to main content
QUICK REVIEW

[论文解读] GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Joshua Ainslie, James Lee-Thorp|arXiv (Cornell University)|May 22, 2023
Topic Modeling被引用 27
一句话总结

本文提出 uptraining,将 multi-head attention 转换为 multi-query 以及广义分组查询注意力(GQA),在使用原始预训练计算的一小部分的同时,实现具有竞争质量的快速推理。

ABSTRACT

Multi-query attention (MQA), which only uses a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference. We (1) propose a recipe for uptraining existing multi-head language model checkpoints into models with MQA using 5% of original pre-training compute, and (2) introduce grouped-query attention (GQA), a generalization of multi-query attention which uses an intermediate (more than one, less than number of query heads) number of key-value heads. We show that uptrained GQA achieves quality close to multi-head attention with comparable speed to MQA.

研究动机与目标

  • 动机:在自回归 Transformer 中,由于内存带宽成为解码端瓶颈,寻求在不牺牲质量的情况下实现更快的推理。
  • 提出一种成本有效的 uptraining 方案,将现有的 multi-head 检查点转换为 multi-query 配置。
  • 引入 grouped-query attention(GQA),作为 multi-head 与 multi-query attention 之间的插值。
  • 证明经过 uptraining 的 GQA 在各类任务中达到接近 multi-head attention 的质量,且速度接近 multi-query attention。

提出的方法

  • 通过对 key 和 value 投影矩阵取均值池化,将 multi-head attention 检查点转换为单一头对应于 MQA,或用于每个 GQA 组。
  • 使用与原始数据与配方相同的、原始预训练步骤的一小部分(α)对转换后的模型进行 uptrain。
  • 将 grouped-query attention(GQA)定义为具有 G 个查询头组,每组共享一个 KV 头,与 MQA(G=1)和 MHA(G=H)之间插值。
  • 在 T5-Large 和 T5-XXL 上进行实验,将 MQA 与 GQA 应用于解码器自注意力和跨注意力(不应用于编码器自注意力)。
  • 在摘要、翻译和问答基准上进行评估,报告每个样本的推理时间和开发集性能。

实验结果

研究问题

  • RQ1在有限额外计算的情况下,是否可以将 multi-head 检查点有效地 uptrain 成快速的 multi-query 形式?
  • RQ2分组查询注意力是否提供可调的速度与质量权衡,超越 MQA 并接近 MHA?
  • RQ3uptraining 如何影响不同任务和模型规模下的稳定性与性能?

主要发现

  • 经过 uptrain 的 MQA 在推理速度更快、质量高于 MHA-Large,但在某些情况下仍落后于 XXL-MHA。
  • 经过 uptrain 的 GQA 在速度接近 MQA、质量接近 MHA-XXL 的情况下,提供了有利的权衡。
  • 对检查点进行均值池化转换比选择某个头或随机初始化更能保留信息。
  • 5% 的 uptraining 可以提升性能,对于 MQA 和 GQA,超过 5–10% 的收益递减。
  • GQA 的性能在较大模型上受益更明显due to bandwidth dynamics and cache considerations.
  • 将 GQA 组数从 1(MQA)增至更多组,带来适度的加速同时成本增加,8 组被认为是一个较有利的中间点。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。