Skip to main content
QUICK REVIEW

[논문 리뷰] Just Train Twice: Improving Group Robustness without Training Group Information

Evan Liu, Behzad Haghgoo|arXiv (Cornell University)|2021. 07. 19.
Machine Learning and Data Classification인용 수 70
한 줄 요약

Jtt는 표준 ERM 모델을 학습하여 고손실 예제를 식별한 다음, 잘못 분류된 예제를 두 번째 ERM 실행에서 가중치를 높여 worst-group 성능을 개선하고 그룹 라벨을 학습하지 않는 방식으로 그룹 DRO 성능에 근접합니다.

ABSTRACT

Standard training via empirical risk minimization (ERM) can produce models that achieve high accuracy on average but low accuracy on certain groups, especially in the presence of spurious correlations between the input and label. Prior approaches that achieve high worst-group accuracy, like group distributionally robust optimization (group DRO) require expensive group annotations for each training point, whereas approaches that do not use such group annotations typically achieve unsatisfactory worst-group accuracy. In this paper, we propose a simple two-stage approach, JTT, that first trains a standard ERM model for several epochs, and then trains a second model that upweights the training examples that the first model misclassified. Intuitively, this upweights examples from groups on which standard ERM models perform poorly, leading to improved worst-group performance. Averaged over four image classification and natural language processing tasks with spurious correlations, JTT closes 75% of the gap in worst-group accuracy between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters.

연구 동기 및 목표

  • 스퓨리어스 상관관계로 인한 소수 집단에서 ERM 기반 모델의 실패 문제를 동기 부여한다.
  • 훈련 그룹 라벨 없이 최악 그룹 정확도를 높이기 위한 간단한 두 단계 방법(Just Train Twice)을 제안한다.
  • 스퓨리어스 상관관계가 있는 네 가지 데이터셋에서 실증적 개선을 보여준다.
  • 오류 집합이 무엇을 나타내는지와 하이퍼파라미터 튜닝을 위한 그룹 주석이 달린 소규모 검증 세트의 역할을 분석한다.]
  • method_for_review_list시스템_필드_오류_표현은_변환되었음
  • method_to_translate_제한없이_2단계_제공되며_주요_기법
  • Stage 1: Train an identification model via ERM for T steps and collect the error set E of misclassified training examples.
  • Stage 2: Train a final model on an upsampled dataset where examples in E are repeated lambda_up times, increasing their influence.
  • The final objective is J_up-ERM(θ,E) = lambda_up * sum_{(x,y) in E} l(x,y;θ) + sum_{(x,y) not in E} l(x,y;θ).
  • Hyperparameters include T (epochs for the identification model) and lambda_up (upweight factor); tuning uses worst-group validation accuracy.
  • Validation-guided tuning is recommended for hyperparameters across Jtt, CVaR DRO, and LfF.

제안 방법

  • Stage 1: ERM을 통해 식별 모델을 T 단계 학습하고 잘못 분류된 학습 예제의 오류 집합 E를 수집한다.
  • Stage 2: E에 속하는 예제를 lambda_up 배로 반복 재샘플링한 데이터셋에서 최종 모델을 학습시켜 이들의 영향력을 증가시킨다.
  • 최종 목적 함수는 J_up-ERM(θ,E) = lambda_up * sum_{(x,y) in E} l(x,y;θ) + sum_{(x,y) not in E} l(x,y;θ)이다.
  • 하이퍼파라미터로는 식별 모델의 에포크 수 T와 가중치 증가 인자 lambda_up가 포함되며, 튜닝은 worst-group 검증 정확도를 사용한다.
  • Jtt, CVaR DRO, 및 LfF 전반에 걸친 하이퍼파라미터 튜닝은 검증 기반으로 권장된다.

실험 결과

연구 질문

  • RQ1두 단계의 그룹 주석 없이도 다양한 작업에서 worst-group 정확도에서 그룹 DRO에 근접하는 방법이 가능한가?
  • RQ2첫 번째 ERM 모델이 식별한 잘못 분류된 예제가 학습 시점의 그룹 라벨 없이 어려운 그룹을 얼마나 잘 포착하는가?
  • RQ3하이퍼파라미터 튜닝이 worst-group 성능에 미치는 영향은 무엇이며 튜닝에 검증-그룹 정보가 필수적인가?
  • RQ4Jtt는 worst-group 성능과 평균 정확도에서 CVaR DRO 및 LfF와 어떻게 비교되는가?
  • RQ5다양한 그룹에 대한 Jtt 오류 집합의 구성을 어떻게 보이며 worst-group 예제에 대한 풍부성은 어떠한가?

주요 결과

MethodWaterbirds Avg Acc.Waterbirds Worst-group Acc.CelebA Avg Acc.CelebA Worst-group Acc.MultiNLI Avg Acc.MultiNLI Worst-group Acc.CivilComments Avg Acc.CivilComments Worst-group Acc.
ERM97.3%72.6%95.6%47.2%82.4%67.9%92.6%57.4%
CVaR DRO (Levy et al., 2020)96.0%75.9%82.5%64.4%82.0%68.0%92.5%60.5%
LfF (Nam et al., 2020)91.2%78.0%85.1%77.2%80.8%70.2%92.5%58.8%
Jtt (Ours)93.3%86.7%88.0%81.1%78.6%72.6%91.1%69.3%
Group DRO (Sagawa et al., 2020a)93.5%91.4%92.9%88.9%81.4%77.7%88.9%69.9%
  • Jtt는 Waterbirds, CelebA, MultiNLI, CivilComments-WILDS에서 ERM보다 일관되게 worst-group 정확도를 개선한다.
  • 평균적으로 Jtt는 ERM과 비교해 worst-group 정확도 격차를 약 16.2 퍼센트 포인트 감소시키고 그룹 DRO 대비 격차의 약 75%를 해소한다.
  • Jtt의 평균 정확도는 최상의 평균 정확도보다 약 4.2 포인트 낮은 수준으로 유리한 트레이드-오프를 보인다.
  • 오류 집합은 worst-group 예제에 대해 풍부하게 구성되어 worst group의 재현율이 평균 86.4%로 높고, 훈련에서의 발생 비율보다 높은 정밀도를 보인다.
  • 비훈련 그룹 기반 하이퍼파라미터 튜닝은 강한 worst-group 성능을 달성하기 위해 필수적이다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.