Skip to main content
QUICK REVIEW

[论文解读] Progressive Distillation for Fast Sampling of Diffusion Models

Tim Salimans, Jonathan Ho|arXiv (Cornell University)|Feb 1, 2022
Generative Adversarial Networks and Image Synthesis被引用 190
一句话总结

论文提出渐进蒸馏,用以压缩扩散模型采样过程,通过迭代地将步骤数减半,在仅需 4 步时实现高质量样本,同时保持性能。

ABSTRACT

Diffusion models have recently shown great promise for generative modeling, outperforming GANs on perceptual quality and autoregressive models at density estimation. A remaining downside is their slow sampling time: generating high quality samples takes many hundreds or thousands of model evaluations. Here we make two contributions to help eliminate this downside: First, we present new parameterizations of diffusion models that provide increased stability when using few sampling steps. Second, we present a method to distill a trained deterministic diffusion sampler, using many steps, into a new diffusion model that takes half as many sampling steps. We then keep progressively applying this distillation procedure to our model, halving the number of required sampling steps each time. On standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we start out with state-of-the-art samplers taking as many as 8192 steps, and are able to distill down to models taking as few as 4 steps without losing much perceptual quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally, we show that the full progressive distillation procedure does not take more time than it takes to train the original model, thus representing an efficient solution for generative modeling using diffusion at both train and test time.

研究动机与目标

  • 动机并解决扩散模型在无条件与类条件图像生成中的慢采样瓶颈。
  • 提出一种蒸馏方法,将来自慢速多步教师的精度转移给更快的学生模型。
  • 证明渐进蒸馏在降低步数的同时,仍能在标准基准(CIFAR-10、ImageNet、LSUN)上保持样本质量。
  • 提供稳定的扩散参数化和损失加权策略,以支持快速蒸馏且不降级。
  • 表明整体蒸馏过程相对于训练原始模型具有计算效率。

提出的方法

  • 提出渐进蒸馏的公式化方法,使学生模型以一个学生步骤去匹配两个教师 DDIM 步骤,然后迭代地将步骤数减半。
  • 使用来自让教师从给定的 z_t 开始运行两步 DDIM 并反演以获得学生的清晰目标的蒸馏目标。
  • 尝试扩散去噪模型的不同参数化(直接 x、epsilon 或 v)以及在低信噪比时仍具信息性的损失加权。
  • 采用余弦噪声调度和具有 BigGAN 风格上采样/下采样的 U-Net 架构进行训练,遵循标准扩散模型训练流程。
  • 将蒸馏模型与 DDIM 和随机基线进行比较,并在各数据集上报告 FID(如适用,亦报告 IS)。
  • 提供开源实现和详细的可重复性说明。

实验结果

研究问题

  • RQ1渐进蒸馏是否能在不显著降低样本质量的前提下,实质性地减少扩散模型的采样步骤?
  • RQ2哪些参数化和损失加权策略能够在无条件与类条件生成中实现稳定且有效的蒸馏?
  • RQ3在标准基准(CIFAR-10、ImageNet、LSUN)上,蒸馏模型的质量与现有快速采样方法相比如何?
  • RQ4相对训练原始扩散模型,渐进蒸馏的计算成本是多少,是否具有实际效率?

主要发现

  • 蒸馏模型在 CIFAR-10 上仅需 4 步采样就可生成高质量样本,取得具有竞争力的 FID 分数。
  • 渐进蒸馏通过在每次迭代中将步数减半带来显著加速,整个过程的训练时间不超过原始模型的训练时间。
  • 不同的稳定参数化(直接 x、epsilon 或 v)和损失加权(基于信噪比的变体)表现出鲁棒性,在消融实验中识别出一种不稳定的组合。
  • 在 CIFAR-10、64×64 ImageNet 和 128×128 LSUN 基准上,蒸馏模型在相近或更低的步数下优于或可与更快的基线相比,尤其是在 4 到 8 步的区间。
  • 蒸馏过程保持高效:总时间不超过训练原始模型所需的时间,使测试时的效率提升具有实际意义。
  • 蒸馏模型也可与随机采样器一起使用,在 DDIM 与随机基线之间获得中等性能。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。