[논문 리뷰] Training Deep Neural Networks with 8-bit Floating Point Numbers
이 논문은 chunk 기반 누적 및 부동 소수점 확률 반올림으로 가능하게 FP16 누적과 FP16 가중치 업데이트를 사용하여 8비트 부동 소수점 수(FP8)로 다양한 DNN을 성공적으로 학습시켰으며, FP32 기준선에 비해 정확도에 상응하는 성능을 달성하면서 메모리 및 계산 요구를 줄였습니다.
The state-of-the-art hardware platforms for training Deep Neural Networks (DNNs) are moving from traditional single precision (32-bit) computations towards 16 bits of precision -- in large part due to the high energy efficiency and smaller bit storage associated with using reduced-precision representations. However, unlike inference, training with numbers represented with less than 16 bits has been challenging due to the need to maintain fidelity of the gradient computations during back-propagation. Here we demonstrate, for the first time, the successful training of DNNs using 8-bit floating point numbers while fully maintaining the accuracy on a spectrum of Deep Learning models and datasets. In addition to reducing the data and computation precision to 8 bits, we also successfully reduce the arithmetic precision for additions (used in partial product accumulation and weight updates) from 32 bits to 16 bits through the introduction of a number of key ideas including chunk-based accumulation and floating point stochastic rounding. The use of these novel techniques lays the foundation for a new generation of hardware training platforms with the potential for 2-4x improved throughput over today's systems.
연구 동기 및 목표
- 모델 정확도를 손실 없이 8비트로 낮춘 학습 정밀도 감소를 목표로 한다.
- FP8/FP16 형식 및 누적·업데이트 문제를 극복하기 위한 기술을 도입한다.
- 표준 데이터셋에서 CNN 및 DNN 전반에 걸친 광범위한 실증 검증을 보여준다.
- 처리량 및 에너지 효율성을 2-4배 향상시키는 하드웨어 친화적 접근법을 제안한다.
제안 방법
- 데이터 및 누적을 위한 FP8(1,5,2) 및 FP16(1,6,9) 형식을 정의한다.
- 길게 이어진 점곱(dot-product)을 파티션하고 swamping 오류를 줄이기 위해 chunk-based 누적을 사용한다.
- 반올림 중 정보 손실을 보존하기 위해 가중치 업데이트에 부동 소수점 확률적 반올림을 적용한다.
- Softmax 계산을 안정시키기 위해 마지막 층 GEMMs에 대해 FP16을 유지한다.
- 역전파 중 작은 기울기를 보존하기 위해 로스 스케일링을 채택한다.
- 여러 네트워크와 데이터셋에서 저정밀도 에뮬레이션 실험을 통해 검증한다.
실험 결과
연구 질문
- RQ1다양한 모델과 데이터셋에서 정확도 손실 없이 8비트 부동 소수점 표현을 사용해 DNN을 학습시킬 수 있는가?
- RQ2저정밀도 형식을 학습 중에 사용할 때 swamping 및 누적 오류를 어떻게 완화할 수 있는가?
- RQ3메모리, 대역폭 및 에너지 효율성 측면에서 FP8 학습의 현실적인 하드웨어 함의는 무엇인가?
- RQ4FP8 학습에서 초/말층 정밀도가 성공에 어떤 역할을 하는가?
- RQ5반올림 모드가 FP8 학습 정확도에 어떤 영향을 미치는가?
주요 결과
- FP8 학습은 FP16 누적 및 FP16 가중치 업데이트와 함께 CIFAR-10 CNN, CIFAR-10 ResNet, BN50-DNN, AlexNet, ResNet-18, ResNet-50에서 FP32 기준선과 유사한 테스트 정확도를 달성한다.
- FP8 가중치 및 FP16 마스터 복사본으로 인한 가중치 메모리 및 마스터 복사본 메모리가 약 2× 감소한다.
- chunk-based 누적 및 부동 소수점 확률적 반올림이 swamping을 효과적으로 완화하여 8비트 학습을 견고하게 만든다.
- 로스 스케일링 및 마지막 층 GEMMs를 FP16으로 예약해 ImageNet과 같은 대규모 데이터셋에서 학습을 안정화한다.
- Nearest 반올림은 정확도를 저하시킬 수 있으며, FP16 가중치 업데이트 중 확률적 반올림이 기준선 성능을 유지한다.
- 하드웨어 시연은 FP8 엔진이 FP16에 비해 2-4× 더 에너지 효율적일 수 있음을 시사한다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.