[論文レビュー] STORM: Efficient Stochastic Transformer based World Models for Reinforcement Learning
STORM は Transformer ベースの確率的世界モデルを導入し、VAE エンコーダでサンプル効率とトレーニング速度を改善。 Atari 100k において lookahead search なしで新しい平均ヒューマン正規化スコアを達成し、リアルタイム学習を faster.
Recently, model-based reinforcement learning algorithms have demonstrated remarkable efficacy in visual input environments. These approaches begin by constructing a parameterized simulation world model of the real environment through self-supervised learning. By leveraging the imagination of the world model, the agent's policy is enhanced without the constraints of sampling from the real environment. The performance of these algorithms heavily relies on the sequence modeling and generation capabilities of the world model. However, constructing a perfectly accurate model of a complex unknown environment is nearly impossible. Discrepancies between the model and reality may cause the agent to pursue virtual goals, resulting in subpar performance in the real environment. Introducing random noise into model-based reinforcement learning has been proven beneficial. In this work, we introduce Stochastic Transformer-based wORld Model (STORM), an efficient world model architecture that combines the strong sequence modeling and generation capabilities of Transformers with the stochastic nature of variational autoencoders. STORM achieves a mean human performance of $126.7\%$ on the Atari $100$k benchmark, setting a new record among state-of-the-art methods that do not employ lookahead search techniques. Moreover, training an agent with $1.85$ hours of real-time interaction experience on a single NVIDIA GeForce RTX 3090 graphics card requires only $4.3$ hours, showcasing improved efficiency compared to previous methodologies.
研究の動機と目的
- 視覚環境におけるモデルベース RL のサンプル効率の改善を動機付ける。
- Transformers と確率的潜在表現を活用した効率的な世界モデルを開発する。
- 予測誤差の蓄積とトレーニング時間を抑えつつ、Atari 100k での性能を維持または改善する。
提案手法
- 観察をカテゴリカル VAE エンコーダでマップし、確率的潜在 z_t(32 カテゴリ × 32 クラス)を得る。
- z_t とアクション a_t を単一トークン e_t に結合し、GPT風の Transformer をシーケンスモデルとして用いて h_t を生成。
- h_t から報酬、継続フラグ、次の潜在分布をMLPヘッドで予測。
- 再構成、報酬、継続、ダイナミクス(KL)、表現(KL)項を組み合わせた自己教師付き損失で世界モデルを学習(β 重み付き)。
- DreamerV3 風の actor-critic 目的で lambda-returns と KV-cache 加速推論を用い、 imagined experiences のみからエージェントポリシーを学習。
実験結果
リサーチクエスチョン
- RQ1確率的 Transformer ベースの世界モデルは Atari 100k において RNN ベースや Transformer-XL ベースのモデルを上回ることができるか?
- RQ2画像ごとに単一の確率的潜在表現がダイナミクスを効果的に捉え、ポリシー学習に有効か?
- RQ3提案損失設計と imagination ベース学習はサンプル効率と計算効率にどのような影響を与えるか?
- RQ4世界モデル設計の選択肢(エンコーダのタイプ、状態表現、トランスフォーマーの深さ)が性能に与える影響は?
- RQ5STORM を用いて実環境との相互作用を限定的にして高い性能を実現可能か?
主な発見
| Game | Random | Human | SimPLe [11] | TWM [12] | IRIS [13] | DreamerV3 [10] | STORM (ours) |
|---|---|---|---|---|---|---|---|
| Alien | 228 | 7128 | 617 | 675 | 420 | 959 | 984 |
| Amidar | 6 | 1720 | 74 | 122 | 143 | 139 | 205 |
| Assault | 222 | 742 | 527 | 683 | 1524 | 706 | 801 |
| Asterix | 210 | 8503 | 1128 | 1116 | 854 | 932 | 1028 |
| Bank Heist | 14 | 753 | 34 | 467 | 53 | 649 | 641 |
| Battle Zone | 2360 | 37188 | 4031 | 5068 | 13074 | 12250 | 13540 |
| Boxing | 0 | 12 | 8 | 78 | 70 | 78 | 80 |
| Breakout | 2 | 30 | 16 | 20 | 84 | 31 | 16 |
| Chopper Command | 811 | 7388 | 979 | 1697 | 1565 | 420 | 1888 |
| Crazy Climber | 10780 | 35829 | 62584 | 71820 | 59234 | 97190 | 66776 |
| Demon Attack | 152 | 1971 | 208 | 350 | 2034 | 303 | 165 |
| Freeway | 0 | 30 | 17 | 24 | 31 | 0 | 34 |
| Freeway w/o traj | 0 | 30 | 17 | 24 | 31 | 0 | 0 |
| Frostbite | 65 | 4335 | 237 | 1476 | 259 | 909 | 1316 |
| Gopher | 258 | 2413 | 597 | 1675 | 2236 | 3730 | 8240 |
| Hero | 1027 | 30826 | 2657 | 7254 | 7037 | 11161 | 11044 |
| James Bond | 29 | 303 | 101 | 362 | 463 | 445 | 509 |
| Kangaroo | 52 | 3035 | 51 | 1240 | 838 | 4098 | 4208 |
| Krull | 1598 | 2666 | 2204 | 6349 | 6616 | 7782 | 8413 |
| Kung Fu Master | 256 | 22736 | 14862 | 24555 | 21760 | 21420 | 26182 |
| Ms Pacman | 307 | 6952 | 1480 | 1588 | 999 | 1327 | 2673 |
| Pong | -21 | 15 | 13 | 19 | 15 | 18 | 11 |
| Private Eye | 25 | 69571 | 35 | 87 | 100 | 882 | 7781 |
| Qbert | 164 | 13455 | 1289 | 3331 | 746 | 3405 | 4522 |
| Road Runner | 12 | 7845 | 5641 | 9109 | 9615 | 15565 | 17564 |
| Seaquest | 68 | 42055 | 683 | 774 | 661 | 618 | 525 |
| Up N Down | 533 | 11693 | 3350 | 15982 | 3546 | 7667 | 7985 |
- STORM は Atari 100k で平均的ヒト正規化スコアを 126.7% に達し lookahead search なしの手法として新記録を樹立。
- RTX 3090 で約 1.85 時間の実データ学習後、約 4.3 時間で完了し、従来手法より効率が向上。
- SimPLe、TWM、IRIS、DreamerV3 と比較して、Transformer シーケンスモデリングと確率的潜在表現の利点により、複数の報酬関連オブジェクトを持つゲームで性能が向上。
- シーケンスモデルとして Transformer を用い、単一の確率的潜在と観察-行動を結合したトークンを用いるアブレーションは有効である一方、より大きな Transformer 深度は Atari 100k で必ずしも性能を改善しない。
- 単一のデモンストレーション軌跡を取り入れると希薄報酬ゲーム(例: Pong)で探索性が向上する一方、密報酬ゲーム(例: Ms. Pacman)では妨げになる可能性がある。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。