QUICK REVIEW

[Paper Review] What Can Transformers Learn In-Context? A Case Study of Simple Function Classes

Shivam Garg, Dimitris Tsipras | arXiv (Cornell University) | Aug 1, 2022
Topic Modeling · Cited by 59
One-Sentence Summary

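The paper trains Transformers from scratch to in-context learn simple function classes (linear, sparse linear, two-layer neural networks, decision trees). On linear functions the trained models perform comparably to the optimal least squares estimator, and on the more complex classes they match or surpass task-specific learning algorithms, even under some forms of distribution shift.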

ABSTRACT

In-context learning refers to the ability of a model to condition on a prompt sequence consisting of in-context examples (input-output pairs corresponding to some task) along with a new query input, and generate the corresponding output. Crucially, in-context learning happens only at inference time without any parameter updates to the model. While large language models such as GPT-3 exhibit some ability to perform in-context learning, it is unclear what the relationship is between tasks on which this succeeds and what is present in the training data. To make progress towards understanding in-context learning, we consider the well-defined problem of training a model to in-context learn a function class (e.g., linear functions): that is, given data derived from some functions in the class, can we train a model to in-context learn "most" functions from this class? We show empirically that standard Transformers can be trained from scratch to perform in-context learning of linear functions -- that is, the trained model is able to learn unseen linear functions from in-context examples with performance comparable to the optimal least squares estimator. In fact, in-context learning is possible even under two forms of distribution shift: (i) between the training data of the model and inference-time prompts, and (ii) between the in-context examples and the query input during inference. We also show that we can train Transformers to in-context learn more complex function classes -- namely sparse linear functions, two-layer neural networks, and decision trees -- with performance that matches or exceeds task-specific learning algorithms. Our code and models are available at https://github.com/dtsip/in-context-learning .

Motivation and Objectives

  • Investigate whether a Transformer can be trained from scratch to in-context learn a defined function class from prompts containing input-output pairs.
  • Assess how well such models approximate the function on unseen inputs compared to traditional learning methods.
  • Analyze robustness to distribution shifts between training prompts and inference prompts.
  • Examine how model capacity and problem dimensionality affect in-context learning performance.

Proposed Method

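  • Train a decoder-only Transformer from the GPT-2 family (12 layers, 8 attention heads, 256-dimensional embeddings) on prompts (x_1, f(x_1), ..., x_k, f(x_k), x_query), predicting f(x_{i+1}) from each prefix that contains the first i examples.
  • Sample a random function f from a distribution D_F over the function class F and inputs x_1, ..., x_k i.i.d. from D_X, forming prompts with varying numbers k of in-context examples.
  • Minimize the squared error averaged over all prompt prefixes (equation (2) of the paper): min_θ E_P[(1/(k+1)) Σ_{i=0}^{k} (M_θ(P^i) − f(x_{i+1}))²], where P^i consists of the first i examples followed by the query x_{i+1}.
  • Use curriculum learning: start with inputs restricted to a low-dimensional subspace and short prompts, then gradually increase the input dimension and the number of in-context examples during training.
  • Compare Transformer performance to least squares, nearest-neighbor, and averaging (inner-product) baselines, and to task-specific algorithms for the more complex function classes.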
  • Extend evaluation to out-of-distribution prompts to test generalization.
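The prompt construction, curriculum, and prefix-averaged objective above can be sketched in NumPy for the linear-function class. This is a minimal sketch under stated assumptions, not the authors' code (their repository is linked in the abstract): we assume x_i ~ N(0, I_d) and f(x) = wᵀx with w ~ N(0, I_d), the curriculum constants are illustrative, and a least squares fit stands in for the Transformer M_θ so the loss can be computed without any training. All function names are hypothetical.

```python
import numpy as np


def sample_linear_prompt(rng, d, d_cur, n_points):
    """Sample inputs x_1..x_n and outputs f(x_i) = w^T x_i for one random linear f."""
    xs = rng.standard_normal((n_points, d))  # assumption: x_i ~ N(0, I_d)
    xs[:, d_cur:] = 0.0                      # curriculum: keep only the first d_cur coordinates
    w = rng.standard_normal(d)               # assumption: f drawn from D_F via w ~ N(0, I_d)
    return xs, xs @ w


def curriculum(step, d_start=5, d_max=20, n_start=11, n_max=41, every=2000):
    """Illustrative schedule: +1 effective dimension and +2 prompt points every `every` steps."""
    stage = step // every
    return min(d_start + stage, d_max), min(n_start + 2 * stage, n_max)


def prefix_predictions(model, xs, ys):
    """For i = 0..n-1, predict f(x_{i+1}) from the first i examples plus the query x_{i+1}."""
    return np.array([model(xs[:i], ys[:i], xs[i]) for i in range(len(xs))])


def prefix_squared_loss(preds, ys):
    """Objective for one prompt: squared error averaged over all prefixes."""
    return float(np.mean((preds - ys) ** 2))


def least_squares_model(x_ctx, y_ctx, x_query):
    """Stand-in for the Transformer M_theta: min-norm least squares on the context."""
    if len(x_ctx) == 0:
        return 0.0  # no examples yet: predict the prior mean of f(x), which is 0
    w_hat, *_ = np.linalg.lstsq(x_ctx, y_ctx, rcond=None)
    return float(x_query @ w_hat)


# One training-style evaluation at curriculum step 0 (no gradient update here).
rng = np.random.default_rng(0)
d_cur, n_points = curriculum(step=0)
xs, ys = sample_linear_prompt(rng, d=20, d_cur=d_cur, n_points=n_points)
loss = prefix_squared_loss(prefix_predictions(least_squares_model, xs, ys), ys)
```

In actual training, `least_squares_model` would be replaced by the Transformer's output at each query position, and the loss would be averaged over a batch of prompts and minimized with a gradient-based optimizer. Zeroing trailing coordinates keeps every prompt the same width d, so a single model can be trained across all curriculum stages.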

Experimental Results

Research Questions

  • RQ1: Can a Transformer be trained from scratch to in-context learn a function class such as linear functions?
  • RQ2: How does in-context learning performance compare to optimal estimators (e.g., least squares) and simple baselines?
  • RQ3: Does the learned in-context learning extend to more complex function classes (sparse linear, neural nets, decision trees)?
  • RQ4: How do model capacity, problem dimension, and distribution shifts affect in-context learning performance?
  • RQ5: To what extent does curriculum learning facilitate training for higher-dimensional tasks?

Key Findings

  • Transformers trained from scratch can in-context learn linear functions with error comparable to the optimal least squares estimator across numbers of in-context examples.
  • In-context learning remains possible under distribution shifts between training and inference prompts and between in-context examples and query inputs, although robustness varies with the type of shift.
  • Models trained on sparse linear functions, two-layer ReLU networks, and decision trees achieve results comparable or superior to task-specific methods (e.g., Lasso for sparse linear functions, XGBoost for decision trees, and two-layer networks trained with gradient descent).
  • Increasing model capacity improves in-context learning, especially on out-of-distribution prompts, and makes learning in higher dimensions feasible.
  • The results suggest that Transformers can encode efficient, algorithm-like learning procedures that run in a single forward pass.
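To make the least squares comparison concrete, here is a hedged NumPy sketch of the three linear-function baselines, evaluated on noiseless prompts as a function of the number of in-context examples i. No Transformer is trained here. The settings are illustrative assumptions, not necessarily the paper's exact evaluation protocol: d = 20, 41 points per prompt, 200 prompts, 3 neighbors, x and w drawn from N(0, I_d), and squared error divided by d so that always predicting 0 scores about 1. Function names are hypothetical.

```python
import numpy as np


def least_squares(x_ctx, y_ctx, x_q):
    """Min-norm least squares fit on the in-context examples (optimal for noiseless linear f)."""
    if len(x_ctx) == 0:
        return 0.0
    w_hat, *_ = np.linalg.lstsq(x_ctx, y_ctx, rcond=None)
    return float(x_q @ w_hat)


def averaging(x_ctx, y_ctx, x_q):
    """Inner-product (averaging) baseline: w_hat = mean_j(y_j * x_j), predict w_hat . x_q."""
    if len(x_ctx) == 0:
        return 0.0
    w_hat = (y_ctx[:, None] * x_ctx).mean(axis=0)
    return float(x_q @ w_hat)


def nearest_neighbors(x_ctx, y_ctx, x_q, n_neighbors=3):
    """Average the outputs of the n_neighbors closest in-context inputs (Euclidean distance)."""
    if len(x_ctx) == 0:
        return 0.0
    dist = np.linalg.norm(x_ctx - x_q, axis=1)
    idx = np.argsort(dist)[:n_neighbors]
    return float(y_ctx[idx].mean())


def error_curve(predictor, d=20, n_points=41, n_prompts=200, seed=0):
    """Squared error divided by d, averaged over prompts, for each prefix length i."""
    rng = np.random.default_rng(seed)
    errs = np.zeros(n_points)
    for _ in range(n_prompts):
        w = rng.standard_normal(d)               # random linear f(x) = w . x
        xs = rng.standard_normal((n_points, d))  # x_i ~ N(0, I_d)
        ys = xs @ w
        for i in range(n_points):
            # predict f(x_{i+1}) from the first i examples
            errs[i] += (predictor(xs[:i], ys[:i], xs[i]) - ys[i]) ** 2
    return errs / (n_prompts * d)


for name, fn in [("least squares", least_squares), ("averaging", averaging),
                 ("3-NN", nearest_neighbors)]:
    curve = error_curve(fn)
    print(f"{name:>13}: " + "  ".join(f"i={i}: {curve[i]:.3f}" for i in (0, 10, 20, 40)))
```

In this noiseless setting the least squares curve drops to numerical zero once i reaches d = 20, because the linear system becomes fully determined. The averaging baseline improves only slowly (its expected normalized error is (d + 1)/i under these assumptions), and nearest neighbors improves little in 20 dimensions. The least squares curve is the reference the trained Transformer is compared against.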

This review was generated by AI and reviewed by human editors.