[논문 리뷰] What Can Transformers Learn In-Context? A Case Study of Simple Function Classes
이 논문은 컨텍스트 내에서 간단한 함수 클래스(선형, 희소 선형, 2층 네트워크, 의사결정 트리)를 학습하기 위해 Transformer를 처음부터 학습시키고, 다양한 분포 조건에서 작업별 학습 방법과 이를 견주거나 능가하는 성능을 보인다.
In-context learning refers to the ability of a model to condition on a prompt sequence consisting of in-context examples (input-output pairs corresponding to some task) along with a new query input, and generate the corresponding output. Crucially, in-context learning happens only at inference time without any parameter updates to the model. While large language models such as GPT-3 exhibit some ability to perform in-context learning, it is unclear what the relationship is between tasks on which this succeeds and what is present in the training data. To make progress towards understanding in-context learning, we consider the well-defined problem of training a model to in-context learn a function class (e.g., linear functions): that is, given data derived from some functions in the class, can we train a model to in-context learn "most" functions from this class? We show empirically that standard Transformers can be trained from scratch to perform in-context learning of linear functions -- that is, the trained model is able to learn unseen linear functions from in-context examples with performance comparable to the optimal least squares estimator. In fact, in-context learning is possible even under two forms of distribution shift: (i) between the training data of the model and inference-time prompts, and (ii) between the in-context examples and the query input during inference. We also show that we can train Transformers to in-context learn more complex function classes -- namely sparse linear functions, two-layer neural networks, and decision trees -- with performance that matches or exceeds task-specific learning algorithms. Our code and models are available at https://github.com/dtsip/in-context-learning .
연구 동기 및 목표
- 프롬프트에 입력-출력 쌍이 포함된 정의된 함수 클래스에서 Transformer가 컨텍스트 내 학습을 할 수 있는지 조사한다.
- 그러한 모델이 보이지 않는 입력에 대해 함수 근사 성능이 전통적 학습 방법에 비해 얼마나 잘 수행되는지 평가한다.
- 학습 프롬프트와 추론 프롬프트 간의 분포 변화에 대한 로버스트성을 분석한다.
- 모델 용량과 문제 차원이 컨텍스트 내 학습 성능에 미치는 영향을 분석한다.
제안 방법
- 12층 디코더 전용 Transformer(헤드 8, 임베딩 차원 256)를 (x_i, f(x_i))의 프롬프트에서 학습시켜 이후 질의에 대해 f(x_i)를 예측한다.
- D_F 클래스에서 임의의 함수를 표본화하고 D_X에서 입력을 샘플링하여 k개의 컨텍스트 예제를 포함하는 프롬프트를 구성한다.
- 식(2)와 같이 프롬프트 접두사들에 걸쳐 평균 제곱 오차를 최소화하는 목표를 최적화한다.
- 학습 중에 함수 클래스의 복잡도와 문제 차원을 점진적으로 증가시키기 위한 커리큘럼 학습을 사용한다.
- 함수 클래스 전반에서 Transformer 성능을 최소자승법, k-최근접 이웃, 내적 베이스라인과 비교한다.
- 일반화 정확도를 테스트하기 위해 분포 외 프롬프트에 대한 평가를 확장한다.
실험 결과
연구 질문
- RQ1Transformer를 처음부터 학습시켜 선형 함수와 같은 함수 클래스를 컨텍스트 내 학습할 수 있는가?
- RQ2컨텍스트 내 학습 성능이 최적 추정치(예: 최소자승법) 및 간단한 베이스라인과 비교해 어떤가?
- RQ3학습된 컨텍스트 내 학습이 희소 선형, 뉴럴 네트워크, 의사결정 트리와 같은 더 복잡한 함수 클래스에 확장되는가?
- RQ4모델 용량, 문제 차원 및 분포 변화가 컨텍스트 내 학습 성능에 어떤 영향을 미치는가?
- RQ5커리큘럼 학습이 고차원 작업 학습에 어느 정도 도움이 되는가?
주요 결과
- 처음부터 학습된 Transformer는 컨텍스트 예제 개수에 따라 선형 함수에 대해 최소자승법과 비교할 만한 오차를 보이며 컨텍스트 내 학습을 수행할 수 있다.
- 학습 프롬프트와 추론 프롬프트 간의 분포 변화 및 컨텍스트 예제와 질의 입력 간의 분포 변화에서도 성능이 강건하게 유지된다.
- 희소 선형 함수, 두 층 ReLU 네트, 의사결정 트리에서 학습된 모델은 Lasso, XGBoost, 그래디언트 기반 네트워크와 같은 작업별 방법에 비해 경쟁력 있는 혹은 우수한 성과를 달성한다.
- 모델 용량을 늘리면 특히 분포 외 프롬프트에서 컨텍스트 내 학습이 향상되며 더 높은 차원의 함수 학습이 가능해진다.
- 이 접근법은 Transformer가 단일 순전파로도 효율적이고 알고리즘과 같은 해법을 인코딩할 수 있음을 시사한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.