Skip to main content
QUICK REVIEW

[논문 리뷰] Orca-Math: Unlocking the potential of SLMs in Grade School Math

Arindam Mitra, Hamed Khanpour|arXiv (Cornell University)|2024. 02. 16.
Cognitive and developmental aspects of mathematical skills인용 수 7
한 줄 요약

Orca-Math는 200K 합성 수학 문제로 훈련된 7B 매개변수 SLM이 외부 도구나 앙상블 없이 반복적 선호 학습과 고품질 에이전트 생성 데이터 세트를 통해 GSM8K pass@1에서 86.81%에 도달할 수 있음을 보여준다.

ABSTRACT

Mathematical word problem-solving has long been recognized as a complex task for small language models (SLMs). A recent study hypothesized that the smallest model size, needed to achieve over 80% accuracy on the GSM8K benchmark, is 34 billion parameters. To reach this level of performance with smaller models, researcher often train SLMs to generate Python code or use tools to help avoid calculation errors. Additionally, they employ ensembling, where outputs of up to 100 model runs are combined to arrive at a more accurate result. Result selection is done using consensus, majority vote or a separate a verifier model used in conjunction with the SLM. Ensembling provides a substantial boost in accuracy but at a significant cost increase with multiple calls to the model (e.g., Phi-GSM uses top-48 to boost the performance from 68.2 to 81.5). In this work, we present Orca-Math, a 7-billion-parameter SLM based on the Mistral-7B, which achieves 86.81% on GSM8k without the need for multiple model calls or the use of verifiers, code execution or any other external tools. Our approach has the following key elements: (1) A high quality synthetic dataset of 200K math problems created using a multi-agent setup where agents collaborate to create the data, (2) An iterative learning techniques that enables the SLM to practice solving problems, receive feedback on its solutions and learn from preference pairs incorporating the SLM solutions and the feedback. When trained with Supervised Fine-Tuning alone, Orca-Math achieves 81.50% on GSM8k pass@1 metric. With iterative preference learning, Orca-Math achieves 86.81% pass@1. Orca-Math surpasses the performance of significantly larger models such as LLAMA-2-70B, WizardMath-70B, Gemini-Pro, ChatGPT-3.5. It also significantly outperforms other smaller models while using much smaller data (hundreds of thousands vs. millions of problems).

연구 동기 및 목표

  • 초등 수준의 단어 문제에 대한 소형 언어 모델(SLM)의 수학적 추론 개선의 필요성을 자극한다.
  • SLM 추론 능력을 향상시키기 위한 고품질 합성 데이터 파이프라인과 반복 학습을 제안한다.
  • 적은 데이터와 외부 도구 없이도 더 작은 모델이 GSM8K에서 더 큰 모델을 능가할 수 있음을 입증한다.

제안 방법

  • GPT-4-Turbo 솔루션을 통한 에이전트 기반 데이터 생성 파이프라인으로 Orca-Math-dataset 200K 수학 문제를 생성한다.
  • 지도학습 미세조정, 학생 풀이, 해결책에 대한 교사 피드백의 반복 학습 루프를 적용한다.
  • External verification 도구 없이 양/음 솔루션 신호와 선호 기반 미세조정(DPO 및 KTO)을 사용하여 모델을 정렬한다.
  • GSM8K 및 기타 벤치마크에서 GPT-4 기반 추출(GPT4-based-Exact-Match)을 이용한 정확도-유사 매 prompt로 평가한다.
  • 더 큰 모델(예: LLama-2-70B, WizardMath-70B, Gemini-Pro, GPT-3.5)과 비교하여 훨씬 적은 데이터로도 경쟁력을 보임을 보여준다.
  • 모델생성 양수 및 음수 샘플의 중요성을 보여주기 위한 제거 실험(ablations)을 보고한다.

실험 결과

연구 질문

  • RQ17B SLM이 앙상블이나 외부 도구 없이 GSM8K pass@1에서 80%를 초과할 수 있는가?
  • RQ2수정된 선호 학습(양/음 신호 포함)이 SLM의 수학 추론에 대한 표준 감독 학습과 비교해 어떤 차이를 보이는가?
  • RQ3에이전트 생성 고품질 합성 데이터가 소형 모델의 수학 추론 성능에 미치는 영향은 무엇인가?
  • RQ4모델 생성 양수 및 합성 음수 샘플이 학습 효율성에 유의미하게 기여하는가?

주요 결과

모델기본 모델모델 크기정답 형식평가 방식GSM8K (%)
Llama-27Bnlppass@1pass@114.6
Llama-213Bnlppass@1pass@128.7
Llama-234Bnlppass@1pass@142.2
Llama-270Bnlppass@1pass@156.8
MetaMathLlama-27Bnlppass@166.5
MetaMathLlama-213Bnlppass@172.3
MetaMathLlama-270Bnlppass@182.3
WizardMathLlama-27Bnlppass@154.9
WizardMathLlama-213Bnlppass@163.9
WizardMathLlama-270Bnlppass@181.6
MammoTHCode-Llama7Bcodepass@159.4
MammoTHCode-Llama12Bcodepass@164.7
MammoTHCode-Llama34Bcodepass@172.7
MammoTHLlama-270Bnlppass@176.9
Mistral7B7Bnlpmaj1@852.2
Mistral8×7B-nlpmaj1@858.4
OVMLlama-27B+7Bnlpverify100@173.7
Mistral7B+7B-nlpverify100@184.7
Llemma7B7Bnlppass@136.4
Llemma34B34Bnlppass@151.5
ToRA-Code7B7BcodeCOT@172.6
ToRA-Code13B13B-COT@175.8
ToRA-Code34B34B-COT@180.7
ToRA-Code70B70B-COT@184.3
Orca 2Llama-27Bnlppass@155.72
Orca 2Llama-213Bnlppass@165.73
Gemini Pro--nlpmaj1@3286.5
GPT-3.5-0613--codepass@177.4
GPT-4-0613----97.0
Phi-1.51.3Bcodepass@144.6
Phi-GSM1.5-tiny125Mcodepass@163.1
Phi-GSM1.5-small350Mcodepass@165.9
Phi-GSM1.51.3Bcodepass@168.2
Phi-GSM+V1.5-tiny+125M+125Mcodeverify48@168.9
Phi-GSM+V1.5-small+350M+350Mcodeverify48@171.3
Phi-GSM+V1.5+1.3B+1.3Bcodeverify48@181.5
Orca-MathMistral7Bnlppass@186.81
  • 지도 학습 미세조정만으로 Orca-Math는 GSM8K pass@1에서 81.50%를 달성한다.
  • 반복적 선호 학습은 검증자나 외부 도구 없이 GSM8K pass@1을 86.81%로 올린다.
  • Orca-Math(7B, Mistral)는 보고된 설정에서 LLama-2-70B, WizardMath-70B, Gemini-Pro, GPT-3.5 등 대형 모델보다 GSM8K에서 더 나은 성능을 보인다.
  • 200K 합성 데이터 세트가 많은 기준선 대비 현저히 적은 데이터로도 경쟁력 있는 성능을 달성한다.
  • 아블레이션 결과 교사 생성 양수만 사용할 때 성능이 저하되고, 모델 생성 양수 및 음성 신호가 학습에 도움이 됨을 보인다.
  • Orca-Math는 GPT-4 기반 정확도 매핑으로 GSM8K 이외의 수학 벤치마크(AddSub, ASDiv, MultiArith, SingleOp, SinglEq, Svamp 구조화)에서도 강력한 결과를 보인다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.