Skip to main content
QUICK REVIEW

[论文解读] GSPMD: General and Scalable Parallelization for ML Computation Graphs

Yuanzhong Xu, HyoukJoong Lee|arXiv (Cornell University)|May 10, 2021
Parallel Computing and Optimization Techniques参考文献 24被引用 37
一句话总结

GSPMD 是一个自动化、编译器驱动的系统,使用简单的张量分片注释自动分割 ML 计算图,实现跨设备的可扩展数据/模型/流水线并行。

ABSTRACT

We present GSPMD, an automatic, compiler-based parallelization system for common machine learning computations. It allows users to write programs in the same way as for a single device, then give hints through a few annotations on how to distribute tensors, based on which GSPMD will parallelize the computation. Its representation of partitioning is simple yet general, allowing it to express different or mixed paradigms of parallelism on a wide variety of models. GSPMD infers the partitioning for every operator based on limited user annotations, making it convenient to scale existing single-device programs. It solves several technical challenges for production usage, allowing GSPMD to achieve 50% to 62% compute utilization on up to 2048 Cloud TPUv3 cores for models with up to one trillion parameters.

研究动机与目标

  • 推动大规模 ML 模型超越单设备执行的可扩展并行化。
  • 提供通用、由编译器驱动的分区机制,支持多种并行范式。
  • 允许用户像在单一设备上编写模型,并对分片进行注释以驱动分布式执行。
  • 在一个统一框架内启用嵌套和混合并行模式(数据、模型、空间、优化器状态)。

提出的方法

  • 扩展基于 XLA 的编译后端以实现通用张量分片表示(replicated, tiled, partially tiled)。
  • 引入 mesh_split API 将张量维度映射到设备网格并生成分片注释。
  • 开发分片完成和每个算子分区化的传递,能够跨算子传播并合并分片信息。
  • 支持单程序多数据(SPMD)分区,以扩展到数千个分区并通过填充/掩码管理静态形状。
  • 通过一个包装器提供流水线并行对张量分片的简化,将微批处理流水线转换为阶段分片张量。
  • 纳入嵌套和递归分区以处理复杂算子(例如 Einsum、Convolution)和秩多态运算。

实验结果

研究问题

  • RQ1GSPMD 能否在一个框架中表达并组合多样的并行模式(数据、模型、空间、优化器状态、流水线)?
  • RQ2有限的用户注释在多大程度上能推动对整个 ML 计算图的完整分片?
  • RQ3SPMD 分区的实际挑战(静态形状、halo 交换、通信)有哪些,如何解决?
  • RQ4GSPMD 在大型 TPU 部署上的计算利用率和内存扩展方面的表现如何?
  • RQ5GSPMD 是否能够为多模态或大规模模型实现嵌套和异构并行模式?

主要发现

  • 在高达 2048 Cloud TPUv3 核心的模型(参数量高达一万亿)上实现 50% 到 62% 的计算利用率。
  • 支持不均匀/静态形状并跨算子合并分片,以在不需要大量手写规则的情况下实现嵌套并行。
  • 在简单的设备网格 API(mesh_split)下,将数据、模型、空间和优化器状态分片表达为统一的张量分片模型。
  • 将流水线并行简化为逐层分片,并通过轻量包装器实现流水线,避免单独的流水线基础设施。
  • 为所有分区维护一个单一程序(SPMD),在处理诸如卷积、爱因斯坦求和等复杂算子语义时保持编译可扩展性。
  • 展示与 XLA/TensorFlow/JAX 后端的集成并支持生产级编译和通信原语(AllReduce, AllGather, ReduceScatter, CollectivePermute)。

更好的研究,从现在开始

从论文设计到论文写作,大幅缩短您的研究时间。

无需绑定信用卡

本解读由 AI 生成,并经人工编辑审核。