Skip to main content
QUICK REVIEW

[论文解读] Learned Token Pruning for Transformers

Sehoon Kim, Sheng Shen|arXiv (Cornell University)|Jul 2, 2021
Advanced Neural Network Applications被引用 24
一句话总结

论文提出 Learned Token Pruning (LTP),一种基于阈值的变换器剪枝方法,能够针对每一层自适应地裁剪 token,使用可学习的每层阈值,在显著减少 FLOPs 的同时几乎不损失准确性,并且对输入长度变化具有更强的鲁棒性。

ABSTRACT

Deploying transformer models in practice is challenging due to their inference cost, which scales quadratically with input sequence length. To address this, we present a novel Learned Token Pruning (LTP) method which adaptively removes unimportant tokens as an input sequence passes through transformer layers. In particular, LTP prunes tokens with an attention score below a threshold value which is learned for each layer during training. Our threshold-based method allows the length of the pruned sequence to vary adaptively based on the input sequence, and avoids algorithmically expensive operations such as top-k token selection. We extensively test the performance of LTP on GLUE tasks and show that our method outperforms the prior state-of-the-art token pruning methods by up to ~2.5% higher accuracy with the same amount of FLOPs. In particular, LTP achieves up to 2.1x FLOPs reduction with less than 1% accuracy drop, which results in up to 1.9x and 2.0x throughput improvement on Intel Haswell CPUs and NVIDIA V100 GPUs, respectively. Furthermore, we demonstrate that LTP is more robust than prior methods to variations on input sentence lengths. Our code has been developed in PyTorch and has been open-sourced.

研究动机与目标

  • 通过在每层自适应裁剪 token 来减少 transformer 模型的推理成本。
  • 开发一个可微分的、基于阈值的剪枝机制,避免 top-k 选择。
  • 在训练期间学习每层的剪枝阈值,并在推理时应用硬剪枝。
  • 展示在 GLUE 与 SQuAD 数据集上的效率提升与鲁棒性。

提出的方法

  • 通过跨头和 token 的平均注意力概率来定义 token 重要性。
  • 用基于 sigmoid 的阈值替换 top-k 剪枝,采用可学习的、可微分的软掩码,使梯度流向阈值。
  • 在软剪枝阶段将阈值与模型参数联合训练,然后二值化并微调(3 步过程)。
  • 引入一个 L1 正则化项以促进剪枝并稳定阈值学习。
  • 展示与量化和知识蒸馏的兼容性,以进一步压缩模型。

实验结果

研究问题

  • RQ1基于阈值的 token 剪枝是否能够在跨 NLP 任务中实现显著的 FLOPs 降低同时保持可比较的准确性?
  • RQ2相比于固定剪枝配置,可学习的每层阈值是否对不同输入序列长度具有更强的鲁棒性?
  • RQ3就准确性与 FLOPs 之比而言,LTP 与先前的 token 剪枝方法(SpAtten、LAT)相比如何?
  • RQ4该方法是否与量化、蒸馏等其他压缩技术兼容?

主要发现

任务RoBERTa_base_准确度LTP_准确度RoBERTa_base_GFLOPsLTP_GFLOPs加速
MNLI-m87.5386.536.833.641.88×
MNLI-mm87.3686.377.153.631.97×
QQP90.3989.695.312.532.10×
QNLI92.8691.988.944.771.87×
SST-294.2793.464.452.132.09×
STS-B90.8990.035.532.841.95×
MRPC92.1491.599.334.442.10×
RTE77.9877.9811.386.301.81×
SQuAD 2.083.0482.2532.1216.991.89×
  • LTP 在 GLUE/SQuAD 上实现高达 2.10× 的 FLOPs 降低且准确率下降不足 1%;在 CPU/GPU 上的吞吐提升高达 1.93–1.97×。
  • 在相同 FLOPs 条件下,LTP 持续优于 SpAtten 和 LAT,跨任务的准确度最高可高出约 2.5%。
  • LTP 对输入长度的变化具有较强的鲁棒性,在句长不同的任务上显著优于 LAT。
  • 直接的硬件吞吐量演示显示出高达约 1.9×–2.0× 的增益,且随批次大小增大而增加。
  • 与 LTP 结合时,量化和知识蒸馏可进一步将 BOPs 降低多达 10×,且精度损失很小。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。