[논문 리뷰] Contrastive Learning of Structured World Models
C-SWMs는 객체 기반 잠재 표현과 행동 조건의 그래프 구조 전이 모델을 대조 학습을 통해 학습하여, 구조화된 환경에서 비지도 객체 발견 및 다객체 동역학 예측의 정확성을 가능하게 한다.
A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.
연구 동기 및 목표
- 일반화 및 반사실적 추론을 개선하기 위한 구조화된 객체 중심의 월드 모델 학습 동기 부여.
- 픽셀로부터 직접 감독 없이 객체를 발견하는 비지도 방법을 개발한다.
- 객체 표현과 전이를 학습시키기 위한 대조적(contrastive) 객체 수준 손실을 제안한다.
- 객체 간의 관계와 상호 작용을 모델링하기 위해 그래프 신경망을 활용한다.
- 구조화된 표현이 장기 상태 예측 및 일반화에 미치는 이점을 입증한다.
제안 방법
- 관찰을 CNN 기반 객체 추출기와 MLP 기반 객체 인코더의 두 부분으로 구성된 객체 중심 잠재 표현 집합으로 인코딩한다.
- 전이(z_t + T(z_t, a_t) ≈ z_{t+1}) 를 예측하는 그래프 신경망으로 객체 간 상호 작용을 모델링한다.
- 실제 상태-행동-상태 트리플을 손상된 음수와 구분하는 객체 수준 대조 힌지 손실로 학습한다(TransE 스타일 에너지를 기반으로).
- 구성적 구조를 포착하고 파라미터 공유를 가능하게 하기 위해 Z = Z_1 × … × Z_K 및 A = A_1 × … × A_K 의 객체 요소화 잠재 공간을 채택한다.
- 다양한 환경에서 다단계 예측을 위한 잠재 공간에서의 순위 측정 지표(Hits@1, MRR)로 평가한다.
실험 결과
연구 질문
- RQ1C-SWMs가 감독 없이 원시 픽셀 관찰에서 객체를 발견할 수 있는가?
- RQ2객체 중심 잠재 표현과 GNN 기반 전이가 다단계 상태 예측과 조합 일반화를 정확하게 가능하게 하는가?
- RQ3대조 학습이 재구성 기반 기준선과 비교하여 잠재 표현 및 예측 정확도를 향상시키는가?
- RQ4객체 요소화가 보지 못한 환경 구성에 대한 일반화에 어떤 영향을 미치는가?
주요 결과
| Model | 1 Step H@1 | 1 Step MRR | 5 Steps H@1 | 5 Steps MRR | 10 Steps H@1 | 10 Steps MRR |
|---|---|---|---|---|---|---|
| 2D SHAPES - C-SWM | 100 ± 0.0 | 100 ± 0.0 | 100 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 |
| 2D SHAPES - latent GNN | 99.9 ± 0.0 | 100 ± 0.0 | 97.4 ± 0.1 | 98.4 ± 0.0 | 89.7 ± 0.3 | 93.1 ± 0.2 |
| 2D SHAPES - factored states | 54.5 ± 18.1 | 65.0 ± 15.9 | 34.4 ± 16.0 | 47.4 ± 16.0 | 24.1 ± 11.2 | 37.0 ± 12.1 |
| 2D SHAPES - contrastive loss | 49.9 ± 0.9 | 55.2 ± 0.9 | 6.5 ± 0.5 | 9.3 ± 0.7 | 1.4 ± 0.1 | 2.6 ± 0.2 |
| 2D SHAPES - World Model (AE) | 98.7 ± 0.5 | 99.2 ± 0.3 | 36.1 ± 8.1 | 44.1 ± 8.1 | 6.5 ± 2.6 | 10.5 ± 3.6 |
| 2D SHAPES - World Model (VAE) | 94.2 ± 1.0 | 96.4 ± 0.6 | 14.1 ± 1.1 | 21.4 ± 1.4 | 1.4 ± 0.2 | 3.5 ± 0.4 |
| 3D BLOCKS - C-SWM | 99.9 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 | 99.9 ± 0.0 | 100 ± 0.0 |
| 3D BLOCKS - latent GNN | 99.9 ± 0.0 | 99.9 ± 0.0 | 96.3 ± 0.4 | 97.7 ± 0.3 | 86.0 ± 1.8 | 90.2 ± 1.5 |
| 3D BLOCKS - factored states | 74.2 ± 9.3 | 82.5 ± 8.3 | 48.7 ± 12.9 | 62.6 ± 13.0 | 65.8 ± 14.0 | 49.6 ± 11.0 |
| 3D BLOCKS - contrastive loss | 48.9 ± 16.8 | 52.5 ± 17.8 | 12.2 ± 5.8 | 16.3 ± 7.1 | 3.1 ± 1.9 | 5.3 ± 2.8 |
| 3D BLOCKS - World Model (AE) | 93.5 ± 0.8 | 95.6 ± 0.6 | 26.7 ± 0.7 | 35.6 ± 0.8 | 4.0 ± 0.2 | 7.6 ± 0.3 |
| 3D BLOCKS - World Model (VAE) | 90.9 ± 0.7 | 94.2 ± 0.6 | 31.3 ± 2.3 | 41.8 ± 2.3 | 7.2 ± 0.9 | 12.9 ± 1.3 |
| ATARI PONG - C-SWM (K=5) | 20.5 ± 3.5 | 41.8 ± 2.9 | 9.5 ± 2.2 | 22.2 ± 3.3 | 5.3 ± 1.6 | 15.8 ± 2.8 |
| ATARI PONG - C-SWM (K=3) | 34.8 ± 5.3 | 54.3 ± 5.2 | 12.8 ± 3.4 | 28.1 ± 4.2 | 9.5 ± 1.7 | 21.1 ± 2.8 |
| ATARI PONG - C-SWM (K=1) | 36.5 ± 5.6 | 56.2 ± 6.2 | 18.3 ± 1.9 | 35.7 ± 2.3 | 11.5 ± 1.0 | 26.0 ± 1.2 |
| ATARI PONG - World Model (AE) | 23.8 ± 3.3 | 44.7 ± 2.4 | 1.7 ± 0.5 | 8.0 ± 0.5 | 1.2 ± 0.8 | 5.3 ± 0.8 |
| ATARI PONG - World Model (VAE) | 1.0 ± 0.0 | 5.1 ± 0.1 | 1.0 ± 0.0 | 5.2 ± 0.0 | 1.0 ± 0.0 | 5.2 ± 0.0 |
| SPACE INVADERS - C-SWM (K=5) | 48.5 ± 7.0 | 66.1 ± 6.6 | 16.8 ± 2.7 | 35.7 ± 3.7 | 11.8 ± 3.0 | 26.0 ± 4.1 |
| SPACE INVADERS - C-SWM (K=3) | 46.2 ± 13.0 | 62.3 ± 11.5 | 10.8 ± 3.7 | 28.5 ± 5.8 | 6.0 ± 0.4 | 20.9 ± 0.9 |
| SPACE INVADERS - C-SWM (K=1) | 31.5 ± 13.1 | 48.6 ± 11.8 | 10.0 ± 2.3 | 23.9 ± 3.6 | 6.0 ± 1.7 | 19.8 ± 3.3 |
| SPACE INVADERS - World Model (AE) | 40.2 ± 3.6 | 59.6 ± 3.5 | 5.2 ± 1.1 | 14.1 ± 2.0 | 3.8 ± 0.8 | 10.4 ± 1.3 |
| SPACE INVADERS - World Model (VAE) | 1.0 ± 0.0 | 5.3 ± 0.1 | 0.8 ± 0.2 | 5.2 ± 0.0 | 1.0 ± 0.0 | 5.2 ± 0.0 |
| 3-BODY PHYSICS - C-SWM | 100 ± 0.0 | 100 ± 0.0 | 97.2 ± 0.9 | 98.5 ± 0.5 | 75.5 ± 4.7 | 85.2 ± 3.1 |
| 3-BODY PHYSICS - World Model (AE) | 100 ± 0.0 | 100 ± 0.0 | 97.7 ± 0.3 | 98.8 ± 0.2 | 67.9 ± 2.4 | 78.4 ± 1.8 |
| 3-BODY PHYSICS - World Model (VAE) | 100 ± 0.0 | 100 ± 0.0 | 83.1 ± 2.5 | 90.3 ± 1.6 | 23.6 ± 4.2 | 37.5 ± 4.8 |
| 3-BODY PHYSICS - PAIG | 89.2 ± 3.5 | 90.7 ± 3.4 | 57.7 ± 12.0 | 63.1 ± 11.1 | 25.1 ± 13.0 | 33.1 ± 13.4 |
- C-SWMs는 해석 가능한 객체 수준 표현과 정확한 전이 예측을 학습하며, 재구성 기반 기준선에 비해 구조적으로 매우 구조화된 환경에서 우수하다.
- 격자 세계 및 물리 엔진 과제에서 C-SWMs는 짧은 및 중간 길이의 예측에 대해 거의 완벽한 잠재 공간 예측을 달성하며, 객체 요인화 표현과 GNN 전이 사용 시 특히 높은 H@1 및 MRR를 보인다.
- 감독 없이 객체 발견이 나타나며, 각 객체의 잠재 좌표가 실제 객체 위치와 근접하게 정렬된다(랜덤 선형 변환까지).
- 대조적 손실은 특히 다객체 설정 및 VAE 기반 디코더를 사용할 때 재구성 기반 손실 대비 보지 못한 구성에 대한 일반화를 크게 향상시킨다.
- 객체 슬롯(K) 수를 늘리는 경우 Atari 과제에서 검증 기반 튜닝이 필요하며, 반복적/객체 중심 인코딩이 로버스트성을 더 높일 수 있다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.