Skip to main content
QUICK REVIEW

[논문 리뷰] Your Classifier is Secretly an Energy Based Model and You Should Treat it Like One

Will Grathwohl, Kuan-Chieh Wang|arXiv (Cornell University)|2019. 12. 06.
Anomaly Detection Techniques and Applications인용 수 119
한 줄 요약

이 논문은 표준 판별 분류기를 결합 에너지 기반 모델(JEM)로 재해석하고 p(x, y)와 p(x)를 모델링하도록 학습시켜, 판별적 성능과 생성적 성능을 경쟁력 있게 달성하며 보정, OOD 탐지 및 강건성의 향상을 이끈다.

ABSTRACT

We propose to reinterpret a standard discriminative classifier of p(y|x) as an energy based model for the joint distribution p(x,y). In this setting, the standard class probabilities can be easily computed as well as unnormalized values of p(x) and p(x|y). Within this framework, standard discriminative architectures may beused and the model can also be trained on unlabeled data. We demonstrate that energy based training of the joint distribution improves calibration, robustness, andout-of-distribution detection while also enabling our models to generate samplesrivaling the quality of recent GAN approaches. We improve upon recently proposed techniques for scaling up the training of energy based models and presentan approach which adds little overhead compared to standard classification training. Our approach is the first to achieve performance rivaling the state-of-the-artin both generative and discriminative learning within one hybrid model.

연구 동기 및 목표

  • 표준 분류기를 p(x, y)를 모델링하기 위한 결합 에너지 기반 모델로 재프레이밍한다.
  • 레이블이 없는 데이터에 대해서도 판별 성능을 보존하며 학습 가능하게 한다.
  • 에너지 기반 학습을 통해 보정, 강건성 및 out-of-distribution 탐지를 향상시킨다.
  • 단일 모델 내에서 판별 정확도와 함께 경쟁력 있는 생성 능력을 시현한다.

제안 방법

  • pθ(x,y) = exp(fθ(x)[y]) / Z(θ)로부터 classifier logits에서 p(x, y)를 정의하고 에너지 Eθ(x,y) = -fθ(x)[y]를 사용한다.
  • y에 대해 주변합을 취해 unnormalized p(x) = sum_y exp(fθ(x)[y]) / Z(θ) 를 얻는다.
  • p(x) 의 에너지 근사로 LogSumExp(fθ(x))를 사용하고 표준 cross-entropy를 통해 p(y|x)를 학습한다.
  • SGLD를 사용하여 모델 분포에서 샘플링하여 p(x)의 로그를 편향되지 않은 기울기로 추정해 학습한다.
  • log p(x)의 기울기 기대치를 추정하기 위해 persistent contrastive divergence를 사용한다.
  • 아키텍처는 Wide Residual Networks에 기초하고 CIFAR10, SVHN, CIFAR100에서 학습한다.

실험 결과

연구 질문

  • RQ1표준 분류기가 p(x, y)와 p(x)로의 결합 에너지 기반 모델로 해석될 수 있는가?
  • RQ2EBMs 기반 학습이 보정, OOD 탐지 및 적대적 강건성을 개선하면서도 판별 성능을 유지하는가?
  • RQ3대규모 이미지 데이터셋에서 결합된 EBMs가 판별 정확도와 함께 경쟁력 있는 생성 품질을 제공하는가?
  • RQ4SGLD 기반 샘플링의 규모가 학습 안정성과 성능에 어떤 영향을 미치는가?

주요 결과

  • JEM은 CIFAR10/SVHN/CIFAR100에서 경쟁력 있는 정확도를 달성하는 한편, 최첨단 하이브리드 모델에 버금가는 생성 능력도 제공한다.
  • JEM은 CIFAR100에서 보정성을 향상시키며 기대 보정 오차(ECE)로 측정할 때 거의 완벽한 보정에 근접한다.
  • JEM은 로그 p(x) 및 그래디언트 기반 매스 스코어를 포함한 여러 점수 방식을 사용한 OOD 탐지를 향상시키고 여러 베이스라인을 능가한다.
  • 모델은 표준 분류기보다 강건성 증가를 보이며, 특수한 강건한 방법에 근접한 수준으로 적대적 교란에 대해 더 큰 회복력을 보인다.
  • CIFAR10에서 JEM은 92.9% 정확도와 IS 8.76 및 FID 38.4를 달성하여 여러 하이브리드 베이스라인을 능가하면서도 판별 전용 모델과 경쟁력을 유지한다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.