[论文解读] Image Classification at Supercomputer Scale
论文提出系统级优化,在 ImageNet 上大规模训练 ResNet-50,使用 1024-chip TPU v3 Pod,在 2.2 分钟内达到 76.3% 的准确率,吞吐量超过 1.05 百万张图像/秒。
Deep learning is extremely computationally intensive, and hardware vendors have responded by building faster accelerators in large clusters. Training deep learning models at petaFLOPS scale requires overcoming both algorithmic and systems software challenges. In this paper, we discuss three systems-related optimizations: (1) distributed batch normalization to control per-replica batch sizes, (2) input pipeline optimizations to sustain model throughput, and (3) 2-D torus all-reduce to speed up gradient summation. We combine these optimizations to train ResNet-50 on ImageNet to 76.3% accuracy in 2.2 minutes on a 1024-chip TPU v3 Pod with a training throughput of over 1.05 million images/second and no accuracy drop.
研究动机与目标
- 推动在 petascale 水平训练深度网络并在保持准确性的前提下缩短墙钟时间。
- 识别在加速器上进行大批量同步 SGD 时的系统瓶颈。
- 开发并验证在不牺牲模型质量的前提下平衡全局批量大小与每个副本的批量大小的技术。
- 展示组合优化,达到最先进的吞吐量并保持准确性。
提出的方法
- 对卷积使用 bfloat16 的混合精度训练,对非卷积运算使用 32 位。
- 采用带 warmup 和衰减的线性学习率标度;使用 LARS 将批量大小扩展至最多 32768。
- 引入分布式批量归一化,使 BN 统计量的控制独立于全局批量大小。
- 通过数据集分片、缓存、预取、融合的 JPEG 解码与裁剪,以及并行解析来优化输入数据管线。
- 为梯度求和采用 2-D torus 全规约算法,以降低通信延迟。
- 在 1024-chip TPU v3 Pod 上展示性能并与先前结果进行对比。
实验结果
研究问题
- RQ1如何将分布式 BN、输入管线优化和 2-D 全规约结合,以在大规模规模下实现同步 SGD?
- RQ2在大批量训练中,每副本批量大小、全局批量大小与模型准确性之间存在哪些取舍?
- RQ3在训练时使用非常大的批量时,混合精度和缩放策略是否能保持 ImageNet 上 ResNet-50 的准确性?
- RQ4在大型 TPU pod 上,若不损失准确性,能够达到怎样的端到端吞吐量和墙钟时间?
主要发现
| Hardware | Chips | Batch | Optimizer | BN | Accuracy | Time | |
|---|---|---|---|---|---|---|---|
| Goyal et al. [6] | P100 | 256 | 8192 | Momentum | Local | 76.3% | 1 hour |
| Smith et al. [16] | TPU v2 | 128 | 8192 → 16384 | Momentum | Local | 76.1% | 30 mins. |
| Akiba et al. [2] | P100 | 1024 | 32768 | RMS + Mom. | Local | 74.9% | 15 mins. |
| Jia et al. [10] | P40 | 1024 | 65536 | LARS | Local | 76.2% | 8.7 mins. |
| Baseline TPU v2 | 4 | 1024 | Momentum | Local | 76.3% | 7.2 hours | |
| Ours TPU v2 | 256 | 16384 | Momentum | Local | 75.3% | 9.7 mins. | |
| Ours TPU v2 | 256 | 32768 | LARS | Local | 76.3% | 8.0 mins. | |
| Ours TPU v3 | 512 | 32768 | LARS | Local | 76.4% | 3.3 mins. | |
| Ours TPU v3 | 1024 | 65536 | LARS | Local | 75.2% | 1.8 mins. | |
| Ours TPU v3 | 1024 | 32768 | LARS | Distributed | 76.3% | 2.2 mins. |
- 在规模化训练中,ImageNet 的 ResNet-50 达到 76.3% 的准确率且不牺牲准确性。
- 在 1024-chip TPU v3 Pod 上,训练时间 2.2 分钟,吞吐量超过 1.05 百万张图像/秒。
- 2-D 梯度求和与 torus 链路在潜延迟和吞吐量上优于 1-D ring all-reduce,使可扩展的同步成为可能。
- 分布式 BN 使 BN 的有效批量大小独立于全局批量大小,有助于在大规模下保持准确性。
- 输入管线优化(缓存、预取、融合 JPEG 解码/裁剪、并行解析)显著提升吞吐量,尤其在数据在各工作节点分片时。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。