Skip to main content
QUICK REVIEW

[논문 리뷰] Trained Transformers Learn Linear Models In-Context

Ruiqi Zhang, Spencer Frei|arXiv (Cornell University)|2023. 06. 16.
Domain Adaptation and Few-Shot Learning인용 수 13
한 줄 요약

이 논문은 선형 회귀 프롬프트에서 그래디언트 흐름으로 학습된 단일 계층 선형 자기 주의 트랜스포머가 컨텍스트 내에서 선형 모델을 학습할 수 있으며, 특정 조건에서 글로벌 최소로 수렴하고 최상의 선형 예측자와 비교 가능한 예측 오차를 달성한다는 것을 보여준다. 또한 분포 변화와 공변량 변화에 대한 강건성을 분석하며, 비선형 트랜스포머가 더 높은 강건성을 제공한다.

ABSTRACT

Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL): Given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. By embedding a sequence of labeled training data and unlabeled test data as a prompt, this allows for transformers to behave like supervised learning algorithms. Indeed, recent work has shown that when training transformer architectures over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts. We complement this finding with experiments on large, nonlinear transformer architectures which we show are more robust under covariate shifts.

연구 동기 및 목표

  • 트랜스포머에서 컨텍스트 내 학습(ICL)이 함수 클래스(특히 선형 모델)에 대해 어떻게 작동하는지 이해하는 데 흥미를 가질 수 있도록 동기를 부여한다.
  • 선형 회귀 프롬프트에서 단일 층 선형 자기 주의 트랜스포머의 그래디언트 흐름 학습이 모집단 손실에서 글로벌 최소로 수렴함을 보인다.
  • 새로운 프롬프트 및 분포 변화 하에서의 예측 오차를 특징짓는다.
  • 공변량 변화에 대한 강건성과 공변량 분포가 다른 프롬프트로 일반화되는지 조사한다.
  • 선형 자기 주의와 더 큰 비선형 트랜스포머 간의 공변량 변화에 대한 강건성 차이를 비교한다.

제안 방법

  • LSA(선형 자기 주의 모듈)를 갖춘 한 층 트랜스포머와 간단한 매개변수화(WPV 및 WKQ)를 연구한다.
  • 가우시안 입력을 갖는 무작위 선형 회귀 태스크로 생성된 프롬프트에 대해 그래디언트 흐름으로 학습한다.
  • 적절한 초기화 하에서 모집단 손실의 글로벌 최솟값을 도출한다.
  • 수렴한 예측기와 테스트 프롬프트 예측에 대한 닫힌 형식 표현을 제공한다.
  • (x,y)의 결합 분포에서 추출된 테스트 프롬프트에 대한 예측 오차에 대한 경 bounds를 도출한다.
  • 등방성 및 비등방성 공분산에서의 동작을 비교하고, 공변량 변화 강건성을 평가하며, 비선형 트랜스포머로의 경험적 확장을 제시한다.

실험 결과

연구 질문

  • RQ1그래디언트 흐름으로 컨텍스트 프롬프트에 대해 학습된 LSAs가 선형 모델을 효과적으로 컨텍스트 내에서 학습하는 글로벌 최솟값에 도달할 수 있는가?
  • RQ2수렴 시 예측기 구조와 새로운 프롬프트에 대한 예측 오차는 어떤가?
  • RQ3특히 공변량 변화에 대해, 선형 모델에서 프롬프트를 학습한 경우 LSAs는 다양한 분포 변화에 얼마나 강건한가?
  • RQ4프롬프트의 공변량 분포가 작업 간에 변화하는 경우 공변량 변화에 따른 취약성을 완화하는가?
  • RQ5비선형 트랜스포머는 공변량 변화에 대해 어떻게 더 강건한가?

주요 결과

  • 모집단 손실에 대한 그래디언트 흐름은 적절한 초기화 아래 LSAs의 글로벌 최소로 수렴한다.
  • 수렴 시 모델은 테스트 프롬프트에서 선형 예측기를 컨텍스트 내에서 학습할 수 있는 학습 규칙을 구현한다.
  • 결합 분포( x,y )에서의 프롬프트에 대해, 질의에서의 예측 y는 최적의 선형 예측기 오차에 샘플 수 N, M이 커질수록 감소하는 유한 샘플 오차 항과 함께 나타난다.
  • 훈련된 LSAs는 작업 전이, 질의 전이와 같은 여러 분포 변화에 대해 강건성을 보이나, 공변량 분포의 변화에는 취약하다.
  • 프롬프트 간 공변량 변화가 존재하는 경우 LSAs는 여전히 글로벌 최소로 수렴하지만 새로운 프롬프트에서 성능이 저조한 반면, 더 큰 비선형 트랜스포머는 실험적으로 강건성이 향상된다.
  • 이론적 결과는 비선형 트랜스포머가 공변량 변화에 더 강건하다는 실험적 증거를 보완한다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.