[論文レビュー] Fast Transformer Decoding: One Write-Head is All You Need
本稿では、すべてのアテンションヘッドでキーと値を共有する多クエリアテンションを提案する。これにより、インクリメンタルデコード時のメモリ帯域幅が著しく削減され、品質のわずかな低下を伴いながらも、最大12倍の高速化が達成される。これは、レイテンシに敏感な応用に最適である。
Multi-head attention layers, as used in the Transformer neural sequence model, are a powerful alternative to RNNs for moving information across and between sequences. While training these layers is generally fast and simple, due to parallelizability across the length of the sequence, incremental inference (where such paralleization is impossible) is often slow, due to the memory-bandwidth cost of repeatedly loading the large "keys" and "values" tensors. We propose a variant called multi-query attention, where the keys and values are shared across all of the different attention "heads", greatly reducing the size of these tensors and hence the memory bandwidth requirements of incremental decoding. We verify experimentally that the resulting models can indeed be much faster to decode, and incur only minor quality degradation from the baseline.
研究の動機と目的
- 現代のハードウェア上で推論速度を制限する、トランスフォーマーにおけるインクリメンタルデコードの高いメモリ帯域幅コストを軽減すること。
- モデル性能を著しく低下させることなく、マルチヘッドアテンションにおけるキーと値テンソルのサイズを小さくすること。
- 自己回帰的生成における繰り返しのメモリアクセスを最小限に抑えることで、レイテンシが重要な応用における高速推論を可能にすること。
- ヘッド間でキーと値を共有しても、標準的なマルチヘッドアテンションと比較して競争力のある品質を維持できるかどうかを評価すること。
提案手法
- すべてのアテンションヘッドが同じキーと値の投影を共有する多クエリアテンションを提案する。これにより、パラメータ数とメモリ帯域幅が削減される。
- 各ヘッドごとに別々のキーと値の投影を行う標準的なマルチヘッドアテンション機構を、すべてのヘッドに共通の1つの投影に置き換えることで変更する。
- 各ヘッドごとに同じクエリ投影を使用するが、キーと値の投影をすべてのヘッドで共有することで、KとVテンソルのサイズをO(hd)からO(d)に削減する。
- 共有されたKとV行列を用いた標準的な自己アテンション計算を採用し、同じアテンションメカニズムを維持しながら、メモリフットプリントを削減する。
- 標準的およびローカルアテンションの両方のバリエーションにこの手法を適用し、さまざまな設定での性能を評価する。
- TPUv2ハードウェア上でバッチ推論とインクリメンタルデコードを実行し、訓練および推論コストを測定する。
実験結果
リサーチクエスチョン
- RQ1すべてのアテンションヘッドでキーと値を共有することで、性能の著しい低下を伴わずにインクリメンタルデコード時のメモリ帯域幅を削減できるか?
- RQ2標準的なマルチヘッドアテンションや他の低減アテンション変種と比較して、提案された多クエリアテンションの品質はいかがであるか?
- RQ3多クエリアテンションは自己回帰的生成における推論速度をどの程度向上させるか?
- RQ4この手法は、機械翻訳および言語モデリングベンチマークでも競争力のある性能を維持するか?
- RQ5トレーニング時間やモデルサイズを増加させることなく、スルーレットの向上を達成できるか?
主な発見
- TPUv2上で、マルチクエリモデルはインクリメンタルデコードフェーズの1トークンあたりのデコーダー推論時間を47μsから3.8μsに短縮し、12.4倍の高速化を達成した。
- 1トークンあたりのアモアタイズド推論コストは、ベースラインの46μsからマルチクエリの3.8μsに低下し、エンコーダーのコストも1.7μsから1.5μsに減少した。
- WMT14 EN-DE翻訳タスクにおいて、マルチクエリモデルはビームサーチ(4)でBLEUスコア28.5を達成し、ベースライン(28.4)をわずかに上回った。
- 10億語の言語モデリングベンチマークでは、マルチクエリモデルのパープレキシティは30.2であり、ベースライン(29.9)とわずかに劣った。
- アーキテクチャの変更にもかかわらず、トレーニング時間は13.0μs/トークン(マルチクエリ)と13.2μs/トークン(ベースライン)とほぼ同等であり、トレーニングのオーバーヘッドがないことが示された。
- h、dk、またはdvを低減したすべての代替手法よりも、品質と速度の両面で優れており、メモリ帯域幅最適化の有効なソリューションであることが実証された。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。