[论文解读] Improving Sequence-to-Sequence Learning via Optimal Transport
本文提出了一种新颖的序列到序列学习框架,通过基于最优传输(OT)的序列级监督来改进训练,替代或增强标准的最大似然估计(MLE)。通过最小化生成序列与参考序列之间的Wasserstein距离,该方法提升了语义对齐性并减少了暴露偏差,从而在机器翻译、抽象摘要和图像字幕生成任务中实现了持续的性能提升。
Sequence-to-sequence models are commonly trained via maximum likelihood estimation (MLE). However, standard MLE training considers a word-level objective, predicting the next word given the previous ground-truth partial sentence. This procedure focuses on modeling local syntactic patterns, and may fail to capture long-range semantic structure. We present a novel solution to alleviate these issues. Our approach imposes global sequence-level guidance via new supervision based on optimal transport, enabling the overall characterization and preservation of semantic features. We further show that this method can be understood as a Wasserstein gradient flow trying to match our model to the ground truth sequence distribution. Extensive experiments are conducted to validate the utility of the proposed approach, showing consistent improvements over a wide variety of NLP tasks, including machine translation, abstractive text summarization, and image captioning.
研究动机与目标
- 解决词级MLE训练与序列级评估指标(如BLEU和ROUGE)之间的不匹配问题。
- 通过引入全局序列级监督,克服自回归生成中的暴露偏差。
- 开发一种稳健且可微分的序列级损失函数,避免强化学习或对抗训练带来的不稳定性。
- 通过OT方法将生成序列与输入序列和参考序列对齐,提升生成序列的语义保持力和结构连贯性。
- 证明基于OT的正则化在多种序列到序列任务(包括翻译、摘要和图像字幕)中的泛化能力。
提出的方法
- 提出一种基于最优传输(OT)的序列级损失,计算生成序列与参考序列之间的Wasserstein距离,以促进语义相似性。
- 将训练目标表述为正则化的MLE损失,结合交叉熵与基于OT的正则化项,引导模型实现更好的语义对齐。
- 通过同时计算生成序列与输入序列之间的OT距离,扩展监督信号,确保模型在生成过程中充分利用源信息。
- 将训练过程解释为近似的Wasserstein梯度流,以最小化模型输出分布与真实数据分布之间的距离。
- 通过熵正则化实现OT距离的可微分近似,以支持神经网络中的端到端反向传播。
- 将该方法应用于多种架构(如基于GRU的Seq2Seq、类似Transformer的模型)和任务,无需架构重构,展现出广泛的适用性。
实验结果
研究问题
- RQ1最优传输能否为序列到序列模型提供比词级MLE更有效的序列级监督?
- RQ2基于OT的正则化是否能减少自回归生成中的暴露偏差并提升泛化能力?
- RQ3与强化学习和对抗训练相比,该方法在训练稳定性和性能方面表现如何?
- RQ4该OT损失是否能提升翻译、摘要和图像字幕等多样化NLP任务中的语义保真度和结构连贯性?
- RQ5该基于OT的方法对超参数选择是否具有鲁棒性,特别是组合损失中的权重系数γ?
主要发现
- 在所有评估任务中,OT增强模型均显著优于MLE基线,WMT'14英德翻译数据集上BLEU分数最高提升2.4分。
- 在Gigaword摘要数据集上,模型ROUGE-L得分为34.0,比基线Seq2Seq高出1.6分,且超越了使用更复杂架构的SOTA结果(36.92)。
- 在DUC-2004摘要数据集上,ROUGE-L从24.8提升至26.0,表明在更小、更具挑战性的数据集上也表现出色。
- 在COCO图像字幕任务中,BLEU-4从81.5提升至83.2,CIDEr从120.1提升至124.3,表明在多个指标上均取得一致提升,且未出现对单一指标的过拟合。
- 该方法对超参数γ具有鲁棒性,测试BLEU分数在γ ∈ (0,1]范围内始终高于基线,表明性能稳定。
- 定性分析显示,与原始MLE模型相比,该模型能更好地保留关键语义术语,减少误解,尤其在翻译和摘要任务中表现更优。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。