Skip to main content
QUICK REVIEW

[논문 리뷰] Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

Shiori Sagawa, Pang Wei Koh|arXiv (Cornell University)|2019. 11. 20.
Domain Adaptation and Few-Shot Learning참고 문헌 58인용 수 364
한 줄 요약

이 논문은 과매개변수화된 신경망에서 그룹 DRO가 더 강력한 정규화(예: 강한 L2 또는 조기 중지)와 함께 사용할 때 최악의 그룹 일반화 성능을 개선하고, 그룹 DRO 모델 학습을 위한 수렴 보장이 있는 확장 가능한 온라인 알고리즘을 제시한다.

ABSTRACT

Overparameterized neural networks can be highly accurate on average on an i.i.d. test set yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that hold on average but not in such groups). Distributionally robust optimization (DRO) allows us to learn models that instead minimize the worst-case training loss over a set of pre-defined groups. However, we find that naively applying group DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss also already has vanishing worst-case training loss. Instead, the poor worst-case performance arises from poor generalization on some groups. By coupling group DRO models with increased regularization---a stronger-than-typical L2 penalty or early stopping---we achieve substantially higher worst-group accuracies, with 10-40 percentage point improvements on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is important for worst-group generalization in the overparameterized regime, even if it is not needed for average generalization. Finally, we introduce a stochastic optimization algorithm, with convergence guarantees, to efficiently train group DRO models.

연구 동기 및 목표

  • i.i.d. 훈련에서 허위 상관관계가 비정형 그룹의 성능 저하를 야기한다는 문제를 동기부여한다.
  • 과매개변수화된 네트워크에 그룹 DRO를 naively 적용하는 것이 왜 최악의 그룹 일반화를 개선하지 못하는지 조사한다.
  • 강한 정규화가 그룹 DRO가 최악의 그룹 정확도에서 상당한 이득을 얻으면서 평균 정확도를 유지하도록 할 수 있음을 입증한다.
  • 수렴 보장이 있는 온라인 최적화 알고리즘을 제안하고 그 성능을 분석한다.

제안 방법

  • 그룹 DRO를 알려진 허위 상관관계로부터 정의된 그룹으로 형식화하고, 최악의 위험이 최대 그룹 위험과 같다는 것을 도출한다.
  • 과매개현 상황에서 ERM과 그룹 DRO 둘 다 제로 훈련 손실을 달성하면 최악의 그룹 테스트 성능이 나쁘다.
  • 강한 L2 페널티나 조기 중지와 같은 정규화 전략을 조사하여 완벽한 훈련 적합을 방지하고 최악의 그룹 일반화 격차를 줄인다.
  • 그룹 보정 DRO를 도입해 그룹별 일반화 격차 항 C/√ng를 추가하여 훈련 중 더 작은 그룹을 우선시한다.
  • 온라인 교대 그래디언트 알고리즘을 개발하여 θ를 SGD로 업데이트하고 그룹 분포 q를 지수 그래디언트 상승으로 업데이트하며, 볼록 설정에서 수렴 보장을 제공한다.

실험 결과

연구 질문

  • RQ1과매개변수화된 신경망에서 그룹 DRO가 최악의 그룹 일반화를 개선할 수 있으며, 어떤 정규화 조건에서 그런가?
  • RQ2강한 L2, 조기 중지 등 서로 다른 정규화 전략이 그룹 DRO에서 최악의 그룹 대 평균 성능에 어떤 영향을 미치는가?
  • RQ3그룹 크기를 기반으로 한 보정이 그룹별 일반화 격차를 고려하여 최악의 그룹 정확도를 더 높일 수 있는가?
  • RQ4제안된 온라인 학습 알고리즘이 안정적이고 수렴하는가, 이론적 보장은 무엇인가?
  • RQ5그룹 DRO를 그룹 이동하에서의 최악의 경우 강건성의 기준으로 중요도 가중치와 비교하면 어떤 차이가 있는가?

주요 결과

  • 과매개변수 모델은 표준 정규화 하에서 훈련 오차는 거의 0에 수렴하지만 최악의 그룹 테스트 성능은 나쁘며 Waterbirds, CelebA, MultiNLI에서 각각 최악의 그룹 정확도가 60.0%, 41.1%, 65.7%로 나타난다.
  • 강한 정규화(큰 L2 패널티 또는 조기 중지)는 그룹 DRO가 최악의 그룹 정확도를 상당히 높이고 평균 정확도도 유지하도록 한다(예: 강한 정규화 하에서 Waterbirds 84.6%, CelebA 86.7%).
  • 정규화를 가진 그룹 DRO는 최악의 경우 성능을 10–40포인트 향상시킨다.
  • 그룹별 일반화 격차를 고려하는 항(C/√ng)에 비례하는 보정을 도입하면 일부 상황에서 최악의 그룹 테스트 정확도가 더 향상된다(예: Waterbirds에서 5.9포인트 향상).
  • θ에 대해 SGD를 수행하고 그룹 가중치 분포 q에 대해 지수 그래디언트 업데이트를 교대로 수행하는 온라인 최적화 알고리즘은 볼록 설정에서 수렴 보장을 제공하고 대규모 모델/데이터셋으로 확장된다.

더 나은 연구,지금 바로 시작하세요

연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.

카드 등록 없음 · 무료 플랜 제공

이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.