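[Paper Walkthrough] Learning Fast Samplers for Diffusion Models by Differentiating Through Sample Quality
This paper proposes Differentiable Diffusion Sampler Search (DDSS), which optimizes fast, non-Markovian samplers (the GGDM family) by differentiating through sample quality. It produces high-quality results in few-step diffusion sampling without retraining the model.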
Diffusion models have emerged as an expressive family of generative models rivaling GANs in sample quality and autoregressive models in likelihood scores. Standard diffusion models typically require hundreds of forward passes through the model to generate a single high-fidelity sample. We introduce Differentiable Diffusion Sampler Search (DDSS): a method that optimizes fast samplers for any pre-trained diffusion model by differentiating through sample quality scores. We also present Generalized Gaussian Diffusion Models (GGDM), a family of flexible non-Markovian samplers for diffusion models. We show that optimizing the degrees of freedom of GGDM samplers by maximizing sample quality scores via gradient descent leads to improved sample quality. Our optimization procedure backpropagates through the sampling process using the reparametrization trick and gradient rematerialization. DDSS achieves strong results on unconditional image generation across various datasets (e.g., FID scores on LSUN church 128x128 of 11.6 with only 10 inference steps, and 4.82 with 20 steps, compared to 51.1 and 14.9 with strongest DDPM/DDIM baselines). Our method is compatible with any pre-trained diffusion model without fine-tuning or re-training required.
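The idea of backpropagating through the sampling process can be seen on a toy problem. The sketch below is a minimal illustration under stated assumptions, not the paper's implementation: a hypothetical linear "sampler" x_{t-1} = a·x_t + s·ε_t stands in for the denoising network, and a squared-error loss stands in for KID. Because the Gaussian noise ε_t is drawn once and held fixed (the reparametrization trick), the final sample is a deterministic function of the sampler parameters `a` and `s`, so its gradient exists. For self-containment the sketch uses forward-mode dual numbers; DDSS itself uses reverse-mode backpropagation with gradient rematerialization, which gives the same gradient while recomputing intermediate activations during the backward pass instead of storing them all. The names `Dual`, `sample_chain` and `loss` are hypothetical.

```python
import random


class Dual:
    """Forward-mode dual number: `val` plus a tangent `dot` = d(val)/d(param)."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    @staticmethod
    def _lift(x):
        return x if isinstance(x, Dual) else Dual(float(x))

    def __add__(self, other):
        o = Dual._lift(other)
        return Dual(self.val + o.val, self.dot + o.dot)

    __radd__ = __add__

    def __sub__(self, other):
        o = Dual._lift(other)
        return Dual(self.val - o.val, self.dot - o.dot)

    def __mul__(self, other):
        o = Dual._lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * o.val, self.dot * o.val + self.val * o.dot)

    __rmul__ = __mul__


def sample_chain(a, s, noises):
    """Hypothetical K-step sampler: x_{t-1} = a * x_t + s * eps_t.

    The noises are drawn outside the chain (reparametrization), so the output
    is a deterministic, differentiable function of the parameters (a, s).
    """
    x = noises[0]               # x_K ~ N(0, 1)
    for eps in noises[1:]:
        x = a * x + s * eps     # mean a * x_t, standard deviation s
    return x


def loss(a, s, noises, target=1.0):
    """Hypothetical stand-in for a sample-quality loss such as KID."""
    d = sample_chain(a, s, noises) - target
    return d * d


rng = random.Random(0)
noises = [rng.gauss(0.0, 1.0) for _ in range(6)]  # x_K plus 5 step noises

# Seed the tangent of the parameter we differentiate with respect to.
grad_a = loss(Dual(0.9, 1.0), Dual(0.1, 0.0), noises).dot
grad_s = loss(Dual(0.9, 0.0), Dual(0.1, 1.0), noises).dot
print(grad_a, grad_s)
```

With these gradients a plain gradient step on `a` and `s` would already be a (very small) sampler search; DDSS does the same with a real model, many parameters, and Adam.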
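Motivation and Goals
- Reduce the number of inference steps a diffusion model needs, without retraining or fine-tuning the original model.
- Propose a differentiable optimization framework for searching over fast samplers.
- Introduce Generalized Gaussian Diffusion Models (GGDM) as a flexible family of samplers.
- Show that optimizing samplers with a perceptual loss (KID) yields better sample quality in the few-step regime.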
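What makes GGDM non-Markovian is that each denoising step may look at every state sampled so far, not only the most recent one. The sketch below assumes a simplified parameterization: the mean of the next state is a learned linear combination of all earlier (noisier) states plus a learned multiple of the model's clean-image prediction x̂_0, and the step adds learned-scale Gaussian noise. The paper's exact parameterization (including its PRED variant) may differ; `ggdm_step`, `ggdm_sample` and the toy denoiser are hypothetical.

```python
def ggdm_step(prev_states, x0_hat, mu, nu, sigma, eps):
    """One GGDM-style step under the simplified parameterization above.

    prev_states: every state sampled so far, noisiest first (lists of pixels)
    x0_hat:      the denoiser's prediction of the clean image
    mu:          one learned coefficient per previous state (non-Markovian part)
    nu:          learned coefficient on x0_hat
    sigma, eps:  learned std and fixed standard-normal noise (reparametrized)
    """
    if len(mu) != len(prev_states):
        raise ValueError("need exactly one mean coefficient per previous state")
    out = []
    for p in range(len(x0_hat)):
        mean = sum(m * s[p] for m, s in zip(mu, prev_states)) + nu * x0_hat[p]
        out.append(mean + sigma * eps[p])
    return out


def ggdm_sample(x_T, predict_x0, coeffs, noises):
    """Run len(coeffs) steps; step i conditions on all i + 1 earlier states."""
    states = [x_T]
    for i, (mu, nu, sigma) in enumerate(coeffs):
        x0_hat = predict_x0(states[-1], i)  # model sees the latest state
        states.append(ggdm_step(states, x0_hat, mu, nu, sigma, noises[i]))
    return states[-1]


def toy_denoiser(x, i):
    """Hypothetical denoiser: shrinks its input halfway toward zero."""
    return [0.5 * v for v in x]


coeffs = [
    ([0.5], 0.5, 0.0),        # step 0: one previous state, deterministic
    ([0.2, 0.3], 0.5, 0.0),   # step 1: two previous states, deterministic
]
zeros = [[0.0, 0.0], [0.0, 0.0]]
x_final = ggdm_sample([1.0, 2.0], toy_denoiser, coeffs, zeros)
print(x_final)
```

If every `mu` is zero except the entry for the most recent state, the step reduces to a Markovian update that is linear in x_t and x̂_0, the same form as a DDIM step; the extra coefficients are the additional degrees of freedom that DDSS can optimize.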
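Proposed Method
- Unroll the diffusion sampling process and optimize a parameterized sampler by backpropagating through it, using the reparametrization trick and gradient rematerialization.
- Define sampler families (DDIM, VARS, GGDM, and GGDM variants) whose learnable parameters control the mean and variance of each step.
- Use a perceptual loss based on Kernel Inception Distance (KID), computed on Inception features, so that the objective aligns with human-perceived quality (Eq. 7–9).
- Backpropagate through the stochastic sampler by differentiating through the sampling chain, optimizing with minibatch stochastic gradient descent (Adam).
- In GGDM, each denoising step combines information from all previously sampled (noisier) images.
- Allow the timestep selection (TIME) and the prediction coefficients (PRED) to be learned as well, which improves few-step performance.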
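A minimal sketch of the KID value itself, assuming the standard KID definition: an unbiased estimate of the squared maximum mean discrepancy (MMD²) between two feature sets, using the cubic polynomial kernel k(x, y) = (xᵀy / d + 1)³ with d the feature dimension. The paper's Eq. 7–9 may differ in details such as batching. Here the features are small hand-made lists instead of Inception activations, and the function names are hypothetical; in DDSS the gradient of this quantity also flows back through the Inception network into the sampler, which this sketch does not show.

```python
def poly_kernel(x, y):
    """Cubic polynomial kernel k(x, y) = (x.y / d + 1)^3, d = feature dim."""
    d = len(x)
    return (sum(a * b for a, b in zip(x, y)) / d + 1.0) ** 3


def kid(feats_x, feats_y):
    """Unbiased MMD^2 estimate between two feature sets (each of size >= 2)."""
    m, n = len(feats_x), len(feats_y)
    # Within-set terms exclude the diagonal (i == j) to stay unbiased.
    xx = sum(poly_kernel(feats_x[i], feats_x[j])
             for i in range(m) for j in range(m) if i != j) / (m * (m - 1))
    yy = sum(poly_kernel(feats_y[i], feats_y[j])
             for i in range(n) for j in range(n) if i != j) / (n * (n - 1))
    # Cross term averages over all pairs.
    xy = sum(poly_kernel(a, b) for a in feats_x for b in feats_y) / (m * n)
    return xx + yy - 2.0 * xy


real = [[0.1, 0.9], [0.2, 0.8], [0.0, 1.0]]
fake = [[0.9, 0.1], [0.8, 0.3], [1.0, 0.0]]
near = [[0.1, 0.8], [0.2, 0.9], [0.05, 1.0]]
print(kid(real, fake), kid(real, near))
```

Because the estimator is unbiased, closely matched sets can produce values slightly below zero; this is expected and is one reason KID behaves well with the small sample sizes of a minibatch.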
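Experimental Results
Research Questions
- RQ1: Can the differentiable optimization procedure find fast samplers for pre-trained diffusion models that outperform existing few-step baselines?
- RQ2: Does optimizing a perceptual loss (KID) produce samplers with higher visual fidelity than optimizing likelihood or the ELBO alone?
- RQ3: Compared with DDIM or VARS, how flexible is the GGDM family for discovering high-quality few-step samplers?
Key Findings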
FID (lower is better) and IS (higher is better) on CIFAR-10 and ImageNet-64 for each sampler and number of sampling steps K.

| Dataset | Sampler | Steps (K) | FID ↓ | IS ↑ |
|---|---|---|---|---|
| CIFAR-10 | DDPM (linear stride) | 5 | 84.27 | 5.396 |
| CIFAR-10 | DDPM (linear stride) | 10 | 43.39 | 7.034 |
| CIFAR-10 | DDPM (linear stride) | 15 | 31.40 | 7.609 |
| CIFAR-10 | DDPM (linear stride) | 20 | 25.94 | 7.879 |
| CIFAR-10 | DDPM (linear stride) | 25 | 22.60 | 8.043 |
| CIFAR-10 | DDPM (quadratic stride) | 5 | 76.25 | 5.435 |
| CIFAR-10 | DDPM (quadratic stride) | 10 | 42.03 | 6.965 |
| CIFAR-10 | DDPM (quadratic stride) | 15 | 27.78 | 7.714 |
| CIFAR-10 | DDPM (quadratic stride) | 20 | 20.225 | 8.128 |
| CIFAR-10 | DDPM (quadratic stride) | 25 | 16.17 | 8.350 |
| CIFAR-10 | DDIM (linear stride) | 5 | 44.41 | 6.750 |
| CIFAR-10 | DDIM (linear stride) | 10 | 19.11 | 7.965 |
| CIFAR-10 | DDIM (linear stride) | 15 | 14.06 | 8.190 |
| CIFAR-10 | DDIM (linear stride) | 20 | 11.82 | 8.420 |
| CIFAR-10 | DDIM (linear stride) | 25 | 10.52 | 8.512 |
| CIFAR-10 | DDIM (quadratic stride) | 5 | 32.66 | 7.090 |
| CIFAR-10 | DDIM (quadratic stride) | 10 | 13.62 | 8.190 |
| CIFAR-10 | DDIM (quadratic stride) | 15 | 9.318 | 8.495 |
| CIFAR-10 | DDIM (quadratic stride) | 20 | 7.500 | 8.641 |
| CIFAR-10 | DDIM (quadratic stride) | 25 | 6.560 | 8.759 |
| CIFAR-10 | GGDM +PRED+TIME | 5 | 13.77 | 8.520 |
| CIFAR-10 | GGDM +PRED+TIME | 10 | 8.227 | 8.903 |
| CIFAR-10 | GGDM +PRED+TIME | 15 | 6.115 | 9.050 |
| CIFAR-10 | GGDM +PRED+TIME | 20 | 4.722 | 9.261 |
| CIFAR-10 | GGDM +PRED+TIME | 25 | 4.250 | 9.186 |
| ImageNet-64 | DDPM (linear stride) | 5 | 122.0 | 5.878 |
| ImageNet-64 | DDPM (linear stride) | 10 | 58.78 | 10.67 |
| ImageNet-64 | DDPM (linear stride) | 15 | 39.30 | 13.22 |
| ImageNet-64 | DDPM (linear stride) | 20 | 31.36 | 14.72 |
| ImageNet-64 | DDPM (linear stride) | 25 | 26.36 | 15.71 |
| ImageNet-64 | DDPM (quadratic stride) | 5 | 394.8 | 1.351 |
| ImageNet-64 | DDPM (quadratic stride) | 10 | 129.5 | 5.997 |
| ImageNet-64 | DDPM (quadratic stride) | 15 | 80.10 | 9.595 |
| ImageNet-64 | DDPM (quadratic stride) | 20 | 61.34 | 11.60 |
| ImageNet-64 | DDPM (quadratic stride) | 25 | 49.60 | 13.01 |
| ImageNet-64 | DDIM (linear stride) | 5 | 135.4 | 5.898 |
| ImageNet-64 | DDIM (linear stride) | 10 | 40.70 | 12.225 |
| ImageNet-64 | DDIM (linear stride) | 15 | 28.54 | 13.99 |
| ImageNet-64 | DDIM (linear stride) | 20 | 24.225 | 14.75 |
| ImageNet-64 | DDIM (linear stride) | 25 | 22.13 | 15.16 |
| ImageNet-64 | DDIM (quadratic stride) | 5 | 409.1 | 1.380 |
| ImageNet-64 | DDIM (quadratic stride) | 10 | 148.6 | 5.533 |
| ImageNet-64 | DDIM (quadratic stride) | 15 | 67.65 | 9.842 |
| ImageNet-64 | DDIM (quadratic stride) | 20 | 45.60 | 11.99 |
| ImageNet-64 | DDIM (quadratic stride) | 25 | 36.11 | 13.225 |
| ImageNet-64 | GGDM +PRED+TIME | 5 | 55.14 | 12.90 |
| ImageNet-64 | GGDM +PRED+TIME | 10 | 37.32 | 14.76 |
| ImageNet-64 | GGDM +PRED+TIME | 15 | 24.69 | 17.225 |
| ImageNet-64 | GGDM +PRED+TIME | 20 | 20.69 | 17.92 |
| ImageNet-64 | GGDM +PRED+TIME | 25 | 18.40 | 18.12 |
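The "linear stride" and "quadratic stride" baselines differ only in which K of the T training timesteps the sampler visits. A minimal sketch, assuming T = 1000 training steps as in common DDPM setups, evenly spaced timesteps for the linear stride, and spacing proportional to i² for the quadratic stride, with the largest timestep capped at 4/5 of T. The cap and the function names are assumptions for illustration; the exact schedules behind the table's baselines may differ.

```python
from fractions import Fraction


def linear_stride(T, K):
    """K evenly spaced timesteps out of T: 0, T//K, 2*(T//K), ..."""
    skip = T // K
    return [i * skip for i in range(K)]


def quadratic_stride(T, K, cap=Fraction(4, 5)):
    """K timesteps at c * i**2 (K >= 2), with c chosen so the last is cap * T.

    Quadratic spacing puts more of the K steps at small t (low noise levels).
    The default 4/5 cap is an assumption; Fraction keeps the arithmetic exact.
    """
    c = cap * T / (K - 1) ** 2
    return [int(c * i * i) for i in range(K)]


print(linear_stride(1000, 5))     # → [0, 200, 400, 600, 800]
print(quadratic_stride(1000, 5))  # → [0, 50, 200, 450, 800]
```

A sampler visits these timesteps in reverse order, from the noisiest one down to t = 0. Learning the timesteps (the TIME option) replaces such a fixed rule with one chosen by gradient descent.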
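- The fast samplers found by DDSS score clearly better than strong baselines at the same small step budgets: GGDM +PRED+TIME has lower FID and higher IS than every DDPM/DDIM baseline at every K from 5 to 25, on both datasets.
- Optimizing KID as a perceptual loss yields high-fidelity samples, with consistent improvements across datasets (CIFAR-10 and ImageNet-64).
- As a more general sampler family, GGDM consistently outperforms narrower families such as DDIM or VARS in the few-step regime, even when its marginals differ from those of the original forward process.
- DDSS requires no fine-tuning or retraining of the pre-trained DDPM; it runs as a one-off, post-hoc sampler search.
- On both CIFAR-10 and ImageNet-64, DDSS variants achieve markedly lower FID and competitive IS with very few steps (5–25).
- Non-cherry-picked qualitative samples show improvements over DDIM (η=0) at low step counts.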
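To put the first finding in numbers, the snippet below takes the FID values from the results table (CIFAR-10 rows first, then ImageNet-64) and computes, for each step budget, the best DDPM/DDIM baseline and how much lower the FID of GGDM +PRED+TIME is relative to it. It uses only numbers copied from the table; `FID`, `STEPS` and `compare` are hypothetical names.

```python
# FID values copied from the results table (lower is better).
STEPS = [5, 10, 15, 20, 25]
FID = {
    "CIFAR-10": {
        "DDPM (linear)":    [84.27, 43.39, 31.40, 25.94, 22.60],
        "DDPM (quadratic)": [76.25, 42.03, 27.78, 20.225, 16.17],
        "DDIM (linear)":    [44.41, 19.11, 14.06, 11.82, 10.52],
        "DDIM (quadratic)": [32.66, 13.62, 9.318, 7.500, 6.560],
        "GGDM +PRED+TIME":  [13.77, 8.227, 6.115, 4.722, 4.250],
    },
    "ImageNet-64": {
        "DDPM (linear)":    [122.0, 58.78, 39.30, 31.36, 26.36],
        "DDPM (quadratic)": [394.8, 129.5, 80.10, 61.34, 49.60],
        "DDIM (linear)":    [135.4, 40.70, 28.54, 24.225, 22.13],
        "DDIM (quadratic)": [409.1, 148.6, 67.65, 45.60, 36.11],
        "GGDM +PRED+TIME":  [55.14, 37.32, 24.69, 20.69, 18.40],
    },
}


def compare(dataset, ours="GGDM +PRED+TIME"):
    """Map K -> (best baseline name, its FID, our FID, relative FID reduction)."""
    rows = FID[dataset]
    result = {}
    for idx, k in enumerate(STEPS):
        baselines = {name: v[idx] for name, v in rows.items() if name != ours}
        best_name = min(baselines, key=baselines.get)
        best, mine = baselines[best_name], rows[ours][idx]
        result[k] = (best_name, best, mine, (best - mine) / best)
    return result


for ds in FID:
    for k, (name, best, mine, rel) in compare(ds).items():
        print(f"{ds} K={k}: {name} {best} -> GGDM {mine} ({rel:.1%} lower FID)")
```

On CIFAR-10 the reduction stays above a third at every budget, while on ImageNet-64 the gap is large at 5 steps and narrower from 10 steps on, which is useful context when reading "significantly better" for each dataset.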
This walkthrough was generated by AI and reviewed by human editors.