Skip to main content
QUICK REVIEW

[论文解读] Generating Multi-Categorical Samples with Generative Adversarial Networks

Ramiro Daniel Camino, Christian Hammerschmidt|arXiv (Cornell University)|Jul 3, 2018
Generative Adversarial Networks and Image Synthesis参考文献 14被引用 43
一句话总结

本工作通过使用多输出 Gumbel-Softmax 或 Softmax 层,改编 GAN 以生成包含多个分类变量的样本,比较了几种架构,并在不同数据集上展示了相对现有基线的性能提升。

ABSTRACT

We propose a method to train generative adversarial networks on mutivariate feature vectors representing multiple categorical values. In contrast to the continuous domain, where GAN-based methods have delivered considerable results, GANs struggle to perform equally well on discrete data. We propose and compare several architectures based on multiple (Gumbel) softmax output layers taking into account the structure of the data. We evaluate the performance of our architecture on datasets with different sparsity, number of features, ranges of categorical values, and dependencies among the features. Our proposed architecture and method outperforms existing models.

研究动机与目标

  • 激发并解决使用 GAN 生成具有多个分类变量的样本的挑战。
  • 提出针对多分类输出的架构与训练损失。
  • 扩展评估指标以评估多分类合成数据的保真度。
  • 在具有不同稀疏性、维度和依赖性的数据集上,实证比较六种模型。

提出的方法

  • 将多分类数据表示为跨 N 个分类变量的一组独热编码的串联。
  • 修改生成器/解码器,以使用 Gumbel-Softmax 或 Softmax 激活为每个分类变量产生单独的输出。
  • 将 GAN、WGAN-GP、ARAE、MedGAN 等变体应用并适应到多分类场景,配以相应的损失形式。
  • 在训练过程中引入考虑多分类结构的重构/奖励目标(按类别的交叉熵)。
  • 使用扩展的评估指标,比较生成数据与真实数据在边际分布和预测依赖信息上的差异。

实验结果

研究问题

  • RQ1如何在不丢失训练可微性的情况下,将 GAN 调整为生成具有多个分类变量的样本?
  • RQ2在稀疏性、维度和分类基数各异的数据集上,多分类 GAN 架构是否优于改编的基线模型?
  • RQ3哪些评估指标能有效同时捕捉合成数据中多分类特征的边际保真度与依赖关系?

主要发现

模型数据集MSE_pMSE_fMSE_a
ARAEFIXED 20.00031 ± 0.000040.00001 ± 0.000010.00059 ± 0.00022
MedGANFIXED 20.00036 ± 0.000310.00005 ± 0.000030.00056 ± 0.00033
MC-ARAEFIXED 20.00046 ± 0.000280.00001 ± 0.000000.00058 ± 0.00024
MC-MedGANFIXED 20.00013 ± 0.000060.00000 ± 0.000000.00032 ± 0.00017
MC-GumbelGANFIXED 20.00337 ± 0.001880.00014 ± 0.000120.00050 ± 0.00012
MC-WGAN-GPFIXED 20.00030 ± 0.000070.00001 ± 0.000000.00068 ± 0.00012
ARAEFIXED 100.00398 ± 0.000020.00274 ± 0.000210.02156 ± 0.00175
MedGANFIXED 100.00720 ± 0.008250.00463 ± 0.004040.01961 ± 0.00214
MC-ARAEFIXED 100.00266 ± 0.000090.00036 ± 0.000180.01086 ± 0.00159
MC-MedGANFIXED 100.00022 ± 0.000030.00167 ± 0.000100.00062 ± 0.00044
MC-GumbelGANFIXED 100.00056 ± 0.000060.00110 ± 0.000130.00055 ± 0.00035
MC-WGAN-GPFIXED 100.00026 ± 0.000010.00123 ± 0.000050.00048 ± 0.00010
ARAEMIX SMALL0.00261 ± 0.000200.01303 ± 0.001460.01560 ± 0.00039
MedGANMIX SMALL0.00083 ± 0.000390.01889 ± 0.002580.02070 ± 0.00170
MC-ARAEMIX SMALL0.00195 ± 0.000400.00081 ± 0.000180.00759 ± 0.00100
MC-MedGANMIX SMALL0.00029 ± 0.000030.00133 ± 0.000120.00080 ± 0.00018
MC-GumbelGANMIX SMALL0.00078 ± 0.000270.00104 ± 0.000130.00047 ± 0.00008
MC-WGAN-GPMIX SMALL0.00048 ± 0.000100.00140 ± 0.000140.00037 ± 0.00016
ARAEMIX BIG0.04209 ± 0.003620.02075 ± 0.011440.00519 ± 0.00087
MedGANMIX BIG0.01023 ± 0.002630.00211 ± 0.000330.00708 ± 0.00162
MC-ARAEMIX BIG0.00800 ± 0.000190.00249 ± 0.000350.00472 ± 0.00092
MC-MedGANMIX BIG0.00142 ± 0.000150.00491 ± 0.000550.01309 ± 0.00106
MC-GumbelGANMIX BIG0.00312 ± 0.000320.00194 ± 0.000170.00430 ± 0.00021
MC-WGAN-GPMIX BIG0.00144 ± 0.000060.00536 ± 0.000300.01664 ± 0.00177
ARAECENSUS0.00165 ± 0.000820.00206 ± 0.000300.00668 ± 0.00175
MedGANCENSUS0.00871 ± 0.010780.00709 ± 0.008890.01723 ± 0.02177
MC-ARAECENSUS0.00333 ± 0.000200.00129 ± 0.000190.00360 ± 0.00095
MC-MedGANCENSUS0.00012 ± 0.000040.00024 ± 0.000030.00013 ± 0.00003
MC-GumbelGANCENSUS0.01866 ± 0.000400.00981 ± 0.000340.03930 ± 0.00469
MC-WGAN-GPCENSUS0.00019 ± 0.000040.00017 ± 0.000020.00008 ± 0.00002
  • 多分类 GAN 变体在多数数据集上通常优于基线 ARAE 和 MedGAN。
  • 性能提升因数据集和配置而异;没有单一模型在所有设置中占主导。
  • 基于 Gumbel-Softmax 与 WGAN-GP 的生成器及其多分类解码器,在若干配置上获得基于 MSE 的指标提升。
  • 接近 Census 的数据集在若干多分类模型上显示出显著提升,表明在高维、多类别数据上的有效性。
  • 在多次实验中,较高的维度和稀疏性会增加捕捉依赖关系的难度,影响模型选择。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。