[論文レビュー] DADA: Differentiable Automatic Data Augmentation
DADAは、偏りのないRELAX勾配推定量を用いたデータ増強ポリシーの学習を可能にする微分可能な枠組みを提示し、従来のAuto-DA手法より少なくとも1桁の高速な探索を達成しつつ精度を競争力のある水準で維持します。ポリシーサンプリングにはGumbel-Softmax緩和を用い、ネットワークと増強パラメータを共同で学習するワンパスの階層最適化を行います。
Data augmentation (DA) techniques aim to increase data variability, and thus train deep networks with better generalisation. The pioneering AutoAugment automated the search for optimal DA policies with reinforcement learning. However, AutoAugment is extremely computationally expensive, limiting its wide applicability. Followup works such as Population Based Augmentation (PBA) and Fast AutoAugment improved efficiency, but their optimization speed remains a bottleneck. In this paper, we propose Differentiable Automatic Data Augmentation (DADA) which dramatically reduces the cost. DADA relaxes the discrete DA policy selection to a differentiable optimization problem via Gumbel-Softmax. In addition, we introduce an unbiased gradient estimator, RELAX, leading to an efficient and effective one-pass optimization strategy to learn an efficient and accurate DA policy. We conduct extensive experiments on CIFAR-10, CIFAR-100, SVHN, and ImageNet datasets. Furthermore, we demonstrate the value of Auto DA in pre-training for downstream detection problems. Results show our DADA is at least one order of magnitude faster than the state-of-the-art while achieving very comparable accuracy. The code is available at https://github.com/VDIGPKU/DADA.
研究の動機と目的
- ラベル付きデータが限られている場合に一般化を改善するため、自動データ拡張(DA)ポリシー学習を動機づける。
- ネットワーク重みと同時最適化を可能にするため、DAポリシー探索の微分可能な定式化を提案する。
- AutoAugment、PBA、Fast AutoAugmentと比較してDAポリシー探索の計算コストを削減する。
提案手法
- 部分ポリシーの選択をカテゴリ分布で表現し、操作適用をベルヌーイ分布で表現する。
- 微分可能な最適化のために、離散ポリシー選択をGumbel-Softmaxで緩和する。
- 分布パラメータの無偏勾配を得るためにRELAX勾配推定器を用いる。
- ネットワーク重みとDAポリシーパラメータを同時に更新するためにワンパスの階層型最適化を適用する。
- ストレートスルー勾配推定と勾配ベースのバックプロパゲーションを用いて増強量を評価する。
実験結果
リサーチクエスチョン
- RQ1Gumbel-SoftmaxとRELAXによる微分可能最適化は、データ拡張ポリシーとネットワーク重みの効率的な共同学習を可能にするか?
- RQ2DADAは最先端のAuto-DA手法と同等の精度を、探索コストを大幅に削減して達成するか?
- RQ3DADAは大規模データセット(ImageNet)および下流タスク(物体検出)へどの程度転移できるか?
主な発見
| データセット | モデル | ベースライン | カットアウト | AA | PBA | Fast AA | DADA |
|---|---|---|---|---|---|---|---|
| CIFAR-10 | Wide-ResNet-40-2 | 5.3 | 4.1 | 3.7 | - | 0 | 3.6 |
| CIFAR-10 | Wide-ResNet-28-10 | 3.9 | 3.1 | 2.6 | 2.6 | 2.7 | 2.7 |
| CIFAR-10 | Shake-Shake(26 2x32d) | 3.6 | 3.0 | 2.5 | 2.5 | 2.7 | 2.7 |
| CIFAR-10 | Shake-Shake(26 2x96d) | 2.9 | 2.6 | 2.0 | 2.0 | 2.0 | 2.0 |
| CIFAR-10 | Shake-Shake(26 2x112d) | 2.8 | 2.6 | 1.9 | 2.0 | 2.0 | 2.0 |
| CIFAR-10 | PyramidNet+ShakeDrop | 2.7 | 2.3 | 1.5 | 1.5 | 1.8 | 1.7 |
| CIFAR-100 | Wide-ResNet-40-2 | 26.0 | 25.2 | 20.7 | - | 20.7 | 20.9 |
| CIFAR-100 | Wide-ResNet-28-10 | 18.8 | 18.4 | 17.1 | 16.7 | 17.3 | 17.5 |
| CIFAR-100 | Shake-Shake(26 2x96d) | 17.1 | 16.0 | 14.3 | 15.3 | 14.9 | 15.3 |
| CIFAR-100 | PyramidNet+ShakeDrop | 14.0 | 12.2 | 10.7 | 10.9 | 11.9 | 11.2 |
- DADAは最先端のDA手法より少なくとも1桁の速度向上を達成しつつ、競争力のある精度を維持します。
- ImageNetでは、検索に1.3 GPU時間を報告し、top-1誤差22.5%(ResNet-50)。”
- CIFAR-10/100およびSVHNでは、検索コストを劇的に削減しつつ競争力のある誤り率を提供します(例:データ量を削減したCIFAR-10の検索は約0.1 GPU時間)。
- RELAXを用いると、Gumbel-Softmaxと比べて勾配推定のバイアスが低減され、CIFAR-10でポリシー性能が向上します。
- DADAが学習したDAポリシーは、COCOデータセット上の下流検出モデル(RetinaNet、Faster R-CNN、Mask R-CNN)を改善します。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。