[论文解读] ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training
ProphetNet 引入未来 n-gram 预测目标和 n-stream 自注意力用于 Seq2Seq 预训练,在摘要抽象与问题生成任务上达到最先进(SOTA)结果。它在每个步骤预测多个未来 token,并且在推理时可以转换为标准 Transformer 解码器。
This paper presents a new sequence-to-sequence pre-training model called ProphetNet, which introduces a novel self-supervised objective named future n-gram prediction and the proposed n-stream self-attention mechanism. Instead of optimizing one-step-ahead prediction in the traditional sequence-to-sequence model, the ProphetNet is optimized by n-step ahead prediction that predicts the next n tokens simultaneously based on previous context tokens at each time step. The future n-gram prediction explicitly encourages the model to plan for the future tokens and prevent overfitting on strong local correlations. We pre-train ProphetNet using a base scale dataset (16GB) and a large-scale dataset (160GB), respectively. Then we conduct experiments on CNN/DailyMail, Gigaword, and SQuAD 1.1 benchmarks for abstractive summarization and question generation tasks. Experimental results show that ProphetNet achieves new state-of-the-art results on all these datasets compared to the models using the same scale pre-training corpus.
研究动机与目标
- 通过计划未来 token 而不仅仅预测下一个 token,来提升 Seq2Seq 预训练的效果。
- 引入未来 n-gram 预测作为自监督目标,以降低对局部相关性的过拟合。
- 开发一个 n-stream 自注意力机制,以实现对多个未来 token 的同时预测。
- 在推理阶段通过禁用预测流,确保模型仍然兼容标准推断。
- 通过对基准 NLG 任务进行大量消融和比较,展示其有效性。
提出的方法
- 在 Transformer 编码器-解码器结构中增加额外的用于未来 token 预测的流(n-stream 自注意力)。
- 将未来 n-gram 损失定义为传统语言建模损失与预测下 n-1 个未来 token 的损失的组合(带衰减权重)。
- 使用去噪自编码器目标进行训练(基于掩码的范围掩码),并将其改造为在被掩盖的区间内预测 n-gram。
- 在 16GB(base)和 160GB(large)语料上进行预训练,使用与 MASS/BART/T5 相似的设置,输入长度 512 且含范围掩码。
- 推理阶段禁用预测流,使模型简化为标准的下一个 token 预测。
- 在 CNN/DailyMail、Gigaword 和 SQuAD 1.1 QG 任务上进行微调,以评估生成质量。
实验结果
研究问题
- RQ1相比一次性预测,预测未来 n-gram 是否改善了长程依赖捕捉与全局连贯性?
- RQ2n-stream 自注意力解码器在训练中是否能有效学会预测多个未来 token,同时保持与标准推断的兼容性?
- RQ3相较于其他预训练 Seq2Seq 模型,ProphetNet 在抽象摘要和问题生成基准上的表现如何?
主要发现
- ProphetNet 在 CNN/DailyMail 上达到最先进的 ROUGE 分数,R-1 为 43.68,R-2 为 20.64,R-L 为 40.72。
- 在 Gigaword 上,ProphetNet 在各指标上超越基线(表 2 中给出最高值)。
- 对于 SQuAD 1.1 问题生成,ProphetNet 相对于先前的方法在 BLEU/METEOR/ROUGE 上取得领先分数。
- 大规模预训练(160GB)带来进一步提升,在 CNN/DailyMail 和 Gigaword 上达到 SOTA,同时所需预训练数据显著少于某些基线。
- 在没有预训练的情况下,ProphetNet 仍相较于 Transformer 基线在 CNN/DailyMail 上有所提升。
- 在不同 n-gram 设置的对比中,2-gram 与 3-gram 变体优于 MASS 和 1-gram 基线,且 2-gram 在速度与精度之间提供有利的权衡。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。