QUICK REVIEW

[Paper Review] How Transformers Learn Causal Structure with Gradient Descent

Eshaan Nichani, Alex Damian | arXiv (Cornell University) | Feb 22, 2024
Bayesian Modeling and Causal Inference · 5 citations
TL;DR

The paper proves that gradient descent on a simplified two-layer transformer learns latent causal structure by encoding the causal graph in its first attention layer, and that an induction head emerges when sequences are in-context Markov chains. The key mechanism is that the gradient of the attention matrix mirrors the mutual information between tokens.
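For reference, the standard definitions behind this claim are written out below. They are textbook facts, not the paper's exact gradient expression. Chi-squared mutual information is the chi-squared divergence between the joint distribution and the product of the marginals, and, like every f-divergence, it obeys the data processing inequality. Assuming each position i has at most one parent p(i), and that parents precede children, every earlier position j ≠ p(i) influences x_i only through x_{p(i)}; this is the step that singles out graph edges.

```latex
\begin{aligned}
I_{\chi^2}(X;Y) &= \chi^2\big(P_{XY} \,\|\, P_X \otimes P_Y\big)
  = \sum_{x,y} \frac{P_{XY}(x,y)^2}{P_X(x)\,P_Y(y)} - 1, \\
X \to Y \to Z &\implies I_{\chi^2}(X;Z) \le I_{\chi^2}(Y;Z), \\
j < i,\ j \ne p(i) &\implies I_{\chi^2}(x_j;x_i) \le I_{\chi^2}(x_{p(i)};x_i).
\end{aligned}
```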

ABSTRACT

The incredible success of transformers on sequence modeling tasks can be largely attributed to the self-attention mechanism, which allows information to be transferred between different parts of a sequence. Self-attention allows transformers to encode causal structure which makes them particularly suitable for sequence modeling. However, the process by which transformers learn such causal structure via gradient-based training algorithms remains poorly understood. To better understand this process, we introduce an in-context learning task that requires learning latent causal structure. We prove that gradient descent on a simplified two-layer transformer learns to solve this task by encoding the latent causal graph in the first attention layer. The key insight of our proof is that the gradient of the attention matrix encodes the mutual information between tokens. As a consequence of the data processing inequality, the largest entries of this gradient correspond to edges in the latent causal graph. As a special case, when the sequences are generated from in-context Markov chains, we prove that transformers learn an induction head (Olsson et al., 2022). We confirm our theoretical findings by showing that transformers trained on our in-context learning task are able to recover a wide variety of causal structures.
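The data processing inequality argument can be illustrated with a small simulation. This is a minimal sketch under simplifying assumptions that differ from the paper's setup: all sequences share one fixed transition matrix (the paper's task is in-context), the graph is a tree whose parents precede their children, and pairwise χ² mutual information is estimated from samples instead of computing the transformer's actual attention gradient. The helper names are hypothetical.

```python
import numpy as np

def sample_tree_sequences(parents, trans, n, rng):
    """Sample n sequences over T positions. parents[i] = -1 marks a root (uniform token);
    otherwise x_i ~ trans[x_{parents[i]}]. Parents must precede their children."""
    T, S = len(parents), trans.shape[0]
    X = np.zeros((n, T), dtype=int)
    for i, p in enumerate(parents):
        if p < 0:
            X[:, i] = rng.integers(0, S, size=n)
        else:
            cdf = np.cumsum(trans[X[:, p]], axis=1)  # per-sequence CDF of x_i given x_p
            u = rng.random((n, 1))
            X[:, i] = np.minimum((u > cdf).sum(axis=1), S - 1)  # inverse-CDF draw
    return X

def chi2_mi(a, b, S):
    """Plug-in estimate of sum_{x,y} p(x,y)^2 / (p(x) p(y)) - 1 from paired samples."""
    joint = np.zeros((S, S))
    np.add.at(joint, (a, b), 1.0)
    joint /= joint.sum()
    denom = joint.sum(axis=1, keepdims=True) * joint.sum(axis=0, keepdims=True)
    mask = denom > 0
    return float((joint[mask] ** 2 / denom[mask]).sum() - 1.0)

def recover_parents(X, S):
    """For each position i > 0, pick the earlier position with the largest chi^2 MI."""
    est = [-1]
    for i in range(1, X.shape[1]):
        scores = [chi2_mi(X[:, j], X[:, i], S) for j in range(i)]
        est.append(int(np.argmax(scores)))
    return est

rng = np.random.default_rng(0)
parents = [-1, 0, 0, 1, 1, 2, 3]   # a tree: every position i > 0 has one earlier parent
trans = 0.7 * np.eye(3) + 0.1      # doubly stochastic, so every marginal stays uniform
X = sample_tree_sequences(parents, trans, 20_000, rng)
recovered = recover_parents(X, S=3)
print(recovered)
```

For this symmetric channel, the χ² mutual information between two positions shrinks with their distance in the tree, so the parent (distance 1) wins the argmax over grandparents, siblings, and more distant relatives.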

Motivation & Objective

  • Understand how gradient-based training leads transformers to learn causal structure.
  • Introduce an in-context learning task, "random sequences with causal structure", whose sequences share a fixed latent causal graph.
  • Analyze the training dynamics of a two-layer attention-only transformer under gradient descent.
  • Show that the gradient of the attention matrix captures the mutual information between tokens, so that its largest entries reveal the edges of the graph.
  • Demonstrate how the approach generalizes to non-tree graphs via multi-head architectures and assess out-of-distribution performance.

Proposed method

  • Define a simplified two-layer disentangled transformer and a reduced model focusing on the two attention matrices A^(1) and A^(2).
  • Construct the "random sequences with causal structure" task, defined by a latent DAG over token positions.
  • Prove that gradient descent recovers the latent graph by encoding it in the first attention layer (A^(1)).
  • Show that the gradient corresponds to a chi-squared mutual information measure, which guides edge recovery via the data processing inequality.
  • Special-case analysis: in-context Markov chains yield induction heads.
  • Provide a multi-head extension that distributes non-tree graphs across heads, and validate the theory empirically.
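The in-context Markov chain special case can be made concrete with a minimal sketch. Assumptions not taken from the paper: each sequence draws its own transition matrix with rows from a symmetric Dirichlet prior, tokens are one-hot embeddings, and the attention weights are set by hand (hypothetical inverse temperatures `beta1`, `beta2`) rather than learned. The first layer's position-only attention points each position at its predecessor, so it plays the role of A^(1) encoding the chain's adjacency; the second layer matches the last token against the copied predecessors, which is the induction-head pattern. Its output is compared with the plain count-based (bigram) estimate of the next-token distribution.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sample_in_context_chain(S, T, alpha, rng):
    """Draw a fresh transition matrix (rows ~ Dirichlet(alpha)) and one chain of length T."""
    pi = rng.dirichlet(alpha * np.ones(S), size=S)
    x = np.empty(T, dtype=int)
    x[0] = rng.integers(S)
    for t in range(1, T):
        x[t] = rng.choice(S, p=pi[x[t - 1]])
    return x, pi

def bigram_estimate(x, S):
    """Count-based estimate of P(next | x[-1]) from earlier transitions in the same sequence."""
    q = x[-1]
    nxt = np.array([x[i] for i in range(1, len(x)) if x[i - 1] == q], dtype=int)
    counts = np.bincount(nxt, minlength=S).astype(float)
    return counts / counts.sum() if counts.sum() > 0 else np.full(S, 1.0 / S)

def two_layer_induction_head(x, S, beta1=50.0, beta2=50.0):
    """Hand-set two-layer attention: layer 1 copies the previous token, layer 2 matches it."""
    T = len(x)
    E = np.eye(S)[x]                        # one-hot embeddings, shape (T, S)
    # Layer 1: position-only logits under a causal mask, with beta1 on the edge (i-1 -> i).
    L1 = np.where(np.tril(np.ones((T, T))) > 0, 0.0, -np.inf)
    rows = np.arange(1, T)
    L1[rows, rows - 1] = beta1
    H = softmax(L1, axis=1) @ E             # H[i] ~ one-hot of x[i-1] for i >= 1
    # Layer 2: query = last token, keys = copied predecessors (position 0 has none).
    w = softmax(beta2 * (H[1:] @ E[-1]))
    return w @ E[1:]                        # average of tokens that followed a match

rng = np.random.default_rng(0)
x, pi = sample_in_context_chain(S=4, T=300, alpha=1.0, rng=rng)
print("true row:    ", np.round(pi[x[-1]], 3))
print("count-based: ", np.round(bigram_estimate(x, 4), 3))
print("2-layer attn:", np.round(two_layer_induction_head(x, 4), 3))
```

With large inverse temperatures the soft attention is effectively hard, so the two estimates agree whenever the last token has occurred earlier in the sequence. The construction only shows what an induction head computes; it is not the paper's parameterization or training procedure.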

Experimental results

Research questions

  • RQ1: Can gradient descent on a transformer recover latent causal structure from data generated with a fixed causal graph?
  • RQ2: How is causal structure encoded inside the transformer's attention layers during training?
  • RQ3: What primitives (e.g., induction heads) arise in in-context learning scenarios such as Markov chains?
  • RQ4: How does the model perform when the latent graph is not a tree, and can a multi-head design solve it?
  • RQ5: Do the trained models generalize to out-of-distribution transitions?

Key findings

  • Gradient descent on a two-layer disentangled transformer learns to encode the latent causal graph in the first attention layer as the adjacency matrix.
  • The gradient of the first attention layer corresponds to the chi-squared mutual information between tokens; by the data processing inequality, its largest entries fall on the edges of the graph, so learning concentrates there.
  • In the special case of in-context Markov chains, the model develops an induction head that estimates transition probabilities in context.
  • When the causal graph is not a tree, a multi-head transformer can distribute the latent graph across heads and solve the task.
  • Empirically, trained transformers recover a variety of causal structures on the proposed task and generalize to out-of-distribution transitions.
  • Theoretical guarantees (Theorem 1 and Theorem 2) establish population-loss convergence and OOD generalization under specified assumptions.
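For non-tree graphs, one simple way to picture distributing the graph across heads is to give each head at most one parent per position. The sketch below assumes that assignment (parents sorted, the k-th one handed to head k), one-hot token embeddings, and hard 0/1 attention patterns. The helper names are hypothetical, and the paper's multi-head construction may assign edges differently.

```python
import numpy as np

def split_parents_across_heads(parent_sets, num_heads):
    """Hypothetical helper: head h attends (weight 1) to the h-th parent, in sorted order,
    of every position; a head's row stays all-zero when a position has fewer parents."""
    T = len(parent_sets)
    heads = np.zeros((num_heads, T, T))
    for i, ps in enumerate(parent_sets):
        if len(ps) > num_heads:
            raise ValueError(f"position {i} has {len(ps)} parents but only {num_heads} heads")
        for h, p in enumerate(sorted(ps)):
            if not 0 <= p < i:
                raise ValueError(f"parent {p} of position {i} must precede it")
            heads[h, i, p] = 1.0
    return heads

def copy_parent_tokens(heads, x, S):
    """Each head copies the one-hot token of its assigned parent: (H, T, T) @ (T, S) -> (H, T, S)."""
    return heads @ np.eye(S)[x]

# A diamond 0 -> 1, 0 -> 2, {1, 2} -> 3 is not a tree: position 3 has two parents.
parent_sets = [[], [0], [0], [1, 2]]
x = np.array([2, 0, 1, 3])
heads = split_parents_across_heads(parent_sets, num_heads=2)
copied = copy_parent_tokens(heads, x, S=4)
print(copied[:, 3])  # head 0 copies x[1], head 1 copies x[2]
```

Keeping the heads' outputs side by side lets position 3 see both parents' tokens separately, whereas a single attention row would have to average over them.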

This review was created by AI and reviewed by human editors.