[論文レビュー] JMLR: Joint Medical LLM and Retrieval Training for Enhancing Reasoning and Professional Question Answering Capability
この論文は Joint Medical LLM and Retrieval Training (JMLR) を紹介します。同期的な retriever-LLM トレーニングフレームワークで、医療QAと推論を改善しつつトレーニング時間を短縮します。7Bおよび13BのLlamaベースモデルで医療ベンチマークの最先端のオープンソース結果を報告します。
Large Language Models (LLMs) have demonstrated a remarkable potential in medical knowledge acquisition and question-answering. However, LLMs can potentially hallucinate and yield factually incorrect outcomes, even with domain-specific pretraining. Previously, retrieval augmented generation (RAG) has limited success in addressing hallucinations. Unlike previous methods in RAG where the retrieval model was trained separately from the LLM, we introduce JMLR (for Jointly trains LLM and information Retrieval) during the fine-tuning phase. The synchronized training mechanism enhances JMLR's ability to retrieve clinical guidelines and leverage medical knowledge to reason and answer questions and reduces the demand for computational resources. We evaluated JMLR on the important medical question-answering application. Our experimental results demonstrate that JMLR-13B (70.5%) outperforms a previous state-of-the-art open-source model using conventional pre-training and fine-tuning Meditron-70B (68.9%) and Llama2-13B with RAG (67.7%) on a medical question-answering dataset. Comprehensive evaluations reveal JMLR-13B enhances reasoning quality and reduces hallucinations better than Claude3-Opus. Additionally, JMLR-13B (148 GPU hours) also trains much faster than Meditron-70B (42630 GPU hours). Through this work, we provide a new and efficient knowledge enhancement method for healthcare, demonstrating the potential of integrating retrieval and LLM training for medical question-answering systems.
研究の動機と目的
- ドメイン知識を活用して医療QAと推論を改善する動機づけ。
- retrieved medical guidelinesやtextsでLLMsを grounding して幻覚を抑制する。
- より良いアライメントのために retriever と LLM を同時に更新する joint training パラダイムを開発する。
- 従来の pretraining+finetuning パイプラインと比較して効率向上と性能を評価する。
提案手法
- 長い入力文脈を扱うための Shifted Sparse Attention (S2-Attn) を使用する。
- ColBERTベースの retriever と joint LLM-retriever training objective (LLM-Rank loss) を採用する。
- AMBOSS と USMLE からの QA ペアを学習し、上位 retrieved documents を LLM に入力する。
- LLM-driven loss を計算し、LLM の改善を反映するランクベースの信号で retriever パラメータを更新する。
- 1 回のイテレーションごとに top-30 retrieved docs をサンプリングし、トップ7 を LLM に入力して回答生成と推論を行う。
- 統合された JMLR を、7B および 13B の Llama モデルで、別個の RAG や従来の pretraining/fine-tuning と比較する。
実験結果
リサーチクエスチョン
- RQ1 retriever と LLM のトレーニングを同期させると、従来の pretraining-finetuning および RAG のベースラインと比較して医療QAの精度と推論能力が向上するか?
- RQ2 JMLR はトレーニング時間とリソース消費を削減しつつ、医療ベンチマークで最先端の性能を維持または超えるか?
- RQ3 JMLR は医療QAにおける幻覚発生傾向と説明可能性にどのような影響を与えるか?
- RQ4 モデル規模(7B vs 13B)が JMLR の性能と効率に与える影響は何か?
主な発見
| Dataset | PMC-Llama-7B | Llama-2-7B | Meditron-7B | JMLR-7B | MedMCQA | MedQA | AMBOSS |
|---|---|---|---|---|---|---|---|
| MMLU-Medical | 59.7 | 56.3 | 55.6 | 57.2 | 57.6 | 51.7 | 68.7 |
| MedMcQA | 57.6 | 54.4 | 59.2 | 61.3 | 57.6 | 61.3 | 68.7 |
| MedQA | 42.4 | 44.0 | 47.9 | 51.7 | 42.4 | 51.7 | 68.7 |
| AMBOSS | 43.7 | 46.5 | 50.1 | 68.7 | 46.5 | 68.7 | 81.2 |
- JMLR-13B は AMBOSS で 81.2%、MedQA で 61.3% を達成し、これらのデータセットで Meditron-70B や ChatGPT を上回る。
- JMLR-7B は AMBOSS で 68.7%、MedQA で 51.7% を達成し、いくつかの公開ベースラインを上回る。
- JMLR のトレーニング時間は 37 時間で、従来の pretraining(127h)プラス finetuning(17h)より大幅に短い。
- JMLR-7B および JMLR-13B は MMLU-Medical、MedMCQA、MedQA、AMBOSS のベンチマークで強力な結果を示し、医療推論と QA 能力の改善を示唆する。
- GPT-4 と3名の医師は、JMLR-13B の推論をほとんどのケースで優れていると独立に判断した(GPT-4 勝率 0.63; 専門家 0.60)。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。