Skip to main content
QUICK REVIEW

[論文レビュー] Axe: A Simple Unified Layout Abstraction for Machine Learning Compilers

Bohan Hou, Hongyi Jin|arXiv (Cornell University)|Jan 27, 2026
Parallel Computing and Optimization Techniques被引用数 0
ひとこと要約

Axe は、デバイスメッシュ、メモリ階層、アクセラレータ全体でデータと計算のマッピングを統一する名前付き軸を備えたハードウェア認識レイアウト抽象を導入し、手動チューニング済みカーネル性能に近づくマルチ粒度 DSL とコンパイラを実現します。

ABSTRACT

Scaling modern deep learning workloads demands coordinated placement of data and compute across device meshes, memory hierarchies, and heterogeneous accelerators. We present Axe Layout, a hardware-aware abstraction that maps logical tensor coordinates to a multi-axis physical space via named axes. Axe unifies tiling, sharding, replication, and offsets across inter-device distribution and on-device layouts, enabling collective primitives to be expressed consistently from device meshes to threads. Building on Axe, we design a multi-granularity, distribution-aware DSL and compiler that composes thread-local control with collective operators in a single kernel. Experiments show that our unified approach can bring performance close to hand-tuned kernels on across latest GPU devices and multi-device environments and accelerator backends.

研究の動機と目的

  • 分散デバイスとメモリ階層全体での大規模 DL ワークロードに対するランタイム最適化の課題を動機付ける。
  • Axe レイアウトを提案し、名前付きハードウェア軸を用いてデバイス間およびデバイス内のデータと計算のマッピングを統一する。
  • スレッドローカル制御と集合演算子を単一カーネル内で組み合わせる分布認識 DSL とコンパイラを開発する。
  • レイアウト操作とコード生成技術を提供して、ハードウェアネイティブなスケジュールを実現できるようにする。
  • 現代の GPU およびマルチデバイス環境で手動チューニング済みカーネルに近い性能向上を示す。

提案手法

  • Axe を、D(Shard)、R(Replica)、O(Offset)を含む名前付き軸上の座標への論理インデックスの集合値写像として定義する。
  • コード生成の分析・変換のためのレイアウト演算子(canonicalize、group、tile、slice)を導入する。
  • スレッドローカル制御と集合演算子を一つのカーネル内で組み合わせる、マルチ粒度・分布認識の Axe DSL とコンパイラを構築する。
  • 分散テンソルと異種ハードウェアレイアウトを Axe で表現し、ランタイム検査とスケジュール選択を可能にする。
  • ハイレベル演算子(コピー、リダクション、GEMM など)とスケジュールのライブラリを提供し、ハードウェアネイティブ実装(例:TMA、NVSHMEM)へマッピングする。
  • ターゲット(GPU ワープ、マルチ GPU メッシュ、Trainium のような AI アクセラレータ)間でのコード生成を三段階のパイプラインとレイアウト駆動の降下で実演する。

実験結果

リサーチクエスチョン

  • RQ1単一のハードウェア認識レイアウト抽象が、論理テンソル座標をデバイスとメモリ階層を跨る多軸の物理空間へどのようにマッピングできるか?
  • RQ2Axe はデバイス間のシャーディング、デバイス内のタイル、および複製を単一の形式モデルで統一できるか?
  • RQ3Axe 上に構築された分布認識 DSL とコンパイラは、GPU やアクセラレータ全体で手動チューニング済みのカーネル性能に近づくか?
  • RQ4異種環境における分析とコード生成のための Axe レイアウト操作はどれほど効果的か?

主な発見

  • Axe は強力なベースラインと同等またはそれを上回る性能を達成し、FlashInfer および SGLang に対して B200 MoE レイヤで最大 1.32x、1.23x を達成。
  • Axe は cuBLAS+NCCL および Triton-Distributed に対するマルチGPU GEMM+Reduce-Scatter で最大 1.40x のスピードアップを達成。
  • Axe は AI アクセラレータ(Trainium-1)における MHA でベンダーライブラリと比べ最大 1.44x の改善を達成。
  • FP16 GEMM では Axe が cuBLAS のスループットの 97% 〜 100% を達成し、Triton の約 90% を上回る。
  • FP8 GEMM では Axe が DeepGEMM のスループットの 92% 〜 96% を達成。
  • 分散カーネルの状況で、Axe は通信と計算の細粒度のオーバーラップを有効にすることで、ベースラインより低遅延を実現する。

より良い研究を、今すぐ始めましょう

論文設計から論文執筆まで、研究時間を劇的に削減しましょう。

クレジットカード登録不要

このレビューはAIが作成し、人間の編集者が確認しました。