[논문 리뷰] Timestep-Aware Block Masking for Efficient Diffusion Model Inference
시간대별로 학습된 이진 마스크를 도입해 확산모델의 계산을 건너뛰어 샘플링 속도를 DDPM, LDM, DiT, PixArt 전반에서 더 빠르게하되 품질 손실은 최소화한다. 마스크는 타임스텝별로 엔드투엔드로 학습되며 특성 충실도, 희소성, 이중모드 정규화와 타임스텝 인식 손실 스케일링 및 지식 가이드 보정으로 최적화된다.
Diffusion Probabilistic Models (DPMs) have achieved great success in image generation but suffer from high inference latency due to their iterative denoising nature. Motivated by the evolving feature dynamics across the denoising trajectory, we propose a novel framework to optimize the computational graph of pre-trained DPMs on a per-timestep basis. By learning timestep-specific masks, our method dynamically determines which blocks to execute or bypass through feature reuse at each inference stage. Unlike global optimization methods that incur prohibitive memory costs via full-chain backpropagation, our method optimizes masks for each timestep independently, ensuring a memory-efficient training process. To guide this process, we introduce a timestep-aware loss scaling mechanism that prioritizes feature fidelity during sensitive denoising phases, complemented by a knowledge-guided mask rectification strategy to prune redundant spatial-temporal dependencies. Our approach is architecture-agnostic and demonstrates significant efficiency gains across a broad spectrum of models, including DDPM, LDM, DiT, and PixArt. Experimental results show that by treating the denoising process as a sequence of optimized computational paths, our method achieves a superior balance between sampling speed and generative quality. Our code will be released.
연구 동기 및 목표
- 확산 모델의 추론 비용을 타임스텝 간 안정적인 특성 역학을 활용해 줄이는 것을 동기로 삼는다.
- 기본 모델의 재학습 없이 타임스텝별 이진 마스크 프레임워크를 제안해 블록 계산을 건너뛰거나 재사용한다.
- 각 타임스텝마다 마스크를 독립적으로 최적화해 메모리 효율적인 학습을 보장한다.
- 타임스텝 인식 손실 스케일링과 지식 기반 마스크 보정을 도입해 생성 품질을 유지한다.
제안 방법
- t×b 이진 마스크 m을 학습해 t는 타임스텝, b는 네트워크 블록을 각각 나타내고, 블록이 계산될지 아니면 캐시된 피처를 재사용할지 결정한다.
- 확산 모델 파라미터를 고정하고 원래 모델과의 피처 충실도 손실로 엔드투엔드 학습을 통해 타임스텝별 마스크를 최적화한다.
- m에 대해 [0,1]의 연속 이완 s를 사용하고 L1 희소성 및 이중모드 정규화를 적용해 이진화를 촉진한다.
- 피처 변화 delta[t]를 바탕으로 타임스텝 인식 손실 가중치를 도입해 민감한 디노이징 단계에서 충실도를 우선시한다.
- 블록과 타임스텝 간의 의존성을 전파하는 지식 기반 후처리 규칙을 적용해 추론 가속을 추가로 달성한다.
- UNet형 CNN 및 diffusion 트랜스포머(DiT의 MHA/MLP, U-Net의 ResBlock/AttnBlock 등)에 적용 가능한 아키텍처 독립적 접근법을 제공한다.
실험 결과
연구 질문
- RQ1타임스텝별 마스킹으로 원래 모델 재학습 없이 확산 모델 추론 속도를 높일 수 있는가?
- RQ2높은 속도 향상을 달성하면서 생성 품질을 보존하기 위해 이러한 마스크를 효율적으로 학습하는 방법은 무엇인가?
- RQ3타임스텝 인식 손실과 마스크 보정이 충실도를 유지하면서 가속화를 극대화하는 데 어떤 역할을 하는가?
- RQ4이 방법이 다양한 확산 아키텍처(DDPM, LDM, DiT, PixArt) 및 데이터셋에 대해 보편적인가?
주요 결과
| 방법 | 추가 데이터 | 학습 시간↓ | MACs↓ | 속도↑ | FID↓ |
|---|---|---|---|---|---|
| DDPM [13] | – | 0.61T | 1× | 1.00 | 4.19 |
| DDPM* | – | 0.61T | 1× | 1.00 | 4.25 |
| Diff-Pruning [7] | ✓ | 0.34T | 1.37× | 1.37 | 5.29 |
| CT [42] * | ✓ | – | – | 1.62× | 4.68 |
| DeepCache [30] | ✗ | 0.35T | 1.61× | 1.61 | 4.70 |
| Ours | ✗ | 0.2h | 0.34T | 1.63× | 4.66 |
- 다양한 아키텍처에서 의미 있는 속도 향상을 달성: DDPM에서 CIFAR-10은 1.63×, LSUN 변형에서 1.31×–1.63×, ImageNet에서 2.75×는 LDM-4-G로, FID/IS 지표는 경쟁력 있음.
- CIFAR-10, LSUN-Bedroom, LSUN-Churches에서 당사 방법은 속도를 consistently 향상시키며 FID를 베이스라인(DeepCache, Diff-Pruning, CT) 대비 보존하거나 약간 개선.
- ImageNet의 DiT-XL/2에 대해 우리 방법은 L2C 정확도와 동등하거나 이를 상회하면서 가속은 더 빠름(256×256에서 1.67×, 512×512에서도 경쟁력).
- 앙상블에서 무작위 마스크 샘플링이 Gumbel-Softmax보다 이 설정에서 더 낫다는 점; 마스크 보정과 타임스텝 인식 손실 스케일링이 속도 향상의 품질 손실을 최소로 크게 향상.
- 학습된 모델에서 마스크 값은 0에 가까운 혹은 1에 가까운 집중을 보이며 안정적이고 결정적인 블록 건너뛰기 결정을 시사.
- 마스크 학습에는 가우시안 노이즈만 입력으로 필요하고 사전 학습된 가중치를 수정하지 않는다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.