[논문 리뷰] Automatic Cross-Replica Sharding of Weight Update in Data-Parallel Training
이 논문은 데이터 병렬 학습에서 모든 레플리카에 반복적으로 수행되는 비분할 업데이트로 인해 발생하는 성능 저하를 줄이기 위해 가중치 업데이트 계산의 자동화된 크로스-레플리카 샤딩을 제안한다. XLA 내의 정적 분석과 그래프 변환을 통해, 코드 수정 없이도 기존 하드웨어를 활용해 가중치와 옵티마이저 보조 변수를 효율적으로 샤딩하고 최적화된 통신을 구현함으로써, Transformer와 같은 대규모 모델에서 최대 45%의 속도 향상을 달성한다.
In data-parallel synchronous training of deep neural networks, different devices (replicas) run the same program with different partitions of the training batch, but weight update computation is repeated on all replicas, because the weights do not have a batch dimension to partition. This can be a bottleneck for performance and scalability in typical language models with large weights, and models with small per-replica batch size which is typical in large-scale training. This paper presents an approach to automatically shard the weight update computation across replicas with efficient communication primitives and data formatting, using static analysis and transformations on the training computation graph. We show this technique achieves substantial speedups on typical image and language models on Cloud TPUs, requiring no change to model code. This technique helps close the gap between traditionally expensive (ADAM) and cheap (SGD) optimizers, as they will only take a small part of training step time and have similar peak memory usage. It helped us to achieve state-of-the-art training performance in Google's MLPerf 0.6 submission.
연구 동기 및 목표
- 데이터 병렬 학습에서 반복적이고 비분할된 가중치 업데이트 계산으로 인한 성능 저하 문제를 해결하기 위해.
- 모든 레플리카에서 전체 가중치 업데이트를 수행하는 데서 기인한 학습 시간을 지배하는 고비용 옵티마이저(예: ADAM)의 런타임을 줄이기 위해.
- 추가 장치를 추가하지 않고도 기존 레플리카를 활용해 가중치와 보조 변수(예: 모멘텀, 분산)를 효율적으로 샤딩할 수 있도록 하기 위해.
- 지능적인 샤딩 전략과 통신 패턴 선택을 통해 통신 및 데이터 포맷 오버헤드를 최소화하기 위해.
- 기존 모델 코드와의 호환성을 유지하면서 대규모 모델에서 상당한 속도 향상과 메모리 절감을 달성하기 위해.
제안 방법
- XLA 계산 그래프에 정적 분석을 적용하여 샤딩에 적합한 반복적 연산(예: 가중치 업데이트)을 식별하기 위해.
- 제어 흐름 분석을 사용해 최적의 통신 지점과 샤딩 후보 연산의 성능 향상 정도를 추정하기 위해.
- 전략적 위치에 샤드된 연산과 통신 원시 연산(예: all-gather, all-reduce)을 삽입하기 위해 계산 그래프를 변환하기 위해.
- 가속기 메모리 레이아웃(예: 타일링된 메모리)과 일치하는 데이터 샤딩 형식을 설계하여 통신 및 메모리 접근 비용을 최소화하기 위해.
- 소규모 학습(샤드 크기 최소화)과 대규모 학습(통신 지연 최소화) 시나리오에 맞는 서로 다른 샤딩 전략을 적용하기 위해.
- 측면 효과가 최소한인 XLA의 기능적 IR을 활용하여 분석을 단순화하고, 가중치 텐서의 라이브 범위를 줄이는 등의 고급 최적화를 가능하게 하기 위해.
실험 결과
연구 질문
- RQ1데이터 병렬 학습에서 모델 코드 수정 없이도 가중치 업데이트 계산을 자동으로 레플리카 간으로 샤딩할 수 있는가?
- RQ2가중치와 보조 변수를 레플리카 간으로 샤딩할 경우의 성능 및 통신 오버헤드 간 상호 트레이드오프는 어떠한가?
- RQ3자동 샤딩이 가중치 및 보조 텐서의 라이브 범위를 줄여 메모리 사용, 특히 피크 메모리에 어떤 영향을 미치는가?
- RQ4대규모 학습에서 ADAM과 비교해 SGD에 비해 고비용 옵티마이저의 런타임을 얼마나 줄일 수 있는가?
- RQ5다양한 가중치 크기와 배치 크기를 가진 모델들 간에 샤딩 전략의 효과는 어떻게 다를 수 있는가?
주요 결과
- 1024개의 TPUv3 코어에서 Transformer 모델의 스텝 시간이 46.5ms에서 25.6ms로 감소하여 45%의 감소를 기록했다.
- 작은 모델인 ResNet-50의 경우에도 1024개 코어에서 스케일업 시 6%의 속도 향상을 관찰하여 광범위한 적용 가능성을 입증했다.
- 언어 모델인 Transformer는 16개 코어에서 작은 스케일에서 9%의 성능 향상을 기록했으며, 이는 대규모 가중치를 가진 모델에 미치는 영향을 잘 보여준다.
- 최적화로 인해 피크 메모리 사용량이 감소하여 보조 변수용 버퍼 재사용이 가능해졌고, 특히 큰 보조 변수 오버헤드를 가진 NCF와 같은 모델에 유리했다.
- ADAM과 SGD 간의 피크 메모리 사용량이 유사해졌으며, 고비용과 저비용 옵티마이저 간의 메모리 격차를 효과적으로 해소했다.
- 모델 코드 수정 없이 기존 레플리카만을 사용하여 추가 하드웨어나 인프라 없이도 구현 가능했다.
더 나은 연구,지금 바로 시작하세요
연구 설계부터 논문 작성까지, 연구 시간을 획기적으로 줄여보세요.
카드 등록 없음 · 무료 플랜 제공
이 리뷰는 AI가 만들고, 인간 에디터가 검토했습니다.