Skip to main content
QUICK REVIEW

[논문 리뷰] Contrastive Learning of Structured World Models

Thomas Kipf, Elise van der Pol|arXiv (Cornell University)|2019. 11. 27.
Domain Adaptation and Few-Shot Learning참고 문헌 70인용 수 68
한 줄 요약

C-SWMs는 객체 기반 잠재 표현과 행동 조건의 그래프 구조 전이 모델을 대조 학습을 통해 학습하여, 구조화된 환경에서 비지도 객체 발견 및 다객체 동역학 예측의 정확성을 가능하게 한다.

ABSTRACT

A structured understanding of our world in terms of objects, relations, and hierarchies is an important component of human cognition. Learning such a structured world model from raw sensory data remains a challenge. As a step towards this goal, we introduce Contrastively-trained Structured World Models (C-SWMs). C-SWMs utilize a contrastive approach for representation learning in environments with compositional structure. We structure each state embedding as a set of object representations and their relations, modeled by a graph neural network. This allows objects to be discovered from raw pixel observations without direct supervision as part of the learning process. We evaluate C-SWMs on compositional environments involving multiple interacting objects that can be manipulated independently by an agent, simple Atari games, and a multi-object physics simulation. Our experiments demonstrate that C-SWMs can overcome limitations of models based on pixel reconstruction and outperform typical representatives of this model class in highly structured environments, while learning interpretable object-based representations.

연구 동기 및 목표

  • 일반화 및 반사실적 추론을 개선하기 위한 구조화된 객체 중심의 월드 모델 학습 동기 부여.
  • 픽셀로부터 직접 감독 없이 객체를 발견하는 비지도 방법을 개발한다.
  • 객체 표현과 전이를 학습시키기 위한 대조적(contrastive) 객체 수준 손실을 제안한다.
  • 객체 간의 관계와 상호 작용을 모델링하기 위해 그래프 신경망을 활용한다.
  • 구조화된 표현이 장기 상태 예측 및 일반화에 미치는 이점을 입증한다.

제안 방법

  • 관찰을 CNN 기반 객체 추출기와 MLP 기반 객체 인코더의 두 부분으로 구성된 객체 중심 잠재 표현 집합으로 인코딩한다.
  • 전이(z_t + T(z_t, a_t) ≈ z_{t+1}) 를 예측하는 그래프 신경망으로 객체 간 상호 작용을 모델링한다.
  • 실제 상태-행동-상태 트리플을 손상된 음수와 구분하는 객체 수준 대조 힌지 손실로 학습한다(TransE 스타일 에너지를 기반으로).
  • 구성적 구조를 포착하고 파라미터 공유를 가능하게 하기 위해 Z = Z_1 × … × Z_K 및 A = A_1 × … × A_K 의 객체 요소화 잠재 공간을 채택한다.
  • 다양한 환경에서 다단계 예측을 위한 잠재 공간에서의 순위 측정 지표(Hits@1, MRR)로 평가한다.

실험 결과

연구 질문

  • RQ1C-SWMs가 감독 없이 원시 픽셀 관찰에서 객체를 발견할 수 있는가?
  • RQ2객체 중심 잠재 표현과 GNN 기반 전이가 다단계 상태 예측과 조합 일반화를 정확하게 가능하게 하는가?
  • RQ3대조 학습이 재구성 기반 기준선과 비교하여 잠재 표현 및 예측 정확도를 향상시키는가?
  • RQ4객체 요소화가 보지 못한 환경 구성에 대한 일반화에 어떤 영향을 미치는가?

주요 결과

