[논문 리뷰] Self-Attention Between Datapoints: Going Beyond Individual Input-Output Pairs in Deep Learning
본 논문은 전체 데이터셋을 입력으로 받아 데이터포인트 간 자체 어텐션을 이용해 상호 간 관계를 학습하는 Non-Parametric Transformers(NPTs)를 제안한다. 이를 통해 데이터포인트 간 조회가 가능해지고 표 형식(tabular) 데이터와 이미지 데이터의 예측 성능이 향상된다.
We challenge a common assumption underlying most supervised deep learning: that a model makes a prediction depending only on its parameters and the features of a single input. To this end, we introduce a general-purpose deep learning architecture that takes as input the entire dataset instead of processing one datapoint at a time. Our approach uses self-attention to reason about relationships between datapoints explicitly, which can be seen as realizing non-parametric models using parametric attention mechanisms. However, unlike conventional non-parametric models, we let the model learn end-to-end from the data how to make use of other datapoints for prediction. Empirically, our models solve cross-datapoint lookup and complex reasoning tasks unsolvable by traditional deep learning models. We show highly competitive results on tabular data, early results on CIFAR-10, and give insight into how the model makes use of the interactions between points.
연구 동기 및 목표
- 지도 학습에서 매개변수 의존성 가정을 의문시한다.
- 전체 데이터셋을 예측에 활용하는 일반적 아키텍처(NPTs)를 제안한다.
- 어텐션 메커니즘을 통해 데이터포인트 간 상호작용을 엔드 투 엔드로 학습할 수 있게 한다.
- 표 형식 데이터와 이미지 데이터셋에서 데이터포인트 간 조회 및 추론을 시연한다.
제안 방법
- NPT에 전체 데이터셋(X)과 마스킹 행렬(M)을 입력하여 마스킹된 값 p(X^M | X^O)의 재구성을 가능하게 한다.
- 데이터포인트 간 대체적 어텐션(ABD)과 속성 간 어텐션(ABA)을 적용해 데이터포인트 간 관계와 각 데이터포인트 변환을 모델링한다.
- 트랜스포머 계열 아키텍처를 따라 잔차 연결과 계층 정규화를 갖춘 다중헤드 자기어텐션을 사용한다.
- BERT에서 영감을 얻은 마스킹 목표를 사용하여 대상 손실과 보조 특징 마스킹 손실을 결합한 마스킹 목표로 학습한다: L^NPT = (1-λ)L^Targets + λL^Features.
- 교차 포인트 어텐션을 가능하게 하도록 학습/테스트 데이터를 같은 배치에 유지하며 미니배치로 대용량 데이터셋을 처리한다.
실험 결과
연구 질문
- RQ1NPT가 표준 감독 학습 벤치마크에서 경쟁력 있는 성능을 달성할 수 있는가?
- RQ2이상적 교차 포인트 조회 작업에서 데이터포인트 간 어텐션을 활용해 예측 학습이 가능할까?
- RQ3실세계 데이터 예측에서도 데이터포인트 간 상호작용에 실제로 의존하는가?
- RQ4NPT를 사용할 때 예측에 가장 관련성이 높은 데이터포인트의 종류는 무엇인가?
주요 결과
- NPT는 UCI 벤치마크의 이진 및 다중 클래스 분류 작업에서 가장 높은 평균 순위를 기록하며 여러 부스팅 방법을 능가한다.
- 회귀 작업에서 NPT는 XGBoost와 동률의 최상 평균 순위를 기록하며 CatBoost에만 밀린다.
- CIFAR-10은 CNN+ABD 아키텍처로 93.7%의 테스트 정확도; MNIST는 선형 패칭으로 98.3%에 도달한다.
- 반합성 단백질 회귀 작업에서 NPT는 중복된 행에서 대상 값을 조회하여 거의 완벽한 상관관계(r = 99.9%)를 얻는다.
- 오염 실험은 다른 데이터포인트를 무작위화하면 예측 성능이 손실됨을 보여주며, 이는 실제 데이터에서 데이터포인트 간 상호작용에 의존한다는 것을 시사한다(데이터세트에 따라 다름).
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.