[論文レビュー] Transformers learn in-context by gradient descent
この論文は、Transformerの文脈学習が勾配降下更新として力学的に理解できることを示し、自己注意層が文脈データ上のGDステップを実装できること、そしてMLPを追加することで深い表現上の勾配ベース学習による非線形回帰が可能になることを証明している。
At present, the mechanisms of in-context learning in Transformers are not well understood and remain mostly an intuition. In this paper, we suggest that training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations. We start by providing a simple weight construction that shows the equivalence of data transformations induced by 1) a single linear self-attention layer and by 2) gradient-descent (GD) on a regression loss. Motivated by that construction, we show empirically that when training self-attention-only Transformers on simple regression tasks either the models learned by GD and Transformers show great similarity or, remarkably, the weights found by optimization match the construction. Thus we show how trained Transformers become mesa-optimizers i.e. learn models by gradient descent in their forward pass. This allows us, at least in the domain of regression problems, to mechanistically understand the inner workings of in-context learning in optimized Transformers. Building on this insight, we furthermore identify how Transformers surpass the performance of plain gradient descent by learning an iterative curvature correction and learn linear models on deep data representations to solve non-linear regression tasks. Finally, we discuss intriguing parallels to a mechanism identified to be crucial for in-context learning termed induction-head (Olsson et al., 2022) and show how it could be understood as a specific case of in-context learning by gradient descent learning within Transformers. Code to reproduce the experiments can be found at https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd .
研究の動機と目的
- Transformersにおける文脈内学習機構の理解を動機づける。
- 線形自己注意更新と線形回帰における1ステップの勾配降下との等価性を示す。
- 注意層を積み重ねることで反復的なGD様の更新と曲率修正(GD++)を可能にすることを示す。
- 深い表現上での勾配降下を介した非線形回帰をMLPが可能にする方法を説明する。
- メタ学習、ファストウェイト、誘導頭(induction-head)機構との関連を議論する。
提案手法
- 線形回帰損失に対する勾配降下更新と等価になる単一の線形自己注意ステップを作る重み構成を導出する。
- 線形回帰タスクで、訓練済みの線形自己注意層とGD構成を経験的に比較して整合性を評価する。
- 多層自己注意へ拡張し、反復的なデータ変換(GD++)と残差曲率補正によるGD様挙動を示す。
- TransformersへのMLPの組み込みが、深い表現上での勾配降下によって非線形回帰タスクを解くことを可能にする(カーネル回帰の視点)。
- トークン構成とデータ変換を調査し、Transformerがフォワードパス内で勾配ベースの更新を介して文脈内学習を実装できることを示す。
実験結果
リサーチクエスチョン
- RQ1単一の線形自己注意層は線形回帰タスクに対して勾配降下ステップを実装できるのか?
- RQ2自己注意層を持つ訓練済みのTransformerは線形回帰データに対してGD様の解に収束するのか?
- RQ3複数の注意層とMLPは、Transformerが勾配降下ベースの更新(GD++、非線形タスク)を実行する能力にどのように影響するのか?
- RQ4Transformerの文脈内学習は、フォワードパス内でアルゴリズム(メサ最適化)を学習することとして理解できるのか?
- RQ5フォワードパス外で明示的な重み更新を伴わずに文脈内学習を可能にするトークン構成とデータ変換の役割は何か?
主な発見
- 1ヘッドの線形自己注意層は線形回帰の訓練データに対して勾配降下風の更新を実行できる。
- 訓練済みの線形自己注意層は、構築されたGD更新と密接に一致し、予測や感度も類似している。
- 自己注意層のスタックは反復的な曲率補正(GD++)を実装でき、線形タスクで普通のGDより優れる。
- MLPを組み込むと、深層表現上で勾配降下を行うことで非線形回帰を解くことが可能になり、カーネル様回帰を実質的に可能にする。
- トランスフォーマーは学習されたデータ変換とタスク固有の表現を介して文脈内学習を生み出すことができ、メサ最適化とファストウェイトの概念と一致する。
- そのアーキテクチャは、分布内/分布外タスクを横断する勾配ベースの学習ダイナミクスを再現または近似できる。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。