[論文レビュー] Rethinking Batch Normalization in Transformers
この論文では、自然言語処理(NLP)におけるTransformerに向けた新しい正規化手法であるパワー正規化(PN)を提案する。PNは、NLPデータにおけるバッチ単位の統計的変動が原因で生じるバッチ正規化(BN)の不安定性を解消することを目的としている。ゼロ平均制約の緩和、走査平均の二次平均の使用、近似逆誤差伝搬の採用により、BNおよび層正規化(LN)よりも優れた訓練安定性と性能を達成しており、WMT14ではLNより0.6 BLEU、WikiText-103では5.6 PPLの向上を示した。
The standard normalization method for neural network (NN) models used in Natural Language Processing (NLP) is layer normalization (LN). This is different than batch normalization (BN), which is widely-adopted in Computer Vision. The preferred use of LN in NLP is principally due to the empirical observation that a (naive/vanilla) use of BN leads to significant performance degradation for NLP tasks; however, a thorough understanding of the underlying reasons for this is not always evident. In this paper, we perform a systematic study of NLP transformer models to understand why BN has a poor performance, as compared to LN. We find that the statistics of NLP data across the batch dimension exhibit large fluctuations throughout training. This results in instability, if BN is naively implemented. To address this, we propose Power Normalization (PN), a novel normalization scheme that resolves this issue by (i) relaxing zero-mean normalization in BN, (ii) incorporating a running quadratic mean instead of per batch statistics to stabilize fluctuations, and (iii) using an approximate backpropagation for incorporating the running statistics in the forward pass. We show theoretically, under mild assumptions, that PN leads to a smaller Lipschitz constant for the loss, compared with BN. Furthermore, we prove that the approximate backpropagation scheme leads to bounded gradients. We extensively test PN for transformers on a range of NLP tasks, and we show that it significantly outperforms both LN and BN. In particular, PN outperforms LN by 0.4/0.6 BLEU on IWSLT14/WMT14 and 5.6/3.0 PPL on PTB/WikiText-103. We make our code publicly available at \url{this https URL}.
研究の動機と目的
- バッチ正規化(BN)が自然言語処理(NLP)におけるTransformerで層正規化(LN)に比べて性能を発揮できない理由を調査すること。
- NLPにおいてナイーブなバッチ正規化を使用した場合の不安定性を引き起こす要因、特に大きなバッチ単位の統計的変動の性質を特定すること。
- これらの変動に対処しつつ、訓練効率を維持したままNLPにおける訓練を安定化させる新しい正規化スキームを設計すること。
- 弱い仮定の下で損失関数のリプシッツ定数が小さくなることを理論的に正当化すること。
提案手法
- バッチ正規化におけるゼロ平均制約を緩和し、バッチ統計に依存する感度を低減すること。
- 各バッチの統計を走査平均の二次平均に置き換えることで、訓練ステップ全体にわたる統計の安定化を図ること。
- 走査統計を順伝播に組み込む近似逆誤差伝搬スキームを導入し、より良い勾配伝搬を実現すること。
- 理論的分析により、弱い仮定の下でPNが損失関数のリプシッツ定数を小さくすることを示した。
- 近似逆誤差伝搬スキームが有界な勾配をもたらすことが証明され、訓練安定性の向上に寄与した。
- 提案手法をTransformerアーキテクチャに統合し、複数のNLPベンチマークで評価した。
実験結果
リサーチクエスチョン
- RQ1なぜ標準的なバッチ正規化は、層正規化(LN)に比べてNLPのTransformerで性能が劣化するのか?
- RQ2NLPデータに特有のどのような性質が、ナイーブなバッチ正規化を用いた場合の不安定性を引き起こすのか?
- RQ3走査統計を用いる修正された正規化スキームが、NLPにおける訓練の安定性と性能を向上させられるか?
- RQ4提案された正規化手法が、既存の手法よりも一般化性能が向上し、収束が速くなるか?
- RQ5提案手法に対して、有界な勾配やリプシッツ定数の低減といった理論的保証を確立できるか?
主な発見
- パワー正規化(PN)は、IWSLT14で層正規化(LN)を大きく上回り、0.4 BLEUの向上を達成した。
- WMT14では、LNより0.6 BLEUの向上を示し、翻訳タスク全体で一貫した改善を確認した。
- PTB言語モデリングベンチマークでは、LNに比べて周辺度(PPL)が5.6ポイント低下した。
- WikiText-103では、LNに比べて周辺度が3.0ポイント改善され、より強力な言語モデリング性能を示した。
- 理論的分析により、弱い仮定の下でPNはBNに比べて損失関数のリプシッツ定数が小さいことが確認された。
- PNにおける近似逆誤差伝搬スキームにより、勾配が有界であることが保証され、訓練安定性に寄与した。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。