Skip to main content
QUICK REVIEW

[论文解读] Entity Embeddings of Categorical Variables

Cheng Guo, Felix Berkhahn|arXiv (Cornell University)|Apr 22, 2016
Bayesian Modeling and Causal Inference参考文献 36被引用 304
一句话总结

本文提出面向分类特征的实体嵌入,在神经网络中学习低维嵌入,以改善结构化数据的函数近似,显示相对独热编码的性能提升,并有助于类别的可视化与聚类。

ABSTRACT

We map categorical variables in a function approximation problem into Euclidean spaces, which are the entity embeddings of the categorical variables. The mapping is learned by a neural network during the standard supervised training process. Entity embedding not only reduces memory usage and speeds up neural networks compared with one-hot encoding, but more importantly by mapping similar values close to each other in the embedding space it reveals the intrinsic properties of the categorical variables. We applied it successfully in a recent Kaggle competition and were able to reach the third position with relative simple features. We further demonstrate in this paper that entity embedding helps the neural network to generalize better when the data is sparse and statistics is unknown. Thus it is especially useful for datasets with lots of high cardinality features, where other methods tend to overfit. We also demonstrate that the embeddings obtained from the trained neural network boost the performance of all tested machine learning methods considerably when used as the input features instead. As entity embedding defines a distance measure for categorical variables it can be used for visualizing categorical data and for data clustering.

研究动机与目标

  • 激发并演示神经网络在使用朴素编码处理高基数分类特征时所面临的困难。
  • 提出并形式化实体嵌入,作为对分类变量学习得到的密集表示。
  • 表明嵌入在稀疏数据上的泛化能力更强,并在作为输入特征时提升多种学习器的性能。
  • 展示嵌入在对分类数据进行理解方面的可视化与聚类能力。

提出的方法

  • 将每个分类值映射到一个密集向量(嵌入),在有监督训练期间与神经网络一起联合学习。
  • 将嵌入层视为对独热输入的线性变换,其中嵌入对应于层权重。
  • 将所有嵌入与连续输入拼接,并通过反向传播端到端训练。
  • 在真实数据集(Rossmann 商店销售量)上比较使用独热编码的神经网络与使用实体嵌入的神经网络。
  • 使用 10 个训练周期、Adam 优化,并通过集成预测来稳定结果。
  • 证明嵌入不仅提升神经网络的表现,而且在将嵌入用作特征时也能提升其他模型的性能。

实验结果

研究问题

  • RQ1实体嵌入能否学习到有意义、紧凑的分类变量表示,反映类别值之间的相似性?
  • RQ2与独热编码相比,嵌入是否能提升高基数分类特征的预测性能和泛化能力?
  • RQ3嵌入是否有助于对分类数据进行可视化和聚类?
  • RQ4为神经网络学习的嵌入在作为输入特征时,是否可以迁移以改善其他机器学习方法?

主要发现

  • 实体嵌入相较于独热编码在内存效率和速度方面有提升。
  • 在功能上相似的嵌入类别往往在嵌入空间中更接近。
  • 嵌入在稀疏数据和高基数特征上提升泛化能力,而其他方法容易过拟合。
  • 为神经网络学习的嵌入在作为输入时显著提升 KNN、随机森林和梯度提升树的性能。
  • 对嵌入的可视化(如 t-SNE)揭示了有意义的结构,例如州的地理聚类以及商店嵌入与销售额的连续变化。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。