[论文解读] Efficient Sharpness-aware Minimization for Improved Training of Neural Networks
ESAM 通过引入随机权重扰动和对锐度敏感的数据选择,在减少 Sharpness Aware Minimization (SAM) 的计算开销的同时,保持或提升相对于 SAM 的泛化性能。
Overparametrized Deep Neural Networks (DNNs) often achieve astounding performances, but may potentially result in severe generalization error. Recently, the relation between the sharpness of the loss landscape and the generalization error has been established by Foret et al. (2020), in which the Sharpness Aware Minimizer (SAM) was proposed to mitigate the degradation of the generalization. Unfortunately, SAM s computational cost is roughly double that of base optimizers, such as Stochastic Gradient Descent (SGD). This paper thus proposes Efficient Sharpness Aware Minimizer (ESAM), which boosts SAM s efficiency at no cost to its generalization performance. ESAM includes two novel and efficient training strategies-StochasticWeight Perturbation and Sharpness-Sensitive Data Selection. In the former, the sharpness measure is approximated by perturbing a stochastically chosen set of weights in each iteration; in the latter, the SAM loss is optimized using only a judiciously selected subset of data that is sensitive to the sharpness. We provide theoretical explanations as to why these strategies perform well. We also show, via extensive experiments on the CIFAR and ImageNet datasets, that ESAM enhances the efficiency over SAM from requiring 100% extra computations to 40% vis-a-vis base optimizers, while test accuracies are preserved or even improved.
研究动机与目标
- 通过促进稳定极小值来提高过参数化 DNN 的泛化能力的动机。
- 在不牺牲性能的前提下,扩展 SAM 以加入提升效率的策略。
- 开发并评估 ESAM,包含两个组成部分:随机权重扰动(SWP)和锐度敏感数据选择(SDS)。
- 为 SWP 和 SDS 提供理论依据,并通过在 CIFAR-10、CIFAR-100 和 ImageNet 上的大量实验进行验证。
提出的方法
- 回顾 SAM 及其计算缺陷:为了锐度的内在最大化在每次迭代中需要额外的前向/后向传播。
- 通过整合两种策略:SWP 和 SDS,引入 ESAM。
- SWP:在锐度估计期间随机选择一部分权重进行扰动,并对扰动进行缩放以使期望扰动的范数等同于 SAM,降低反向传播成本。
- SDS:从每个小批量中选择在权重扰动下损失增加最多的一部分数据,以在更少样本的情况下近似 SAM 目标。
- 提供理论论证,表明 SWP 的期望扰动在范数和方向上与 SAM 相符,且 SDS 使用锐度敏感子集对 SAM 损失给出上界。
- 算法 1 概述了带有邻域大小、扰动缩放和子集比率参数的 ESAM。
实验结果
研究问题
- RQ1ESAM 是否能够在降低计算开销的同时实现与 SAM 相似的稳定极小值?
- RQ2随机权重扰动对锐度估计保真度有何影响?
- RQ3锐度敏感数据选择是否在提高训练效率的同时保留 SAM 的泛化性能?
- RQ4SWP 和 SDS 在多种体系结构与数据集上的单独及共同表现如何?
主要发现
| 数据集 / 模型 | SGD 准确率 | SGD 图像/秒 | SAM 准确率 | SAM 图像/秒 | ESAM 准确率 | ESAM 图像/秒 |
|---|---|---|---|---|---|---|
| CIFAR-10 / ResNet-18 | 95.41 | 3387 | 96.52 | 1717 | 96.56 | 2409 |
| CIFAR-10 / Wide-28-10 | 96.34 | 801 | 97.27 | 396 | 97.29 | 550 |
| CIFAR-10 / PyramidNet-110 | 96.62 | 580 | 97.30 | 289 | 97.81 | 401 |
| CIFAR-100 / ResNet-18 | 78.17 | 3438 | 80.17 | 1730 | 80.41 | 2423 |
| CIFAR-100 / Wide-28-10 | 81.56 | 792 | 83.42 | 391 | 84.51 | 545 |
| CIFAR-100 / PyramidNet-110 | 81.89 | 555 | 84.46 | 276 | 85.56 | 381 |
| ImageNet / ResNet-50 | 76.00* | 1327 | 76.70* | 654 | 77.05 | 846 |
| ImageNet / ResNet-101 | 77.80* | 891 | 78.60* | 438 | 79.09 | 564 |
- ESAM 将开销从 SAM 的大约额外 100% 计算降至基线优化器的约 40%,同时保持或提升测试准确率。
- 在 CIFAR-10/100 上,ESAM 在多种体系结构(ResNet-18、Wide-ResNet-28-10、PyramidNet-110)上优于 SAM,具有更高或可比的准确度和更高的吞吐量(图像/秒)。
- 在 ImageNet 上,ESAM 的准确率高于 SAM(例如 ResNet-50 与 ResNet-101),训练速度比 SAM 大约快 28.7%。
- 消融研究表明 SWP 与 SDS 两者均提升了效率与性能;最佳设置通常接近 beta ≈ 0.5–0.6 与 gamma ≈ 0.5。
- 损失景观可视化显示 ESAM 达到的极小值比 SGD 更平坦,且与 SAM 相似。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。