[論文レビュー] What Can Transformers Learn In-Context? A Case Study of Simple Function Classes
本論文は、Transformerをゼロから訓練してインコンテキスト学習を行い、単純な関数クラス(線形、スパース線形、二層ネット、決定木)を学習させ、さまざまな分布下でタスク特有の学習法と同等かそれを上回る性能を達成する。
In-context learning refers to the ability of a model to condition on a prompt sequence consisting of in-context examples (input-output pairs corresponding to some task) along with a new query input, and generate the corresponding output. Crucially, in-context learning happens only at inference time without any parameter updates to the model. While large language models such as GPT-3 exhibit some ability to perform in-context learning, it is unclear what the relationship is between tasks on which this succeeds and what is present in the training data. To make progress towards understanding in-context learning, we consider the well-defined problem of training a model to in-context learn a function class (e.g., linear functions): that is, given data derived from some functions in the class, can we train a model to in-context learn "most" functions from this class? We show empirically that standard Transformers can be trained from scratch to perform in-context learning of linear functions -- that is, the trained model is able to learn unseen linear functions from in-context examples with performance comparable to the optimal least squares estimator. In fact, in-context learning is possible even under two forms of distribution shift: (i) between the training data of the model and inference-time prompts, and (ii) between the in-context examples and the query input during inference. We also show that we can train Transformers to in-context learn more complex function classes -- namely sparse linear functions, two-layer neural networks, and decision trees -- with performance that matches or exceeds task-specific learning algorithms. Our code and models are available at https://github.com/dtsip/in-context-learning .
研究の動機と目的
- 入力-出力ペアを含むプロンプトから、定義された関数クラスをインコンテキスト学習できるように、Transformerをゼロから訓練できるかを調査する。
- 従来の学習法と比較して、未知の入力に対してこのようなモデルが関数をどの程度近似できるかを評価する。
- 訓練プロンプトと推論プロンプトとの間の分布シフトに対するロバスト性を分析する。
- モデル容量と問題次元がインコンテキスト学習の性能にどう影響するかを検討する。
提案手法
- デコーダー専用のTransformer(12層、8ヘッド、256次元の埋め込み)を、(x_i, f(x_i))のプロンプト上で訓練し、後続のクエリに対してf(x_i)を予測する。
- クラスD_Fからランダムな関数を、D_Xから入力をサンプルして、さまざまなkのインコンテキスト例を含むプロンプトを作成する。
- 式(2)のように、プロンプト接頭辞全体にわたる平均二乗誤差を最小化する目的関数を最適化する。
- 訓練中にカリキュラム学習を用いて、関数クラスの複雑さと問題次元を徐々に増やす。
- 関数クラス全体で、最小二乗法、k-NN、内積ベースのベースラインとTransformerの性能を比較する。
- 一般化をテストするために、分布外のプロンプトに対する評価を拡張する。
実験結果
リサーチクエスチョン
- RQ1線形関数のような関数クラスを、ゼロから訓練してインコンテキスト学習できるTransformerを作れるか。
- RQ2インコンテキスト学習の性能は、最適推定量(例:最小二乗法)や単純なベースラインとどう比較されるか。
- RQ3学習済みのインコンテキスト学習は、より複雑な関数クラス(スパース線形、ニューラルネット、決定木)に拡張されるか。
- RQ4モデル容量、問題次元、および分布シフトがインコンテキスト学習の性能にどう影響するか。
- RQ5カリキュラム学習が高次元タスクの訓練をどの程度促進するか。
主な発見
- ゼロから訓練されたTransformerは、インコンテキスト例数にわたって最小二乗法と同等の誤差で線形関数をインコンテキスト学習できる。
- 訓練プロンプトと推論プロンプトの分布シフト、またインコンテキスト例とクエリ入力間の分布シフトにおいて、性能は頑健性を保つ。
- スパース線形関数、2層のReLUネット、決定木で訓練されたモデルは、タスク特化法(例:Lasso、XGBoost、勾配ベースのネット)と競合するか、または上回る結果を達成する。
- モデル容量を増やすとインコンテキスト学習が改善され、特に分布外プロンプト下で、より高次元の関数学習を可能にする。
- このアプローチは、Transformersが効率的な、アルゴリズムのような解を単一の前方伝播で符号化できることを示している。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。