Model1 Step H@11 Step MRR5 Steps H@15 Steps MRR10 Steps H@110 Steps MRR
2D SHAPES - C-SWM100 ± 0.0100 ± 0.0100 ± 0.0100 ± 0.099.9 ± 0.0100 ± 0.0
2D SHAPES - latent GNN99.9 ± 0.0100 ± 0.097.4 ± 0.198.4 ± 0.089.7 ± 0.393.1 ± 0.2
2D SHAPES - factored states54.5 ± 18.165.0 ± 15.934.4 ± 16.047.4 ± 16.024.1 ± 11.237.0 ± 12.1
2D SHAPES - contrastive loss49.9 ± 0.955.2 ± 0.96.5 ± 0.59.3 ± 0.71.4 ± 0.12.6 ± 0.2
2D SHAPES - World Model (AE)98.7 ± 0.599.2 ± 0.336.1 ± 8.144.1 ± 8.16.5 ± 2.610.5 ± 3.6
2D SHAPES - World Model (VAE)94.2 ± 1.096.4 ± 0.614.1 ± 1.121.4 ± 1.41.4 ± 0.23.5 ± 0.4
3D BLOCKS - C-SWM99.9 ± 0.0100 ± 0.099.9 ± 0.0100 ± 0.099.9 ± 0.0100 ± 0.0
3D BLOCKS - latent GNN99.9 ± 0.099.9 ± 0.096.3 ± 0.497.7 ± 0.386.0 ± 1.890.2 ± 1.5
3D BLOCKS - factored states74.2 ± 9.382.5 ± 8.348.7 ± 12.962.6 ± 13.065.8 ± 14.049.6 ± 11.0
3D BLOCKS - contrastive loss48.9 ± 16.852.5 ± 17.812.2 ± 5.816.3 ± 7.13.1 ± 1.95.3 ± 2.8
3D BLOCKS - World Model (AE)93.5 ± 0.895.6 ± 0.626.7 ± 0.735.6 ± 0.84.0 ± 0.27.6 ± 0.3
3D BLOCKS - World Model (VAE)90.9 ± 0.794.2 ± 0.631.3 ± 2.341.8 ± 2.37.2 ± 0.912.9 ± 1.3
ATARI PONG - C-SWM (K=5)20.5 ± 3.541.8 ± 2.99.5 ± 2.222.2 ± 3.35.3 ± 1.615.8 ± 2.8
ATARI PONG - C-SWM (K=3)34.8 ± 5.354.3 ± 5.212.8 ± 3.428.1 ± 4.29.5 ± 1.721.1 ± 2.8
ATARI PONG - C-SWM (K=1)36.5 ± 5.656.2 ± 6.218.3 ± 1.935.7 ± 2.311.5 ± 1.026.0 ± 1.2
ATARI PONG - World Model (AE)23.8 ± 3.344.7 ± 2.41.7 ± 0.58.0 ± 0.51.2 ± 0.85.3 ± 0.8
ATARI PONG - World Model (VAE)1.0 ± 0.05.1 ± 0.11.0 ± 0.05.2 ± 0.01.0 ± 0.05.2 ± 0.0
SPACE INVADERS - C-SWM (K=5)48.5 ± 7.066.1 ± 6.616.8 ± 2.735.7 ± 3.711.8 ± 3.026.0 ± 4.1
SPACE INVADERS - C-SWM (K=3)46.2 ± 13.062.3 ± 11.510.8 ± 3.728.5 ± 5.86.0 ± 0.420.9 ± 0.9
SPACE INVADERS - C-SWM (K=1)31.5 ± 13.148.6 ± 11.810.0 ± 2.323.9 ± 3.66.0 ± 1.719.8 ± 3.3
SPACE INVADERS - World Model (AE)40.2 ± 3.659.6 ± 3.55.2 ± 1.114.1 ± 2.03.8 ± 0.810.4 ± 1.3
SPACE INVADERS - World Model (VAE)1.0 ± 0.05.3 ± 0.10.8 ± 0.25.2 ± 0.01.0 ± 0.05.2 ± 0.0
3-BODY PHYSICS - C-SWM100 ± 0.0100 ± 0.097.2 ± 0.998.5 ± 0.575.5 ± 4.785.2 ± 3.1
3-BODY PHYSICS - World Model (AE)100 ± 0.0100 ± 0.097.7 ± 0.398.8 ± 0.267.9 ± 2.478.4 ± 1.8
3-BODY PHYSICS - World Model (VAE)100 ± 0.0100 ± 0.083.1 ± 2.590.3 ± 1.623.6 ± 4.237.5 ± 4.8
3-BODY PHYSICS - PAIG89.2 ± 3.590.7 ± 3.457.7 ± 12.063.1 ± 11.125.1 ± 13.033.1 ± 13.4
  • C-SWMs는 해석 가능한 객체 수준 표현과 정확한 전이 예측을 학습하며, 재구성 기반 기준선에 비해 구조적으로 매우 구조화된 환경에서 우수하다.
  • 격자 세계 및 물리 엔진 과제에서 C-SWMs는 짧은 및 중간 길이의 예측에 대해 거의 완벽한 잠재 공간 예측을 달성하며, 객체 요인화 표현과 GNN 전이 사용 시 특히 높은 H@1 및 MRR를 보인다.
  • 감독 없이 객체 발견이 나타나며, 각 객체의 잠재 좌표가 실제 객체 위치와 근접하게 정렬된다(랜덤 선형 변환까지).
  • 대조적 손실은 특히 다객체 설정 및 VAE 기반 디코더를 사용할 때 재구성 기반 손실 대비 보지 못한 구성에 대한 일반화를 크게 향상시킨다.
  • 객체 슬롯(K) 수를 늘리는 경우 Atari 과제에서 검증 기반 튜닝이 필요하며, 반복적/객체 중심 인코딩이 로버스트성을 더 높일 수 있다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.