QUICK REVIEW

[Paper Review] Scaling Relationship on Learning Mathematical Reasoning with Large Language Models

Zheng Yuan, Hongyi Yuan | arXiv (Cornell University) | Aug 3, 2023
Topic Modeling | Cited by 8
One-Line Summary

tldr: The paper analyzes how pre-training loss, supervised data, and augmented data affect math reasoning in supervised LLMs and introduces rejection sampling fine-tuning (RFT) to augment data with diverse reasoning paths, showing significant gains over standard SFT.

ABSTRACT

Mathematical reasoning is a challenging task for large language models (LLMs), while the scaling relationship of it with respect to LLM capacity is under-explored. In this paper, we investigate how the pre-training loss, supervised data amount, and augmented data amount influence the reasoning performances of a supervised LLM. We find that pre-training loss is a better indicator of the model's performance than the model's parameter count. We apply supervised fine-tuning (SFT) with different amounts of supervised data and empirically find a log-linear relation between data amount and model performance, and we find better models improve less with enlarged supervised datasets. To augment more data samples for improving model performances without any human effort, we propose to apply Rejection sampling Fine-Tuning (RFT). RFT uses supervised models to generate and collect correct reasoning paths as augmented fine-tuning datasets. We find with augmented samples containing more distinct reasoning paths, RFT improves mathematical reasoning performance more for LLMs. We also find RFT brings more improvement for less performant LLMs. Furthermore, we combine rejection samples from multiple models which push LLaMA-7B to an accuracy of 49.3% on GSM8K which outperforms the supervised fine-tuning (SFT) accuracy of 35.9% significantly.

Motivation and Objectives

  • Understand how pre-training loss correlates with math reasoning performance under supervised fine-tuning (SFT) and in-context learning (ICL).
  • Characterize how increasing supervised data quantitatively affects reasoning accuracy across model sizes.
  • Investigate data augmentation through rejection sampling to create diverse reasoning paths and its impact on performance.
  • Demonstrate the benefits of aggregating rejection-sampled data from multiple models and compare with baselines on GSM8K.

Proposed Method

  • Evaluate SFT and ICL performance across multiple LLMs (LLaMA/LLaMA2 variants) on GSM8K as a math reasoning benchmark.
  • Compare performance as a function of pre-training loss rather than model size or token counts.
  • Analyze performance against varying amounts of supervised data to identify log-linear scaling with data.
  • Apply rejection sampling to generate multiple reasoning paths, filter for correct answers, and fine-tune models (RFT).
  • Deduplicate and aggregate rejection-sampled data from multiple base models to study diversity effects on performance.
  • Provide comparisons with existing baselines (ICL, SFT, RFT from single/multiple models) on GSM8K.
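The RFT data pipeline listed above (sample several reasoning paths per question, keep only those with a correct final answer, deduplicate, then pool paths from several models) can be sketched in Python. This is a minimal sketch under stated assumptions, not the authors' code. It assumes GSM8K-style solutions that annotate each calculation as `<<expr=result>>` and end with `#### <answer>`, and it treats two paths as distinct when their equation sequences differ. The `sample` callable is a hypothetical stand-in for decoding from a fine-tuned model, and all function names are illustrative.

```python
import re
from typing import Callable, Dict, Iterable, List, Optional, Tuple

# Assumed GSM8K-style format: calculations annotated as <<expr=result>>,
# final answer after "####". These patterns are illustrative assumptions.
EQUATION_RE = re.compile(r"<<([^>]*)>>")
ANSWER_RE = re.compile(r"####\s*(-?[\d,.]+)")

Signature = Tuple[str, ...]


def extract_answer(path: str) -> Optional[str]:
    """Return the final numeric answer of a reasoning path, if present."""
    match = ANSWER_RE.search(path)
    return match.group(1).replace(",", "") if match else None


def equation_signature(path: str) -> Signature:
    """Equation sequence of a path; equal signatures are treated as duplicates."""
    return tuple(eq.replace(" ", "") for eq in EQUATION_RE.findall(path))


def rejection_sample(
    question: str,
    gold_answer: str,
    sample: Callable[[str], str],  # hypothetical stand-in for model decoding
    k: int,
) -> List[str]:
    """Draw k paths, reject wrong answers, keep one path per distinct signature."""
    kept: Dict[Signature, str] = {}
    for _ in range(k):
        path = sample(question)
        if extract_answer(path) != gold_answer:
            continue  # rejected: wrong or missing final answer
        kept.setdefault(equation_signature(path), path)  # first path per signature wins
    return list(kept.values())


def aggregate(per_model_paths: Iterable[List[str]]) -> List[str]:
    """Pool correct paths from several models, deduplicating across models."""
    merged: Dict[Signature, str] = {}
    for paths in per_model_paths:
        for path in paths:
            merged.setdefault(equation_signature(path), path)
    return list(merged.values())
```

With four samples where one has the wrong answer and one repeats an earlier equation sequence, `rejection_sample` keeps two paths; `aggregate` then adds only those paths from other models whose equation sequences are new. Deduplicating on equations rather than raw text is what lets the augmented set grow in distinct reasoning paths rather than in paraphrases.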

Experimental Results

Research Questions

  • RQ1: How does pre-training loss correlate with SFT and ICL performance for math reasoning in LLMs?
  • RQ2: What is the relationship between the amount of supervised data and model performance on math reasoning tasks?
  • RQ3: Does rejection sampling fine-tuning (RFT) improve math reasoning, and how does performance scale with the number of distinct reasoning paths?
  • RQ4: Does aggregating rejection-sampled data from multiple models yield further gains over single-model RFT?

Key Findings

  • Pre-training loss is a better performance indicator for math reasoning than parameter count, with accuracy roughly negatively linearly related to pre-training loss in the studied interval.
  • SFT performance scales log-linearly with the amount of supervised data, with diminishing returns as models become better pre-trained.
  • RFT improves math reasoning when augmented data contains many distinct reasoning paths, with larger gains for weaker models.
  • Aggregating rejection-sampled data from multiple models yields higher accuracy than single-model RFT across several LLaMA/LLaMA2 variants (e.g., 49.3% for LLaMA-7B and 55.4% for LLaMA2-13B on GSM8K).
  • RFT is substantially cheaper than pre-training, and improving pre-training loss remains the fundamental solution for scaling math reasoning abilities.
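The two scaling trends above (accuracy roughly linear in pre-training loss, and roughly linear in the logarithm of the supervised data size) can be checked with an ordinary least-squares line fit. This is a minimal sketch: the helper names are hypothetical, and any numbers you feed it should come from your own runs; nothing here reproduces the paper's measurements.

```python
import math
from typing import Sequence, Tuple


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least-squares fit of y = slope * x + intercept."""
    n = len(xs)
    if n < 2 or n != len(ys):
        raise ValueError("need at least two paired points")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def fit_log_linear(
    data_sizes: Sequence[int], accuracies: Sequence[float]
) -> Tuple[float, float]:
    """Fit accuracy = slope * log2(n) + intercept.

    With log base 2, the slope reads as accuracy points gained per doubling
    of the supervised data.
    """
    return fit_line([math.log2(n) for n in data_sizes], accuracies)
```

For the pre-training-loss trend, call `fit_line(losses, accuracies)` directly; a negative slope matches the reported direction. Fitting `fit_log_linear` separately for each base model is one way to examine the diminishing-returns finding, since a better pre-trained model would be expected to show a smaller per-doubling slope.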


This review was generated by AI and checked by a human editor.