[论文解读] Sharpness-Aware Minimization for Efficiently Improving Generalization
SAM 通过同时优化训练损失和损失景观的陡度来提升泛化,在 CIFAR、ImageNet 和迁移任务上获得更好的测试表现。
In today's heavily overparameterized models, the value of the training loss provides few guarantees on model generalization ability. Indeed, optimizing only the training loss value, as is commonly done, can easily lead to suboptimal model quality. Motivated by prior work connecting the geometry of the loss landscape and generalization, we introduce a novel, effective procedure for instead simultaneously minimizing loss value and loss sharpness. In particular, our procedure, Sharpness-Aware Minimization (SAM), seeks parameters that lie in neighborhoods having uniformly low loss; this formulation results in a min-max optimization problem on which gradient descent can be performed efficiently. We present empirical results showing that SAM improves model generalization across a variety of benchmark datasets (e.g., CIFAR-10, CIFAR-100, ImageNet, finetuning tasks) and models, yielding novel state-of-the-art performance for several. Additionally, we find that SAM natively provides robustness to label noise on par with that provided by state-of-the-art procedures that specifically target learning with noisy labels. We open source our code at \url{https://github.com/google-research/sam}.
研究动机与目标
- 将损失景观几何与过参数化模型的泛化之间的联系动机化。
- 提出一个实用的优化目标,最小化损失值和损失陡度。
- 证明优化陡度可以在多样化数据集和体系结构上实现更好的泛化。
提出的方法
- 引入 SAM:min_w max_{||epsilon||_p <= rho} L_S(w+epsilon) + lambda||w||^2。
- 通过对内层最大化求导并使用一阶泰勒展开,推导出高效的梯度近似。
- 将扰动 hat{epsilon}(w) 计算为训练损失梯度的 rho 尺度标准化(p=2 情况)。
- 在对 w+hat{epsilon}(w) 处计算梯度并对 SAM 目标使用 SGD 更新 w。
- 提供伪代码(算法 1)并讨论实际实现细节,包括并行化和二阶项消融。
- 分析 m-陡度(子批次扰动)和 Hessian 谱来将陡度与泛化联系起来。
实验结果
研究问题
- RQ1在标准视觉基准上将陡度项纳入训练目标是否能提升泛化?
- RQ2相较于 SGD,SAM 在 CIFAR-10/100、ImageNet 以及微调任务上的表现如何?
- RQ3邻域大小 rho 与 m-陡度变体对性能和泛化的影响如何?
- RQ4SAM 是否在标签噪声鲁棒性方面与最先进的噪声标签方法相当?
主要发现
| 模型 | 数据增强 | CIFAR-10(SAM) | CIFAR-10(SGD) | CIFAR-100(SAM) | CIFAR-100(SGD) |
|---|---|---|---|---|---|
| WRN-28-10 (200 epochs) | Basic | 2.7±0.1 | 3.5±0.1 | 16.5±0.2 | 18.8±0.2 |
| WRN-28-10 (200 epochs) | Cutout | 2.3±0.1 | 2.6±0.1 | 14.9±0.2 | 16.9±0.1 |
| WRN-28-10 (200 epochs) | AA | 2.1±<0.1 | 2.3±0.1 | 13.6±0.2 | 15.8±0.2 |
| WRN-28-10 (1800 epochs) | Basic | 2.4±0.1 | 3.5±0.1 | 16.3±0.2 | 19.1±0.1 |
| WRN-28-10 (1800 epochs) | Cutout | 2.1±0.1 | 2.7±0.1 | 14.0±0.1 | 17.4±0.1 |
| WRN-28-10 (1800 epochs) | AA | 1.6±<0.1 | 1.9±<0.1 | 11.3±0.1 | 14.6±0.1 |
| Shake-Shake (26 2x96d) | Basic | 2.3±<0.1 | 2.7±0.1 | 15.1±0.1 | 17.0±0.1 |
| Shake-Shake (26 2x96d) | Cutout | 2.0±<0.1 | 2.3±0.1 | 14.2±0.2 | 15.7±0.2 |
| Shake-Shake (26 2x96d) | AA | 1.6±<0.1 | 1.9±0.1 | 12.8±0.1 | 14.1±0.2 |
| PyramidNet | Basic | 2.7±0.1 | 4.0±0.1 | 14.6±0.4 | 19.7±0.3 |
| PyramidNet | Cutout | 1.9±0.1 | 2.5±0.1 | 12.6±0.2 | 16.4±0.1 |
| PyramidNet | AA | 1.6±0.1 | 1.9±0.1 | 11.6±0.1 | 14.6±0.1 |
| PyramidNet+ShakeDrop | Basic | 2.1±0.1 | 2.5±0.1 | 13.3±0.2 | 14.5±0.1 |
| PyramidNet+ShakeDrop | Cutout | 1.6±<0.1 | 1.9±0.1 | 11.3±0.1 | 11.8±0.2 |
| PyramidNet+ShakeDrop | AA | 1.4±<0.1 | 1.6±<0.1 | 10.3±0.1 | 10.6±0.1 |
- 与常规 SGD 相比,SAM 在 CIFAR-10/100、ImageNet 和微调任务上持续提升泛化。
- 在 CIFAR-10/100 上,SAM 对若干模型和增强方式达到接近最先进的结果(如 WRN、Shake-Shake、PyramidNet及其组合),例如在某些设置下 CIFAR-10 的错误率为 1.6%,在 AA 配置下 CIFAR-100 的错误率为 11.3%。
- SAM 提供对标签噪声的鲁棒性,达到与专门的噪声标签方法相当的水平,且通常优于无显式噪声标签策略的基线。
- m-陡度变体表明较小的 m(每个加速器子批次)能带来更好的泛化并与实际泛化差距的相关性强于全批次陡度。
- 在 ImageNet 上对 ResNet 变体的训练中,SAM 提高了 top-1 和 top-5 精度;例如在 400 世纪期时,ResNet-152 的 top-1 误差从 20.3%(非 SAM)降至 18.4%(使用 SAM)。
- Hessian 分析证实 SAM 收敛到的极小值具较低的曲率(lambda_max 约 1.0,相较于无 SAM 的约 24)并且谱更平坦。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。