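[Paper Summary] What Can Transformers Learn In-Context? A Case Study of Simple Function Classes
The paper trains Transformers from scratch to in-context learn simple function classes (linear functions, sparse linear functions, two-layer neural networks, and decision trees). The trained models match the optimal least-squares estimator on linear functions and match or surpass task-specific learning methods on the other classes, including under several kinds of distribution shift.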
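To make the four function classes concrete, here is a minimal NumPy sketch that samples one function from each. The distributional choices are our assumptions for illustration, not a verified copy of the authors' code: Gaussian weights for the linear classes, 100 hidden ReLU units with output weights drawn from N(0, 2/100), and depth-4 trees that split on a random coordinate at threshold 0 with N(0, 1) leaf values. All helper names are hypothetical.

```python
import numpy as np

# Hypothetical samplers; each returns f mapping an (n, d) array to an (n,) array.
# The distributions used here are illustrative assumptions.

def sample_linear(d, rng):
    w = rng.standard_normal(d)                      # w ~ N(0, I_d)
    return lambda x: x @ w

def sample_sparse_linear(d, s, rng):
    w = np.zeros(d)
    support = rng.choice(d, size=s, replace=False)  # s nonzero coordinates
    w[support] = rng.standard_normal(s)
    return lambda x: x @ w

def sample_relu_net(d, hidden, rng):
    W = rng.standard_normal((hidden, d))            # hidden-layer weights ~ N(0, 1)
    a = rng.normal(0.0, np.sqrt(2.0 / hidden), size=hidden)  # output weights ~ N(0, 2/hidden)
    return lambda x: np.maximum(x @ W.T, 0.0) @ a

def sample_tree(d, depth, rng):
    n_internal = 2 ** depth - 1                     # internal nodes, heap order
    feats = rng.integers(0, d, size=n_internal)     # split coordinate for each node
    leaves = rng.standard_normal(2 ** depth)        # leaf values ~ N(0, 1)

    def f(x):
        node = np.zeros(len(x), dtype=int)          # every row starts at the root
        for _ in range(depth):
            go_right = x[np.arange(len(x)), feats[node]] > 0   # split at threshold 0
            node = 2 * node + 1 + go_right.astype(int)
        return leaves[node - n_internal]

    return f

rng = np.random.default_rng(0)
xs = rng.standard_normal((5, 20))
functions = {
    "linear": sample_linear(20, rng),
    "sparse linear": sample_sparse_linear(20, 3, rng),
    "two-layer ReLU net": sample_relu_net(20, 100, rng),
    "decision tree": sample_tree(20, 4, rng),
}
for name, f in functions.items():
    print(name, np.round(f(xs), 2))
```

Returning closures keeps every class behind the same `f(x)` interface, so the prompt-building code does not need to know which class a task came from.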
In-context learning refers to the ability of a model to condition on a prompt sequence consisting of in-context examples (input-output pairs corresponding to some task) along with a new query input, and generate the corresponding output. Crucially, in-context learning happens only at inference time without any parameter updates to the model. While large language models such as GPT-3 exhibit some ability to perform in-context learning, it is unclear what the relationship is between tasks on which this succeeds and what is present in the training data. To make progress towards understanding in-context learning, we consider the well-defined problem of training a model to in-context learn a function class (e.g., linear functions): that is, given data derived from some functions in the class, can we train a model to in-context learn "most" functions from this class? We show empirically that standard Transformers can be trained from scratch to perform in-context learning of linear functions -- that is, the trained model is able to learn unseen linear functions from in-context examples with performance comparable to the optimal least squares estimator. In fact, in-context learning is possible even under two forms of distribution shift: (i) between the training data of the model and inference-time prompts, and (ii) between the in-context examples and the query input during inference. We also show that we can train Transformers to in-context learn more complex function classes -- namely sparse linear functions, two-layer neural networks, and decision trees -- with performance that matches or exceeds task-specific learning algorithms. Our code and models are available at https://github.com/dtsip/in-context-learning .
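For concreteness, the sketch below computes the least-squares reference curve for noiseless linear prompts. For each i, it fits least squares on the first i examples, predicts the next input, and records the squared error normalized by the dimension d (a normalization we assume for readability). When i < d, `np.linalg.lstsq` returns the minimum-norm solution. Once i reaches d, the fit is exact and the error drops to numerical zero. The helper name is hypothetical.

```python
import numpy as np

def least_squares_prefix_errors(xs, ys):
    """Hypothetical helper: squared error of the least-squares prediction for
    query i (0-indexed), fitted on the first i examples, for i = 0..len(xs)-1."""
    errs = np.empty(len(xs))
    for i in range(len(xs)):
        if i == 0:
            pred = 0.0                                    # no examples yet: predict 0
        else:
            # Minimum-norm solution when i < d; exact fit when i >= d (noiseless).
            w_hat, *_ = np.linalg.lstsq(xs[:i], ys[:i], rcond=None)
            pred = xs[i] @ w_hat
        errs[i] = (pred - ys[i]) ** 2
    return errs

rng = np.random.default_rng(0)
d, k, n_prompts = 20, 40, 100
curve = np.zeros(k)
for _ in range(n_prompts):
    w = rng.standard_normal(d)                            # linear task f(x) = w^T x
    xs = rng.standard_normal((k, d))                      # x_i ~ N(0, I_d)
    curve += least_squares_prefix_errors(xs, xs @ w) / d  # error normalized by d
curve /= n_prompts
for i in (0, 5, 10, 19, 20, 30):
    print(f"{i} in-context examples: {curve[i]:.4f}")
```

A trained Transformer "comparable to least squares" would trace roughly this curve as the number of in-context examples grows.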
Research Motivation and Goals
- Investigate whether a Transformer can be trained from scratch to in-context learn a defined function class from prompts containing input-output pairs.
- Assess how well such models approximate the function on unseen inputs compared to traditional learning methods.
- Analyze robustness to distribution shifts between training prompts and inference prompts.
- Examine how model capacity and problem dimensionality affect in-context learning performance.
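Proposed Method
- Train a decoder-only Transformer (GPT-2 architecture, 12 layers, 8 attention heads, 256-dimensional embeddings) on prompts (x_1, f(x_1), x_2, f(x_2), ...) to predict f(x_{i+1}) from every prompt prefix.
- Sample a random function f from a distribution D_F over the function class and inputs x_i i.i.d. from D_X to form prompts with k in-context examples.
- Optimize an objective that averages the squared error of predicting f(x_{i+1}) over all prompt prefixes P^i = (x_1, f(x_1), ..., x_i, f(x_i), x_{i+1}), i = 0, ..., k, as in equation (2).
- Use curriculum learning to gradually increase problem difficulty (the effective input dimension and the number of in-context examples) during training.
- Compare Transformer performance with least squares, nearest-neighbor averaging, and an averaging (inner-product) estimator on linear functions, and with task-specific methods on the other function classes.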
- Extend evaluation to out-of-distribution prompts to test generalization.
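The prompt construction and per-prefix objective can be sketched as follows, assuming the linear class with Gaussian inputs. `prefix_objective` computes the per-prompt term (1/(k+1)) * sum over i = 0..k of (M(P^i) - f(x_{i+1}))^2. Training would minimize its expectation over prompts with respect to the Transformer's parameters, which this sketch omits. Here the "model" is any predictor with the hypothetical signature `predict(context_x, context_y, query)`, so the baselines can be plugged in directly. Using an unweighted average over the 3 nearest neighbors is our assumption.

```python
import numpy as np

def sample_prompt(d, k, rng):
    """Hypothetical helper: one linear-task prompt with inputs x_1..x_{k+1}."""
    w = rng.standard_normal(d)
    xs = rng.standard_normal((k + 1, d))
    return xs, xs @ w

def prefix_objective(predict, xs, ys):
    """Mean over prefixes P^0..P^k of the squared error of predicting
    f(x_{i+1}) from the first i examples (0-indexed: query xs[i])."""
    losses = [(predict(xs[:i], ys[:i], xs[i]) - ys[i]) ** 2 for i in range(len(xs))]
    return float(np.mean(losses))

def least_squares(cx, cy, q):
    if len(cx) == 0:
        return 0.0
    w_hat, *_ = np.linalg.lstsq(cx, cy, rcond=None)
    return float(q @ w_hat)

def nearest_neighbors(cx, cy, q, n=3):
    if len(cx) == 0:
        return 0.0
    nearest = np.argsort(np.linalg.norm(cx - q, axis=1))[:n]
    return float(np.mean(cy[nearest]))              # unweighted average (assumption)

def averaging(cx, cy, q):
    if len(cx) == 0:
        return 0.0
    w_hat = np.mean(cy[:, None] * cx, axis=0)       # average of y_i * x_i
    return float(q @ w_hat)                         # inner product with the query

rng = np.random.default_rng(0)
d, k = 20, 40
prompts = [sample_prompt(d, k, rng) for _ in range(50)]
scores = {}
for name, fn in [("least squares", least_squares), ("3-NN", nearest_neighbors),
                 ("averaging", averaging)]:
    scores[name] = np.mean([prefix_objective(fn, xs, ys) for xs, ys in prompts]) / d
    print(f"{name}: {scores[name]:.3f}")
```

Because every prefix counts as a separate prediction, one prompt of length k+1 supervises the model at every context size from 0 to k.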
Experimental Results
Research Questions
- RQ1: Can a Transformer be trained from scratch to in-context learn a function class such as linear functions?
- RQ2: How does in-context learning performance compare to optimal estimators (e.g., least squares) and simple baselines?
- RQ3: Does in-context learning extend to more complex function classes (sparse linear functions, neural networks, decision trees)?
- RQ4: How do model capacity, problem dimension, and distribution shifts affect in-context learning performance?
- RQ5: To what extent does curriculum learning facilitate training on higher-dimensional tasks?
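RQ5 concerns the curriculum. Below is a minimal sketch of such a schedule, assuming the linear-regression setting. The effective input dimension and the prompt length grow in fixed steps, and inputs are confined to the first d_cur coordinates by zeroing the rest. The starting values, increments, and interval (5 to 20 dimensions, 11 to 41 examples, +1 and +2 every 2000 steps) are illustrative assumptions, and the function names are hypothetical. Check them against the released code before relying on them.

```python
import numpy as np

def curriculum(step, d_start=5, d_end=20, k_start=11, k_end=41,
               every=2000, d_inc=1, k_inc=2):
    """Hypothetical schedule: (effective input dimension, number of in-context
    examples) at a given training step. All defaults are assumptions."""
    n = step // every
    return min(d_start + d_inc * n, d_end), min(k_start + k_inc * n, k_end)

def sample_inputs(k, d, d_cur, rng):
    """Draw k inputs in R^d, zeroing all but the first d_cur coordinates so the
    inputs lie in a d_cur-dimensional subspace (assumed mechanism)."""
    xs = rng.standard_normal((k, d))
    xs[:, d_cur:] = 0.0
    return xs

for step in (0, 2000, 10000, 30000, 100000):
    print(step, curriculum(step))
# 0 (5, 11)
# 2000 (6, 13)
# 10000 (10, 21)
# 30000 (20, 41)
# 100000 (20, 41)
```

Keeping the ambient dimension fixed at d and only zeroing coordinates means the model's input layer never changes shape as the curriculum advances.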
Key Findings
- Transformers trained from scratch can in-context learn linear functions with error comparable to the least-squares estimator at every number of in-context examples.
- Performance remains largely robust under several distribution shifts, both between training and inference prompts and between in-context examples and query inputs.
- Models trained on sparse linear functions, two-layer ReLU networks, and decision trees match or outperform task-specific methods: Lasso for sparse linear functions, gradient-based training for two-layer networks, and XGBoost for decision trees.
- Increasing model capacity improves in-context learning, especially under out-of-distribution prompts, and enables higher-dimensional function learning.
- The approach demonstrates that Transformers can encode efficient, algorithm-like solutions in a single forward pass.
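The sparse-linear comparison can be illustrated with the Lasso baseline. The sketch below solves the Lasso with ISTA (proximal gradient descent). We chose that solver for self-containment; it is not necessarily what the authors used. It then compares the Lasso with minimum-norm least squares when there are fewer examples than dimensions (k < d). The settings d = 20, s = 3, k = 14, and alpha = 0.01 are assumptions for illustration.

```python
import numpy as np

def lasso_ista(X, y, alpha, n_iter=2000):
    """Minimize (1/(2n)) * ||X w - y||^2 + alpha * ||w||_1 with ISTA."""
    n, d = X.shape
    lip = np.linalg.norm(X, 2) ** 2 / n           # Lipschitz constant of the gradient
    w = np.zeros(d)
    for _ in range(n_iter):
        z = w - (X.T @ (X @ w - y) / n) / lip     # gradient step on the squared loss
        w = np.sign(z) * np.maximum(np.abs(z) - alpha / lip, 0.0)  # soft-thresholding
    return w

rng = np.random.default_rng(0)
d, s, k, trials = 20, 3, 14, 30                   # fewer examples than dimensions
ls_err, lasso_err = [], []
for _ in range(trials):
    w = np.zeros(d)
    w[rng.choice(d, size=s, replace=False)] = rng.standard_normal(s)
    X = rng.standard_normal((k, d))
    y = X @ w
    q = rng.standard_normal(d)                    # query input
    w_ls, *_ = np.linalg.lstsq(X, y, rcond=None)  # minimum-norm least squares
    w_lasso = lasso_ista(X, y, alpha=0.01)
    ls_err.append((q @ (w_ls - w)) ** 2)
    lasso_err.append((q @ (w_lasso - w)) ** 2)
print(f"least squares: {np.mean(ls_err):.4f}  lasso: {np.mean(lasso_err):.4f}")
```

With k < d, least squares cannot pin down w. The l1 penalty exploits sparsity, which is why Lasso is the natural task-specific baseline for this class.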
This summary was generated by AI and reviewed by human editors.