[论文解读] Transformers learn in-context by gradient descent
这篇论文表明,在Transformer中的上下文学习可以被机械地理解为梯度下降更新,证明自注意力层能够在上下文数据上实现GD步骤,并且增加MLP能够通过对深度表征的梯度学习实现非线性回归。
At present, the mechanisms of in-context learning in Transformers are not well understood and remain mostly an intuition. In this paper, we suggest that training Transformers on auto-regressive objectives is closely related to gradient-based meta-learning formulations. We start by providing a simple weight construction that shows the equivalence of data transformations induced by 1) a single linear self-attention layer and by 2) gradient-descent (GD) on a regression loss. Motivated by that construction, we show empirically that when training self-attention-only Transformers on simple regression tasks either the models learned by GD and Transformers show great similarity or, remarkably, the weights found by optimization match the construction. Thus we show how trained Transformers become mesa-optimizers i.e. learn models by gradient descent in their forward pass. This allows us, at least in the domain of regression problems, to mechanistically understand the inner workings of in-context learning in optimized Transformers. Building on this insight, we furthermore identify how Transformers surpass the performance of plain gradient descent by learning an iterative curvature correction and learn linear models on deep data representations to solve non-linear regression tasks. Finally, we discuss intriguing parallels to a mechanism identified to be crucial for in-context learning termed induction-head (Olsson et al., 2022) and show how it could be understood as a specific case of in-context learning by gradient descent learning within Transformers. Code to reproduce the experiments can be found at https://github.com/google-research/self-organising-systems/tree/master/transformers_learn_icl_by_gd .
研究动机与目标
- 激发对 Transformer 中上下文学习机制的理解。
- 展示线性自注意力更新与线性回归的一步梯度下降之间的等价性。
- 演示堆叠注意力层能够实现迭代的类似 GD 的更新和曲率矫正(GD++)。
- 解释MLP如何通过对深度表征的梯度下降实现非线性回归。
- 讨论与元学习、快速权重,以及 induction-head 机制的联系。
提出的方法
- 推导一种权重构造,使单步线性自注意力等价于线性回归损失上的梯度下降更新。
- 在线性回归任务上,经验性地将训练好的线性自注意力层与 GD 构造进行比较,以评估一致性。
- 扩展到多层自注意力,展示在迭代数据变换(GD++)和残差曲率校正下的类似 GD 行为。
- 证明在 Transformer 中加入 MLPs 可以通过对深度表征的梯度下降来解决非线性回归任务(核回归视角)。
- 研究标记构造和数据变换,以展示 Transformer 如何在前向传播中通过基于梯度的更新实现上下文学习。
实验结果
研究问题
- RQ1单个线性自注意力层是否可以在线性回归任务上实现梯度下降步骤?
- RQ2在带有自注意力层的训练过的 Transformer 上线性回归数据是否收敛到类似 GD 的解?
- RQ3多层注意力和 MLPs 如何影响 Transformer 执行基于梯度下降的更新(GD++,非线性任务)的能力?
- RQ4在 Transformer 中的上下文学习是否可以理解为在前向传播中学习一个算法(mesa-optimization)?
- RQ5标记构造和数据变换在使上下文学习无需在前向传播之外进行显式权重更新的作用是什么?
主要发现
- 一个单头线性自注意力层可以在用于线性回归的训练数据上执行梯度下降式更新。
- 训练好的线性自注意力层与构造的 GD 更新高度一致,包括相似的预测和灵敏度。
- 多层自注意力堆栈能够实现迭代曲率校正(GD++),在线性任务上表现优于普通的 GD。
- 将 MLPs 融入 Transformer 使其通过对深度表征执行梯度下降来解决非线性回归,从而有效地实现了核回归風格。
- Transformer 可以通过学习的数据变换和任务特异性表示实现上下文学习,符合 mesa-optimization 与 fast weights 的概念。
- 该架构能够在分布内和分布外任务中复制或逼近基于梯度的学习动态。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。