[논문 리뷰] ReMiT: RL-Guided Mid-Training for Iterative LLM Evolution
ReMiT은 mid-training 동안 RL로 조정된 참조를 사용하여 토큰의 가중치를 동적으로 재조정하고, 기본 모델을 개선하며 post-training에서 이익을 유지하고, pre-training과 post-training 사이에 자기강화 추진력을 가능하게 한다.
Standard training pipelines for large language models (LLMs) are typically unidirectional, progressing from pre-training to post-training. However, the potential for a bidirectional process--where insights from post-training retroactively improve the pre-trained foundation--remains unexplored. We aim to establish a self-reinforcing flywheel: a cycle in which reinforcement learning (RL)-tuned model strengthens the base model, which in turn enhances subsequent post-training performance, requiring no specially trained teacher or reference model. To realize this, we analyze training dynamics and identify the mid-training (annealing) phase as a critical turning point for model capabilities. This phase typically occurs at the end of pre-training, utilizing high-quality corpora under a rapidly decaying learning rate. Building upon this insight, we introduce ReMiT (Reinforcement Learning-Guided Mid-Training). Specifically, ReMiT leverages the reasoning priors of RL-tuned models to dynamically reweight tokens during the mid-training phase, prioritizing those pivotal for reasoning. Empirically, ReMiT achieves an average improvement of 3\% on 10 pre-training benchmarks, spanning math, code, and general reasoning, and sustains these gains by over 2\% throughout the post-training pipeline. These results validate an iterative feedback loop, enabling continuous and self-reinforcing evolution of LLMs.
연구 동기 및 목표
- LLM 능력의 중요한 전환점으로 중간 훈련을 식별한다.
- RL-참조 모델에 의해 안내되는 토큰 수준의 동적 재가중화 메커니즘을 제안한다.
- 외부 교사를 사용하지 않고 post-training과 pre-training 간의 양방향 영향을 가능하게 한다.
- 모델 계열 전반에서 post-training 동안 중간 훈련의 개선이 전달되고 증폭됨을 시연한다.
제안 방법
- 중간 훈련 중 RL-튜닝 모델을 참조로 사용하는 토큰 수준 재가중 스킴인 ReMiT를 도입한다.
- 기본 모델과 RL 참조 간의 토큰당 손실 차이를 계산하고, 시퀀스별 델타 손실을 중앙 정렬하며, 이를 잘린(클리핑된) 스케일된 시그모이드로 가중치에 매핑한다.
- 일반 다음 토큰 예측 손실의 소프트 재가중으로 중간 훈련 목적에 이 가중치를 통합한다.
- 외부 교사를 피하기 위해 파이프라인 내 RL-튜닝 모델을 참조로 사용한다.
- ReMiT를 암시적 목표 분포를 향한 KL-발산(KL-divergence) 및 KL-정규화된 RL과의 연결에 대한 이론적 정당화를 제공한다.
- 세 가지 오픈 소스 기본 모델 계열(OLMo-1B, SmolLM3-3B, Youtu-LLM-2B)을 대상으로 실험하고, 10개의 다운스트림 벤치마크에서 ReMiT을 기준선과 비교한다.
실험 결과
연구 질문
- RQ1RL 참조에 의해 안내되는 중간 훈련 재가중이 기본 모델의 능력을 향상시킬 수 있는가?
- RQ2중간 훈련의 이익이 SFT, DPO, RLVR 등 post-training 단계로 전이되고 지속되는가?
- RQ3ReMiT가 지식 증류와 토큰 수준 데이터 필터링 접근법에 비해 이점을 제공하는가?
주요 결과
| 모델 계열 | 사전 학습 | Vanilla NTP | MiniPLM | RHO-1 | ReMiT | Avg. |
|---|---|---|---|---|---|---|
| OLMo-1B | 3.03 | 48.14 | 48.45 | 50.42 | 61.64 | 27.56 |
| MATH | 2.94 | 10.26 | 9.60 | 10.32 | 14.50 | 9.91 |
| GPQA | 20.31 | 22.54 | 23.21 | 25.45 | 24.55 | 23.02 |
| BBH | 28.43 | 30.87 | 30.38 | 29.33 | 32.07 | 30.22 |
| IFE | 22.66 | 16.19 | 16.79 | 19.06 | 28.54 | 20.84 |
| HE | 6.71 | 8.54 | 7.32 | 6.71 | 12.80 | 8.22 |
| MBPP | 4.80 | 4.60 | 6.80 | 6.20 | 9.20 | 6.28 |
| TQA | 21.30 | 22.40 | 23.13 | 23.38 | 25.58 | 23.48 |
| ARC-C | 44.71 | 46.67 | 45.31 | 46.42 | 49.23 | 46.07 |
| MMLU-P | 9.54 | 13.31 | 13.15 | 13.68 | 17.44 | 13.62 |
| Avg. | 16.44 | 22.35 | 22.41 | 23.10 | 27.56 | 22.58 |
- ReMiT은 모델 계열 전반에 걸쳐 10개의 pre-training 벤치마크에서 평균 3%의 개선을 보인다.
- 중간 훈련의 이익이 post-training으로 전이되어 파이프라인 전반에서 2% 이상 개선을 유지한다.
- Downstream 작업에서 Vanilla NTP, MiniPLM, RHO-1과 같은 베이스라인보다 ReMiT이 우수하다.
- 이 방법은 외부 교사 없이 기본 모델과 RL 모델 간의 공동 향상 플라이휠을 가능하게 한다.
- 토큰 가중치를 잘라내기(클리핑)하면 훈련이 안정화되고 핵심 토큰에 중점을 두면서 데이터 일관성을 유지한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.