[论文解读] How Transformers Learn Causal Structure with Gradient Descent
论文证明对一个简化的两层 Transformer 的梯度下降能够通过在第一注意力层编码因果图来学习潜在的因果结构,并且在上下文中的马尔可夫链设置中出现归纳头;注意力中的梯度信息反映互信息。
The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
研究动机与目标
- 激发对基于梯度的训练如何在 Transformer 中产生因果结构的理解。
- 引入一个带有因果结构的随机序列任务,以固定潜在图。
- 分析在梯度下降下一个两层仅注意力的 Transformer 的训练动态。
- 表明注意力矩阵的梯度捕捉到互信息并揭示图的边。
- 演示该方法如何通过多头架构推广到非树图,并评估分布外表现。
提出的方法
- 定义一个简化的两层解耦 Transformer,以及一个聚焦于 A^(1) 和 A^(2) 的简化模型。
- 构造一个基于令牌位置的潜在有向无环图(DAG)定义的带因果结构的随机序列任务。
- 证明梯度下降通过将潜在图编码在第一注意力层 (A^(1)) 来恢复潜在图。
- 表明梯度对应卡方互信息量度,并通过数据处理不等式引导边的恢复。
- 特例分析:上下文中的马尔可夫链产生归纳头。
- 提供多头扩展,将非树图分散到各头并在经验上验证。
实验结果
研究问题
- RQ1变换器上的梯度下降是否能够从用固定因果图生成的数据中恢复潜在因果结构?
- RQ2在训练过程中,因果结构如何编码在变换器的注意力层中?
- RQ3在上下文学习场景如马尔可夫链中,会出现哪些基本元件(如归纳头)?
- RQ4当潜在图不是树时模型的表现如何,能否通过多头设计解决?
- RQ5训练好的模型是否对分布外的转移具有泛化能力?
主要发现
- 对一个两层解耦 Transformer 的梯度下降学会将潜在因果图编码在第一注意力层为邻接矩阵。
- 第一注意力层的梯度对应令牌之间的卡方互信息,数据处理不等式将学习集中在图边上。
- 在上下文马尔可夫链的特例中,模型发展出一个归纳头来执行对转移的上下文估计。
- 当因果图不是树时,多头 Transformer 可以将潜在图分布到各头以实现求解行为。
- 实证上,训练好的 Transformer 在所提出的任务上恢复出多样的因果结构,并且对转移表现出分布外泛化。
- 理论保证(定理1和定理2)在给定假设下建立了总体损失收敛和OOD泛化。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。