[论文解读] Generating Multi-Categorical Samples with Generative Adversarial Networks
本工作通过使用多输出 Gumbel-Softmax 或 Softmax 层,改编 GAN 以生成包含多个分类变量的样本,比较了几种架构,并在不同数据集上展示了相对现有基线的性能提升。
We propose a method to train generative adversarial networks on mutivariate feature vectors representing multiple categorical values. In contrast to the continuous domain, where GAN-based methods have delivered considerable results, GANs struggle to perform equally well on discrete data. We propose and compare several architectures based on multiple (Gumbel) softmax output layers taking into account the structure of the data. We evaluate the performance of our architecture on datasets with different sparsity, number of features, ranges of categorical values, and dependencies among the features. Our proposed architecture and method outperforms existing models.
研究动机与目标
- 激发并解决使用 GAN 生成具有多个分类变量的样本的挑战。
- 提出针对多分类输出的架构与训练损失。
- 扩展评估指标以评估多分类合成数据的保真度。
- 在具有不同稀疏性、维度和依赖性的数据集上,实证比较六种模型。
提出的方法
- 将多分类数据表示为跨 N 个分类变量的一组独热编码的串联。
- 修改生成器/解码器,以使用 Gumbel-Softmax 或 Softmax 激活为每个分类变量产生单独的输出。
- 将 GAN、WGAN-GP、ARAE、MedGAN 等变体应用并适应到多分类场景,配以相应的损失形式。
- 在训练过程中引入考虑多分类结构的重构/奖励目标(按类别的交叉熵)。
- 使用扩展的评估指标,比较生成数据与真实数据在边际分布和预测依赖信息上的差异。
实验结果
研究问题
- RQ1如何在不丢失训练可微性的情况下,将 GAN 调整为生成具有多个分类变量的样本?
- RQ2在稀疏性、维度和分类基数各异的数据集上,多分类 GAN 架构是否优于改编的基线模型?
- RQ3哪些评估指标能有效同时捕捉合成数据中多分类特征的边际保真度与依赖关系?
主要发现
| 模型 | 数据集 | MSE_p | MSE_f | MSE_a |
|---|---|---|---|---|
| ARAE | FIXED 2 | 0.00031 ± 0.00004 | 0.00001 ± 0.00001 | 0.00059 ± 0.00022 |
| MedGAN | FIXED 2 | 0.00036 ± 0.00031 | 0.00005 ± 0.00003 | 0.00056 ± 0.00033 |
| MC-ARAE | FIXED 2 | 0.00046 ± 0.00028 | 0.00001 ± 0.00000 | 0.00058 ± 0.00024 |
| MC-MedGAN | FIXED 2 | 0.00013 ± 0.00006 | 0.00000 ± 0.00000 | 0.00032 ± 0.00017 |
| MC-GumbelGAN | FIXED 2 | 0.00337 ± 0.00188 | 0.00014 ± 0.00012 | 0.00050 ± 0.00012 |
| MC-WGAN-GP | FIXED 2 | 0.00030 ± 0.00007 | 0.00001 ± 0.00000 | 0.00068 ± 0.00012 |
| ARAE | FIXED 10 | 0.00398 ± 0.00002 | 0.00274 ± 0.00021 | 0.02156 ± 0.00175 |
| MedGAN | FIXED 10 | 0.00720 ± 0.00825 | 0.00463 ± 0.00404 | 0.01961 ± 0.00214 |
| MC-ARAE | FIXED 10 | 0.00266 ± 0.00009 | 0.00036 ± 0.00018 | 0.01086 ± 0.00159 |
| MC-MedGAN | FIXED 10 | 0.00022 ± 0.00003 | 0.00167 ± 0.00010 | 0.00062 ± 0.00044 |
| MC-GumbelGAN | FIXED 10 | 0.00056 ± 0.00006 | 0.00110 ± 0.00013 | 0.00055 ± 0.00035 |
| MC-WGAN-GP | FIXED 10 | 0.00026 ± 0.00001 | 0.00123 ± 0.00005 | 0.00048 ± 0.00010 |
| ARAE | MIX SMALL | 0.00261 ± 0.00020 | 0.01303 ± 0.00146 | 0.01560 ± 0.00039 |
| MedGAN | MIX SMALL | 0.00083 ± 0.00039 | 0.01889 ± 0.00258 | 0.02070 ± 0.00170 |
| MC-ARAE | MIX SMALL | 0.00195 ± 0.00040 | 0.00081 ± 0.00018 | 0.00759 ± 0.00100 |
| MC-MedGAN | MIX SMALL | 0.00029 ± 0.00003 | 0.00133 ± 0.00012 | 0.00080 ± 0.00018 |
| MC-GumbelGAN | MIX SMALL | 0.00078 ± 0.00027 | 0.00104 ± 0.00013 | 0.00047 ± 0.00008 |
| MC-WGAN-GP | MIX SMALL | 0.00048 ± 0.00010 | 0.00140 ± 0.00014 | 0.00037 ± 0.00016 |
| ARAE | MIX BIG | 0.04209 ± 0.00362 | 0.02075 ± 0.01144 | 0.00519 ± 0.00087 |
| MedGAN | MIX BIG | 0.01023 ± 0.00263 | 0.00211 ± 0.00033 | 0.00708 ± 0.00162 |
| MC-ARAE | MIX BIG | 0.00800 ± 0.00019 | 0.00249 ± 0.00035 | 0.00472 ± 0.00092 |
| MC-MedGAN | MIX BIG | 0.00142 ± 0.00015 | 0.00491 ± 0.00055 | 0.01309 ± 0.00106 |
| MC-GumbelGAN | MIX BIG | 0.00312 ± 0.00032 | 0.00194 ± 0.00017 | 0.00430 ± 0.00021 |
| MC-WGAN-GP | MIX BIG | 0.00144 ± 0.00006 | 0.00536 ± 0.00030 | 0.01664 ± 0.00177 |
| ARAE | CENSUS | 0.00165 ± 0.00082 | 0.00206 ± 0.00030 | 0.00668 ± 0.00175 |
| MedGAN | CENSUS | 0.00871 ± 0.01078 | 0.00709 ± 0.00889 | 0.01723 ± 0.02177 |
| MC-ARAE | CENSUS | 0.00333 ± 0.00020 | 0.00129 ± 0.00019 | 0.00360 ± 0.00095 |
| MC-MedGAN | CENSUS | 0.00012 ± 0.00004 | 0.00024 ± 0.00003 | 0.00013 ± 0.00003 |
| MC-GumbelGAN | CENSUS | 0.01866 ± 0.00040 | 0.00981 ± 0.00034 | 0.03930 ± 0.00469 |
| MC-WGAN-GP | CENSUS | 0.00019 ± 0.00004 | 0.00017 ± 0.00002 | 0.00008 ± 0.00002 |
- 多分类 GAN 变体在多数数据集上通常优于基线 ARAE 和 MedGAN。
- 性能提升因数据集和配置而异;没有单一模型在所有设置中占主导。
- 基于 Gumbel-Softmax 与 WGAN-GP 的生成器及其多分类解码器,在若干配置上获得基于 MSE 的指标提升。
- 接近 Census 的数据集在若干多分类模型上显示出显著提升,表明在高维、多类别数据上的有效性。
- 在多次实验中,较高的维度和稀疏性会增加捕捉依赖关系的难度,影响模型选择。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。