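[Paper Summary] Reinforcement Learning for Optimizing RAG for Domain Chatbots
The paper develops a policy-based reinforcement learning (RL) method, external to the RAG pipeline, that decides whether to fetch FAQ context. Under GPT-4 evaluation it saves about 31% of LLM tokens while slightly improving accuracy on a domain FAQ chatbot. It also shows that an in-house embedding model outperforms a public one in both retrieval and OOD detection.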
With the advent of Large Language Models (LLMs), conversational assistants have become prevalent for domain use cases. LLMs acquire the ability to perform contextual question answering through training, and Retrieval Augmented Generation (RAG) further enables the bot to answer domain-specific questions. This paper describes a RAG-based approach for building a chatbot that answers users' queries using Frequently Asked Questions (FAQ) data. We train an in-house retrieval embedding model using infoNCE loss, and experimental results demonstrate that the in-house model works significantly better than a well-known general-purpose public embedding model, both in terms of retrieval accuracy and Out-of-Domain (OOD) query detection. As the LLM, we use a paid ChatGPT model accessed through its open API. We noticed that a previously retrieved context could be reused to generate an answer for specific patterns/sequences of queries (e.g., follow-up queries). Hence, there is scope to optimize the number of LLM tokens and thus the cost. Assuming a fixed retrieval model and LLM, we optimize the number of LLM tokens using Reinforcement Learning (RL). Specifically, we propose a policy-based model external to the RAG, which interacts with the RAG pipeline through policy actions and updates the policy to optimize the cost. The policy model can perform two actions: fetch FAQ context or skip retrieval. We use the open API-based GPT-4 as the reward model. We then train the policy model using policy gradient on multiple training chat sessions. As the policy model, we experimented with a public GPT-2 model and an in-house BERT model. With the proposed RL-based optimization combined with a similarity threshold, we achieve significant cost savings while getting slightly improved accuracy. Though we demonstrate results for the FAQ chatbot, the proposed RL approach is generic and can be applied to any existing RAG pipeline.
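The abstract names infoNCE as the training loss for the in-house retriever but gives no further detail. The sketch below shows the standard InfoNCE objective with in-batch negatives: for query i the loss is -log( exp(s(q_i, p_i)/τ) / Σ_j exp(s(q_i, p_j)/τ) ), so the paired FAQ is the positive and every other FAQ in the batch acts as a negative. The choice of cosine similarity for s, the temperature τ = 0.05, and the use of in-batch negatives are assumptions for illustration, not settings confirmed by the paper.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two dense vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def info_nce_loss(
    queries: List[Sequence[float]],
    positives: List[Sequence[float]],
    temperature: float = 0.05,  # assumed value, not taken from the paper
) -> float:
    """Mean InfoNCE loss with in-batch negatives.

    queries[i] is paired with positives[i]; every positives[j] with j != i
    serves as a negative for queries[i].
    """
    if len(queries) != len(positives) or not queries:
        raise ValueError("queries and positives must be non-empty and aligned")
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cosine(q, p) / temperature for p in positives]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_denominator = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_denominator - logits[i]  # equals -log softmax(logits)[i]
    return total / len(queries)
```

With toy 2-D vectors, pairing each query with its own FAQ gives a loss close to zero, while swapping the pairs gives a loss of roughly 1/τ = 20, which is the gradient signal that pulls matching query/FAQ embeddings together.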
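Research Motivation and Goals
- Make domain-specific chatbots more efficient by reducing LLM token cost in a RAG setting.
- Show that an in-house embedding model trained with infoNCE outperforms public embeddings on domain FAQ retrieval and OOD detection.
- Propose and evaluate a policy-gradient RL agent that decides when to fetch FAQ context in order to minimize cost.
- Demonstrate that combining RL-based context selection with a similarity threshold yields significant token savings without sacrificing accuracy.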
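Proposed Method
- Train an in-house embedding model with infoNCE loss for domain FAQ retrieval.
- Compare it with public models on Top-1/Top-3 retrieval accuracy for English and Hinglish queries.
- Use GPT-4 as the reward evaluator, converting its Good/Bad ratings into numeric rewards for policy-gradient training.
- Build a policy network external to the RAG that chooses a FETCH or NO_FETCH action based on the state (previous queries, previous actions, and the current query).
- Train the policy with policy gradient plus entropy regularization on (state, action, reward) sequences.
- Combine the RL policy with a similarity threshold (SimThr) to further reduce token usage.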
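The method steps above can be sketched as a REINFORCE-style training loop with an entropy bonus. This is not the paper's implementation: the policy is a hypothetical linear softmax over a one-hot state (new query vs. follow-up) instead of a BERT or GPT-2 model, `toy_reward` is a made-up stand-in for GPT-4 Good/Bad ratings, the mean-reward baseline and learning rate are assumptions, and `use_faq_context` is one hypothetical way to combine the policy's FETCH decision with a similarity threshold (SimThr).

```python
import math
import random
from typing import List, Sequence, Tuple

FETCH, NO_FETCH = 0, 1
ACTIONS = (FETCH, NO_FETCH)
Step = Tuple[Sequence[float], int, float]  # (state features, action, reward)


def softmax(z: Sequence[float]) -> List[float]:
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


class LinearPolicy:
    """Hypothetical stand-in for the BERT/GPT-2 policy: a softmax over
    one linear score per action, computed from a featurized state."""

    def __init__(self, n_features: int, lr: float = 0.5, entropy_coef: float = 0.01):
        self.w = [[0.0] * n_features for _ in ACTIONS]
        self.lr = lr
        self.entropy_coef = entropy_coef

    def probs(self, x: Sequence[float]) -> List[float]:
        return softmax([sum(wi * xi for wi, xi in zip(wk, x)) for wk in self.w])

    def sample(self, x: Sequence[float], rng: random.Random) -> int:
        return rng.choices(ACTIONS, weights=self.probs(x))[0]

    def update(self, episode: List[Step]) -> None:
        """One REINFORCE step (mean-reward baseline) plus an entropy bonus."""
        baseline = sum(r for _, _, r in episode) / len(episode)
        grad = [[0.0] * len(wk) for wk in self.w]
        for x, a, r in episode:
            p = self.probs(x)
            entropy = -sum(pk * math.log(pk) for pk in p if pk > 0)
            for k in ACTIONS:
                g_logp = (1.0 if k == a else 0.0) - p[k]  # d log pi(a|x) / d z_k
                g_ent = -p[k] * (math.log(p[k]) + entropy) if p[k] > 0 else 0.0  # dH / d z_k
                coef = (r - baseline) * g_logp + self.entropy_coef * g_ent
                for i, xi in enumerate(x):
                    grad[k][i] += coef * xi
        for k in ACTIONS:  # gradient ascent on the averaged objective
            for i in range(len(self.w[k])):
                self.w[k][i] += self.lr * grad[k][i] / len(episode)


def use_faq_context(action: int, top_similarity: float, sim_thr: float) -> bool:
    """Hypothetical SimThr combination: attach FAQ context only if the policy
    picks FETCH and the best FAQ match clears the similarity threshold."""
    return action == FETCH and top_similarity >= sim_thr


def toy_reward(is_follow_up: bool, action: int) -> float:
    """Hypothetical reward table standing in for GPT-4 Good/Bad ratings:
    follow-ups can reuse earlier context, new queries need a fresh fetch."""
    if is_follow_up:
        return 1.0 if action == NO_FETCH else 0.2
    return 1.0 if action == FETCH else -1.0


def train(episodes: int = 300, turns: int = 20, seed: int = 0) -> LinearPolicy:
    rng = random.Random(seed)
    policy = LinearPolicy(n_features=2)
    for _ in range(episodes):
        episode: List[Step] = []
        for _ in range(turns):
            follow_up = rng.random() < 0.5
            x = [0.0, 1.0] if follow_up else [1.0, 0.0]  # one-hot: [new query, follow-up]
            a = policy.sample(x, rng)
            episode.append((x, a, toy_reward(follow_up, a)))
        policy.update(episode)
    return policy
```

After training on these toy rewards, the policy should come to prefer FETCH for new queries and NO_FETCH for follow-ups. In the paper's setting the state would instead encode the actual previous queries, actions, and current query, and the reward would come from GPT-4's ratings.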

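Experimental Results
Research Questions
- RQ1: Can an external policy model learn when to fetch FAQ context so as to reduce LLM token usage without hurting answer quality?
- RQ2: Does an in-house, domain-fine-tuned embedding model improve retrieval accuracy and OOD detection over public embeddings?
- RQ3: How does RL-based context selection interact with a similarity-threshold rule to optimize RAG cost?
- RQ4: What is the effect of using GPT-4 as an automatic evaluator to train the policy in a RAG setting?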
Main Findings
| Model | English Top-1 | English Top-3 | Hinglish Top-1 | Hinglish Top-3 |
|---|---|---|---|---|
| e5-base-v2 | 0.82 | 0.91 | 0.71 | 0.87 |
| triplet-loss | 0.90 | 0.93 | 0.84 | 0.89 |
| infoNCE | 0.97 | 0.98 | 0.94 | 0.95 |
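Top-1/Top-3 accuracy in the table above is, in the usual sense, the fraction of queries whose correct FAQ appears among the k highest-ranked candidates. A minimal sketch, assuming cosine-similarity ranking over precomputed FAQ embeddings; `rank_faqs` and `top_k_accuracy` are hypothetical helpers, and the paper's exact evaluation protocol is not described in this summary.

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two dense vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_faqs(query: Sequence[float], faqs: Dict[str, Sequence[float]]) -> List[str]:
    """Return FAQ ids ordered from most to least similar to the query."""
    return sorted(faqs, key=lambda fid: cosine(query, faqs[fid]), reverse=True)


def top_k_accuracy(rankings: Sequence[Sequence[str]], gold: Sequence[str], k: int) -> float:
    """Fraction of queries whose gold FAQ id is among the first k ranked ids."""
    if len(rankings) != len(gold) or not gold:
        raise ValueError("need exactly one ranking per gold label")
    hits = sum(1 for ranked, g in zip(rankings, gold) if g in list(ranked)[:k])
    return hits / len(gold)
```

A query whose gold FAQ ranks second counts as a Top-3 hit but a Top-1 miss, which is why each Top-3 column is never lower than the corresponding Top-1 column.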
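- The in-house embedding model trained with infoNCE exceeds the public e5-base-v2 in Top-1/Top-3 accuracy on both English and Hinglish queries.
- The in-house model better separates in-domain from OOD queries, which lets a similarity threshold selectively skip retrieval.
- Combined with the similarity threshold, the RL policy external to the RAG reduces token usage by about 31% on a test session of 91 queries, with a slight accuracy improvement.
- GPT-4 ratings can be converted into rewards that drive policy-gradient updates for action selection.
- Using GPT-2 as the policy model also yields token savings (about 25%), suggesting the approach generalizes across policy architectures.
- The choice of reward shaping affects token savings (e.g., about 30% with an alternative shaping).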
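Reward shaping and token savings can be made concrete with a small sketch. The reward table below is hypothetical: its values and the "Good"/"Bad" label strings are illustrative assumptions, not the paper's. It only encodes the intuition that a correct answer obtained without fetching context should earn more than one that paid for context. `token_savings` computes savings as 1 - (tokens used with the policy / tokens used by a baseline run), where the baseline is assumed here to fetch context on every turn.

```python
from typing import Dict, Iterable, Tuple

FETCH, NO_FETCH = "FETCH", "NO_FETCH"

# Hypothetical shaping table mapping (GPT-4 rating, action) to a reward.
# The numbers are illustrative assumptions, not values from the paper.
DEFAULT_REWARDS: Dict[Tuple[str, str], float] = {
    ("Good", NO_FETCH): 2.0,  # correct answer and no FAQ-context tokens spent
    ("Good", FETCH): 1.0,     # correct answer, but context tokens were paid for
    ("Bad", FETCH): -1.0,     # wrong answer even with context
    ("Bad", NO_FETCH): -2.0,  # skipping retrieval likely caused the wrong answer
}


def shaped_reward(rating: str, action: str,
                  table: Dict[Tuple[str, str], float] = DEFAULT_REWARDS) -> float:
    """Convert a Good/Bad rating plus the chosen action into a scalar reward."""
    key = (rating.strip().capitalize(), action)
    if key not in table:
        raise ValueError(f"no reward defined for {key}")
    return table[key]


def token_savings(baseline_tokens: Iterable[int], policy_tokens: Iterable[int]) -> float:
    """Fractional token savings of a policy run relative to a baseline run."""
    base, used = sum(baseline_tokens), sum(policy_tokens)
    if base <= 0:
        raise ValueError("baseline must use a positive number of tokens")
    return 1.0 - used / base
```

Passing a different table, as in `shaped_reward(rating, action, table=alt_table)`, is one way to experiment with alternative shapings of the kind the findings above mention.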
This summary was generated by AI and reviewed by human editors.