[論文レビュー] Dynamic Sparse Training: Find Efficient Sparse Network From Scratch With\n Trainable Masked Layers
tldr: 本論文は Dynamic Sparse Training (DST) を提案します。学習中に trainable pruning 閾値を用いてスパースなネットワーク構造と重みを共同で学習し、細かなステップごとの剪定と回復を可能にするエンドツーエンド手法です。
We present a novel network pruning algorithm called Dynamic Sparse Training\nthat can jointly find the optimal network parameters and sparse network\nstructure in a unified optimization process with trainable pruning thresholds.\nThese thresholds can have fine-grained layer-wise adjustments dynamically via\nbackpropagation. We demonstrate that our dynamic sparse training algorithm can\neasily train very sparse neural network models with little performance loss\nusing the same number of training epochs as dense models. Dynamic Sparse\nTraining achieves the state of the art performance compared with other sparse\ntraining algorithms on various network architectures. Additionally, we have\nseveral surprising observations that provide strong evidence for the\neffectiveness and efficiency of our algorithm. These observations reveal the\nunderlying problems of traditional three-stage pruning algorithms and present\nthe potential guidance provided by our algorithm to the design of more compact\nnetwork architectures.\n
研究の動機と目的
- 推論時のメモリと計算量を削減するために効率的なスパースネットワークの必要性を動機づける。
- 重みとレイヤー毎の剪定マスクの両方を学習するエンドツーエンドのスパース学習フレームワークを提案する。
- 各学習ステップ内でエポック間ではなく、細かなステップごとの剪定と回復を可能にする。
- バックプロパゲーションと trainable threshold 機構を介してレイヤー毎の剪定率を自動的に調整する。
- 複数のアーキテクチャで MNIST、CIFAR-10、ImageNet において最先端の性能を示す。
提案手法
- 各レイヤーごとにニューロン/フィルタ用の trainable な閾値として剪定を表現する。
- |W| - t に対する単位階を適用した S によるバイナリマスク M を導出して W ∘ M のスパース化を得る。
- 閾値ベクトル t の学習のためのストレートスルー推定器に基づく微分を導入する。
- より高いスパース性を促すための sparse 正則化項 Ls = sum exp(-ti) を組み込む。
- 性能勾配と構造勾配の両方を伝搬可能とするため、密結合層を trainable masked 層に置換する。
- 閾値とマスクは各学習ステップで更新可能であり、細かな剪定と回復を実現する。
実験結果
リサーチクエスチョン
- RQ1trainable pruning thresholds を用いて、重みとスパース構造をエンドツーエンドで共同学習できるか。
- RQ2学習中のステップごとの剪定と回復は、スパース学習における事前定義された剪定スケジュールよりも優れているか。
- RQ3レイヤーごとに学習可能な閾値が、アーキテクチャ全体で最終的なスパース性パターンとモデル性能にどのように影響するか。
- RQ4DST が観察されたスパースパターンを通じて、コンパクトなアーキテクチャ設計の指針を提供するか。
主な発見
| アーキテクチャ | 方法 | Dense baseline | Model Remaining Percentage (%) | Sparse Accuracy | Difference |
|---|---|---|---|---|---|
| Lenet-300-100 | Dense baseline vs sparsity (DST results) | 98.16 ± 0.06 | 2.48 ± 0.21 | 97.69 ± 0.14 | -0.47 |
| Lenet-5-Caffe | Dense baseline vs sparsity (DST results) | 99.18 ± 0.05 | 1.64 ± 0.13 | 99.11 ± 0.07 | -0.07 |
| LSTM-a | Dense baseline vs sparsity (DST results) | 98.64 ± 0.12 | 1.93 ± 0.03 | 98.70 ± 0.06 | +0.06 |
| LSTM-b | Dense baseline vs sparsity (DST results) | 98.87 ± 0.07 | 0.98 ± 0.04 | 98.89 ± 0.11 | +0.02 |
| VGG-16 (CIFAR-10) | Sparse Momentum vs DST results | 93.51 ± 0.05 | 10 | 93.36 ± 0.04 | -0.15 |
| VGG-16 (CIFAR-10) | DST results (8.82% remaining) | 93.75 ± 0.21 | 8.82 ± 0.34 | 93.93 ± 0.05 | +0.18 |
| VGG-16 (CIFAR-10) | DST results (3.76% remaining) | 93.75 ± 0.21 | 3.76 ± 0.53 | 93.02 ± 0.37 | -0.73 |
| WideResNet-16-8 (CIFAR-10) | Sparse Momentum vs DST results | 95.43 ± 0.02 | 10 | 94.87 ± 0.04 | -0.56 |
| WideResNet-16-8 (CIFAR-10) | DST results (8.?? remaining) | 95.18 ± 0.06 | 9.86 ± 0.22 | 95.05 ± 0.08 | -0.13 |
| WideResNet-16-8 (CIFAR-10) | DST results (4.64% remaining) | 95.18 ± 0.06 | 4.64 ± 0.15 | 94.73 ± 0.11 | -0.45 |
| ResNet-50 (ImageNet) | DS T vs baselines | Dense baseline 74.90 / 92.40 | 20 | 73.80 / 91.80 | -1.10 / -0.60 |
| ResNet-50 (ImageNet) | DST results (19.24 remaining) | 74.95 / 92.60 | 19.24 | 74.02 / 92.49 | -0.73 / -0.11 |
| ResNet-50 (ImageNet) | DST results (9.87 remaining) | 74.95 / 92.60 | 9.87 | 72.78 / 91.53 | -2.17 / -1.07 |
- DST は MNIST Lenet-300-100 でほぼ 98% のパラメータを剪定しても性能低下が小さい(スパース結果: 残り 2.48%)。
- MNIST Lenet-5-Caffe ではスパース学習により残りが 1.64% でほぼ密での精度を達成(スパース: 99.11%)。
- シーケンシャルMNIST の LSTM モデルは 99% 以上のパラメータ剪定で同等かそれ以上のスパース精度を達成。
- CIFAR-10 の VGG-16 および WideResNet で、DST は高スパース性で Sparse Momentum や Dynamic Sparse Reparameterization を上回る(例: VGG-16: DST 残り 8.82%、スパース精度 93.93% 同様に Others は 10% 残り)。
- ImageNet (ResNet-50): DST はベースラインより少し高いスパース性で高い top-1/top-5 精度を達成(設定により残り ~9.87-19.24%)。
- DST は α の異なる値に対して一貫したスパースパターンを示し、レイヤーごとの冗長性を示し、アーキテクチャ設計の指針を提供する。
より良い研究を、今すぐ始めましょう
論文設計から論文執筆まで、研究時間を劇的に削減しましょう。
クレジットカード登録不要
このレビューはAIが作成し、人間の編集者が確認しました。