[논문 리뷰] Consistency Trajectory Models: Learning Probability Flow ODE Trajectory of Diffusion
CTM은 확률 흐름 ODE 궤적을 학습함으로써 점수 기반(diffusion) 모델과 증류(diffusion) 모델을 하나의 프레임워크에서 통합하고, 유연하고 고품질의 적은 NFE 샘플링과 궤적을 통해 탐색하는 새로운 gamma 샘플링 방식을 가능하게 한다.
Consistency Models (CM) (Song et al., 2023) accelerate score-based diffusion model sampling at the cost of sample quality but lack a natural way to trade-off quality for speed. To address this limitation, we propose Consistency Trajectory Model (CTM), a generalization encompassing CM and score-based models as special cases. CTM trains a single neural network that can -- in a single forward pass -- output scores (i.e., gradients of log-density) and enables unrestricted traversal between any initial and final time along the Probability Flow Ordinary Differential Equation (ODE) in a diffusion process. CTM enables the efficient combination of adversarial training and denoising score matching loss to enhance performance and achieves new state-of-the-art FIDs for single-step diffusion model sampling on CIFAR-10 (FID 1.73) and ImageNet at 64x64 resolution (FID 1.92). CTM also enables a new family of sampling schemes, both deterministic and stochastic, involving long jumps along the ODE solution trajectories. It consistently improves sample quality as computational budgets increase, avoiding the degradation seen in CM. Furthermore, unlike CM, CTM's access to the score function can streamline the adoption of established controllable/conditional generation methods from the diffusion community. This access also enables the computation of likelihood. The code is available at https://github.com/sony/ctm.
연구 동기 및 목표
- 단일 프레임워크 내에서 점수 기반 및 증류 확산 모델을 연결한다.
- PF ODE의 미분적(점수) 및 적분적(궤적) 성분을 모두 일치시키도록 학습을 가능하게 한다.
- 품질과 계산 비용 사이의 트레이드를 가능하게 하기 위해 PF ODE 궤적을 자유롭게 탐색하도록 한다.
- 성능 향상을 위해 적대적 학습 및 재구성/잡음 제거 손실을 도입한다.
- 긴 궤적 구간을 제어 가능한 확률성으로 탐색하기 위한 gamma 샘플링을 도입한다.
제안 방법
- G(x_t, t, s)를 PF ODE 해로 정의하고, 적분과 피적분에 모두 접근하기 위한 보조 변수 g를 도입한다(도lem마 1).
- G_theta(x_t, t, s) = (s/t) x_t + (1 - s/t) g_theta(x_t, t, s)로 매개화하여 궤적과 피적분함수 접근을 모두 가능하게 한다.
- 재구성 유사 손실과 소프트 일관성 손실(Eq. 5)을 사용하여 CTM 예측을 경험적 PF ODE 궤적에 소프트 매칭하여 CTM을 훈련시킨다.
- 사전 학습된 점수 모델 D_phi를 교사로 사용하여 궤적 재구성(Eq. 3) 및 소프트 매칭(Eq. 5)을 위한 Solver(x_t, t, u; phi)을 얻는다.
- CTM 손실 L_CTM과 DSM 손실 L_DSM 및 GAN 손실 L_GAN의 합으로 L을 공동 최적화하여 CTM, 잡음 제거 점수 매칭 및 적대적 학습을 융합한다.
- gamma-샘플링을 도입하여 PF ODE 궤적을 순방향 및 역방향으로 탐색하되, gamma를 조정하여 확률성을 제어한다(gamma ∈ [0,1]).
실험 결과
연구 질문
- RQ1하나의 신경망이 점수 추정과 궤적 기반 업데이트를 모두 출력하여 점수 기반 샘플링과 증류 샘플링을 통합할 수 있는가?
- RQ2PF ODE 궤적을 학습하면 NFEs에 따른 품질 저하 없이 샘플링 속도와 샘플 품질 사이를 유연하게 트레이드할 수 있는가?
- RQ3gamma-샘플링이 서로 다른 NFEs에서 샘플의 진실도와 다양성에 어떤 영향을 미치는가?
- RQ4CTM이 CIFAR-10 및 ImageNet 64×64에서 몇 NFEs로도 최첨단 FID 및 가능성(likelihood)을 달성할 수 있는가?
- RQ5CTM이 점수 접근을 통해 정확한 가능도 계산을 지원하고 궤적의 적대적 정제(adversarial refinement)를 가능하게 하는가?
주요 결과
| 모델 | NFE | 무조건 FID | 조건부 FID | NLL | 비고 |
|---|---|---|---|---|---|
| CTM (ours) | 1 | 1.98 | 2.43 | - | CIFAR-10, unconditional/conditional; NFE=1 |
| CTM (ours) | 2 | 1.87 | 2.43 | - | CIFAR-10, NFE=2; surpasses teacher with 2 NFEs |
| CTM (ours) | 1 | 2.06 | - | - | ImageNet 64×64; NFE=1; SOTA among distillation/short-NFE methods |
| CTM (ours) | 2 | 1.90 | - | - | ImageNet 64×64; NFE=2; SOTA among distillation methods |
- CTM은 단일 네트워크가 점수와 궤적 정보를 모두 출력하도록 학습하여 CIFAR-10의 단일 스텝 확산 샘플링에서 FID 1.73, ImageNet 64×64에서 FID 2.06의 최첨단 성능을 달성한다.
- CTM은 점수 접근과 재구성 손실을 활용하여 CIFAR-10에서 새로운 NLL(가능도) 추정치를 달성한다.
- gamma-샘플링이 완전히 확률적(GAMMA=1), 결정적(GAMMA=0), 중간의 확률적 체계를 포함하는 단일 샘플링 프레임워크를 제공하며, 서로 다른 NFEs에서 안정성과 품질이 향상된다.
- CTM은 NFE=1에서 교사 모델(EDM/CM)을 능가하고 NFE=2에서 더 우수한 결과를 얻으며 CIFAR-10에서 ImageNet 64×64에서도 성능이 좋고, 학습 반복 수가 CM/EDM의 10%에 불과하다.
- CTM은 점수 접근을 통한 정확한 가능도 계산을 가능하게 하고 궤적의 적대적 학습 정제를 통해 최소한의 NFEs로도 경쟁력 있는 결과를 달성한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.