Skip to main content
QUICK REVIEW

[論文レビュー] Dash: Semi-Supervised Learning with Dynamic Thresholding

Yi Xu, Lei Shang|arXiv (Cornell University)|Sep 1, 2021
Machine Learning and Data Classification被引用数 52
ひとこと要約

Dashは動的閾値設定機構を導入し、SSLトレーニング中にラベルなしデータを選択することで、各反復で使用する疑似ラベル付き例を適応させて性能を向上させ、理論的な収束保証を提供します。

ABSTRACT

While semi-supervised learning (SSL) has received tremendous attentions in many machine learning tasks due to its successful use of unlabeled data, existing SSL algorithms use either all unlabeled examples or the unlabeled examples with a fixed high-confidence prediction during the training progress. However, it is possible that too many correct/wrong pseudo labeled examples are eliminated/selected. In this work we develop a simple yet powerful framework, whose key idea is to select a subset of training examples from the unlabeled data when performing existing SSL methods so that only the unlabeled examples with pseudo labels related to the labeled data will be used to train models. The selection is performed at each updating iteration by only keeping the examples whose losses are smaller than a given threshold that is dynamically adjusted through the iteration. Our proposed approach, Dash, enjoys its adaptivity in terms of unlabeled data selection and its theoretical guarantee. Specifically, we theoretically establish the convergence rate of Dash from the view of non-convex optimization. Finally, we empirically demonstrate the effectiveness of the proposed method in comparison with state-of-the-art over benchmarks.

研究の動機と目的

  • SSLを固定の高信頼度閾値を避けることで改善する動機づけ。
  • 反復ごとに減少する損失閾値に基づいて未ラベルデータを選択する動的閾値フレームワーク(Dash)を提案する。
  • 非凸設定におけるDashアルゴリズムの理論的収束保証を提供する。
  • 画像分類ベンチマークで最先端のSSL手法に対するDashの実験的有効性を示す。

提案手法

  • Dashは動的閾値rho_t以下の損失を持つ例を維持して、更新ごとに未ラベルデータのサブセットを選択する。
  • 閾値rho_tはrho_t = C * gamma^{-(t-1)} * rho_hatとして設定され、反復ごとに低下する。
  • 初期のウォームアップ段階でラベル付きデータを用いてrho_hatを推定する;以降の選択段階ではFixMatchからの疑似ラベルを持つ未ラベルデータを使用する。
  • 確率的勾配は、unsupervised損失f_u(w; xi^u) <= rho_tを満たす未ラベル例のみと、ラベル付きデータの損失を組み合わせて計算する。
  • DashはFixMatchのような既存のSSLパイプラインと統合可能で、標準仮定(PL条件)の下で非漸近的収束保証を提供する。
  • 理論的結果は、非凸仮定の下でサンプル複雑性と収束速度を確立し、監視付きSGDに似た速度と一致する。

実験結果

リサーチクエスチョン

  • RQ1未ラベルデータが分布の混合から来る場合に、収束を保証できるSSLアルゴリズムを設計できるか。
  • RQ2減少する損失閾値を介して未ラベルデータを動的に選択することは、FixMatchのような固定閾値手法よりSSL性能を改善するか。
  • RQ3正しい疑似ラベルの包含と誤ったものの排除をバランスさせるために、動的閾値をどのように構築・推定すべきか。
  • RQ4このような動的閾値SSL手法の理論的収束保証とサンプル複雑性はどうなるか。

主な発見

アルゴリズムCIFAR-10 40ラベルCIFAR-10 250ラベルCIFAR-10 4000ラベルCIFAR-100 400ラベルCIFAR-100 2500ラベルCIFAR-100 10000ラベル
Pi-model------
Pseudo-Labeling------
Mean Teacher------
MixMatch47.54 ± 11.5011.05 ± 0.866.42 ± 0.1067.61 ± 1.3239.94 ± 0.3728.31 ± 0.33
UDA29.05 ± 5.938.82 ± 1.084.88 ± 0.1859.28 ± 0.8833.13 ± 0.2224.50 ± 0.25
ReMixMatch19.10 ± 9.645.44 ± 0.054.72 ± 0.1344.28 ± 2.0627.43 ± 0.3123.03 ± 0.56
RYS (UDA)-5.53 ± 0.174.75 ± 0.28---
RYS (FixMatch)-5.05 ± 0.124.35 ± 0.06---
FixMatch (CTA)11.39 ± 3.355.07 ± 0.334.31 ± 0.1549.95 ± 3.0128.64 ± 0.2423.18 ± 0.11
Dash (CTA, ours)9.16 ± 4.314.78 ± 0.124.13 ± 0.0644.83 ± 1.3627.85 ± 0.1922.77 ± 0.21
FixMatch (RA)13.81 ± 3.375.07 ± 0.654.26 ± 0.0548.85 ± 1.7528.29 ± 0.1122.60 ± 0.12
Dash (RA, ours)13.22 ± 3.754.56 ± 0.134.08 ± 0.0644.76 ± 0.9627.18 ± 0.2121.97 ± 0.14
  • Dashは提案された動的閾値SSLに対して非漸近的収束保証を非凸設定下で提供する。
  • 経験的に、Dashは標準の画像分類ベンチマーク(CIFAR-10、CIFAR-100、SVHN、STL-10)において、さまざまなラベル制限下で複数の最先端SSL手法より優れている。
  • Dashはトレーニング初期により多くの正しい疑似ラベル付き未ラベル例を維持し、後半エポックで固定閾値手法(FixMatch)より誤った例をより積極的に削減する。
  • 理論的結果はDashの高確率収束を示し、O(1/ε)の具体的なサンプル複雑性境界を提供する。
  • 異なるデータ拡張 regimes(CTA, RA)を用いた実験は、DashがFixMatchベースのパイプラインと互換性を持ち、競争力のある利得を示すことを示している。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。