[论文解读] WizardMath: Empowering Mathematical Reasoning for Large Language Models via Reinforced Evol-Instruct
WizardMath 是一个基于开源 Llama-2 的模型,使用 Evol-Instruct Feedback 的强化学习(RLEIF)进行训练,在 GSM8k 和 MATH 上实现最先进的数学推理,超过许多开源模型以及在这些基准测试中的一些闭源模型。
Large language models (LLMs), such as GPT-4, have shown remarkable performance in natural language processing (NLP) tasks, including challenging mathematical reasoning. However, most existing open-source models are only pre-trained on large-scale internet data and without math-related optimization. In this paper, we present WizardMath, which enhances the mathematical CoT reasoning abilities of LLMs without using external python tools, by applying our proposed Reinforcement Learning from Evol-Instruct Feedback (RLEIF) method to the domain of math. Through extensive experiments on two mathematical reasoning benchmarks, namely GSM8k and MATH, we reveal the extraordinary capabilities of our model. Remarkably, WizardMath-Mistral 7B surpasses top-tier open-source LLMs by a substantial margin with higher data efficiency. Furthermore, WizardMath 70B even outperforms GPT-3.5-Turbo, Claude 2, Gemini Pro and GPT-4-early-version. Additionally, our preliminary exploration highlights the pivotal role of instruction evolution and process supervision in achieving exceptional math performance. For more details refer to https://github.com/nlpxucan/WizardLM
研究动机与目标
- 提高对开源 LLM 中更好数学推理的需求的认识。
- 提出一个新训练框架(RLEIF),将 Evol-Instruct、指令奖励建模和过程监督奖励结合起来。
- 在 GSM8k 和 MATH 上与开源模型及部分闭源模型相比,展示最先进性能。
提出的方法
- 用包含逐步数学解答的监督指令跟随数据对 Llama-2 进行微调。
- 开发 Evol-Instruct,生成多样且逐步更难/更易的数学指令(向下与向上演化)。
- 训练两种奖励模型:指令奖励模型(IRM)用于指令质量,过程监督奖励模型(PRM)用于逐步解答反馈。
- 使用近端策略优化(PPO)在演化数据上以最终奖励 r = rI × rA 进行强化学习。
实验结果
研究问题
- RQ1RLEIF 与基于 Evol-Instruct 的数据增强是否可以提升开源 LLM 的数学推理,使其超越基线开源模型?
- RQ2在 GSM8k 和 MATH 上,WizardMath 与闭源及其他开源模型相比的表现如何?
- RQ3在 RLEIF 框架下,扩大模型规模(7B、13B、70B)对 GSM8k 和 MATH 性能有何影响?
主要发现
- WizardMath 70B 在 GSM8k 上达到 81.6 的 pass@1,相对于基线 56.8,提升了 +24.8。
- WizardMath 70B 在 MATH 上达到 22.7 的 pass@1,相对于基线 13.5,提升了 +9.2。
- WizardMath 13B 在 GSM8k 上达到 63.9 的 pass@1,相对于基线 28.7,提升了 +35.2。
- WizardMath 13B 在 MATH 上达到 14.0 的 pass@1,相对于基线 3.9,提升了 +10.1。
- WizardMath 7B 在 GSM8k 上达到 54.9 的 pass@1,相对于基线 51.6,提升了 +3.3,在 MATH 上达到 10.7,相对于基线 2.9,提升了 +7.7。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。