[论文解读] Towards Optimal Structured CNN Pruning via Generative Adversarial Learning
本文提出 Generative Adversarial Learning (GAL),通过学习稀疏软掩码在端到端、无标签的方式对 CNN 进行裁剪,联合裁剪通道、分支和块,获得强大的压缩和加速。
Structured pruning of filters or neurons has received increased focus for compressing convolutional neural networks. Most existing methods rely on multi-stage optimizations in a layer-wise manner for iteratively pruning and retraining which may not be optimal and may be computation intensive. Besides, these methods are designed for pruning a specific structure, such as filter or block structures without jointly pruning heterogeneous structures. In this paper, we propose an effective structured pruning approach that jointly prunes filters as well as other structures in an end-to-end manner. To accomplish this, we first introduce a soft mask to scale the output of these structures by defining a new objective function with sparsity regularization to align the output of baseline and network with this mask. We then effectively solve the optimization problem by generative adversarial learning (GAL), which learns a sparse soft mask in a label-free and an end-to-end manner. By forcing more scaling factors in the soft mask to zero, the fast iterative shrinkage-thresholding algorithm (FISTA) can be leveraged to fast and reliably remove the corresponding structures. Extensive experiments demonstrate the effectiveness of GAL on different datasets, including MNIST, CIFAR-10 and ImageNet ILSVRC 2012. For example, on ImageNet ILSVRC 2012, the pruned ResNet-50 achieves 10.88\% Top-5 error and results in a factor of 3.7x speedup. This significantly outperforms state-of-the-art methods.
研究动机与目标
- 相较于多阶段逐层方法,提出高效、宽松且无标签的结构化裁剪动机。
- 提出一种软掩码框架,用于对 CNN 中异构结构进行稀疏化与裁剪。
- 开发一种端到端的 GAL 优化,利用判别器和 FISTA 来移除冗余结构。
提出的方法
- 引入一个稀疏软掩码 m,用于对可裁剪的结构(通道、分支、块)的输出进行尺度缩放。
- 用生成器(裁剪后的网络)和判别器来构建对抗目标,使裁剪后的输出与基线输出对齐(通过 MSE 的数据损失)。
- 采用交替的 GAN+FISTA 方法求解优化:通过 SGD 更新判别器,使用带 L1 稀疏性约束的 FISTA 进行裁剪。
- 对 m 使用 L1 稀疏性惩罚,以便在 m_i → 0 时实现结构移除。
- 对权重和判别器应用正则化(L1/L2 或对抗正则化),以平衡博弈。
- 使用 FISTA 高效获得用于裁剪的精确零掩码条目。
实验结果
研究问题
- RQ1一个软可学习掩码是否能在端到端、无标签的条件下同时裁剪异构的 CNN 结构(通道、分支、块)?
- RQ2将 L1 稀疏掩码与 FISTA 相结合的生成对抗学习是否在压缩率和精度保持方面优于传统的多阶段裁剪?
- RQ3GAL 在数据集(MNIST、CIFAR-10、ImageNet)和架构(LeNet、VGG、DenseNet、GoogLeNet、ResNet、DenseNet-40、ResNet 变体)上的表现如何?
主要发现
- 在 ImageNet 上,使用 GAL 裁剪的 ResNet-50 达到 10.88% 的 Top-5 错误率和 3.7× 加速。
- GAL 在 MNIST、CIFAR-10 和 ImageNet 上显示出对不同结构(通道、分支、块)的强裁剪效率。
- 消融实验表明对判别器的对抗正则化优于 L1/L2,提升了裁剪效果。
- 在 ResNet-50 上对块和通道的联合裁剪(GAL-0.5-joint)比单独裁剪块或通道获得更高的加速和压缩。
- GAL 在多种网络和数据集上常常达到甚至超过最先进的裁剪方法。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。