Skip to main content
QUICK REVIEW

[논문 리뷰] Gradient Boosting Neural Networks: GrowNet

Sarkhan Badirli, Xuanqing Liu|arXiv (Cornell University)|2020. 02. 19.
Advanced Neural Network Applications참고 문헌 27인용 수 54
한 줄 요약

GrowNet은 얕은 신경망을 약한 학습기로 활용하는 그래디언트 부스팅 프레임워크에 보정 단계와 2차 통계를 도입하여 분류, 회귀, 랭킹 학습에서 최첨단 성능을 달성합니다.

ABSTRACT

A novel gradient boosting framework is proposed where shallow neural networks are employed as ``weak learners''. General loss functions are considered under this unified framework with specific examples presented for classification, regression, and learning to rank. A fully corrective step is incorporated to remedy the pitfall of greedy function approximation of classic gradient boosting decision tree. The proposed model rendered outperforming results against state-of-the-art boosting methods in all three tasks on multiple datasets. An ablation study is performed to shed light on the effect of each model components and model hyperparameters.

연구 동기 및 목표

  • 신경망을 약한 학습기로 활용하는 unified 그래디언트-부스팅 패러다임을 의사 결정 트리 대신 목표로 삼는다.
  • 2차 통계와 보정 단계를 통해 안정성과 작업 특성 튜닝을 개선한 효율적인 off-the-shelf 학습 알고리즘을 개발한다.
  • 여러 실제 데이터 세트에서 분류, 회귀, 학습-대-랭크에 대한 GrowNet의 적용 가능성과 우수성을 입증한다.
  • 구성요소(2차 통계, 보정 단추, 동적 부스팅) 및 하이퍼파라미터의 영향에 대한 절삭 연구를 제공한다.

제안 방법

  • 모델을 얕은 신경망의 가법 앙상블로 표현한다: ŷ_i = sum_{k=1}^K α_k f_k(x_i).
  • 각 약한 학습기 f_t를 잔차 g_i와 h_i에 대한 2차 Newton-Raphson 단계로 손실을 테일러 확장하여 최소화하여 학습시킨다.
  • 다음 학습기의 스택된 특징 집합을 형성하기 위해 이전 약한 학습기의 펜ultimate-레이어 특징으로 입력을 보강한다.
  • 추가된 모든 학습기를 원래 입력에서 재학습시키고 α_t를 업데이트하여 상호 학습자 상관관계를 감소시키는 보정 단계(C/S)를 도입한다.
  • 3가지 작업(회귀, 분류, 랭킹)에 맞춘 최적화를 가능하게 하는 각 약한 학습기의 목표를 2차 통계로 형성한다.
  • 보정 단계 동안 동적 부스팅 비율 α_t를 적용하여 하이퍼파라미터 튜닝을 용이하게 한다.

실험 결과

연구 질문

  • RQ1얕은 신경망을 약한 학습기로 사용하는 그래디언트 부스팅이 일반적인 GBDT 방법 및 깊은 신경망을 넘어서는 성능을 보일 수 있는가?
  • RQ22차 통계와 보정 단계를 도입하면 분류, 회귀, 학습-대-랭크 전반에서 안정성, 수렴성 및 일반화가 개선되는가?
  • RQ3다양한 데이터 세트에서 GrowNet이 XGBoost, AdaNet, 및 딥 뉴럴 네트워크와 비교했을 때 성능, 학습 시간, 튜닝 노력 면에서 어떤 차이가 있는가?

주요 결과

데이터셋 / 작업지표XGBoostGrowNet (pairwise)GrowNet (Gen. I div.)
MSLR-WEB 10KNDCG@50.4677(0.0287)0.5106(0.0011)0.5044(0.0072)
MSLR-WEB 10KNDCG@100.4858(0.0245)0.5203(0.0015)0.5137(0.0070)
Yahoo LTRNDCG@50.76180.77260.7713(0.0006)
Yahoo LTRNDCG@100.79130.81010.8088(0.0005)
  • GrowNet은 Microsoft Learning to Rank(MSLR-WEB 10K) 및 Yahoo LTR 데이터 세트에서 XGBoost 및 GrowNet 변형들(pairwise 및 일반화된 I-다이버전스 손실)보다 더 우수한 NDCG@5 및 NDCG@10를 달성한다.
  • MSLR-WEB 10K에서 NDCG@5는 XGBoost에서 0.4677(0.0287)에서 GrowNet(pairwise)에서 0.5106(0.0011)로, NDCG@10은 0.4858(0.0245)에서 0.5203(0.0015)로 향상된다.
  • Yahoo LTR에서 GrowNet(pairwise)로 NDCG@5가 0.7618에서 0.7726으로, NDCG@10이 0.7913에서 0.8101로 향상된다.
  • GrowNet은 또한 Higgs, CT 슬라이스 로컬라이제이션 및 YearPredictionMSD 데이터 세트에서 베이스라인 대비 회귀에서 RMSE를, 분류에서 AUC를 경쟁력 있게 달성한다.
  • 삭감 연구(ablation studies)에서 보정 단계와 2차 통계가 측정 가능한 이점을 제공하며, 펜ultimate-레이어 특징을 활용한 스택형 특징 접근이 특히 랭킹 작업에서 성능을 향상시킨다.
  • 30개의 얕은 학습기(두 개의 은닉층MLP)로 구성된 GrowNet은 더 깊은 DNN 스택과 견주어도 학습 속도가 빠르고 튜닝이 덜 필요하다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.