Skip to main content
QUICK REVIEW

[論文レビュー] Trained Transformers Learn Linear Models In-Context

Ruiqi Zhang, Spencer Frei|arXiv (Cornell University)|Jun 16, 2023
Domain Adaptation and Few-Shot Learning被引用数 13
ひとこと要約

本論文は、線形回帰プロンプト上での勾配流によって訓練された単一層の線形自己注意トランスフォーマーが、文脈内で線形モデルを学習できることを示し、グローバルミニマムへ収束し、かつ特定の条件下で最良の線形予測子と競合する予測誤差を達成する。さらに、分布シフトおよび共変量シフトに対する頑健性を分析し、非線形トランスフォーマーは頑健性を高めることを示す。

ABSTRACT

Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL): Given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. By embedding a sequence of labeled training data and unlabeled test data as a prompt, this allows for transformers to behave like supervised learning algorithms. Indeed, recent work has shown that when training transformer architectures over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts. We complement this finding with experiments on large, nonlinear transformer architectures which we show are more robust under covariate shifts.

研究の動機と目的

  • 関数クラスとしてのトランスフォーマーにおける文脈内学習(ICL)の理解を促進し、線形モデルに焦点を当てる。
  • 線形回帰プロンプト上で単一層の線形自己注意トランスフォーマーを勾配流で訓練するとグローバルミニマムに収束することを示す。
  • 新しいプロンプトおよび分布シフト下での予測誤差を特徴づける。
  • 共変量シフトに対する頑健性を調査し、共変量分布が異なるプロンプトへの一般化を検討する。
  • 共変量シフトの頑健性の観点から、線形自己注意とより大規模な非線形トランスフォーマーを対比する。

提案手法

  • 線形自己注意モジュール(LSA)を備えた1層トランスフォーマーと簡略化されたパラメータ化(WPVおよびWKQ)を検討する。
  • ガウス入力を持つランダムな線形回帰タスクから生成されたプロンプトに対して勾配流で訓練する。
  • 適切な初期化の下で母集団損失のグローバルミニマイザを導出する。
  • 極限予測子およびテストプロンプト予測の閉形式表現を提供する。
  • xとyの結合分布から抽出されたテストプロンプトに対する予測誤差の上界を導出する。
  • 等方および非等方の共分散の下での振る舞いを比較し、共変量シフトの頑健性を評価する。実証的には非線形トランスフォーマーへ拡張する。

実験結果

リサーチクエスチョン

  • RQ1文脈内プロンプトに対する勾配流訓練は、LSAをグローバルミニマムへ導き、文脈内で線形モデルを効果的に学習させることができるか。
  • RQ2収束時の予測子の構造と新しいプロンプトに対する予測誤差はどうなるか。
  • RQ3線形モデルからのプロンプトで訓練した場合、さまざまな分布シフト、特に共変量シフトに対するLSAの頑健性はどれくらいか。
  • RQ4タスク間で共変量分布が異なるプロンプトは、共変量シフト下での脆さを緩和するか。
  • RQ5非線形トランスフォーマーは共変量シフトへの頑健性でどう比較されるか。

主な発見

  • 適切な初期化の下で、母集団損失に対する勾配流はLSAにグローバルミニマムへ収束する。
  • 収束時、モデルはテストプロンプト上で線形予測子を文脈内学習できる学習規則を実装する。
  • 結合分布 (x,y) のプロンプトからの場合、クエリに対する予測yは最良の線形予測子誤差と、NおよびMのプロンプト長に応じて小さくなる有限サンプル誤差項の和である。
  • 訓練されたLSAsは、タスクシフトやクエリシフトなどのいくつかの分布シフトに対して頑健だが、共変量分布の共変量シフトには脆弱である。
  • プロンプト間で共変量シフトが存在する場合、LSAは依然としてグローバルミニマムへ収束するが新しいプロンプトでは性能が低下する。一方、より大きな非線形トランスフォーマーは実験的に頑健性の改善を示す。
  • 理論的結果は、非線形トランスフォーマーが共変量シフトに対してより頑健であることを示す実験によって補完される。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。