Skip to main content
QUICK REVIEW

[논문 리뷰] Learning Fast Samplers for Diffusion Models by Differentiating Through Sample Quality

Daniel Watson, William Chan|arXiv (Cornell University)|2022. 02. 11.
Generative Adversarial Networks and Image Synthesis인용 수 23
한 줄 요약

이 논문은 Differentiable Diffusion Sampler Search (DDSS)를 도입하여 샘플 품질에 미분 가능하게 경유하여 샘플러의 최적화를 수행하고 재훈련 없이도 몇 단계의 고품질 확산 샘플링이 가능하도록 GGDM 계열의 비-마르코프 샘플러를 빠르게 확보한다.

ABSTRACT

Diffusion models have emerged as an expressive family of generative models rivaling GANs in sample quality and autoregressive models in likelihood scores. Standard diffusion models typically require hundreds of forward passes through the model to generate a single high-fidelity sample. We introduce Differentiable Diffusion Sampler Search (DDSS): a method that optimizes fast samplers for any pre-trained diffusion model by differentiating through sample quality scores. We also present Generalized Gaussian Diffusion Models (GGDM), a family of flexible non-Markovian samplers for diffusion models. We show that optimizing the degrees of freedom of GGDM samplers by maximizing sample quality scores via gradient descent leads to improved sample quality. Our optimization procedure backpropagates through the sampling process using the reparametrization trick and gradient rematerialization. DDSS achieves strong results on unconditional image generation across various datasets (e.g., FID scores on LSUN church 128x128 of 11.6 with only 10 inference steps, and 4.82 with 20 steps, compared to 51.1 and 14.9 with strongest DDPM/DDIM baselines). Our method is compatible with any pre-trained diffusion model without fine-tuning or re-training required.

연구 동기 및 목표

  • Diffusion 모델에서 원래 모델의 재훈련이나 미세 조정 없이 추론 단계를 줄이는 것을 동기화한다.
  • 빠른 샘플러를 탐색하기 위한 미분 가능 최적화 프레임워크를 제안한다.
  • Generalized Gaussian Diffusion Model (GGDM)을 유연한 샘플러 계열로 도입한다.
  • 지각 손실(KID)을 사용한 샘플러 최적화가 몇 단계 구간에서 샘플 품질을 향상시킴을 보여준다.

제안 방법

  • 확산 샘플링 과정을 언롤하고 재매개변수화 트릭과 그래디언트 재생성을 이용해 매개 샘플러를 최적화한다.
  • 평균과 분산을 제어하는 학습 가능한 매개변수를 갖는 샘플러 계열(DDIM, VARS, GGDM 및 GGDM 변형)을 정의한다.
  • KID를 기반으로 한 지각 손실을 Inception 특징에서 계산하여 인간이 지각하는 품질과 정렬되도록 한다( Eq. 7–9 ).
  • 확률적 샘플러를 역전파하기 위해 샘플링 체인을 따라 미니배치 SGD(Adam)를 통해 미분한다.
  • 각 denoising 단계에서 이전의 더 노이즈가 있는 이미지의 정보를 통합하는 Generalized Gaussian Diffusion Models(GGDM)을 도입한다.
  • few-step 성능을 개선하기 위해 시간(step) 선택(TIME)과 예측 계수(PRED)의 학습을 허용한다.

실험 결과

연구 질문

  • RQ1미분 가능 최적화 절차가 사전 학습된 확산 모델의 기존 few-step 기준보다 빠른 샘플러를 식별할 수 있는가?
  • RQ2지각 손실(KID)을 최적화하는 것이 가능성이나 ELBO만 기반의 최적화보다 시각적으로 더 높은 충실도를 가진 샘플러를 만들어내는가?
  • RQ3GGDM 계열은 DDIM이나 VARS에 비해 얼마나 유연하게 고품질의 few-step 샘플러를 발견하는가?

주요 결과

SamplerKFIDIS
DDPM (linear stride)584.275.396
DDPM (linear stride)1043.397.034
DDPM (linear stride)1531.407.609
DDPM (linear stride)2025.947.879
DDPM (linear stride)2522.608.043
DDPM (quadratic stride)576.255.435
DDPM (quadratic stride)1042.036.965
DDPM (quadratic stride)1527.787.714
DDPM (quadratic stride)2020.2258.128
DDPM (quadratic stride)2516.178.350
DDIM (linear stride)544.416.750
DDIM (linear stride)1019.117.965
DDIM (linear stride)1514.068.190
DDIM (linear stride)2011.828.420
DDIM (linear stride)2510.528.512
DDIM (quadratic stride)532.667.090
DDIM (quadratic stride)1013.628.190
DDIM (quadratic stride)159.3188.495
DDIM (quadratic stride)207.5008.641
DDIM (quadratic stride)256.5608.759
GGDM +PRED+TIME513.778.520
GGDM +PRED+TIME108.2278.903
GGDM +PRED+TIME156.1159.050
GGDM +PRED+TIME204.7229.261
GGDM +PRED+TIME254.2509.186
DDPM (linear stride)5122.05.878
DDPM (linear stride)1058.7810.67
DDPM (linear stride)1539.3013.22
DDPM (linear stride)2031.3614.72
DDPM (linear stride)2526.3615.71
DDPM (quadratic stride)5394.81.351
DDPM (quadratic stride)10129.55.997
DDPM (quadratic stride)1580.109.595
DDPM (quadratic stride)2061.3411.60
DDPM (quadratic stride)2549.6013.01
DDIM (linear stride)5135.45.898
DDIM (linear stride)1040.7012.225
DDIM (linear stride)1528.5413.99
DDIM (linear stride)2024.22514.75
DDIM (linear stride)2522.1315.16
DDIM (quadratic stride)5409.11.380
DDIM (quadratic stride)10148.65.533
DDIM (quadratic stride)1567.659.842
DDIM (quadratic stride)2045.6011.99
DDIM (quadratic stride)2536.1113.225
GGDM +PRED+TIME555.1412.90
GGDM +PRED+TIME1037.3214.76
GGDM +PRED+TIME1524.6917.225
GGDM +PRED+TIME2020.6917.92
GGDM +PRED+TIME2518.4018.12
  • DDSS는 동일한 작은 스텝 버짓에서 강력한 기준선보다 substantially 더 나은 FID/IS 점수를 달성하는 빠른 샘플러를 발견한다(예: GGDM +PRED+TIME, 5–25 스텝이 DDPM/DDIM 기준선을 능가).
  • 지각 손실(KID)을 지각 손실로 최적화하면 높은 충실도의 샘플이 만들어지며 CIFAR-10, ImageNet-64 등 데이터셋에서 Robust한 개선이 나타난다.
  • GGDM은 DDIM이나 VARS 같은 좁은 계열보다 넓은 샘플러 계열로서 항상 몇 단계에서 더 나은 성능을 보이며, 모달리티가 원 forward 프로세스와 다를 때도 나타난다.
  • DDSS는 사전 학습된 DDPM의 미세 조정이나 재훈련이 필요하지 않으며, 원샷 포스트호크 샘플러 탐색으로 작동한다.
  • CIFAR-10과 ImageNet-64 전반에 걸쳐 DDSS 변형은 매우 적은 단계에서 FID를 크게 감소시키고 IS를 경쟁력 있게 유지한다(예: CIFAR-10: 5–25 단계; ImageNet-64: 5–25 단계).
  • Qualitative 샘플은 DDIM(η=0) 대비 저 스텝에서 비선정된 개선을 보여준다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.