Skip to main content
QUICK REVIEW

[論文レビュー] Learned Token Pruning for Transformers

Sehoon Kim, Sheng Shen|arXiv (Cornell University)|Jul 2, 2021
Advanced Neural Network Applications被引用数 24
ひとこと要約

本論文は Learned Token Pruning (LTP) を導入します。これはトランスフォーマー向けの閾値ベースのトークン剪定法で、学習可能な各層の閾値を用いて層ごとに適応的にトークンを剪定します。FLOPsを大幅に削減しつつ精度の低下を最小限に抑え、入力長の変動へのロバスト性を改善します。

ABSTRACT

Deploying transformer models in practice is challenging due to their inference cost, which scales quadratically with input sequence length. To address this, we present a novel Learned Token Pruning (LTP) method which adaptively removes unimportant tokens as an input sequence passes through transformer layers. In particular, LTP prunes tokens with an attention score below a threshold value which is learned for each layer during training. Our threshold-based method allows the length of the pruned sequence to vary adaptively based on the input sequence, and avoids algorithmically expensive operations such as top-k token selection. We extensively test the performance of LTP on GLUE tasks and show that our method outperforms the prior state-of-the-art token pruning methods by up to ~2.5% higher accuracy with the same amount of FLOPs. In particular, LTP achieves up to 2.1x FLOPs reduction with less than 1% accuracy drop, which results in up to 1.9x and 2.0x throughput improvement on Intel Haswell CPUs and NVIDIA V100 GPUs, respectively. Furthermore, we demonstrate that LTP is more robust than prior methods to variations on input sentence lengths. Our code has been developed in PyTorch and has been open-sourced.

研究の動機と目的

  • 層ごとに適応的にトークンを剪定して、トランスフォーマーモデルの推論コストを削減する。
  • top-k トークン選択を回避する微分可能な閾値ベースの剪定機構を開発する。
  • 学習中に層ごとの剪定閾値を学習し、推論時にはハード剪定を適用する。
  • GLUEおよびSQuADデータセットで効率向上とロバスト性を示す。

提案手法

  • ヘッドとトークン全体での平均アテンション確率によってトークンの重要性を定義する。
  • top-k剪定を、シグモイドベースの閾値を用いた学習可能で微分可能なソフトマスクに置き換え、閾値への勾配伝播を可能にする。
  • ソフト剪定フェーズで閾値をモデルパラメータとともに訓練し、次に二値化してファインチューニングする(3段階プロセス)。
  • 剪定を促進し閾値学習を安定化させるL1正則化項を導入する。
  • 量子化および知識蒸留との適合性を示し、モデルをさらに圧縮する。

実験結果

リサーチクエスチョン

  • RQ1閾値ベースのトークン剪定は、NLPタスク全体でFLOPsを大幅に削減しつつ、精度を比較可能なレベルに保てるか?
  • RQ2学習可能な層ごとの閾値は、固定剪定設定と比較して、さまざまな入力シーケンス長へのロバスト性を提供するか?
  • RQ3精度とFLOPsの観点で、従来のトークン剪定法(SpAtten、LAT)と比較してLTPはどうか?
  • RQ4このアプローチは、量子化や蒸留など他の圧縮技術と互換性があるか?

主な発見

TaskRoBERTa_base_AccuracyLTP_AccuracyRoBERTa_base_GFLOPsLTP_GFLOPsSpeedup
MNLI-m87.5386.536.833.641.88×
MNLI-mm87.3686.377.153.631.97×
QQP90.3989.695.312.532.10×
QNLI92.8691.988.944.771.87×
SST-294.2793.464.452.132.09×
STS-B90.8990.035.532.841.95×
MRPC92.1491.599.334.442.10×
RTE77.9877.9811.386.301.81×
SQuAD 2.083.0482.2532.1216.991.89×
  • LTPは、GLUE/SQuADで最大で2.10×のFLOPs削減を達成し、精度低下は1%未満、CPU/GPUで最大1.93–1.97×のスループット向上を実現。
  • LTPは同じFLOPsでSpAttenおよびLATを一貫して上回り、タスク全体で最大約2.5%の精度向上。
  • LTPは入力長の変動に対して高いロバスト性を示し、LATを大幅に上回るタスクもある。
  • 直接的なハードウェアスループットのデモンストレーションは、最大で約1.9×–2.0×のゲインを示し、バッチサイズとともに増加。
  • 量子化と知識蒸留は、LTPと組み合わせた場合、最大で約10×のBOPs削減を、僅かな精度低下とともに実現可能。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。