Skip to main content
QUICK REVIEW

[논문 리뷰] What Can Transformers Learn In-Context? A Case Study of Simple Function Classes

Shivam Garg, Dimitris Tsipras|arXiv (Cornell University)|2022. 08. 01.
Topic Modeling인용 수 59
한 줄 요약

이 논문은 컨텍스트 내에서 간단한 함수 클래스(선형, 희소 선형, 2층 네트워크, 의사결정 트리)를 학습하기 위해 Transformer를 처음부터 학습시키고, 다양한 분포 조건에서 작업별 학습 방법과 이를 견주거나 능가하는 성능을 보인다.

ABSTRACT

In-context learning refers to the ability of a model to condition on a prompt sequence consisting of in-context examples (input-output pairs corresponding to some task) along with a new query input, and generate the corresponding output. Crucially, in-context learning happens only at inference time without any parameter updates to the model. While large language models such as GPT-3 exhibit some ability to perform in-context learning, it is unclear what the relationship is between tasks on which this succeeds and what is present in the training data. To make progress towards understanding in-context learning, we consider the well-defined problem of training a model to in-context learn a function class (e.g., linear functions): that is, given data derived from some functions in the class, can we train a model to in-context learn "most" functions from this class? We show empirically that standard Transformers can be trained from scratch to perform in-context learning of linear functions -- that is, the trained model is able to learn unseen linear functions from in-context examples with performance comparable to the optimal least squares estimator. In fact, in-context learning is possible even under two forms of distribution shift: (i) between the training data of the model and inference-time prompts, and (ii) between the in-context examples and the query input during inference. We also show that we can train Transformers to in-context learn more complex function classes -- namely sparse linear functions, two-layer neural networks, and decision trees -- with performance that matches or exceeds task-specific learning algorithms. Our code and models are available at https://github.com/dtsip/in-context-learning .

연구 동기 및 목표

  • 프롬프트에 입력-출력 쌍이 포함된 정의된 함수 클래스에서 Transformer가 컨텍스트 내 학습을 할 수 있는지 조사한다.
  • 그러한 모델이 보이지 않는 입력에 대해 함수 근사 성능이 전통적 학습 방법에 비해 얼마나 잘 수행되는지 평가한다.
  • 학습 프롬프트와 추론 프롬프트 간의 분포 변화에 대한 로버스트성을 분석한다.
  • 모델 용량과 문제 차원이 컨텍스트 내 학습 성능에 미치는 영향을 분석한다.

제안 방법

  • 12층 디코더 전용 Transformer(헤드 8, 임베딩 차원 256)를 (x_i, f(x_i))의 프롬프트에서 학습시켜 이후 질의에 대해 f(x_i)를 예측한다.
  • D_F 클래스에서 임의의 함수를 표본화하고 D_X에서 입력을 샘플링하여 k개의 컨텍스트 예제를 포함하는 프롬프트를 구성한다.
  • 식(2)와 같이 프롬프트 접두사들에 걸쳐 평균 제곱 오차를 최소화하는 목표를 최적화한다.
  • 학습 중에 함수 클래스의 복잡도와 문제 차원을 점진적으로 증가시키기 위한 커리큘럼 학습을 사용한다.
  • 함수 클래스 전반에서 Transformer 성능을 최소자승법, k-최근접 이웃, 내적 베이스라인과 비교한다.
  • 일반화 정확도를 테스트하기 위해 분포 외 프롬프트에 대한 평가를 확장한다.

실험 결과

연구 질문

  • RQ1Transformer를 처음부터 학습시켜 선형 함수와 같은 함수 클래스를 컨텍스트 내 학습할 수 있는가?
  • RQ2컨텍스트 내 학습 성능이 최적 추정치(예: 최소자승법) 및 간단한 베이스라인과 비교해 어떤가?
  • RQ3학습된 컨텍스트 내 학습이 희소 선형, 뉴럴 네트워크, 의사결정 트리와 같은 더 복잡한 함수 클래스에 확장되는가?
  • RQ4모델 용량, 문제 차원 및 분포 변화가 컨텍스트 내 학습 성능에 어떤 영향을 미치는가?
  • RQ5커리큘럼 학습이 고차원 작업 학습에 어느 정도 도움이 되는가?

주요 결과

  • 처음부터 학습된 Transformer는 컨텍스트 예제 개수에 따라 선형 함수에 대해 최소자승법과 비교할 만한 오차를 보이며 컨텍스트 내 학습을 수행할 수 있다.
  • 학습 프롬프트와 추론 프롬프트 간의 분포 변화 및 컨텍스트 예제와 질의 입력 간의 분포 변화에서도 성능이 강건하게 유지된다.
  • 희소 선형 함수, 두 층 ReLU 네트, 의사결정 트리에서 학습된 모델은 Lasso, XGBoost, 그래디언트 기반 네트워크와 같은 작업별 방법에 비해 경쟁력 있는 혹은 우수한 성과를 달성한다.
  • 모델 용량을 늘리면 특히 분포 외 프롬프트에서 컨텍스트 내 학습이 향상되며 더 높은 차원의 함수 학습이 가능해진다.
  • 이 접근법은 Transformer가 단일 순전파로도 효율적이고 알고리즘과 같은 해법을 인코딩할 수 있음을 시사한다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.