QUICK REVIEW

[Paper Review] Toward Understanding Generative Data Augmentation

Chenyu Zheng, Guoqiang Wu | arXiv (Cornell University) | May 27, 2023
Generative Adversarial Networks and Image Synthesis · Cited by 8
One-Sentence Summary

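This paper provides a general stability bound for generative data augmentation (GDA) in the non-i.i.d. setting, derives bounds for a binary Gaussian mixture model and for GAN-based GDA, and validates the theory with simulations and CIFAR-10 experiments, highlighting that GDA helps when the training set is small or when overfitting occurs.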

ABSTRACT

Generative data augmentation, which scales datasets by obtaining fake labeled examples from a trained conditional generative model, boosts classification performance in various learning tasks including (semi-)supervised learning, few-shot learning, and adversarially robust learning. However, little work has theoretically investigated the effect of generative data augmentation. To fill this gap, we establish a general stability bound in this not independently and identically distributed (non-i.i.d.) setting, where the learned distribution is dependent on the original train set and generally not the same as the true distribution. Our theoretical result includes the divergence between the learned distribution and the true distribution. It shows that generative data augmentation can enjoy a faster learning rate when the order of divergence term is $o\left(\max\left(\log(m)\beta_m, 1/\sqrt{m}\right)\right)$, where $m$ is the train set size and $\beta_m$ is the corresponding stability constant. We further specify the learning setup to the Gaussian mixture model and generative adversarial nets. We prove that in both cases, though generative data augmentation does not enjoy a faster learning rate, it can improve the learning guarantees at a constant level when the train set is small, which is significant when the awful overfitting occurs. Simulation results on the Gaussian mixture model and empirical results on generative adversarial nets support our theoretical conclusions. Our code is available at https://github.com/ML-GSAI/Understanding-GDA.
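The rate condition in the abstract can be checked numerically. The sketch below computes the threshold max(log(m)·β_m, 1/√m) and the ratio of a divergence term to it; GDA gets a faster rate only if that ratio tends to 0 as m grows. It assumes β_m = 1/m (a common uniform-stability rate, used here purely for illustration), under which 1/√m always dominates because log(m) < √m. The two divergence decays m^(-3/4) and m^(-1/2) are hypothetical, and the function names are illustrative.

```python
import math

def rate_threshold(m: int, beta_m: float) -> float:
    """max(log(m) * beta_m, 1 / sqrt(m)): the divergence term must be o() of this."""
    return max(math.log(m) * beta_m, 1.0 / math.sqrt(m))

def divergence_ratio(m: int, beta_m: float, divergence: float) -> float:
    """divergence / threshold; for a faster rate this must tend to 0 as m grows."""
    return divergence / rate_threshold(m, beta_m)

# Assumption: uniform stability beta_m = 1/m, plus two hypothetical divergence decays.
for m in (10**2, 10**4, 10**6):
    beta_m = 1.0 / m
    fast = divergence_ratio(m, beta_m, divergence=m ** -0.75)  # o(1/sqrt(m)): ratio shrinks
    slow = divergence_ratio(m, beta_m, divergence=m ** -0.5)   # Theta(1/sqrt(m)): ratio stays ~1
    print(f"m={m:>8}  threshold={rate_threshold(m, beta_m):.2e}  fast={fast:.3f}  slow={slow:.3f}")
```

With a less stable algorithm, for example β_m = 1/√m, the log(m)·β_m term dominates instead, so the bar the divergence term must clear becomes log(m)/√m.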

Research Motivation and Goals

  • Motivate the study of theoretical learning guarantees for generative data augmentation (GDA).
  • Develop a general algorithmic stability bound for GDA in non-i.i.d. settings where learned and true distributions differ.
  • Specialize the general bound to binary Gaussian mixture models (bGMM) and GAN-based GDA to derive explicit guarantees.
  • Analyze implications for deep generative models and practical settings, including diffusion models and CIFAR-10 experiments.

Proposed Method

  • Define GDA formally with the training set S, the learned model distribution D_G(S), the generated set S_G, and the mixed distribution D̃(S).
  • Derive a generalization bound (Gen-error) for the hypothesis A(S̃) trained on the augmented set S̃ = S ∪ S_G, which decomposes into a divergence term between distributions and a generalization term with respect to the mixed distribution.
  • Establish the condition under which GDA yields a faster learning rate: the divergence term must be of order o(max(log(m)β_m, 1/√m)).
  • Specialize the bound to a binary Gaussian mixture model (bGMM) to obtain explicit rates and discuss negative learning rates for large m_G.
  • Extend the analysis to deep learning with GANs, bounding quantities via SGD stability and TV distance between distributions, and relate to diffusion models.
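The objects defined above (S, D_G(S), S_G, S̃) can be made concrete with a toy bGMM simulation. This is a minimal sketch, not the authors' code (their implementation is in the linked GitHub repository). It assumes y is uniform on {-1, +1} and x | y ~ N(yμ, σ²I); a hypothetical learner A that estimates θ as the average of y·x and predicts sign(θᵀx); and a hypothetical plug-in generator D_G(S) that resamples the same model with the estimated mean. All function names are illustrative.

```python
import math
import numpy as np

def sample_bgmm(rng, n, mu, sigma):
    """Draw n points: y uniform on {-1, +1}, x | y ~ N(y * mu, sigma^2 I) (assumed bGMM)."""
    y = rng.choice([-1.0, 1.0], size=n)
    x = y[:, None] * mu[None, :] + sigma * rng.standard_normal((n, mu.shape[0]))
    return x, y

def fit_mean_direction(x, y):
    """Hypothetical learner A: theta = (1/n) * sum_i y_i x_i, classifier sign(theta^T x)."""
    return (y[:, None] * x).mean(axis=0)

def true_error(theta, mu, sigma):
    """Exact 0-1 error of sign(theta^T x) under the true bGMM: Phi(-theta^T mu / (sigma ||theta||))."""
    z = float(theta @ mu) / (sigma * np.linalg.norm(theta))
    return 0.5 * (1.0 + math.erf(-z / math.sqrt(2.0)))

def gda_experiment(m_s, m_g, d=20, sigma=1.0, seed=0):
    rng = np.random.default_rng(seed)
    mu = np.ones(d) / math.sqrt(d)  # assumed true mean with ||mu|| = 1
    x_s, y_s = sample_bgmm(rng, m_s, mu, sigma)  # real training set S
    mu_hat = fit_mean_direction(x_s, y_s)  # classifier trained on S alone
    # Hypothetical generator D_G(S): the same bGMM with the estimated mean plugged in.
    x_g, y_g = sample_bgmm(rng, m_g, mu_hat, sigma)  # generated set S_G
    x_aug = np.concatenate([x_s, x_g])  # augmented set S~ = S U S_G
    y_aug = np.concatenate([y_s, y_g])
    theta_gda = fit_mean_direction(x_aug, y_aug)
    return true_error(mu_hat, mu, sigma), true_error(theta_gda, mu, sigma)

for m_s in (10, 100, 1000):
    plain, gda = gda_experiment(m_s, m_g=10 * m_s)
    print(f"m_S={m_s:>5}  error without GDA={plain:.4f}  error with GDA={gda:.4f}")
```

Because the classifier is linear and the data are Gaussian, the test error is computed in closed form rather than by sampling. Whether augmentation helps in this toy depends on the assumed learner and generator; the sketch only shows how S, S_G, and S̃ fit together and does not reproduce the paper's numbers.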

Experimental Results

Research Questions

  • RQ1: Can we establish learning guarantees for GDA and characterize when it improves learning performance?
  • RQ2: How does the divergence between the learned distribution and the true distribution affect GDA's effectiveness?
  • RQ3: What augmentation size m_G optimally balances improvement against data consumption under GDA?
  • RQ4: Do the results extend to practical deep generative models such as GANs and diffusion models on standard datasets?
  • RQ5: What do the theoretical bounds predict for overfitting scenarios on real-world data (e.g., CIFAR-10)?

Key Findings

  • A general stability-based bound for GDA shows Gen-error is controlled by distributions' divergence plus generalization error on the mixed distribution.
  • GDA can yield a faster learning rate when the learned distribution converges to the true distribution fast enough, specifically when the divergence term is o(max(log(m)β_m, 1/√m)).
  • For bGMM and GANs, the divergence term scales as at least max(log(m)β_m, 1/√m), implying limited or no faster learning rate with large m_S, but potential constant-level improvement when data is scarce and overfitting is severe.
  • In deep learning settings (GANs with SGD-trained classifiers), the analysis suggests diffusion models are promising for GDA because they converge faster than GANs in TV distance, while standard data augmentation may cancel out the gains of GDA when m_S is large.
  • Experiments on bGMM support the theoretical bounds, and CIFAR-10 experiments show GAN-based GDA helps when overfitting is present but may hurt with large m_S and standard augmentation.
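The divergence term in the GAN and diffusion analyses is measured in TV distance. As a toy illustration (not the paper's GAN or diffusion analysis), the sketch below uses the standard closed form for two isotropic Gaussians with equal covariance, TV = 2Φ(‖μ₁ − μ₂‖ / (2σ)) − 1 = erf(‖μ₁ − μ₂‖ / (2√2 σ)), and assumes a hypothetical plug-in generator whose mean is the sample-mean estimate, so the typical mean gap is σ√(d/m). Under these assumptions TV·√m levels off at a constant, meaning the divergence decays like 1/√m and is not o(1/√m), in line with the finding that bGMM-style GDA does not gain a faster rate.

```python
import math

def tv_isotropic_gaussians(mean_gap: float, sigma: float) -> float:
    """TV(N(mu1, sigma^2 I), N(mu2, sigma^2 I)) = 2*Phi(gap / (2 sigma)) - 1 = erf(gap / (2 sqrt(2) sigma))."""
    return math.erf(mean_gap / (2.0 * math.sqrt(2.0) * sigma))

def typical_plugin_gap(m: int, d: int, sigma: float) -> float:
    """RMS error of a d-dimensional sample-mean estimate from m points: sigma * sqrt(d / m) (assumption)."""
    return sigma * math.sqrt(d / m)

for m in (100, 10_000, 1_000_000):
    gap = typical_plugin_gap(m, d=20, sigma=1.0)
    tv = tv_isotropic_gaussians(gap, sigma=1.0)
    print(f"m={m:>9}  gap={gap:.4f}  TV={tv:.5f}  TV*sqrt(m)={tv * math.sqrt(m):.3f}")
```

For small gaps erf(t) ≈ 2t/√π, so TV·√m approaches √(2/π)·√d/2, which is why the last column stabilizes instead of shrinking.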


This review was generated by AI and checked by human editors.