[논문 리뷰] JMLR: Joint Medical LLM and Retrieval Training for Enhancing Reasoning and Professional Question Answering Capability
이 논문은 Joint Medical LLM and Retrieval Training (JMLR)을 제안하는데, 동기화된 검색기-LLM 학습 프레임워크로 의료 QA 및 추론을 개선하고 학습 시간을 줄인다. 7B 및 13B Llama 기반 모델로 의료 벤치마크에서 최첨단 오픈 소스 결과를 보고한다.
Large Language Models (LLMs) have demonstrated a remarkable potential in medical knowledge acquisition and question-answering. However, LLMs can potentially hallucinate and yield factually incorrect outcomes, even with domain-specific pretraining. Previously, retrieval augmented generation (RAG) has limited success in addressing hallucinations. Unlike previous methods in RAG where the retrieval model was trained separately from the LLM, we introduce JMLR (for Jointly trains LLM and information Retrieval) during the fine-tuning phase. The synchronized training mechanism enhances JMLR's ability to retrieve clinical guidelines and leverage medical knowledge to reason and answer questions and reduces the demand for computational resources. We evaluated JMLR on the important medical question-answering application. Our experimental results demonstrate that JMLR-13B (70.5%) outperforms a previous state-of-the-art open-source model using conventional pre-training and fine-tuning Meditron-70B (68.9%) and Llama2-13B with RAG (67.7%) on a medical question-answering dataset. Comprehensive evaluations reveal JMLR-13B enhances reasoning quality and reduces hallucinations better than Claude3-Opus. Additionally, JMLR-13B (148 GPU hours) also trains much faster than Meditron-70B (42630 GPU hours). Through this work, we provide a new and efficient knowledge enhancement method for healthcare, demonstrating the potential of integrating retrieval and LLM training for medical question-answering systems.
연구 동기 및 목표
- 도메인 특화 지식을 활용해 의료 QA 및 추론을 개선하려는 동기.
- LLM의 헛소리를 줄이고 신뢰성을 높이기 위해 검색된 의료 가이드라인과 텍스트로 기반을 다진다.
- 더 나은 정렬을 위해 검색기와 LLM을 함께 업데이트하는 공동 학습 파라다임 개발.
- 전통적인 사전학습+미세조정 파이프라인과 비교한 효율성 향상 및 성능 평가.
제안 방법
- 길어진 입력 컨텍스트를 처리하기 위해 Shifted Sparse Attention (S2-Attn)을 사용한다.
- LLM-랭크 손실(LLM-Rank loss)이라는 공동 LLM-검색기 학습 목표를 갖는 ColBERT 기반 검색기를 활용한다.
- AM BOSS와 USMLE의 QA 쌍으로 학습하되 상위 검색 결과 문서를 LLM에 입력한다.
- LLM 구동 손실을 계산하고 LLM의 개선을 반영하는 순위 기반 신호를 통해 검색기 파라미터를 업데이트한다.
- 반복당 상위 30개 검색 문서를 샘플링하고 상위 7개를 LLM에 입력하여 답 변환과 추론을 수행한다.
- 7B 및 13B Llama 모델에서 통합 JMLR과 분리된 RAG 및 기준 사전학습/미세조정 파이프라인을 비교한다.
실험 결과
연구 질문
- RQ1검색기와 LLM 학습의 동기화가 기존의 사전학습+미세조정 및 RAG 기준선에 비해 의료 QA 정확도와 추론을 향상시키는가?
- RQ2JMLR이 학습 시간과 자원 사용을 줄이면서 의료 벤치마크에서 최첨단 성능을 유지하거나 능가할 수 있는가?
- RQ3의료 QA에서 헛소리 발생 가능성과 설명가능성에 JMLR이 어떤 영향을 미치는가?
- RQ4모델 규모(7B 대 13B)가 JMLR의 성능과 효율성에 미치는 영향은 무엇인가?
주요 결과
| 데이터셋 | PMC-Llama-7B | Llama-2-7B | Meditron-7B | JMLR-7B | MedMCQA | MedQA | AMBOSS |
|---|---|---|---|---|---|---|---|
| MMLU-Medical | 59.7 | 56.3 | 55.6 | 57.2 | 57.6 | 51.7 | 68.7 |
| MedMcQA | 57.6 | 54.4 | 59.2 | 61.3 | 57.6 | 61.3 | 68.7 |
| MedQA | 42.4 | 44.0 | 47.9 | 51.7 | 42.4 | 51.7 | 68.7 |
| AMBOSS | 43.7 | 46.5 | 50.1 | 68.7 | 46.5 | 68.7 | 81.2 |
- JMLR-13B는 AMBOSS에서 81.2%, MedQA에서 61.3%를 달성하여 이 데이터셋에서 Meditron-70B 및 ChatGPT를 능가한다.
- JMLR-7B는 AMBOSS에서 68.7%, MedQA에서 51.7%를 달성하여 여러 공개 기준선을 능가한다.
- JMLR의 학습 시간은 37시간으로, 기존의 사전학습(127h)과 미세조정(17h)보다 훨씬 짧다.
- JMLR-7B와 JMLR-13B는 MMLU-Medical, MedMCQA, MedQA, AMBOSS 벤치마크 전반에서 강한 결과를 보여주며 의료 추론 및 QA 능력 개선을 시사한다.
- GPT-4와 다수의 의사 3명이 대부분의 사례에서 JMLR-13B의 추론을 우수하다고 독립적으로 판단했다( GPT-4 승률 0.63; 전문가 0.60).
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.