[논문 리뷰] Extreme Classification in Log Memory using Count-Min Sketch: A Case Study of Amazon Search with 50M Products
이 논문은 대규모 분류에서 메모리 스케일링을 O(K)에서 O(log K)로 줄이기 위해 유니버설 해싱을 사용하는 Count-Min 스킴을 통합한 새로운 극단적 분류 프레임워크인 MACH를 소개한다. 4946만 개의 제품을 포함한 아마존 검색 데이터셋에서 훈련된 MACH는 64억 개의 파라미터를 가진 상태에서 35시간 이내에 훈련되며, 기존 방법보다 7~10배 빠르고 2~4배 더 메모리 효율적인 성능을 기록했다. 이는 단일 p3.16x 인스턴스에서 수행된 결과이다.
In the last decade, it has been shown that many hard AI tasks, especially in NLP, can be naturally modeled as extreme classification problems leading to improved precision. However, such models are prohibitively expensive to train due to the memory bottleneck in the last layer. For example, a reasonable softmax layer for the dataset of interest in this paper can easily reach well beyond 100 billion parameters (> 400 GB memory). To alleviate this problem, we present Merged-Average Classifiers via Hashing (MACH), a generic $K$-classification algorithm where memory provably scales at $O(\log K)$ without any assumption on the relation between classes. MACH is subtly a count-min sketch structure in disguise, which uses universal hashing to reduce classification with a large number of classes to few embarrassingly parallel and independent classification tasks with a small (constant) number of classes. MACH naturally provides a technique for zero communication model parallelism. We experiment with 6 datasets; some multiclass and some multilabel, and show consistent improvement in precision and recall metrics compared to respective baselines. In particular, we train an end-to -end deep classifier on a private product search dataset sampled from Amazon Search Engine with 70 million queries and 49.46 million documents. MACH outperforms, by a significant margin, the state-of-the-art extreme classification models deployed on commercial search engines: Parabel and dense embedding models. Our largest model has 6.4 billion parameters and trains in less than 35 hrs on a single p3.16x machine. Our training times are 7-10x faster, and our memory footprints are 2-4x smaller than the best baselines. This training time is also significantly lower than the one reported by Google’s mixture of experts (MoE) language model on a comparable model size and hardware.
연구 동기 및 목표
- 수백 GB의 메모리가 필요한 수백만 개의 클래스를 위한 소프트맥스 레이어로 인해 발생하는 메모리 병목 현상을 해결한다.
- 클래스 간 관계에 대한 가정 없이도 높은 정밀도와 재현율을 유지하면서도 확장성 있고 메모리 효율적인 분류 프레임워크를 개발한다.
- 대규모 분류 문제를 독립적이고 소규모 클래스 문제로 분해함으로써 분산 훈련에서 통신이 없는 모델 병렬화를 가능하게 한다.
- 특히 5000만 개 이상의 클래스를 포함하는 실세계 산업 규모의 데이터셋에서 뛰어난 성능과 효율성을 입증한다.
제안 방법
- MACH는 유니버설 해싱을 사용하여 고차원 분류 문제를 일련의 더 작은 독립 분류 문제로 매핑하는 Count-Min 스킴 아키텍처를 활용한다.
- 각 해싱 함수는 원래의 K개 클래스 문제를 일정 크기의 하위 문제로 매핑함으로써, 간단한 병렬 훈련과 추론을 가능하게 한다.
- 최종 예측은 모든 해싱 기반 하위 분류기의 예측을 평균화하여 병합함으로써 모델의 표현력을 유지하면서도 메모리 사용량을 줄인다.
- 유니버설 해싱과 스킴 기법의 집중 성질을 활용하여, 이론적으로 O(log K)의 메모리 스케일링을 입증한다.
- 심층 신경망의 극단적 출력 레이어와 함께 엔드 투 엔드 훈련이 가능하며, 표준 딥러닝 프레임워크와 원활하게 통합된다.
- 각 하위 분류기가 상호 동기화 없이 독립적으로 훈련될 수 있으므로, 통신이 없는 모델 병렬화가 가능하다.
실험 결과
연구 질문
- RQ1클래스 수에 대해 비선형적인 메모리 복잡도, 특히 O(log K)로 극단적 분류를 수행하면서도 모델 정확도를 손상시키지 않을 수 있는가?
- RQ2대규모 제품 검색 데이터에서 Parabel 및 밀도 기반 임베딩 방법과 같은 최신 기술 모델과 비교해 MACH는 정밀도, 재현율 및 훈련 효율성 측면에서 어떻게 성과를 내는가?
- RQ34900만 개 이상의 클래스를 포함하는 데이터셋에 대해 MACH는 낮은 메모리 사용량과 짧은 훈련 시간을 유지하면서 얼마나 잘 스케일링되는가?
- RQ4유니버설 해싱을 사용한 Count-Min 스킴이 극단적 분류에서 효과적인 통신이 없는 모델 병렬화를 가능하게 하는가?
- RQ5유사한 하드웨어 및 모델 크기 조건에서 MACH는 구글의 믹스처 오브 익스퍼트 모델보다 더 빠른 훈련 속도와 낮은 메모리 사용량을 달성할 수 있는가?
주요 결과
- MACH는 7000만 개의 쿼리와 4946만 개의 문서를 포함한 비공개 아마존 제품 검색 데이터셋에서 최신 기술(SOTA) 수준의 정밀도와 재현율을 달성했다.
- 64억 개의 파라미터를 가진 가장 큰 MACH 모델은 단일 p3.16x 인스턴스에서 35시간 이내에 훈련되었으며, 기준 모델들보다 빠른 속도와 높은 효율성을 확보했다.
- Parabel 및 밀도 기반 임베딩 모델을 포함한 최고의 기존 기준 모델들보다 MACH는 메모리 사용량을 2~4배 줄이고, 훈련 시간을 7~10배 단축시켰다.
- 유사한 모델 크기와 하드웨어 구성에서 구글의 믹스처 오브 익스퍼트 언어 모델의 보고된 훈련 시간보다 MACH의 훈련 시간은 상당히 짧았다.
- 다양한 다중 클래스 및 다중 레이블 설정을 포함한 여섯 가지 다양한 데이터셋에서 일관된 성능 향상이 관찰되어, MACH의 일반성과 강건성을 입증했다.
- 이론적 메모리 스케일링 O(log K)가 실증적으로 검증되었으며, 클래스 구조에 대한 가정 없이도 극단적인 클래스 수에 대한 스케일링 가능성을 확인했다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.