[論文レビュー] Orca-Math: Unlocking the potential of SLMs in Grade School Math
Orca-Math は 200K の合成数学問題で訓練された 7B パラメータの SLM が external tools や ensembling なしで 86.81% の GSM8K pass@1 を達成できることを示す。反復的な好み学習と高品質なエージェント生成データセットを介して。
Mathematical word problem-solving has long been recognized as a complex task for small language models (SLMs). A recent study hypothesized that the smallest model size, needed to achieve over 80% accuracy on the GSM8K benchmark, is 34 billion parameters. To reach this level of performance with smaller models, researcher often train SLMs to generate Python code or use tools to help avoid calculation errors. Additionally, they employ ensembling, where outputs of up to 100 model runs are combined to arrive at a more accurate result. Result selection is done using consensus, majority vote or a separate a verifier model used in conjunction with the SLM. Ensembling provides a substantial boost in accuracy but at a significant cost increase with multiple calls to the model (e.g., Phi-GSM uses top-48 to boost the performance from 68.2 to 81.5). In this work, we present Orca-Math, a 7-billion-parameter SLM based on the Mistral-7B, which achieves 86.81% on GSM8k without the need for multiple model calls or the use of verifiers, code execution or any other external tools. Our approach has the following key elements: (1) A high quality synthetic dataset of 200K math problems created using a multi-agent setup where agents collaborate to create the data, (2) An iterative learning techniques that enables the SLM to practice solving problems, receive feedback on its solutions and learn from preference pairs incorporating the SLM solutions and the feedback. When trained with Supervised Fine-Tuning alone, Orca-Math achieves 81.50% on GSM8k pass@1 metric. With iterative preference learning, Orca-Math achieves 86.81% pass@1. Orca-Math surpasses the performance of significantly larger models such as LLAMA-2-70B, WizardMath-70B, Gemini-Pro, ChatGPT-3.5. It also significantly outperforms other smaller models while using much smaller data (hundreds of thousands vs. millions of problems).
研究の動機と目的
- 小学生レベルの語学問題に対する数学的推論を向上させるため、SLM のモチベーションを高める。
- 高品質な合成データパイプラインと反復学習を提案して SLM の推論を強化。
- 少量データと外部ツールなしで小さなモデルが GSM8K で大規模モデルを上回ることを示す。
提案手法
- エージェントベースのデータ生成パイプラインを介して GPT-4-Turbo の解答を用いた 200K の数学問題の Orca-Math データセットを作成する。
- 監視付き微調整、学生の問題解答、解法に対する教員フィードバックを含む反復学習ループを採用する。
- 正/負の解答信号と嗜好ベースの微調整(DPO と KTO)を用いてモデルを整合させ、外部検証ツールなしで。
- GSM8K などのベンチマークで GPT-4 ベースの抽出を用いた厳密一致風のプロンプト(GPT4-based-Exact-Match)で評価する。
- より大規模なモデルと比較して、はるかに少ないデータで競争力を示す(例: LLama-2-70B, WizardMath-70B, Gemini-Pro, GPT-3.5)。
- モデル生成のポジティブ/ネガティブサンプルの重要性を示すアブレーションを報告する。
実験結果
リサーチクエスチョン
- RQ17B の SLM が ensembling や外部ツールなしで GSM8K pass@1 を 80% 以上達成できるか。
- RQ2正/負の信号を伴う反復的好み学習は、SLM の数学的推論における標準的な監視付き微調整とどう比較されるか。
- RQ3エージェント生成の高品質な合成データが小型モデルの数学的推論性能に与える影響は何か。
- RQ4モデル生成のポジティブと合成ネガティブが訓練効果に有意に寄与するか。
主な発見
| Model | Base model | Model size | Answer format | Eval method | GSM8K (%) |
|---|---|---|---|---|---|
| Llama-2 | 7B | nlp | pass@1 | pass@1 | 14.6 |
| Llama-2 | 13B | nlp | pass@1 | pass@1 | 28.7 |
| Llama-2 | 34B | nlp | pass@1 | pass@1 | 42.2 |
| Llama-2 | 70B | nlp | pass@1 | pass@1 | 56.8 |
| MetaMath | Llama-2 | 7B | nlp | pass@1 | 66.5 |
| MetaMath | Llama-2 | 13B | nlp | pass@1 | 72.3 |
| MetaMath | Llama-2 | 70B | nlp | pass@1 | 82.3 |
| WizardMath | Llama-2 | 7B | nlp | pass@1 | 54.9 |
| WizardMath | Llama-2 | 13B | nlp | pass@1 | 63.9 |
| WizardMath | Llama-2 | 70B | nlp | pass@1 | 81.6 |
| MammoTH | Code-Llama | 7B | code | pass@1 | 59.4 |
| MammoTH | Code-Llama | 12B | code | pass@1 | 64.7 |
| MammoTH | Code-Llama | 34B | code | pass@1 | 72.7 |
| MammoTH | Llama-2 | 70B | nlp | pass@1 | 76.9 |
| Mistral | 7B | 7B | nlp | maj1@8 | 52.2 |
| Mistral | 8×7B | - | nlp | maj1@8 | 58.4 |
| OVM | Llama-2 | 7B+7B | nlp | verify100@1 | 73.7 |
| Mistral | 7B+7B | - | nlp | verify100@1 | 84.7 |
| Llemma | 7B | 7B | nlp | pass@1 | 36.4 |
| Llemma | 34B | 34B | nlp | pass@1 | 51.5 |
| ToRA-Code | 7B | 7B | code | COT@1 | 72.6 |
| ToRA-Code | 13B | 13B | - | COT@1 | 75.8 |
| ToRA-Code | 34B | 34B | - | COT@1 | 80.7 |
| ToRA-Code | 70B | 70B | - | COT@1 | 84.3 |
| Orca 2 | Llama-2 | 7B | nlp | pass@1 | 55.72 |
| Orca 2 | Llama-2 | 13B | nlp | pass@1 | 65.73 |
| Gemini Pro | - | - | nlp | maj1@32 | 86.5 |
| GPT-3.5-0613 | - | - | code | pass@1 | 77.4 |
| GPT-4-0613 | - | - | - | - | 97.0 |
| Phi-1.5 | 1.3B | code | pass@1 | 44.6 | |
| Phi-GSM | 1.5-tiny | 125M | code | pass@1 | 63.1 |
| Phi-GSM | 1.5-small | 350M | code | pass@1 | 65.9 |
| Phi-GSM | 1.5 | 1.3B | code | pass@1 | 68.2 |
| Phi-GSM+V | 1.5-tiny+ | 125M+125M | code | verify48@1 | 68.9 |
| Phi-GSM+V | 1.5-small+ | 350M+350M | code | verify48@1 | 71.3 |
| Phi-GSM+V | 1.5+ | 1.3B+1.3B | code | verify48@1 | 81.5 |
| Orca-Math | Mistral | 7B | nlp | pass@1 | 86.81 |
- 監視付き微調整だけで Orca-Math は 81.50% の GSM8K pass@1 を達成。
- 反復的好み学習により検証者や外部ツールなしで GSM8K pass@1 が 86.81% に向上。
- Orca-Math(7B、Mistral)は報告された設定下で LLama-2-70B、WizardMath-70B、Gemini-Pro、GPT-3.5 などの大規模モデルを上回る。
- 200K の合成データセットは、多くのベースラインよりはるかに少ないデータ量で競争力のある性能を達成。
- アブレーションにより教師生成ポジティブのみの使用は性能を低下させることが示され、モデル生成ポジティブおよびネガティブ信号は有益。
- Orca-Math は GPT4 ベースの exact-match を用いた非 GSM8K の数学ベンチマーク(AddSub, ASDiv, MultiArith, SingleOp, SinglEq, Svamp 構造化)でも強い結果を示す。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。