[논문 리뷰] TabTransformer: Tabular Data Modeling Using Contextual Embeddings
TabTransformer 는 Transformer 층의 맥락 임베딩(contextual embeddings)을 사용하여 표 형식 데이터를 모델링하고, ML 베이스라인 대비 우수한 정확도와 GBDT에 필적하며, 누락/노이즈 데이터에 강건하고 두 단계의 반지도학습(pre-training) 접근법을 사용한다.
We propose TabTransformer, a novel deep tabular data modeling architecture for supervised and semi-supervised learning. The TabTransformer is built upon self-attention based Transformers. The Transformer layers transform the embeddings of categorical features into robust contextual embeddings to achieve higher prediction accuracy. Through extensive experiments on fifteen publicly available datasets, we show that the TabTransformer outperforms the state-of-the-art deep learning methods for tabular data by at least 1.0% on mean AUC, and matches the performance of tree-based ensemble models. Furthermore, we demonstrate that the contextual embeddings learned from TabTransformer are highly robust against both missing and noisy data features, and provide better interpretability. Lastly, for the semi-supervised setting we develop an unsupervised pre-training procedure to learn data-driven contextual embeddings, resulting in an average 2.1% AUC lift over the state-of-the-art methods.
연구 동기 및 목표
- 범주형 특징에 대한 맥락 임베딩을 학습함으로써 표 형식 데이터에서 MLP와 그래디언트 부스팅 의사결정 트리(GBDT) 간의 성능 차이를 줄인다.
- 트랜스포머 기반의 자기 주의(self-attention)를 활용하여 열 임베딩을 맥락 표현으로 변환하고 예측 정확도를 향상시킨다.
- 누락되거나 노이즈가 있는 범주형 특징에 대한 강건성을 입증하고 학습된 임베딩의 해석가능성을 제공한다.
- 레이블링된 데이터가 부족할 때 성능을 향상시키기 위한 두 단계 반지도학습 파이프라인(비레이블 데이터로의 프리트레이닝, 그다음 라벨링 데이터로의 파인튜닝)을 제안한다.
제안 방법
- 각 범주형 특징을 누락 값 임베딩을 포함하는 전용 열 임베딩 테이블로 임베드한다.
- 임베딩 시퀀스를 N개의 트랜스포머 층을 통과시키며(멀티헤드 자기 주의 뒤에 피드포워드 블록).
- 최상위 트랜스포머 층의 맥락 임베딩을 연속형 특징과 연결(concatenate)하고 이를 최종 예측을 위한 MLP로 입력한다.
- 선택적으로 비레이블 데이터에서 MLM(마스킹된 언어 모델링) 또는 RTD(대체 토큰 탐지) 작업을 사용하여 트랜스포머 층을 프리트레이닝한 다음 라벨링 데이터로 파인튜닝한다.
- 표준 지도 학습 손실을 최소화하도록 그래디언트 기반 학습으로 엔드-투-엔드 최적화한다(분류는 교차 엔트로피, 회귀는 평균 제곱 오차).
- 반지도 설정에서는 두 단계 워크플로우를 수행한다: (i) 비레이블 데이터에서 프리트레이닝, (ii) 라벨링 데이터에서 파인튜닝.
실험 결과
연구 질문
- RQ1범주형 특징에 대한 트랜스포머 기반 맥락 임베딩이 표 형식 데이터에서 전통적 MLP를 능가할 수 있는가?
- RQ2맥락 임베딩이 기본 신경망 모델에 비해 누락 및 노이즈가 있는 범주형 특징에 대한 강건성을 제공하는가?
- RQ3다양한 데이터셋에 걸쳐 TabTransformer가 트리 기반 모델(GBDT) 및 다른 딥 표 형식 모델과 비교해 어떻게 수행하는가?
- RQ4제2의 두 단계 반지도 학습 프리트레이닝/파인튜닝 파이프라인이 한정된 라벨링 데이터에서 측정 가능한 AUC 증가를 가져오는가?
주요 결과
| 모델 이름 | 평균 AUC (%) | 표준편차 (%) |
|---|---|---|
| TabTransformer | 82.8 | 0.4 |
| MLP | 81.8 | 0.4 |
| GBDT | 82.9 | 0.4 |
| Sparse MLP | 81.4 | 0.4 |
| Logistic Regression | 80.4 | 0.4 |
| TabNet | 77.1 | 0.5 |
| VIB | 80.5 | 0.4 |
- TabTransformer는 15개 데이터셋 중 14개에서 기준 MLP를 넘었으며 평균 1.0% AUC 이득을 기록했다.
- TabTransformer는 지도 학습에서 GBDT와 동등하거나 근접하게 경쟁하며, 여러 딥 표 형식 기준선 모델(TabNet, VIB 등)을 능가한다.
- 맥락 임베딩은 트랜스포머 층을 거치며 더 예측력이 높아져 임베딩에 선형 모델을 적용해 엔드투엔드 성능에 근접하게 한다.
- 모델은 노이즈가 많거나 결측이 증가함에 따라 MLP를 능가하며 누락/잡음에 대한 강건성을 보인다.
- 반지도 설정에서 TabTransformer-RTD/MLM 프리트레이닝은 비레이블 데이터가 풍부할 때 경쟁자들 대비 의미 있는 AUC 증가(평균 최대 2.1% 승)를 가져온다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.