[논문 리뷰] Transformers Can Do Bayesian Inference
이 논문은 사전 분포에서 추출된 데이터 포인트의 집합 값 입력을 활용하여 인라인 학습을 통해 베이지안 추론을 수행하는 데 Transformer를 사용하는 Prior-Data Fitted Networks (PFNs)를 소개한다. PFNs는 MCMC와 NUTS에 비해 최대 200배 빠른 속도로 사후 예측 분포를 근사하며, 가우스 프로세스의 거의 완벽한 모방을 이룩하고, 표본 수가 적은 회귀, 소수의 샘플을 이용한 이미지 분류, 베이지안 신경망 등 다양한 작업에서 뛰어난 성능을 보인다.
Currently, it is hard to reap the benefits of deep learning for Bayesian methods, which allow the explicit specification of prior knowledge and accurately capture model uncertainty. We present Prior-Data Fitted Networks (PFNs). PFNs leverage in-context learning in large-scale machine learning techniques to approximate a large set of posteriors. The only requirement for PFNs to work is the ability to sample from a prior distribution over supervised learning tasks (or functions). Our method restates the objective of posterior approximation as a supervised classification problem with a set-valued input: it repeatedly draws a task (or function) from the prior, draws a set of data points and their labels from it, masks one of the labels and learns to make probabilistic predictions for it based on the set-valued input of the rest of the data points. Presented with a set of samples from a new supervised learning task as input, PFNs make probabilistic predictions for arbitrary other data points in a single forward propagation, having learned to approximate Bayesian inference. We demonstrate that PFNs can near-perfectly mimic Gaussian processes and also enable efficient Bayesian inference for intractable problems, with over 200-fold speedups in multiple setups compared to current methods. We obtain strong results in very diverse areas such as Gaussian process regression, Bayesian neural networks, classification for small tabular data sets, and few-shot image classification, demonstrating the generality of PFNs. Code and trained PFNs are released at https://github.com/automl/TransformersCanDoBayesianInference.
연구 동기 및 목표
- 딥 러닝을 활용하여 저자료 환경에서 베이지안 방법을 적용하는 데 어려움을 해결하고 효율적인 사후 근사화를 가능하게 한다.
- 가우스 프로세스와 베이지안 신경망과 같은 복잡한 모델에서 정확한 베이지안 추론이 불가능한 문제를 해결한다.
- 데이터를 샘플링할 수 있는 임의의 사전 분포를 사용하여 사후 근사화에 활용할 수 있는 일반적인 프레임워크를 개발한다.
- 해석적 형태나 복잡한 근사화 없이 사전 분포에 대한 샘플링 메커니즘만으로도 확장 가능하고, 미분 가능하며, 민감한 베이지안 추론을 가능하게 한다.
- 소수의 표본 데이터, 소수의 샘플 학습, 회귀 등 다양한 작업에서 이론적 성능과 신뢰도 측정이 우수한 방법의 효과성을 입증한다.
제안 방법
- PFNs는 사후 근사화를 집합 값 입력을 가진 지도 학습 분류 문제로 재정의한다: 각 학습 작업마다 사전에서 함수를 샘플링하고, (x, y) 쌍의 집합을 수집하며, 하나의 레이블을 마스킹하고, 모델이 마스킹된 레이블을 확률적으로 예측하도록 훈련시킨다.
- PFN의 입력은 (x, y) 쌍의 집합이며, 그 중 하나의 레이블이 마스킹된다; 모델은 Transformer 아키텍처의 어텐션 메커니즘을 사용하여 누락된 레이블에 대한 분포를 학습한다.
- PFN은 여러 샘플링된 작업에 대해 마스킹된 예측 목표를 최대 우도 기반으로 엔드 투 엔드로 훈련한다: −∑ log qθ(y_test|x_test, D_train)를 최소화한다.
- 추론 도중에는 실제 데이터 세트 D_train과 테스트 포인트 x_test가 훈련된 PFN에 입력되며, 단일 순전파를 통해 전체 예측 분포 qθ*(y_test|x_test, D_train)를 출력한다.
- 연속적인 출력에서 적절한 불확실성 추정을 가능하게 하기 위해 회귀 작업을 위한 새로운 예측 분포가 도입된다.
- 이 방법은 민감한 사전 분포를 지원한다: 데이터를 샘플링할 수 있는 임의의 분포를 사용할 수 있으며, 해석이 어려운 복잡한 사전 분포인 베이지안 신경망이나 가우스 프로세스도 포함된다.
실험 결과
연구 질문
- RQ1해석적 형태가 필요 없이 베이지안 추론에서 사후 예측 분포를 근사하는 데에 Transformer를 효과적으로 사용할 수 있는가?
- RQ2Transformer의 인라인 학습 능력을 다양한 저자료 기계학습 작업에서의 베이지안 추론에 얼마나 효과적으로 활용할 수 있는가?
- RQ3정확도, 속도, 캘리브레이션 측면에서 PFNs의 성능이 MCMC와 NUTS, SVI와 Bayes-by-Backprop를 사용한 기존 기준 모델에 비해 어떻게 되는가?
- RQ4PFNs는 해석이 어려운 복잡한 사전 분포를 포함한 다양한 종류의 사전 분포에 일반화될 수 있는가?
- RQ5어느 정도의 아키텍처 선택(예: 어텐션 헤드, 위치 인코딩, 활성화 함수)이 사후 근사에서 PFNs의 성능에 영향을 미치는가?
주요 결과
- PFNs는 가우스 프로세스 예측을 거의 완벽하게 모방하며, Dionis 데이터셋에서 평균 AUC 0.981, jannis 데이터셋에서 0.996을 기록하여 모든 기준 모델을 초월한다.
- PFN-BNN 모델은 30개의 훈련 샘플을 가진 21개의 표본 데이터셋에서 평균 AUC 0.855를 기록하며, T-BNN(0.654) 및 기타 기준 모델보다 정확도와 캘리브레이션 측면에서 뚜렷한 우월성을 보였다.
- MCMC와 NUTS에 비해 추론 시간에서 최대 200배의 속도 향상을 보였으며, 동일한 벤치마크에서 NUTS가 12시간 이상 소요되는 데 비해 GPU에서 PFNs는 단 13초만에 추론을 완료했다.
- 강력한 캘리브레이션 성능을 기록하였으며, PFN-BNN의 경우 기대 캘리브레이션 오차(ECE)가 0.025로, 기준 로지스틱 회귀 모델의 0.157보다 훨씬 낮았다.
- Omniglot에서의 소수의 샘플 이미지 분류 작업에서 PFNs는 평균 AUC 0.981을 기록하여 KNN(0.871)과 CatBoost(0.945)를 모두 압도하며, 소수의 샘플 설정에서 강력한 일반화 능력을 입증했다.
- PFN 프레임워크는 다양한 사전 분포로 일반화된다: 단지 사전 분포에 대한 샘플링 메커니즘만으로도 가우스 프로세스, 베이지안 신경망, 기타 해석이 어려운 모델의 사후 분포를 성공적으로 근사한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.