[论文解读] Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro
本文介绍 NumPyro,一个基于 NumPy 的概率编程库,使用可组合的 effect 处理程序与 JAX 的变换来实现端到端的 JIT 编译并带来显著加速,其中包括一个迭代式的 NUTS 采样器,优于现有实现。
NumPyro is a lightweight library that provides an alternate NumPy backend to the Pyro probabilistic programming language with the same modeling interface, language primitives and effect handling abstractions. Effect handlers allow Pyro's modeling API to be extended to NumPyro despite its being built atop a fundamentally different JAX-based functional backend. In this work, we demonstrate the power of composing Pyro's effect handlers with the program transformations that enable hardware acceleration, automatic differentiation, and vectorization in JAX. In particular, NumPyro provides an iterative formulation of the No-U-Turn Sampler (NUTS) that can be end-to-end JIT compiled, yielding an implementation that is much faster than existing alternatives in both the small and large dataset regimes.
研究动机与目标
- 演示类似 Pyro 的效应处理程序可以与基于 JAX 的后端集成,从而扩展 NumPyro 的建模与推断接口。
- 展示可组合变换(jit、grad、vmap)如何在保持 Pyro 兼容的建模语言的同时,加速推断子例程。
- 给出一个迭代式、端到端 JIT 编译的 NUTS 实现,利用 XLA 在 CPU 与 GPU 上带来加速。
- 展示使用 vmap 将子例程向量化以实现批量推断与预测的好处。
- 对 NumPyro 与 Stan 和 Pyro 进行基准测试,以量化在小型和大型数据集上的性能提升。
提出的方法
- 在 NumPyro 内采用类似 Pyro 的效应处理抽象(seed、trace、condition),以在 JAX 后端实现可组合的程序变换。
- 利用 JAX 变换(jit、grad、vmap)构建并优化具有函数式、可追踪代码的推断子例程。
- 将 NUTS BuildTree 子例程转换成迭代形式,以实现端到端 JIT 编译并通过 XLA 提高算子融合。
- 使用 vmap 将常见推断任务(先验采样、后验预测采样和对数似然计算)向量化批处理。
- 演示效应处理器与 JAX 变换的集成,以在获得加速的同时保持 Pyro 兼容的建模接口。
实验结果
研究问题
- RQ1Pyro 风格的效应处理程序是否可以与 JAX 变换有效组合,以加速 NumPyro 中的概率编程工作负载?
- RQ2与 Pyro 和 Stan 相比,NumPyro 对推断子例程(如 NUTS)进行端到端 JIT 编译所带来的性能提升是多少?
- RQ3通过 vmap 的向量化如何与效应处理程序交互,以实现可扩展的批量推断与预测?
- RQ4迭代、JIT 编译的 NUTS 实现在不同规模数据集上在保持正确性的同时是否仍然提供显著的加速?
- RQ5在使用 NumPyro 的可组合变换处理大规模模型时,在 CPU、GPU 和 TPU 后端上的实际收益是多少?
主要发现
- NumPyro 的迭代式 NUTS 实现对所测试的模型显著快于 Pyro 和 Stan,在某些任务上报告的加速大约为 Pyro 的 340x 和 Stan 的 6x。
- 通过 XLA 的端到端 JIT 编译通过实现完整融合和推断子例程的优化执行带来显著加速。
- 使用 vmap 将推断向量化使先验和后验预测分布的批量采样以及对数似然计算的批量化变得高效。
- NumPyro 的框架在保持 Pyro 兼容的建模语言的同时,利用 JAX 变换在 CPU、GPU 与 TPU 后端实现加速。
- 该论文显示端到端加速在小规模 CPU 情况和大规模 GPU 情况下均有观察到,推动了在 JAX 下可组合效应的更广泛适用性。
更好的研究,从现在开始
从论文设计到论文写作,大幅缩短您的研究时间。
无需绑定信用卡
本解读由 AI 生成,并经人工编辑审核。