[論文レビュー] Contrastive Learning of Structured World Models
C-SWMs は物体ベースの潜在表現と、グラフ構造の遷移モデルを対照学習を介して学習し、構造化された環境における教師なし物体発見と正確な多物体ダイナミクス予測を可能にする。
A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
研究の動機と目的
- 構造化され、物体中心のワールドモデルを学習させて一般化と反事実推論を改善する動機づけ。
- 直接的な supervision なしにピクセルから物体を発見する教師なし法を開発する。
- 物体表現と遷移を訓練する対照的なオブジェクトレベルの損失を提案する。
- 物体間の関係と相互作用をモデル化するグラフニューラルネットワークを活用する。
- 構造化された表現が長期的な状態予測と一般化を改善することを示す。
提案手法
- 観測を物体中心の潜在表現の集合へ符号化する二部エンコーダを用いる:CNN ベースのオブジェクト抽出器とMLPベースのオブジェクトエンコーダ。
- グラフニューラルネットワークで物体間の相互作用をモデル化し、潜在状態更新を転写的遷移により予測する(z_t + T(z_t, a_t) ≈ z_{t+1})。
- 真の状態-作用-状態の三つ組を腐敗したネガティブと区別するオブジェクトレベルのヒンジ対照損失で訓練する(TransE風エネルギーに基づく)。
- 組成的構造を捉え、パラメータ共有を可能にするためにオブジェクト因子化潜在空間 Z = Z_1 × … × Z_K と対応するアクション A = A_1 × … × A_K を採用する。
- 多様な環境での長期前方予測を latent space で評価するリランキング指標(Hits@1, MRR)を用いる。
実験結果
リサーチクエスチョン
- RQ1C-SWMs は監督なしに raw ピクセル観測から物体を発見できるか?
- RQ2物体中心の潜在表現と GNN ベースの遷移は多段階の状態予測と組み合わせ一般化を正確に可能にするか?
- RQ3対照学習は再構成ベースのベースラインと比較して潜在表現と予測精度を改善するか?
- RQ4物体因子化は未知環境構成への一般化をどう影響するか?
主な発見
| モデル | 1 Step H@1 | 1 Step MRR | 5 Steps H@1 | 5 Steps MRR | 10 Steps H@1 | 10 Steps MRR |
|---|---|---|---|---|---|---|
| 2D SHAPES - C-SWM | 100 ± 0.0 | 100 ± 0.0 | 100 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 |
| 2D SHAPES - latent GNN | 99.9 ± 0.0 | 100 ± 0.0 | 97.4 ± 0.1 | 98.4 ± 0.0 | 89.7 ± 0.3 | 93.1 ± 0.2 |
| 2D SHAPES - factored states | 54.5 ± 18.1 | 65.0 ± 15.9 | 34.4 ± 16.0 | 47.4 ± 16.0 | 24.1 ± 11.2 | 37.0 ± 12.1 |
| 2D SHAPES - contrastive loss | 49.9 ± 0.9 | 55.2 ± 0.9 | 6.5 ± 0.5 | 9.3 ± 0.7 | 1.4 ± 0.1 | 2.6 ± 0.2 |
| 2D SHAPES - World Model (AE) | 98.7 ± 0.5 | 99.2 ± 0.3 | 36.1 ± 8.1 | 44.1 ± 8.1 | 6.5 ± 2.6 | 10.5 ± 3.6 |
| 2D SHAPES - World Model (VAE) | 94.2 ± 1.0 | 96.4 ± 0.6 | 14.1 ± 1.1 | 21.4 ± 1.4 | 1.4 ± 0.2 | 3.5 ± 0.4 |
| 3D BLOCKS - C-SWM | 99.9 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 |
| 3D BLOCKS - latent GNN | 99.9 ± 0.0 | 99.9 ± 0.0 | 96.3 ± 0.4 | 97.7 ± 0.3 | 86.0 ± 1.8 | 90.2 ± 1.5 |
| 3D BLOCKS - factored states | 74.2 ± 9.3 | 82.5 ± 8.3 | 48.7 ± 12.9 | 62.6 ± 13.0 | 65.8 ± 14.0 | 49.6 ± 11.0 |
| 3D BLOCKS - contrastive loss | 48.9 ± 16.8 | 52.5 ± 17.8 | 12.2 ± 5.8 | 16.3 ± 7.1 | 3.1 ± 1.9 | 5.3 ± 2.8 |
| 3D BLOCKS - World Model (AE) | 93.5 ± 0.8 | 95.6 ± 0.6 | 26.7 ± 0.7 | 35.6 ± 0.8 | 4.0 ± 0.2 | 7.6 ± 0.3 |
| 3D BLOCKS - World Model (VAE) | 90.9 ± 0.7 | 94.2 ± 0.6 | 31.3 ± 2.3 | 41.8 ± 2.3 | 7.2 ± 0.9 | 12.9 ± 1.3 |
| ATARI PONG - C-SWM (K=5) | 20.5 ± 3.5 | 41.8 ± 2.9 | 9.5 ± 2.2 | 22.2 ± 3.3 | 5.3 ± 1.6 | 15.8 ± 2.8 |
| ATARI PONG - C-SWM (K=3) | 34.8 ± 5.3 | 54.3 ± 5.2 | 12.8 ± 3.4 | 28.1 ± 4.2 | 9.5 ± 1.7 | 21.1 ± 2.8 |
| ATARI PONG - C-SWM (K=1) | 36.5 ± 5.6 | 56.2 ± 6.2 | 18.3 ± 1.9 | 35.7 ± 2.3 | 11.5 ± 1.0 | 26.0 ± 1.2 |
| ATARI PONG - World Model (AE) | 23.8 ± 3.3 | 44.7 ± 2.4 | 1.7 ± 0.5 | 8.0 ± 0.5 | 1.2 ± 0.8 | 5.3 ± 0.8 |
| ATARI PONG - World Model (VAE) | 1.0 ± 0.0 | 5.1 ± 0.1 | 1.0 ± 0.0 | 5.2 ± 0.0 | 1.0 ± 0.0 | 5.2 ± 0.0 |
| SPACE INVADERS - C-SWM (K=5) | 48.5 ± 7.0 | 66.1 ± 6.6 | 16.8 ± 2.7 | 35.7 ± 3.7 | 11.8 ± 3.0 | 26.0 ± 4.1 |
| SPACE INVADERS - C-SWM (K=3) | 46.2 ± 13.0 | 62.3 ± 11.5 | 10.8 ± 3.7 | 28.5 ± 5.8 | 6.0 ± 0.4 | 20.9 ± 0.9 |
| SPACE INVADERS - C-SWM (K=1) | 31.5 ± 13.1 | 48.6 ± 11.8 | 10.0 ± 2.3 | 23.9 ± 3.6 | 6.0 ± 1.7 | 19.8 ± 3.3 |
| SPACE INVADERS - World Model (AE) | 40.2 ± 3.6 | 59.6 ± 3.5 | 5.2 ± 1.1 | 14.1 ± 2.0 | 3.8 ± 0.8 | 10.4 ± 1.3 |
| SPACE INVADERS - World Model (VAE) | 1.0 ± 0.0 | 5.3 ± 0.1 | 0.8 ± 0.2 | 5.2 ± 0.0 | 1.0 ± 0.0 | 5.2 ± 0.0 |
| 3-BODY PHYSICS - C-SWM | 100 ± 0.0 | 100 ± 0.0 | 97.2 ± 0.9 | 98.5 ± 0.5 | 75.5 ± 4.7 | 85.2 ± 3.1 |
| 3-BODY PHYSICS - World Model (AE) | 100 ± 0.0 | 100 ± 0.0 | 97.7 ± 0.3 | 98.8 ± 0.2 | 67.9 ± 2.4 | 78.4 ± 1.8 |
| 3-BODY PHYSICS - World Model (VAE) | 100 ± 0.0 | 100 ± 0.0 | 83.1 ± 2.5 | 90.3 ± 1.6 | 23.6 ± 4.2 | 37.5 ± 4.8 |
| 3-BODY PHYSICS - PAIG | 89.2 ± 3.5 | 90.7 ± 3.4 | 57.7 ± 12.0 | 63.1 ± 11.1 | 25.1 ± 13.0 | 33.1 ± 13.4 |
- C-SWMs は解釈可能なオブジェクトレベルの表現と正確な遷移予測を学習し、構造化が高度に進んだ環境で再構成ベースのベースラインを上回る。
- グリッドワールドと物理シミュレーションタスクで、C-SWMs は短期・中期の潜在空間予測(1, 5, 10 ステップ)でほぼ完璧に近く、特にオブジェクト因子化表現と GNN 遷移を使用した場合に高い H@1 と MRR を達成する。
- 監督なしでの物体発見が現れ、各オブジェクトの潜在座標が真のオブジェクト位置と密接に一致する(ランダムな線形変換まで)。
- 対照的な損失はピクセル再構成損失よりも unseen 構成への一般化を大幅に改善し、特に複数オブジェクト設定と VAE ベースのデコーダ使用時に効果的。
- オブジェクトスロットの数を増やすと(K)、Atari タスクで検証ベースの調整が必要となる可能性がある。反復的・オブジェクト中心のエンコーディングは頑健性をさらに向上させる可能性。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。