[论文解读] Extreme Classification in Log Memory using Count-Min Sketch: A Case Study of Amazon Search with 50M Products
该论文提出MACH,一种新颖的极端分类框架,通过结合Count-Min Sketch与通用哈希技术,将大规模分类的内存消耗从O(K)降低至O(log K)。在包含4946万种产品的Amazon搜索数据集上进行训练,MACH在64亿参数下实现了SOTA级别的精确率与召回率,仅用单个p3.16x实例在35小时内完成训练——比先前方法快7–10倍,内存效率提升2–4倍。
In the last decade, it has been shown that many hard AI tasks, especially in NLP, can be naturally modeled as extreme classification problems leading to improved precision. However, such models are prohibitively expensive to train due to the memory bottleneck in the last layer. For example, a reasonable softmax layer for the dataset of interest in this paper can easily reach well beyond 100 billion parameters (> 400 GB memory). To alleviate this problem, we present Merged-Average Classifiers via Hashing (MACH), a generic $K$-classification algorithm where memory provably scales at $O(\log K)$ without any assumption on the relation between classes. MACH is subtly a count-min sketch structure in disguise, which uses universal hashing to reduce classification with a large number of classes to few embarrassingly parallel and independent classification tasks with a small (constant) number of classes. MACH naturally provides a technique for zero communication model parallelism. We experiment with 6 datasets; some multiclass and some multilabel, and show consistent improvement in precision and recall metrics compared to respective baselines. In particular, we train an end-to -end deep classifier on a private product search dataset sampled from Amazon Search Engine with 70 million queries and 49.46 million documents. MACH outperforms, by a significant margin, the state-of-the-art extreme classification models deployed on commercial search engines: Parabel and dense embedding models. Our largest model has 6.4 billion parameters and trains in less than 35 hrs on a single p3.16x machine. Our training times are 7-10x faster, and our memory footprints are 2-4x smaller than the best baselines. This training time is also significantly lower than the one reported by Google’s mixture of experts (MoE) language model on a comparable model size and hardware.
研究动机与目标
- 解决极端分类中的内存瓶颈问题,即数百万类别所需的Softmax层会消耗数百GB内存。
- 开发一种可扩展、内存高效的分类框架,同时保持高精确率与召回率,且不依赖于类别间关系的假设。
- 通过将大规模分类问题转化为独立的小类别任务,实现零通信模型并行训练。
- 在真实世界工业级大规模数据集上展示卓越的性能与效率,特别是在拥有5000万个以上类别的商品搜索任务中。
提出的方法
- MACH采用Count-Min Sketch结构,利用通用哈希将高维分类任务映射为一组更小、独立的分类问题。
- 每个哈希函数将原始的K类问题映射为固定大小的子问题,从而实现无需协调的并行训练与推理。
- 最终预测通过所有基于哈希的子分类器预测结果的平均值合并计算,既保留了模型表达能力,又显著降低了内存占用。
- 通过利用通用哈希与Sketch技术的集中性特性,理论上证明了内存复杂度可达到O(log K)。
- MACH支持带有极端输出层的深度神经网络端到端训练,可无缝集成至标准深度学习框架。
- 该架构支持零通信模型并行,因为每个子分类器可独立训练,无需参数同步。
实验结果
研究问题
- RQ1是否能在不牺牲模型准确率的前提下,实现类别数K的子线性内存复杂度,特别是O(log K)?
- RQ2在大规模商品搜索数据上,MACH相较于Parabel与密集嵌入方法等SOTA模型,在精确率、召回率与训练效率方面表现如何?
- RQ3MACH在超过4900万个类别的数据集上,能否在保持低内存占用与快速训练时间的同时实现良好扩展性?
- RQ4使用Count-Min Sketch结合通用哈希是否能有效实现极端分类中的零通信模型并行?
- RQ5在同等硬件与模型规模下,MACH是否能实现比Google的专家混合模型更快的训练速度与更低的内存使用?
主要发现
- MACH在包含7000万次查询与4946万篇文档的私有Amazon商品搜索数据集上,实现了SOTA级别的精确率与召回率。
- 最大规模的MACH模型(64亿参数)在单个p3.16x实例上训练时间不足35小时,显著优于基线模型的速度与效率。
- 与最佳现有基线(包括Parabel与密集嵌入模型)相比,MACH将内存占用降低了2–4倍,训练时间缩短了7–10倍。
- MACH的训练时间远低于Google专家混合语言模型在相似模型规模与硬件配置下的报告值。
- 该方法在六个多样化数据集上均表现出一致的性能提升,涵盖多类与多标签设置,验证了其通用性与鲁棒性。
- 理论上的O(log K)内存扩展性得到实证验证,确认了该方法在无需类别结构假设的前提下,可有效扩展至极端类别数量。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。