[Paper Review] GSPMD: General and Scalable Parallelization for ML Computation Graphs
GSPMD is an automatic, compiler-based system that partitions ML computation graphs from simple tensor sharding annotations, enabling scalable data, model, and pipeline parallelism across devices.
We present GSPMD, an automatic, compiler-based parallelization system for common machine learning computations. It allows users to write programs in the same way as for a single device, then give hints through a few annotations on how to distribute tensors, based on which GSPMD will parallelize the computation. Its representation of partitioning is simple yet general, allowing it to express different or mixed paradigms of parallelism on a wide variety of models. GSPMD infers the partitioning for every operator based on limited user annotations, making it convenient to scale existing single-device programs. It solves several technical challenges for production usage, allowing GSPMD to achieve 50% to 62% compute utilization on up to 2048 Cloud TPUv3 cores for models with up to one trillion parameters.
Motivation and Objectives
- Scale large ML models beyond what a single device can execute.
- Provide a general, compiler-driven partitioning mechanism that supports multiple parallelism paradigms.
- Allow users to write models as on a single device and annotate sharding to drive distributed execution.
- Enable nested and mixed parallelism patterns (data, model, spatial, optimizer-state) within a unified framework.
Proposed Method
- Extend the XLA-based compiler backend to implement a general tensor sharding representation (replicated, tiled, partially tiled).
- Introduce the mesh_split API, which maps tensor dimensions to the dimensions of a device mesh and generates the corresponding sharding annotations.
- Develop sharding completion and per-operator partitioning passes that propagate and merge shardings across operators.
- Support Single Program Multiple Data (SPMD) partitioning to scale to thousands of partitions and manage static shapes with padding/masking.
- Reduce pipeline parallelism to tensor sharding via a wrapper that translates micro-batched pipelines into stage-sharded tensors.
- Incorporate nested and recursive partitioning to handle complex operators (e.g., Einsum, Convolution) and rank-polymorphic ops.
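The mapping from tensor dimensions to mesh dimensions can be illustrated with a small, single-process sketch. This is not GSPMD's implementation or the real mesh_split API: the helper name `mesh_split_shards`, its signature, and the convention that `-1` marks an unsplit dimension are assumptions made for illustration. It only computes which index range of each tensor dimension every device in the mesh would hold, which is enough to show the replicated, tiled, and partially tiled cases.

```python
from itertools import product

def mesh_split_shards(tensor_shape, mesh_shape, dims_mapping):
    """Hypothetical helper: return {device_coords: ((start, stop), ...)}.

    dims_mapping[i] is the mesh dimension that tensor dim i is split
    across, or -1 (assumed convention) if tensor dim i is not split.
    """
    assert len(dims_mapping) == len(tensor_shape)
    used = [m for m in dims_mapping if m != -1]
    assert len(used) == len(set(used)), "a mesh dim may shard only one tensor dim"

    shards = {}
    for coords in product(*(range(n) for n in mesh_shape)):
        ranges = []
        for size, mesh_dim in zip(tensor_shape, dims_mapping):
            if mesh_dim == -1:
                # Not split: every device along the mesh sees the full dim.
                ranges.append((0, size))
            else:
                parts = mesh_shape[mesh_dim]
                chunk = -(-size // parts)  # ceil division: shard size
                start = min(coords[mesh_dim] * chunk, size)
                stop = min(start + chunk, size)
                ranges.append((start, stop))
        shards[coords] = tuple(ranges)
    return shards

# An (8, 6) tensor on a 2x4 mesh, rows split across mesh dim 0 only.
# Devices that differ only in mesh dim 1 hold identical data, i.e. the
# tensor is partially tiled.
shards = mesh_split_shards((8, 6), (2, 4), [0, -1])
print(shards[(1, 3)])  # ((4, 8), (0, 6))
```

Passing `[-1, -1]` gives full replication, and `[0, 1]` tiles both dimensions. With `[0, 1]` the second dimension (size 6) does not divide evenly across 4 devices, so the last device's range is empty; in an SPMD setting with static shapes that shard would be padded to the common shard size.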
Experimental Results
Research Questions
- RQ1: Can GSPMD express and combine diverse parallelism patterns (data, model, spatial, optimizer-state, pipeline) in a single framework?
- RQ2: How effectively can limited user annotations drive complete sharding across an ML computation graph?
- RQ3: What are the practical challenges of SPMD partitioning (static shapes, halo exchanges, communication), and how can they be addressed?
- RQ4: How does GSPMD perform in terms of compute utilization and memory scaling on large TPU deployments?
- RQ5: Can GSPMD enable nested and heterogeneous parallelism patterns for multimodal or large-scale models?
Key Results
- Achieves 50% to 62% compute utilization on up to 2048 Cloud TPUv3 cores for models up to one trillion parameters.
- Supports unevenly divisible dimensions under static shapes (via padding and masking) and merges shardings across operators to enable nested parallelism without extensive handwritten rules.
- Expresses data, model, spatial, and optimizer-state sharding in a unified tensor-sharding model with a simple device mesh API (mesh_split).
- Reduces pipeline parallelism to layer-wise sharding and enables pipelining with a lightweight wrapper, avoiding separate pipeline infrastructure.
- Maintains a single program for all partitions (SPMD) to keep compilation scalable while handling complex operator semantics (e.g., Convolution, Einsum).
- Demonstrates integration with XLA/TensorFlow/JAX backends and supports production-style compilation and communication primitives (AllReduce, AllGather, ReduceScatter, CollectivePermute).
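To make the interplay between uneven shapes, per-partition computation, and AllReduce concrete, here is a minimal single-process sketch. It is a simulation in plain Python, not GSPMD or XLA code, and all function names (`pad_and_split_contraction`, `all_reduce_sum`, `sharded_matmul`) are hypothetical. It splits a matrix multiplication along the contraction dimension, zero-pads that dimension so every partition gets the same static shard shape, lets each simulated partition compute a partial product, and then sums the partials as an AllReduce would.

```python
def matmul(a, b):
    """Reference dense matmul on lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def pad_and_split_contraction(a, b, num_parts):
    """Shard a (m x k) and b (k x n) along k into equal-sized pieces.

    If k is not divisible by num_parts, k is zero-padded up to the next
    multiple. Zeros are the identity for the sum in a contraction, so
    the padded entries do not change the result.
    """
    k = len(b)
    chunk = -(-k // num_parts)  # ceil division: static shard size
    pad = chunk * num_parts - k
    a_pad = [row + [0] * pad for row in a]
    b_pad = b + [[0] * len(b[0]) for _ in range(pad)]
    return [([row[p * chunk:(p + 1) * chunk] for row in a_pad],
             b_pad[p * chunk:(p + 1) * chunk])
            for p in range(num_parts)]

def all_reduce_sum(partials):
    """Simulated AllReduce: elementwise sum of per-partition results."""
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

def sharded_matmul(a, b, num_parts):
    # Every partition runs the same program on its own shard (SPMD).
    partials = [matmul(a_s, b_s)
                for a_s, b_s in pad_and_split_contraction(a, b, num_parts)]
    return all_reduce_sum(partials)

a = [[1, 2, 3], [4, 5, 6]]
b = [[1, 0], [0, 1], [1, 1]]
print(sharded_matmul(a, b, 2))  # [[4, 5], [10, 11]]
```

Here the contraction dimension has size 3 and is split two ways, so it is padded to 4. On real hardware each partial product would live on a different device and the final sum would be a collective over the interconnect; the sketch only shows why the padded, partitioned computation matches the unpartitioned one.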
This review was generated by AI and reviewed by a human editor.