[论文解读] NormFormer: Improved Transformer Pretraining with Extra Normalization
NormFormer 在每个 transformer 层中为 Pre-LN 模型添加三种轻量级基于归一化的操作,降低梯度不匹配并加速预训练,同时在因果和掩码语言模型上提升困惑度和下游任务性能。
During pretraining, the Pre-LayerNorm transformer suffers from a gradient magnitude mismatch: gradients at early layers are much larger than at later layers. These issues can be alleviated by our proposed NormFormer architecture, which adds three normalization operations to each layer: a Layer Norm after self attention, head-wise scaling of self-attention outputs, and a Layer Norm after the first fully connected layer. The extra operations incur negligible compute cost (+0.4% parameter increase), but improve pretraining perplexity and downstream task performance for both causal and masked language models ranging from 125 Million to 2.7 Billion parameters. For example, adding NormFormer on top of our strongest 1.3B parameter baseline can reach equal perplexity 24% faster, or converge 0.27 perplexity better in the same compute budget. This model reaches GPT3-Large (1.3B) zero shot performance 60% faster. For masked language modeling, NormFormer improves fine-tuned GLUE performance by 1.9% on average. Code to train NormFormer models is available in fairseq https://github.com/pytorch/fairseq/tree/main/examples/normformer .
研究动机与目标
- 在预训练期间识别 Pre-LN 转换器中的梯度幅值不匹配。
- 提出基于轻量级归一化的附加以稳定并加速训练。
- 在多种规模下对因果与掩码语言模型评估 NormFormer。
- 展示预训练困惑度和下游任务性能的提升。
- 提供消融研究与分析以理解每个附加项的作用。
提出的方法
- 引入每层三个附加项:对 MHA 输出进行头维度缩放(HeadScale)、在注意力模块后添加 LayerNorm,以及在第一 FFN 层后添加 LayerNorm。
- 在 MHA 路径内再应用一个 LayerNorm,并在 FFN 之后再应用一个 LN,具有每个头和每个残差路径的可学习参数 γ。
- 可选地在 FFN 路径包含残差缩放(ResScale),并分析其在不同规模下的影响。
- 在大小为 125M、355M、1.3B 与 2.7B 的因果与掩码语言模型上进行训练,在等量计算预算下将 NormFormer 与计算匹配的基线进行比较。
- 进行 GPT-3 风格任务和 GLUE 基准的零-shot 评估以评估泛化。
实验结果
研究问题
- RQ1添加 NormFormer 的额外归一化操作是否稳定 Pre-LN 转换器并在各层之间缩小梯度差距?
- RQ2NormFormer 的增益是否在从 125M 到 2.7B 参数的模型规模中持续?
- RQ3新增操作如何影响预训练困惑度和下游任务性能(GLUE)?相较于调优后的 Pre-LN 基线。
- RQ4NormFormer 中残差缩放在不同模型规模下的作用是什么?
- RQ5若去除任一新增组件,增益是否仍然健壮?
主要发现
| Model Size (|θ|, M) | λ_resid | PPL | CoLA | MNLI | MRPC | QNLI | QQP | RTE | SST-2 | Avg | |||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 125 | - | 125.42 | 3.42 | 74.3 | 85.9 | 84.6 | 91.6 | 90.7 | 66.4 | 92.9 | 83.77 | ||
| 125 | - | NormFormer | 125.50 | - | 3.31 | 82.6 | 86.3 | 86.0 | 91.9 | 91.3 | 67.9 | 93.8 | 85.69 |
| 125 | - | NormFormer | 125.51 | - | 3.29 | 80.9 | 86.2 | 85.3 | 91.5 | 91.2 | 62.8 | 94.2 | 84.59 |
| 355 | - | GPT3-355M (paper) | 355.0 | - | 3e-4 | - | - | - | - | - | - | - | - |
| 355 | - | GPT3-355M (replicated) | 355.0 | - | 15.41 | 46.1 | 70.8 | 54.6 | 71.1 | 41.2 | 56.8 | ||
| 355 | - | NormFormer-355M | 355.0 | - | 14.54 | 49.7 | 71.8 | 56.0 | 73.8 | 43.6 | 59.0 | ||
| 355 | - | NormFormer-355M | 355.0 | - | 14.52 | 49.7 | 72.0 | 56.7 | 73.2 | 43.8 | 59.1 | ||
| 1300 | - | GPT3-1.3B (paper) | 1313.5 | - | 2e-4 | - | - | - | - | - | - | - | - |
| 1300 | - | GPT3-1.3B (replicated) | 1313.5 | - | 12.56 | 58.5 | 74.6 | 58.1 | 76.8 | 49.4 | 63.5 | ||
| 1300 | - | GPT3-1.3B (High LR) | 1313.5 | - | 6e-4 | 57.5 | 74.3 | 59.3 | 76.3 | 50.8 | 63.6 | ||
| 1300 | - | NormFormer-1.3B | 1314.0 | - | 6e-4 | 60.5 | 74.5 | 60.1 | 77.5 | 50.8 | 64.7 | ||
| 2649 | - | GPT3-2.7B (paper) | 2648.7 | - | 1.6e-4 | - | - | - | - | - | - | - | - |
| 2649 | - | GPT3-2.7B (replicated) | 2648.7 | - | 10.92 | 65.9 | 76.6 | 61.4 | 78.2 | 49.6 | 66.3 | ||
| 2649 | 6e-4 | NormFormer-2.7B | 2649.5 | - | 6e-4 | 68.1 | 78.1 | 64.4 | 79.4 | 53.4 | 68.7 |
- NormFormer 在 125M–2.7B 范围内对因果和掩码语言模型的预训练困惑度与下游任务性能均有提升。
- 对于 1.3B 模型,NormFormer 能更快达到基线困惑度并在相同计算预算下比对齐达到相同困惑度的速度快 24%;在相同的计算预算下,其困惑度可比基线低 0.27。
- 零-shot 评估显示 NormFormer 在所有规模上都优于 GPT-3 在所测试任务上的表现。
- GLUE 微调结果显示 NormFormer MLMs 在各任务上优于 Pre-LN 基线,且有平均增益。
- 消融研究表明移除任何新增操作都会降低性能;HeadScale 和注意力后 LN 的影响尤为显著。
- 学习得到的缩放参数 (γ) 可以降低前层 FG 梯度并下调前期 FFN 输入,而 HeadScale 可以强调某些头,从而帮助稳定性和性能。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。