[论文解读] Are Sixteen Heads Really Better than One?
论文表明,许多 transformer 的注意力头在测试时可以剪枝,几乎不损失性能,有时将层降到单一头并显著提高推理速度。
Attention is a powerful and ubiquitous mechanism for allowing neural models to focus on particular salient pieces of information by taking their weighted average when making predictions. In particular, multi-headed attention is a driving force behind many recent state-of-the-art NLP models such as Transformer-based MT models and BERT. These models apply multiple attention mechanisms in parallel, with each attention "head" potentially focusing on different parts of the input, which makes it possible to express sophisticated functions beyond the simple weighted average. In this paper we make the surprising observation that even if models have been trained using multiple heads, in practice, a large percentage of attention heads can be removed at test time without significantly impacting performance. In fact, some layers can even be reduced to a single head. We further examine greedy algorithms for pruning down models, and the potential speed, memory efficiency, and accuracy improvements obtainable therefrom. Finally, we analyze the results with respect to which parts of the model are more reliant on having multiple heads, and provide precursory evidence that training dynamics play a role in the gains provided by multi-head attention.
研究动机与目标
- 推动对经过训练的 Transformer 模型中多头注意力必要性的实证研究。
- 量化在机器翻译和自然语言推断中哪些注意力头对性能是必需的。
- 开发一种剪枝策略,在不重新训练的情况下识别并移除不那么重要的头。
- 分析头剪枝对不同注意力组件(Enc-Enc、Enc-Dec、Dec-Dec)以及训练动态中的影响。
提出的方法
- 定义一种掩蔽机制,在多头注意力中禁用单个注意力头。
- 在移除单个头后以及将整个层减少到一个头后评估性能。
- 提出一个基于损失对头掩蔽的预期敏感度的重要性分数 I_h,通过前向/反向传播估计。
- 通过将头按 I_h(或代理)排序并分步剪枝来进行迭代剪枝,以研究累积效应。
- 比较两种成熟模型的剪枝效果:WMT English→French Transformer 和 BERT base on MNLI。
- 在 GPU 上测量性能(BLEU、准确度)和推理速度提升。
实验结果
研究问题
- RQ1在训练好的基于 Transformer 的模型中,单个注意力头对于 MT 和 NLI 任务有多重要?
- RQ2将层减少到一个头不会损害性能吗?哪些层对这种减少具有抵抗力?
- RQ3在 MT 模型中,头剪枝对编码器-解码器注意力与自注意力组件有何影响?
- RQ4头的重要性在训练过程中如何演变,何时被判定为重要或冗余?
- RQ5从剪枝头中产生的实际效率提升(速度/内存),在什么条件下这些提升最明显?
主要发现
- 大多数注意力头在测试时可以移除而不会显著影响性能。
- 一些层可以减少到单个头而几乎不产生影响,尽管编码器-解码器注意力通常需要更多头。
- 使用重要性代理进行迭代剪枝可以在 WMT 中剪掉最多约 20% 的头,在 BERT 中剪掉约 40% 而不明显下降;进一步剪枝会导致显著下降。
- 在 MT 中,编码器-解码器注意力对剪枝比自注意力更敏感,表明对多头的依赖程度不同。
- 训练动态显示头在训练早期变得更明显重要,后期出现更具容忍性的剪枝机制。
- 剪枝带来显著的效率提升,在较大批量下移除 50% 的头时,BERT 的推理速度提高可达 17.5%;分配给 MHA 的参数占总参数的相当大比例(大致三分之一)。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。