[论文解读] Policy Distillation
本文提出策略蒸馏(policy distillation)方法,可将深度Q网络(DQN)的策略无性能损失地迁移至更小、更高效的“学生”网络。该方法实现模型压缩、多任务策略整合与在线蒸馏,使在Atari环境下的性能优于单任务教师网络与联合训练的DQN智能体。
Policies for complex visual tasks have been successfully learned with deep reinforcement learning, using an approach called deep Q-networks (DQN), but relatively large (task-specific) networks and extensive training are needed to achieve good performance. In this work, we present a novel method called policy distillation that can be used to extract the policy of a reinforcement learning agent and train a new network that performs at the expert level while being dramatically smaller and more efficient. Furthermore, the same method can be used to consolidate multiple task-specific policies into a single policy. We demonstrate these claims using the Atari domain and show that the multi-task distilled agent outperforms the single-task teachers as well as a jointly-trained DQN agent.
研究动机与目标
- 解决通过DQN训练的深度强化学习智能体存在的高计算成本与大模型尺寸问题。
- 通过蒸馏技术,实现从大型、任务特定的DQN教师网络到更小、更高效的学生网络的知识迁移。
- 将多个单任务DQN策略蒸馏为一个统一的多任务策略,使其表现优于单个教师网络。
- 探索在线蒸馏方法,通过在训练过程中持续追踪表现最佳的策略,以稳定DQN训练。
提出的方法
- 训练学生网络以模仿预训练DQN教师网络的动作价值输出分布,使用软标签(soft labels)。
- 采用温度缩放的softmax函数,使动作价值分布更平滑,从而提升知识迁移效果。
- 应用知识蒸馏损失函数,根据动作差距对动作分类进行加权,类似CAPI框架的设计。
- 使用监督回归方法,在教师策略生成的轨迹上训练学生网络。
- 通过定期用当前表现最佳的DQN策略更新学生网络,实现在线蒸馏。
- 采用多控制器架构,共享卷积特征提取层,使用任务特定的输出头,以实现多款游戏的泛化能力。
实验结果
研究问题
- RQ1策略蒸馏能否在不造成性能下降的前提下,有效将DQN策略压缩为更小、更高效的模型?
- RQ2能否将多个单任务DQN策略蒸馏为一个统一的多任务策略,使其泛化能力优于单个教师网络?
- RQ3在线蒸馏是否能通过实时追踪最佳表现策略,稳定DQN训练过程?
- RQ4当教师策略在训练过程中发生显著演化时,蒸馏方法的性能如何?
- RQ5在强化学习中,哪种损失函数设计能取得最佳蒸馏性能——尤其在非概率性、实值动作价值设置下?
主要发现
- 策略蒸馏可将DQN模型大小压缩至原大小的1/15,且在单款Atari游戏任务上无性能损失。
- 蒸馏后的多任务智能体在10个单任务DQN教师网络的几何平均性能上达到89.3%,在Q*bert和Seaquest等多款游戏中表现优于单个教师网络。
- 在三款游戏的多任务设置中,蒸馏智能体(Multi-Dist-KL)性能达到单任务DQN教师网络的116.9%,显著优于联合训练的多任务DQN智能体(83.5%)。
- 在线蒸馏使学生智能体在训练过程中表现出与DQN教师相当或更优的性能,且训练方差显著降低。
- 采用基于动作差距加权的软max损失函数(如CAPI框架)效果最佳,表明损失函数设计在强化学习蒸馏中至关重要。
- 即使在无迭代交互或无法控制数据分布的条件下,强化学习中的蒸馏依然有效,证实其作为通用正则化技术的潜力。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。