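[Paper Digest] Contrastive Learning of Structured World Models
C-SWMs use contrastive learning to learn object-based latent representations together with an action-conditioned, graph-structured transition model. This enables unsupervised object discovery and highly accurate multi-object dynamics prediction in structured environments.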
A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
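Motivation and Goals
- Motivate learning a structured, object-centric world model to improve generalization and counterfactual reasoning.
- Develop a method for discovering objects from pixels without supervision.
- Propose an object-level contrastive loss for training object representations and transitions.
- Use a graph neural network (GNN) to model relations and interactions between objects.
- Show that structured representations improve long-horizon state prediction and generalization.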
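Proposed Method
- Encode each observation into a set of object-centric latent representations with a two-part encoder: a CNN-based object extractor followed by an MLP-based object encoder.
- Model interactions between objects with a graph neural network, and predict latent state updates through a translational transition model (z_t + T(z_t, a_t) ≈ z_{t+1}).
- Train with an object-level contrastive hinge loss that separates real state-action-next-state triples from negative (corrupted) samples, using a TransE-style energy.
- Use an object-factored latent space Z = Z_1 × … × Z_K with a matching factored action space A = A_1 × … × A_K to capture compositional structure and enable parameter sharing across objects.
- Evaluate multi-step forward prediction in latent space across diverse environments using ranking metrics (Hits@1, MRR).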
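The method bullets above can be made concrete with a small sketch. It is written in plain Python rather than a deep learning framework, and several details are assumptions made for illustration: squared Euclidean distance as the energy, a margin of 1.0, a single round of message passing on a fully connected object graph, and the names `gnn_transition`, `contrastive_hinge_loss`, `toy_edge_fn`, and `toy_node_fn`, none of which come from the digest. The toy edge and node functions stand in for learned MLPs.

```python
from typing import Callable, List

Vector = List[float]
ObjectSet = List[Vector]  # K object slots, each a D-dimensional latent vector


def sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def energy(pred: ObjectSet, target: ObjectSet) -> float:
    """TransE-style energy: squared distance averaged over object slots."""
    return sum(sq_dist(p, t) for p, t in zip(pred, target)) / len(target)


def gnn_transition(
    z: ObjectSet,
    actions: List[Vector],
    edge_fn: Callable[[Vector, Vector], Vector],
    node_fn: Callable[[Vector, Vector, Vector], Vector],
) -> ObjectSet:
    """One round of message passing on the fully connected object graph.

    edge_fn and node_fn are shared by all objects (parameter sharing).
    Returns the predicted latent update T(z, a) for every object slot.
    """
    deltas = []
    for i, (z_i, a_i) in enumerate(zip(z, actions)):
        messages = [edge_fn(z_i, z_j) for j, z_j in enumerate(z) if j != i]
        if messages:
            agg = [sum(col) for col in zip(*messages)]
        else:
            # Single object: no messages. Assumes messages share the latent size.
            agg = [0.0] * len(z_i)
        deltas.append(node_fn(z_i, a_i, agg))
    return deltas


def contrastive_hinge_loss(
    z_t: ObjectSet,
    actions: List[Vector],
    z_next: ObjectSet,
    z_neg: ObjectSet,
    edge_fn: Callable[[Vector, Vector], Vector],
    node_fn: Callable[[Vector, Vector, Vector], Vector],
    margin: float = 1.0,  # assumed value, a tunable hyperparameter
) -> float:
    """Pull z_t + T(z_t, a_t) toward z_next; push a negative state z_neg
    at least `margin` away from z_next (in energy)."""
    deltas = gnn_transition(z_t, actions, edge_fn, node_fn)
    z_pred = [[zi + di for zi, di in zip(zk, dk)] for zk, dk in zip(z_t, deltas)]
    positive = energy(z_pred, z_next)
    negative = energy(z_neg, z_next)
    return positive + max(0.0, margin - negative)


# Hypothetical stand-ins for the learned edge and node MLPs.
def toy_edge_fn(z_i: Vector, z_j: Vector) -> Vector:
    return [0.0] * len(z_i)  # objects do not interact in this toy setup


def toy_node_fn(z_i: Vector, a_i: Vector, agg: Vector) -> Vector:
    return [a + m for a, m in zip(a_i, agg)]  # action acts as a displacement
```

With `z_t = [[0.0, 0.0], [1.0, 1.0]]`, actions `[[1.0, 0.0], [0.0, 0.0]]`, and `z_next = [[1.0, 0.0], [1.0, 1.0]]`, the positive energy is zero, so the loss depends only on how close the negative sample lies to `z_next`. How negatives are sampled is left to the caller in this sketch.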
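Experimental Results
Research Questions
- RQ1: Can C-SWMs discover objects from raw pixel observations without supervision?
- RQ2: Do object-centric latent representations combined with a GNN-based transition model enable accurate multi-step state prediction and compositional generalization?
- RQ3: Does contrastive learning improve latent representations and prediction accuracy compared with reconstruction-based baselines?
- RQ4: How does object factorization affect generalization to unseen environment configurations?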
Key Findings
| Environment - Model | 1 Step H@1 (%) | 1 Step MRR (%) | 5 Steps H@1 (%) | 5 Steps MRR (%) | 10 Steps H@1 (%) | 10 Steps MRR (%) |
|---|---|---|---|---|---|---|
| 2D SHAPES - C-SWM | 100 ± 0.0 | 100 ± 0.0 | 100 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 |
| 2D SHAPES - C-SWM w/o latent GNN (ablation) | 99.9 ± 0.0 | 100 ± 0.0 | 97.4 ± 0.1 | 98.4 ± 0.0 | 89.7 ± 0.3 | 93.1 ± 0.2 |
| 2D SHAPES - C-SWM w/o factored states (ablation) | 54.5 ± 18.1 | 65.0 ± 15.9 | 34.4 ± 16.0 | 47.4 ± 16.0 | 24.1 ± 11.2 | 37.0 ± 12.1 |
| 2D SHAPES - C-SWM w/o contrastive loss (ablation) | 49.9 ± 0.9 | 55.2 ± 0.9 | 6.5 ± 0.5 | 9.3 ± 0.7 | 1.4 ± 0.1 | 2.6 ± 0.2 |
| 2D SHAPES - World Model (AE) | 98.7 ± 0.5 | 99.2 ± 0.3 | 36.1 ± 8.1 | 44.1 ± 8.1 | 6.5 ± 2.6 | 10.5 ± 3.6 |
| 2D SHAPES - World Model (VAE) | 94.2 ± 1.0 | 96.4 ± 0.6 | 14.1 ± 1.1 | 21.4 ± 1.4 | 1.4 ± 0.2 | 3.5 ± 0.4 |
| 3D BLOCKS - C-SWM | 99.9 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 |
| 3D BLOCKS - C-SWM w/o latent GNN (ablation) | 99.9 ± 0.0 | 99.9 ± 0.0 | 96.3 ± 0.4 | 97.7 ± 0.3 | 86.0 ± 1.8 | 90.2 ± 1.5 |
| 3D BLOCKS - C-SWM w/o factored states (ablation) | 74.2 ± 9.3 | 82.5 ± 8.3 | 48.7 ± 12.9 | 62.6 ± 13.0 | 65.8 ± 14.0 | 49.6 ± 11.0 |
| 3D BLOCKS - C-SWM w/o contrastive loss (ablation) | 48.9 ± 16.8 | 52.5 ± 17.8 | 12.2 ± 5.8 | 16.3 ± 7.1 | 3.1 ± 1.9 | 5.3 ± 2.8 |
| 3D BLOCKS - World Model (AE) | 93.5 ± 0.8 | 95.6 ± 0.6 | 26.7 ± 0.7 | 35.6 ± 0.8 | 4.0 ± 0.2 | 7.6 ± 0.3 |
| 3D BLOCKS - World Model (VAE) | 90.9 ± 0.7 | 94.2 ± 0.6 | 31.3 ± 2.3 | 41.8 ± 2.3 | 7.2 ± 0.9 | 12.9 ± 1.3 |
| ATARI PONG - C-SWM (K=5) | 20.5 ± 3.5 | 41.8 ± 2.9 | 9.5 ± 2.2 | 22.2 ± 3.3 | 5.3 ± 1.6 | 15.8 ± 2.8 |
| ATARI PONG - C-SWM (K=3) | 34.8 ± 5.3 | 54.3 ± 5.2 | 12.8 ± 3.4 | 28.1 ± 4.2 | 9.5 ± 1.7 | 21.1 ± 2.8 |
| ATARI PONG - C-SWM (K=1) | 36.5 ± 5.6 | 56.2 ± 6.2 | 18.3 ± 1.9 | 35.7 ± 2.3 | 11.5 ± 1.0 | 26.0 ± 1.2 |
| ATARI PONG - World Model (AE) | 23.8 ± 3.3 | 44.7 ± 2.4 | 1.7 ± 0.5 | 8.0 ± 0.5 | 1.2 ± 0.8 | 5.3 ± 0.8 |
| ATARI PONG - World Model (VAE) | 1.0 ± 0.0 | 5.1 ± 0.1 | 1.0 ± 0.0 | 5.2 ± 0.0 | 1.0 ± 0.0 | 5.2 ± 0.0 |
| SPACE INVADERS - C-SWM (K=5) | 48.5 ± 7.0 | 66.1 ± 6.6 | 16.8 ± 2.7 | 35.7 ± 3.7 | 11.8 ± 3.0 | 26.0 ± 4.1 |
| SPACE INVADERS - C-SWM (K=3) | 46.2 ± 13.0 | 62.3 ± 11.5 | 10.8 ± 3.7 | 28.5 ± 5.8 | 6.0 ± 0.4 | 20.9 ± 0.9 |
| SPACE INVADERS - C-SWM (K=1) | 31.5 ± 13.1 | 48.6 ± 11.8 | 10.0 ± 2.3 | 23.9 ± 3.6 | 6.0 ± 1.7 | 19.8 ± 3.3 |
| SPACE INVADERS - World Model (AE) | 40.2 ± 3.6 | 59.6 ± 3.5 | 5.2 ± 1.1 | 14.1 ± 2.0 | 3.8 ± 0.8 | 10.4 ± 1.3 |
| SPACE INVADERS - World Model (VAE) | 1.0 ± 0.0 | 5.3 ± 0.1 | 0.8 ± 0.2 | 5.2 ± 0.0 | 1.0 ± 0.0 | 5.2 ± 0.0 |
| 3-BODY PHYSICS - C-SWM | 100 ± 0.0 | 100 ± 0.0 | 97.2 ± 0.9 | 98.5 ± 0.5 | 75.5 ± 4.7 | 85.2 ± 3.1 |
| 3-BODY PHYSICS - World Model (AE) | 100 ± 0.0 | 100 ± 0.0 | 97.7 ± 0.3 | 98.8 ± 0.2 | 67.9 ± 2.4 | 78.4 ± 1.8 |
| 3-BODY PHYSICS - World Model (VAE) | 100 ± 0.0 | 100 ± 0.0 | 83.1 ± 2.5 | 90.3 ± 1.6 | 23.6 ± 4.2 | 37.5 ± 4.8 |
| 3-BODY PHYSICS - PAIG | 89.2 ± 3.5 | 90.7 ± 3.4 | 57.7 ± 12.0 | 63.1 ± 11.1 | 25.1 ± 13.0 | 33.1 ± 13.4 |
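The H@1 and MRR columns are ranking metrics computed in latent space. The sketch below shows one way such metrics can be computed, under assumptions that are not stated in the digest: each predicted state is flattened into a single vector, it is compared against the encodings of all reference states by squared Euclidean distance, and ties are resolved optimistically. The function names are hypothetical.

```python
from typing import List, Tuple

Vector = List[float]


def sq_dist(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def rank_of_target(pred: Vector, candidates: List[Vector], target_index: int) -> int:
    """Rank (1 = best) of the true next state among all candidate encodings."""
    dists = [sq_dist(pred, c) for c in candidates]
    target = dists[target_index]
    # Count candidates strictly closer than the target (optimistic on ties).
    return 1 + sum(1 for d in dists if d < target)


def hits_at_1_and_mrr(preds: List[Vector], candidates: List[Vector]) -> Tuple[float, float]:
    """preds[i] is the predicted encoding whose ground truth is candidates[i]."""
    ranks = [rank_of_target(p, candidates, i) for i, p in enumerate(preds)]
    hits_at_1 = sum(1 for r in ranks if r == 1) / len(ranks)
    mrr = sum(1.0 / r for r in ranks) / len(ranks)
    return hits_at_1, mrr
```

For example, with candidates `[[0.0, 0.0], [1.0, 0.0], [5.0, 5.0]]` and predictions `[[0.1, 0.0], [0.2, 0.0], [5.0, 5.0]]`, the second prediction lands closer to the first candidate than to its own target, so it receives rank 2. Because 1/rank is never smaller than the rank-1 indicator, MRR computed this way is always at least as large as Hits@1.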
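- C-SWMs learn interpretable object-level representations and accurate transition predictions, and outperform reconstruction-based baselines in highly structured environments.
- In the grid-world environments (2D Shapes, 3D Blocks), C-SWM latent-space predictions are near-perfect at 1, 5, and 10 steps (H@1 of at least 99.9%). In 3-body physics, H@1 is 100% at 1 step and 97.2% at 5 steps, then drops to 75.5% at 10 steps. The gains are largest when the object-factored representation is combined with the GNN transition model.
- Object discovery emerges without supervision: each object's latent coordinates align closely with the true object position, up to a random linear transformation.
- Relative to pixel-reconstruction losses, the contrastive loss markedly improves generalization to unseen configurations, especially in multi-object settings and against the VAE-based baseline.
- On the Atari tasks, the number of object slots K has to be tuned on validation data: K=1 scores highest on Pong, while K=5 scores highest on Space Invaders. Iterative or object-centric encoders may further improve robustness.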
This digest was generated by AI and reviewed by human editors.