Skip to main content
QUICK REVIEW

[論文レビュー] Just Train Twice: Improving Group Robustness without Training Group Information

Evan Liu, Behzad Haghgoo|arXiv (Cornell University)|Jul 19, 2021
Machine Learning and Data Classification被引用数 70
ひとこと要約

Jtt は標準的な ERM モデルを訓練して高損失の例を識別し、次にその誤分類例を第2の ERM 実行で重みづけして扱うことで、グループ DRO の性能に近づきつつ最悪グループの性能を改善します。訓練時にグループラベルを学習せずに実現します。

ABSTRACT

Standard training via empirical risk minimization (ERM) can produce models that achieve high accuracy on average but low accuracy on certain groups, especially in the presence of spurious correlations between the input and label. Prior approaches that achieve high worst-group accuracy, like group distributionally robust optimization (group DRO) require expensive group annotations for each training point, whereas approaches that do not use such group annotations typically achieve unsatisfactory worst-group accuracy. In this paper, we propose a simple two-stage approach, JTT, that first trains a standard ERM model for several epochs, and then trains a second model that upweights the training examples that the first model misclassified. Intuitively, this upweights examples from groups on which standard ERM models perform poorly, leading to improved worst-group performance. Averaged over four image classification and natural language processing tasks with spurious correlations, JTT closes 75% of the gap in worst-group accuracy between standard ERM and group DRO, while only requiring group annotations on a small validation set in order to tune hyperparameters.

研究の動機と目的

  • ERM ベースのモデルが虚偽の相関により少数派グループで性能を落とす問題を動機づける。
  • 訓練グループのラベルなしで最悪グループの精度を向上させる単純な二段階法(Just Train Twice)を提案する。
  • 虚偽の相関を伴う4つのデータセットで実証的な改善を示す。
  • エラー集合が何を表すのか、そしてハイパーパラメータ調整のためにグループ注釈付きの小さな検証セットの役割を分析する。

提案手法

  • ステージ1: ERM による識別モデルを T 回訓練し、誤分類した訓練例のエラー集合 E を収集する。
  • ステージ2: E に含まれる例を lambda_up 回繰り返すことでアップサンプリングしたデータセット上で最終モデルを訓練し、影響力を高める。
  • 最終目的関数は J_up-ERM(θ,E) = lambda_up * sum_{(x,y) in E} l(x,y;θ) + sum_{(x,y) not in E} l(x,y;θ)。
  • ハイパーパラメータには識別モデルのエポック数 T とアップウェイト係数 lambda_up が含まれる;チューニングは worst-group バリデーション精度を用いて行う。
  • Jtt、CVaR DRO、LfF の各ハイパーパラメータ調整には検証用のグループ情報を用いたチューニングを推奨する。

実験結果

リサーチクエスチョン

  • RQ12段階のグループ注釈なし手法が diverse なタスクにおいて最悪グループの精度で group DRO とのギャップを埋めることができるか?
  • RQ2第一の ERM モデルで識別された誤分類例は、訓練時のグループラベルなしで hard グループをどれくらいうまく捉えられるか?
  • RQ3ハイパーパラメータ調整が最悪グループの性能に与える影響はどうか、検証グループ情報は調整に必須か?
  • RQ4Jtt は最悪グループの性能と平均精度の点で CVaR DRO および LfF とどう比較されるか?
  • RQ5Jtt のエラー集合は異なるグループの構成と最悪グループの例の濃縮度にどのように寄与しているか?

主な発見

MethodWaterbirds Avg Acc.Waterbirds Worst-group Acc.CelebA Avg Acc.CelebA Worst-group Acc.MultiNLI Avg Acc.MultiNLI Worst-group Acc.CivilComments Avg Acc.CivilComments Worst-group Acc.
ERM97.3%72.6%95.6%47.2%82.4%67.9%92.6%57.4%
CVaR DRO (Levy et al., 2020)96.0%75.9%82.5%64.4%82.0%68.0%92.5%60.5%
LfF (Nam et al., 2020)91.2%78.0%85.1%77.2%80.8%70.2%92.5%58.8%
Jtt (Ours)93.3%86.7%88.0%81.1%78.6%72.6%91.1%69.3%
Group DRO (Sagawa et al., 2020a)93.5%91.4%92.9%88.9%81.4%77.7%88.9%69.9%
  • Jtt は Waterbirds、CelebA、MultiNLI、CivilComments-WILDS の各データセットで ERM よりも一貫して最悪グループの精度を改善する。
  • 平均して、Jtt は ERM と比較して最悪グループの精度ギャップを約16.2ポイント減少させ、group DRO へのギャップを約75%解消する。
  • Jtt の平均精度は最良の平均精度より約4.2ポイント低い程度であり、有利なトレードオフを示す。
  • エラー集合は最悪グループの例で濃縮されており、最悪グループのリコールが高く(データセット全体の平均86.4%)、トレーニング時の出現率よりも高い精度を示す。
  • ハイパーパラメータ調整を最悪グループ検証精度に基づいて行うことは、訓練グループ注釈を持たない方法が強力な最悪グループ性能を達成するために重要である。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。