[論文レビュー] Consistency Trajectory Models: Learning Probability Flow ODE Trajectory of Diffusion
CTM は確率流ODEの軌道を学習することにより、スコアベースおよび蒸留拡散モデルを統合し、柔軟かつ高品質な few-NFE サンプリングと軌道を横断する新しい gamma-sampling スキームを可能にする。
Consistency Models (CM) (Song et al., 2023) accelerate score-based diffusion model sampling at the cost of sample quality but lack a natural way to trade-off quality for speed. To address this limitation, we propose Consistency Trajectory Model (CTM), a generalization encompassing CM and score-based models as special cases. CTM trains a single neural network that can -- in a single forward pass -- output scores (i.e., gradients of log-density) and enables unrestricted traversal between any initial and final time along the Probability Flow Ordinary Differential Equation (ODE) in a diffusion process. CTM enables the efficient combination of adversarial training and denoising score matching loss to enhance performance and achieves new state-of-the-art FIDs for single-step diffusion model sampling on CIFAR-10 (FID 1.73) and ImageNet at 64x64 resolution (FID 1.92). CTM also enables a new family of sampling schemes, both deterministic and stochastic, involving long jumps along the ODE solution trajectories. It consistently improves sample quality as computational budgets increase, avoiding the degradation seen in CM. Furthermore, unlike CM, CTM's access to the score function can streamline the adoption of established controllable/conditional generation methods from the diffusion community. This access also enables the computation of likelihood. The code is available at https://github.com/sony/ctm.
研究の動機と目的
- スコアベースと蒸留拡散モデルを単一のフレームワーク内で橋渡しする。
- PF ODE の infinitesimal (score) および integral (trajectory) 成分の両方に合わせてトレーニングを可能にする。
- PF ODE trajectories に沿う制限のない走査を可能にして、品質と計算量をトレードオフする。
- 性能向上のための対抗的学習と再構成/ノイズ除去ロスを組み込む。
- 長い軌道ジャンプを制御可能な確率性で横断する gamma-sampling を導入する。
提案手法
- G(x_t, t, s) を t から s までの PF ODE 解として定義し、積分項と被積分項の両方へアクセスする補助変数 g を導入する (Lemma 1).
- G_theta(x_t, t, s) を (s/t) x_t + (1 - s/t) g_theta(x_t, t, s) とパラメータ化して、軌道と被積分関数の両方へアクセス可能にする。
- 再構成様の損失とソフト整合性損失(Eq. 5)を用いて、CTM の予測を実データ PF ODE trajectory にソフトマッチさせることで CTM を訓練する。
- 事前訓練済みスコアモデル D_phi を教師として使用し、Trajectory reconstruction (Eq. 3) および soft matching (Eq. 5) の Solver(x_t, t, u; phi) を取得する。
- L = L_CTM + lambda_DSM L_DSM + lambda_GAN L_GAN を同時最適化して、CTM、ノイズ除去スコアマッチング、対向学習を融合する。
- gamma-sampling を導入し、PF ODE 軌道に沿って前方・後方へ走査し、乱雑さを制御する調整可能な gamma を用いて(gamma in [0,1])。
実験結果
リサーチクエスチョン
- RQ11 つのニューラルネットワークが、スコア推定と軌道ベースの更新の両方を出力して、スコアベースと蒸留サンプリングを統一できるか?
- RQ2PF ODE 軌道を学習することは、NFEs にわたって劣化させることなく、サンプリング速度とサンプル品質を柔軟に取り引きできるか?
- RQ3gamma-sampling は、異なる NFEs においてサンプルの忠実度と多様性にどのように影響するか?
- RQ4CTM は CIFAR-10 および ImageNet 64×64 において、少数の NFE で最先端の FID および尤度を達成できるか?
- RQ5CTM はスコアアクセスを介した厳密な尤度計算をサポートし、軌道の対向的な洗練を可能にするか?
主な発見
| Model | NFE | 無条件 FID | 条件付き FID | NLL | メモ |
|---|---|---|---|---|---|
| CTM (ours) | 1 | 1.98 | 2.43 | - | CIFAR-10, unconditional/conditional; NFE=1 |
| CTM (ours) | 2 | 1.87 | 2.43 | - | CIFAR-10, NFE=2; surpasses teacher with 2 NFEs |
| CTM (ours) | 1 | 2.06 | - | - | ImageNet 64×64; NFE=1; SOTA among distillation/short-NFE methods |
| CTM (ours) | 2 | 1.90 | - | - | ImageNet 64×64; NFE=2; SOTA among distillation methods |
- CTM は、スコアと軌道情報の両方を出力する単一のネットワークを訓練することで、CIFAR-10 の単一ステップ拡散サンプリングで最先端の FID(FID 1.73)、ImageNet 64×64 で最先端の FID(FID 2.06)を達成。
- CTM はスコアアクセスと再構成損失を活用して CIFAR-10 で新しい SOTA の尤度推定 (NLL) を達成。
- gamma-sampling を用いたCTMは、全確率的(gamma=1)、決定論的(gamma=0)、および中間の確率的領域を包含する統一サンプリングフレームワークを提供し、NFEs 全体で安定性と品質を向上させる。
- NFE=1 では teacher models (EDM/CM) を上回り、NFE=2 では CIFAR-10 で優れた結果を達成、ImageNet 64×64 でも同様に優れ、訓練反復数は CM/EDM の 10% と大幅に少なくて済む。
- CTM はスコアアクセスを通じた厳密な尤度計算を可能にし、軌道品質をさらに磨くための対向訓練を可能にする。最小の NFEs で競争力のあるまたは優れた結果を達成。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。