[논문 리뷰] DADA: Differentiable Automatic Data Augmentation
DADA는 편향되지 않은 RELAX 그래디언트 추정기로 데이터 증강 정책을 학습하기 위한 차별화 가능 프레임워크를 제시하며, 기존 Auto-DA 방법들보다 최소 한 차수 빠른 탐색을 달성하면서도 경쟁력 있는 정확도를 유지합니다. 정책 샘플링에 Gumbel-Softmax 이완을 사용하고 네트워크와 증강 매개변수를 공동으로 학습하기 위한 일회성 양층 최적화를 적용합니다.
Data augmentation (DA) techniques aim to increase data variability, and thus train deep networks with better generalisation. The pioneering AutoAugment automated the search for optimal DA policies with reinforcement learning. However, AutoAugment is extremely computationally expensive, limiting its wide applicability. Followup works such as Population Based Augmentation (PBA) and Fast AutoAugment improved efficiency, but their optimization speed remains a bottleneck. In this paper, we propose Differentiable Automatic Data Augmentation (DADA) which dramatically reduces the cost. DADA relaxes the discrete DA policy selection to a differentiable optimization problem via Gumbel-Softmax. In addition, we introduce an unbiased gradient estimator, RELAX, leading to an efficient and effective one-pass optimization strategy to learn an efficient and accurate DA policy. We conduct extensive experiments on CIFAR-10, CIFAR-100, SVHN, and ImageNet datasets. Furthermore, we demonstrate the value of Auto DA in pre-training for downstream detection problems. Results show our DADA is at least one order of magnitude faster than the state-of-the-art while achieving very comparable accuracy. The code is available at https://github.com/VDIGPKU/DADA.
연구 동기 및 목표
- 레이블 데이터가 제한될 때 일반화 성능을 향상시키기 위한 자동 데이터 증강(DA) 정책 학습 동기 부여.
- 네트워크 가중치와의 공동 최적화를 가능하게 하는 DA 정책 탐색의 차별화 가능 공식화 제안.
- AutoAugment, PBA, Fast AutoAugment에 비해 DA 정책 탐색의 계산 비용 감소.
제안 방법
- 부분 정책 선택을 범주형 분포로 표현하고, 작동 적용은 베르누이 분포로 나타낸다.
- 정책 선택의 이산성을 Gumbel-Softmax로 이완하여 차별화 가능 최적화를 수행한다.
- 배포 매개변수에 대해 편향 없는 그래디언트를 얻기 위해 RELAX 그래디언트 추정기를 사용한다.
- 네트워크 가중치와 DA 정책 매개변수를 함께 업데이트하기 위한 일회성 양층 최적화를 적용한다.
- 직전 경사 추정 및 그래디언트 기반 역전파를 통해 증강 크기를 평가한다.
실험 결과
연구 질문
- RQ1Gumbel-Softmax 및 RELAX를 통한 차별화 가능 최적화가 데이터 증강 정책과 네트워크 가중치의 효율적인 공동 학습을 가능하게 하는가?
- RQ2DADA가 검색 비용을 크게 줄이면서 최첨단 Auto-DA 방법과 유사한 정확도를 달성하는가?
- RQ3DADA가 대규모 데이터셋(ImageNet) 및 하위 작업(객체 탐지)에 얼마나 잘 일반화되는가?
주요 결과
| 데이터셋 | 모델 | baseline | 컷아웃 | AA | PBA | 빠른 AA | DADA |
|---|---|---|---|---|---|---|---|
| CIFAR-10 | Wide-ResNet-40-2 | 5.3 | 4.1 | 3.7 | - | 0 | 3.6 |
| CIFAR-10 | Wide-ResNet-28-10 | 3.9 | 3.1 | 2.6 | 2.6 | 2.7 | 2.7 |
| CIFAR-10 | Shake-Shake(26 2x32d) | 3.6 | 3.0 | 2.5 | 2.5 | 2.7 | 2.7 |
| CIFAR-10 | Shake-Shake(26 2x96d) | 2.9 | 2.6 | 2.0 | 2.0 | 2.0 | 2.0 |
| CIFAR-10 | Shake-Shake(26 2x112d) | 2.8 | 2.6 | 1.9 | 2.0 | 2.0 | 2.0 |
| CIFAR-10 | PyramidNet+ShakeDrop | 2.7 | 2.3 | 1.5 | 1.5 | 1.8 | 1.7 |
| CIFAR-100 | Wide-ResNet-40-2 | 26.0 | 25.2 | 20.7 | - | 20.7 | 20.9 |
| CIFAR-100 | Wide-ResNet-28-10 | 18.8 | 18.4 | 17.1 | 16.7 | 17.3 | 17.5 |
| CIFAR-100 | Shake-Shake(26 2x96d) | 17.1 | 16.0 | 14.3 | 15.3 | 14.9 | 15.3 |
| CIFAR-100 | PyramidNet+ShakeDrop | 14.0 | 12.2 | 10.7 | 10.9 | 11.9 | 11.2 |
- DADA는 최첨단 DA 방법들에 비해 최소 한 차수의 속도 향상을 달성하면서도 경쟁력 있는 정확도를 유지합니다.
- ImageNet에서 DADA는 22.5% 상위-1 오류(ResNet-50)에 대해 탐색에 1.3 GPU-시간을 보고합니다.
- CIFAR-10/100 및 SVHN에서 DADA는 탐색 비용을 크게 줄인 상태로 경쟁력 있는 오류율을 제공합니다(예: CIFAR-10의 축소 데이터 탐색 약 0.1 GPU-시간).
- RELAX를 사용하면 Gumbel-Softmax에 비해 그래디언트 추정의 바이어스가 감소하고 CIFAR-10에서의 정책 성능이 향상됩니다.
- DADA의 학습 DA 정책은 COCO에서 RetinaNet, Faster R-CNN, Mask R-CNN과 같은 하위 탐지 모델의 성능을 개선합니다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.