[论文解读] Unsupervised State Representation Learning in Atari
本文提出时空深度互信息最大化(Spatiotemporal DeepInfomax, ST-DIM),一种自监督表示学习方法,通过在Atari 2600观测数据的时空特征间最大化互信息,实现对解耦的高层状态因子的建模。在基于ALE源代码提取真实状态变量的新基准上,ST-DIM在捕捉小物体和低熵特征方面优于以往的对比学习与生成方法。
State representation learning, or the ability to capture latent generative factors of an environment, is crucial for building intelligent agents that can perform a wide variety of tasks. Learning such representations without supervision from rewards is a challenging open problem. We introduce a method that learns state representations by maximizing mutual information across spatially and temporally distinct features of a neural encoder of the observations. We also introduce a new benchmark based on Atari 2600 games where we evaluate representations based on how well they capture the ground truth state variables. We believe this new framework for evaluating representation learning models will be crucial for future representation learning research. Finally, we compare our technique with other state-of-the-art generative and contrastive representation learning methods. The code associated with this work is available at https://github.com/mila-iqia/atari-representation-learning
研究动机与目标
- 开发一种自监督方法,在强化学习环境中无需奖励监督即可学习有意义且解耦的状态表示。
- 解决现有方法仅关注像素级重建或难以捕捉小物体或低熵状态因子的局限性。
- 提出一种新基准,通过源代码分析从Atari 2600游戏中提取真实状态变量,用于评估状态表示学习。
- 评估不同表示学习技术在捕捉多种生成因子(如物体位置、得分和敌人位置)方面的表现。
- 证明在空间与时间维度上最大化互信息,可生成更鲁棒且语义上更清晰的表示。
提出的方法
- 该方法使用卷积神经网络编码器从Atari观测中提取特征,表示在多个空间与时间尺度上计算。
- 基于InfoNCE损失的对比目标,最大化当前帧全局表示与未来帧局部补丁表示之间的互信息。
- 该方法结合两种目标:空间补丁之间的局部-局部互信息,以及完整帧与补丁之间的全局-局部互信息。
- 通过对比学习目标端到端训练模型,鼓励来自同一观测的正样本对(时空特征)在表示空间中比负样本对更接近。
- 关键创新在于使用多个负样本,以提升对比目标中互信息估计的稳定性和准确性。
- 通过线性探测评估方法,训练线性分类器以从学习到的表示中预测真实状态变量。
实验结果
研究问题
- RQ1在空间与时间维度上同时最大化互信息,是否能产生优于现有对比或生成方法的更优解耦状态表示?
- RQ2不同表示学习方法在捕捉Atari游戏中小物体(如钥匙或敌人)等低对比度物体方面表现如何?
- RQ3当高熵、易预测的特征(如时钟)主导学习目标时,对比方法在多大程度上会失效?
- RQ4所提出的基于真实状态变量的新基准在多大程度上实现了更可靠、可解释的表示学习模型评估?
- RQ5对比方法(倾向于高熵特征)与生成模型(倾向于大而低熵物体)在表示质量上存在哪些定性差异?
主要发现
- 在Atari基准上,ST-DIM在所有状态变量上的平均F1分数最高,优于所有对比基线方法及VAE和像素预测等生成模型。
- ST-DIM在捕捉小物体(如钥匙和敌人)方面显著优于其他方法,在如《蒙特祖马的复仇》等游戏中,其F1分数比竞争性对比方法高出20%–30%。
- 在《拳击》游戏中,ST-DIM对时钟变量的F1得分为0.92,对玩家得分的F1得分为0.88,而CPC和Global-T-DIM尽管在时钟上表现优异,却未能有效捕捉玩家和敌人的位置。
- 消融实验表明,若移除空间对比组件(即Global-T-DIM),所有状态变量的性能均下降,证实了目标中空间归纳偏置的重要性。
- 对比方法如ST-DIM对易被利用的特征(如《拳击》中的时钟)更具鲁棒性,而CPC和Global-T-DIM则在这些特征上趋于饱和,导致在更复杂、低熵的状态因子上表现欠佳。
- 生成模型如PIXEL-PRED在高熵特征(如时钟和得分)上表现较差,但在大而低熵的特征(如玩家和敌人位置)上表现优异,凸显了其与对比方法的互补优势。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。