[论文解读] Trained Transformers Learn Linear Models In-Context
本论文表明,通过对线性回归提示进行梯度流训练的单层线性自注意力变换器可以在上下文中学习线性模型,收敛到全局最小值,并在某些条件下实现的预测误差与最佳线性预测器相当;它还分析对分布偏移和协变量偏移的鲁棒性,非线性变换器提供了更高的鲁棒性。
Attention-based neural networks such as transformers have demonstrated a remarkable ability to exhibit in-context learning (ICL): Given a short prompt sequence of tokens from an unseen task, they can formulate relevant per-token and next-token predictions without any parameter updates. By embedding a sequence of labeled training data and unlabeled test data as a prompt, this allows for transformers to behave like supervised learning algorithms. Indeed, recent work has shown that when training transformer architectures over random instances of linear regression problems, these models' predictions mimic those of ordinary least squares. Towards understanding the mechanisms underlying this phenomenon, we investigate the dynamics of ICL in transformers with a single linear self-attention layer trained by gradient flow on linear regression tasks. We show that despite non-convexity, gradient flow with a suitable random initialization finds a global minimum of the objective function. At this global minimum, when given a test prompt of labeled examples from a new prediction task, the transformer achieves prediction error competitive with the best linear predictor over the test prompt distribution. We additionally characterize the robustness of the trained transformer to a variety of distribution shifts and show that although a number of shifts are tolerated, shifts in the covariate distribution of the prompts are not. Motivated by this, we consider a generalized ICL setting where the covariate distributions can vary across prompts. We show that although gradient flow succeeds at finding a global minimum in this setting, the trained transformer is still brittle under mild covariate shifts. We complement this finding with experiments on large, nonlinear transformer architectures which we show are more robust under covariate shifts.
研究动机与目标
- 理解变换器中在上下文学习(ICL)对于函数类的理解,聚焦线性模型。
- 证明单层线性自注意力变换器在线性回归提示上通过梯度流训练收敛到全局最小值。
- 描述对新提示的预测误差以及在分布偏移下的情况。
- 研究对协变量偏移的鲁棒性并推广到具有不同协变量分布的提示。
- 对比线性自注意力与更大非线性变换器在协变量偏移鲁棒性方面。
提出的方法
- 研究一个具备线性自注意力模块(LSA)的单层变换器及简化参数化(WPV 和 WKQ)。
- 在来自带高斯输入的随机线性回归任务的提示上通过梯度流进行训练。
- 在合适初始化下推导总体损失的全局极小值。
- 为极限预测器和测试提示预测提供闭式表达式。
- 推导从( x, y) 联合分布抽取的测试提示的预测误差界限。
- 比较各向同性与各向异性协方差下的行为并评估对协变量偏移的鲁棒性,将其扩展到非线性变换器的经验分析。
实验结果
研究问题
- RQ1梯度流训练的就上下文提示是否能使 LSAs 达到全局极小值,从而在上下文中有效学习线性模型?
- RQ2收敛时的预测器结构及其对新提示的预测误差是什么?
- RQ3当在来自线性模型的提示上训练时,LSAs 对各种分布偏移,尤其是协变量偏移,有多鲁棒?
- RQ4不同任务之间具有不同协变量分布的提示是否能减轻协变量偏移下的脆弱性?
- RQ5非线性变换器在对协变量偏移的鲁棒性方面有何差异?
主要发现
- 在适当初始化下,对总体损失的梯度流收敛到 LSAs 的全局极小值。
- 收敛时,模型实现了一个学习规则,可以在测试提示上就地学习线性预测器。
- 对来自联合分布(x, y) 的提示,查询上的预测 y 等于最佳线性预测误差再加上随提示长度 N 和 M 减少的有限样本误差项。
- 训练的 LSAs 对若干分布偏移(任务偏移、查询偏移)表现出鲁棒性,但在协变量分布的协变量偏移下较脆弱。
- 当提示间存在协变量偏移时,LSAs 仍然收敛到全局极小值但在新提示上表现较差,而更大的非线性变换器在经验上显示出更强的鲁棒性。
- 理论结果辅以实验,表明非线性变换器对协变量偏移更鲁棒。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。