Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro
Pith reviewed 2026-05-15 15:58 UTC · model grok-4.3
The pith
NumPyro composes Pyro effect handlers with JAX to deliver a fully JIT-compiled iterative NUTS sampler.
A machine-rendered reading of the paper's core claim, the machinery that carries it, and where it could break.
Core claim
NumPyro shows that Pyro's effect handlers compose with JAX's functional transformations to preserve the original modeling API while adding hardware acceleration and automatic differentiation. In particular it supplies an iterative formulation of the No-U-Turn Sampler that can be compiled end-to-end with JAX's JIT, producing faster runtimes than existing implementations in both the small-data and large-data regimes.
What carries the argument
Effect handlers that extend Pyro's modeling abstractions to JAX's functional transformations for acceleration and compilation.
If this is right
- Probabilistic models written in the Pyro interface can run with full JIT compilation and hardware acceleration.
- The same modeling code benefits from vectorization and automatic differentiation supplied by JAX.
- Inference scales to both small and large datasets without separate code paths.
- Effect-handler composition becomes a reusable pattern for adding new backends to probabilistic programming languages.
Where Pith is reading between the lines
- The same handler-composition technique could be applied to accelerate other MCMC or variational methods inside JAX.
- Models could be automatically ported between CPU, GPU, and TPU execution without rewriting inference logic.
- New modeling primitives that exploit JAX's functional purity might become feasible once the handler layer is stable.
Load-bearing premise
Pyro effect handlers compose cleanly with JAX transformations without introducing correctness problems or reducing modeling expressiveness.
What would settle it
A direct runtime comparison on standard benchmark models showing that the NumPyro NUTS implementation is not faster than existing Pyro or Stan alternatives in either the small-dataset or large-dataset regime.
Original abstract
NumPyro is a lightweight library that provides an alternate NumPy backend to the Pyro probabilistic programming language with the same modeling interface, language primitives and effect handling abstractions. Effect handlers allow Pyro's modeling API to be extended to NumPyro despite its being built atop a fundamentally different JAX-based functional backend. In this work, we demonstrate the power of composing Pyro's effect handlers with the program transformations that enable hardware acceleration, automatic differentiation, and vectorization in JAX. In particular, NumPyro provides an iterative formulation of the No-U-Turn Sampler (NUTS) that can be end-to-end JIT compiled, yielding an implementation that is much faster than existing alternatives in both the small and large dataset regimes.
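An iterative formulation matters here because recursive tree doubling with data-dependent termination is awkward to express under JAX's tracing-based JIT. A loop over leaves with a fixed-size memory avoids that problem. The pure-Python sketch below shows one bit-counting scheme that decides which earlier leaves to keep and which balanced subtrees each new leaf closes. With this scheme, U-turn checks need only O(depth) stored leaves. The scheme is modeled, as an assumption, on the approach in NumPyro's source rather than taken from the paper. It tracks leaf indices only (no momenta and no actual U-turn criterion), and all function names are hypothetical. A JAX version would express the loops with `jax.lax.while_loop`.

```python
def leaf_idx_to_ckpt_idxs(n):
    """Range of checkpoint slots to test when leaf `n` (0-based) is generated.

    idx_max is the number of 1 bits of n >> 1. num_subtrees is the number of
    trailing 1 bits of n, i.e. how many balanced subtrees (sizes 2, 4, 8, ...)
    leaf n completes. An empty range (idx_min > idx_max) means no checks."""
    idx_max = bin(n >> 1).count("1")
    num_subtrees = 0
    while n & 1:
        n >>= 1
        num_subtrees += 1
    return idx_max - num_subtrees + 1, idx_max


def uturn_checks(depth):
    """Simulate one doubling of `depth` levels, tracking leaf indices only.

    Returns (first_leaf, last_leaf) pairs, one per balanced subtree whose
    endpoints would be tested with the U-turn criterion."""
    ckpts = [None] * max(depth, 1)  # O(depth) stored leaves, not O(2**depth)
    checks = []
    for n in range(2 ** depth):
        if n % 2 == 0:
            # even leaves may start a subtree; slot = popcount(n) < depth
            ckpts[bin(n).count("1")] = n
        else:
            idx_min, idx_max = leaf_idx_to_ckpt_idxs(n)
            for slot in range(idx_max, idx_min - 1, -1):  # smallest subtree first
                checks.append((ckpts[slot], n))
    return checks


def uturn_checks_reference(depth):
    """Direct enumeration for comparison: leaf n closes subtree
    [n + 1 - size, n] whenever n + 1 is divisible by size (size = 2, 4, ...)."""
    out = []
    for n in range(2 ** depth):
        size = 2
        while size <= 2 ** depth and (n + 1) % size == 0:
            out.append((n + 1 - size, n))
            size *= 2
    return out


print(uturn_checks(3))
# [(0, 1), (2, 3), (0, 3), (4, 5), (6, 7), (4, 7), (0, 7)]
```

Stored slots are overwritten safely. For example, leaf 4 reuses the slot that held leaf 2, but by then the subtree starting at leaf 2 was already closed by leaf 3.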
Editorial analysis
A structured set of objections, weighed in public.
Referee Report
Summary. The paper introduces NumPyro as a lightweight NumPy-based backend for the Pyro probabilistic programming language that preserves the same modeling interface and effect-handling abstractions. It demonstrates that Pyro effect handlers compose with JAX's functional transformations (JIT, autodiff, vectorization) to support an iterative formulation of the No-U-Turn Sampler (NUTS) that is end-to-end JIT-compilable, yielding substantially faster inference than existing alternatives across both small- and large-dataset regimes.
Significance. If the performance and correctness claims hold, the work is significant for showing how effect-handler composition can bridge imperative PPL APIs with functional autodiff frameworks, enabling scalable, hardware-accelerated sampling without loss of modeling expressiveness. The engineering result directly addresses practical bottlenecks in Bayesian inference for machine-learning models.
major comments (2)
- [Abstract and §4] Abstract and §4 (results): the central claim that the JIT-compiled iterative NUTS is 'much faster than existing alternatives in both the small and large dataset regimes' is load-bearing yet supported only by high-level statements; specific wall-clock timings, hardware specifications, baseline implementations (e.g., Pyro, Stan, TensorFlow Probability), and dataset sizes must be reported with error bars or multiple runs to allow verification.
- [§3] §3 (NUTS formulation): the iterative NUTS algorithm is presented as end-to-end JIT-compatible, but the manuscript does not explicitly address potential non-differentiable control flow or side-effect leakage when the effect handlers are transformed; a short proof sketch or counter-example check would strengthen the correctness argument.
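The side-effect concern in the second major comment can be illustrated directly. Under `jax.jit`, Python-level side effects run while the function is traced, not on every call. Data-dependent loops must also be written with structured primitives such as `jax.lax.while_loop`. A minimal sketch, assuming a recent JAX; `double_until` and `trace_log` are hypothetical names:

```python
import jax
import jax.numpy as jnp

trace_log = []  # ordinary Python list, invisible to JAX's compiler


@jax.jit
def double_until(x, limit):
    trace_log.append("traced")  # side effect: executes at trace time only
    x = jnp.asarray(x, dtype=jnp.float32)
    # data-dependent loop as structured control flow, so it can be compiled
    return jax.lax.while_loop(lambda v: v < limit, lambda v: v * 2.0, x)


a = double_until(1.0, 10.0)  # 1 -> 2 -> 4 -> 8 -> 16
b = double_until(1.5, 10.0)  # 1.5 -> 3 -> 6 -> 12; reuses the cached compilation
```

One reading of this behavior suggests the invariant an effect-handler layer must respect. Handler logic may run freely at trace time, as long as everything that must vary per call (data, PRNG keys) flows through the traced arguments.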
minor comments (2)
- [§4] Add a table or figure in §4 that directly tabulates speedup factors versus the closest competing samplers for the reported models.
- [Introduction] Clarify in the introduction whether the modeling API is byte-for-byte identical to Pyro or admits any syntactic differences.
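The referee asks for repeated runs, means with standard deviations, and speedup factors. That kind of reporting can be sketched with the standard library alone. This is a generic harness under stated assumptions, not the paper's benchmark. Warm-up calls run first and are not timed, because for a JIT-compiled sampler the first call would include compilation time. The workloads `slow_sum` and `fast_sum` are hypothetical placeholders for two implementations of the same computation.

```python
import statistics
import time


def benchmark(fn, *args, n_runs=5, n_warmup=1, n_inner=1):
    """Return (mean, stdev) of per-call wall-clock seconds over `n_runs` runs.

    Warm-up calls run first and are not timed; each run averages `n_inner`
    calls so that very fast functions still register a measurable time."""
    for _ in range(n_warmup):
        fn(*args)
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        for _ in range(n_inner):
            fn(*args)
        times.append((time.perf_counter() - start) / n_inner)
    return statistics.mean(times), statistics.stdev(times)


def speedup(baseline_mean, candidate_mean):
    """Speedup factor: values above 1 mean the candidate is faster."""
    return baseline_mean / candidate_mean


# Hypothetical stand-ins for two implementations of the same computation
def slow_sum(n):
    return sum(i * i for i in range(n))


def fast_sum(n):
    return (n - 1) * n * (2 * n - 1) // 6  # closed form for sum of i*i, i < n


base_mean, base_std = benchmark(slow_sum, 20_000)
cand_mean, cand_std = benchmark(fast_sum, 20_000, n_inner=1000)
print(f"baseline  {base_mean:.2e} s ± {base_std:.1e}")
print(f"candidate {cand_mean:.2e} s ± {cand_std:.1e}")
print(f"speedup   {speedup(base_mean, cand_mean):.1f}x")
```

For JAX code, the timed call should also block on its result, for example by calling `block_until_ready()` on the returned array. JAX dispatches work asynchronously, so without that step the measured time can omit the computation itself.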
Simulated Author's Rebuttal
We thank the referee for their careful reading and positive recommendation for minor revision. We address the major comments point-by-point below, agreeing to incorporate additional details and clarifications in the revised manuscript.
Referee: [Abstract and §4] Abstract and §4 (results): the central claim that the JIT-compiled iterative NUTS is 'much faster than existing alternatives in both the small and large dataset regimes' is load-bearing yet supported only by high-level statements; specific wall-clock timings, hardware specifications, baseline implementations (e.g., Pyro, Stan, TensorFlow Probability), and dataset sizes must be reported with error bars or multiple runs to allow verification.
Authors: We agree that the performance claims would benefit from more detailed empirical support. In the revised manuscript, we will expand §4 to include specific wall-clock timings, hardware specifications (such as the CPU and GPU models used), the exact baseline implementations (Pyro, Stan, TensorFlow Probability), dataset sizes, and results reported as means with standard deviations over multiple independent runs. This will provide the necessary quantitative evidence for the 'much faster' claim in both small and large dataset regimes. revision: yes
Referee: [§3] §3 (NUTS formulation): the iterative NUTS algorithm is presented as end-to-end JIT-compatible, but the manuscript does not explicitly address potential non-differentiable control flow or side-effect leakage when the effect handlers are transformed; a short proof sketch or counter-example check would strengthen the correctness argument.
Authors: We thank the referee for highlighting this point on correctness. The effect handlers in NumPyro are implemented to be fully compatible with JAX's functional transformations, ensuring no side-effect leakage and that control flow remains traceable. In the revision, we will add a short paragraph in §3 providing a sketch of why the iterative NUTS formulation avoids non-differentiable operations and side effects, referencing the pure functional nature of the handlers and JAX's tracing mechanism. If space permits, we can include a brief counter-example check or note on the absence of such issues in our implementation. revision: yes
Circularity Check
No significant circularity in implementation description
This is an implementation paper presenting NumPyro as a JAX-based backend for Pyro's modeling interface. The central claim concerns the engineering outcome of composing effect handlers with JAX transformations to enable an end-to-end JIT-compilable iterative NUTS sampler, with reported performance gains. No mathematical derivations, parameter fits, or self-referential equations appear in the provided text that reduce to their own inputs by construction. The work is self-contained as a software design and benchmarking description, with no load-bearing steps that match the enumerated circularity patterns.
Axiom & Free-Parameter Ledger
axioms (1)
- Domain assumption: Pyro effect handlers can be composed with JAX program transformations while preserving modeling semantics.
Forward citations
Cited by 20 Pith papers
- Mixed neural posterior estimation for simulators with discrete and continuous parameters
  Extends NPE to mixed discrete-continuous parameter spaces via a factorized inference network combining an autoregressive classifier and generative model, trained jointly to yield accurate calibrated posteriors.
- Variational predictive resampling
  Variational predictive resampling iteratively imputes data from a variational predictive to produce posterior samples that converge to the exact Bayesian posterior in Gaussian models where mean-field VI retains a gap.
- Variational predictive resampling
  Variational predictive resampling uses sequential imputation from variational predictives to generate samples whose distribution converges to the exact Bayesian posterior in Gaussian models and improves dependence cap...
- Bayesian Doppler Imaging: Simultaneous Inference of Surface Maps and Geometric Parameters
  A fully Bayesian pixel-based Doppler imaging framework uses Gaussian Process priors and Hamiltonian Monte Carlo to simultaneously infer surface maps and geometric parameters from spectral data.
- ADELIA: Automatic Differentiation for Efficient Laplace Inference Approximations
  ADELIA is the first AD-enabled INLA system that computes exact hyperparameter gradients via a structure-exploiting multi-GPU backward pass, delivering 4.2-7.9x per-gradient speedups and 5-8x better energy efficiency t...
- Archival Multiband Gravitational-Wave Signals from Massive Black Hole Binary Mergers
  Massive black hole binary mergers produce orphaned low-frequency signals in PTA pulsar terms that can be stacked for archival multiband gravitational-wave detection.
- High-dimensional inference for the $\gamma$-ray sky with differentiable programming
  A differentiable forward model and likelihood enable probabilistic inference over many spatial morphologies for the Galactic Center gamma-ray Excess using variational methods on GPUs.
- Stories in Space: In-Context Learning Trajectories in Conceptual Belief Space
  LLMs perform in-context learning as trajectories through a structured low-dimensional conceptual belief space, with the structure visible in both behavior and internal representations and causally manipulable via inte...
- A hierarchical Bayesian pipeline for soliton-plus-NFW inference on SPARC rotation curves: diagnostics and prior-boundary behaviour
  A hierarchical Bayesian pipeline applied to 106 SPARC galaxies yields posteriors that reach prior boundaries for soliton parameters, indicating no detectable interior population-level soliton within the Schive-normali...
- What You Don't Know Won't Hurt You: Self-Consistent Hierarchical Inference with Unknown Follow-up Selection Strategies
  Hierarchical Bayesian inference allows accurate recovery of intrinsic astrophysical source populations even when follow-up selection is unknown and correlated with parameters of interest.
- Bayesian Modeling and Prediction of Generalized Contact Matrices
  A Bayesian model for multi-feature contact matrices that uses tensor structures and contingency table theory to satisfy structural constraints and impute missing contact features, validated on simulations and US/Germa...
- Towards E-Value Based Stopping Rules for Bayesian Deep Ensembles
  E-value sequential tests enable early stopping of MCMC sampling in Bayesian deep ensembles, often needing only a fraction of the full budget while improving over standard deep ensembles.
- A unified harmonic framework for dark siren cosmology
  The GW-galaxy cross-correlation method, unified with spectral sirens in a harmonic framework, can measure H0 to 1% and Omega_m to 5% precision with 2 years of data from next-generation detectors like Einstein Telescop...
- Towards Understanding Sycophancy in Language Models
  Sycophancy is prevalent in state-of-the-art AI assistants and is likely driven in part by human preferences that favor agreement over truthfulness.
- A Uniform Determination of the Bulk Metallicities and Alpha Enrichments of Confirmed Exoplanet Systems with TRES
  A uniform spectroscopic catalog of 625 exoplanet hosts shows subsolar-metallicity giant-planet hosts are alpha-enhanced relative to both iron-rich hosts and typical metal-poor field stars.
- Plato's view on supermassive black hole binaries: Exploring the faint limit of ESA's Plato space mission
  Simulations show Plato can recover relativistic photometric signatures of supermassive black hole binaries in bright quasars (G≤18) via Bayesian inference on mock light curves.
- Mitigating effects of telescope jitter through differentiable forward-modeling
  Differentiable optical simulation models telescope jitter blurring and shows that two-dimensional jitter models avoid systematic bias in binary separation measurements for the TOLIMAN exoplanet mission.
- Fast Bayesian equipment condition monitoring via simulation based inference: applications to heat exchanger health
  Amortized neural posterior estimation via simulation-based inference delivers 82x faster inference than MCMC for heat exchanger fouling and leakage diagnosis while maintaining comparable accuracy on synthetic data.
- Neural posterior estimation for scalable and accurate inverse parameter inference in Li-ion batteries
  NPE delivers millisecond-scale parameter inference for Li-ion batteries that matches or exceeds Bayesian calibration accuracy while adding local sensitivity interpretability, though with higher voltage prediction errors.
- GW250114: testing Hawking's area law and the Kerr nature of black holes
  GW250114 data confirm the remnant black hole ringdown frequencies lie within 30% of Kerr predictions and that the final horizon area is larger than the sum of the progenitors' areas to high credibility.
Reference graph
Works this paper leans on
- [1] Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro
  ...and Edward2 [6] based on TensorFlow, and PyMC3 [7] based on Theano. NumPyro is a package for probabilistic programming built atop JAX [8, 9], which is a high-level tracing library for program transformations (e.g. automatic differentiation, vectorization and JIT compilation) of Python and NumPy functions. Thus NumPyro enables users to write probabilistic ...
- [2] Jan-Willem van de Meent, Brooks Paige, Hongseok Yang, and Frank Wood. An introduction to probabilistic programming. arXiv preprint arXiv:1809.10756, 2018.
- [3] Eli Bingham, Jonathan P. Chen, Martin Jankowiak, Fritz Obermeyer, Neeraj Pradhan, Theofanis Karaletsos, Rohit Singh, Paul Szerlip, Paul Horsfall, and Noah D. Goodman. Pyro: Deep universal probabilistic programming. Journal of Machine Learning Research, 20(28):1–6, 2019. URL http://jmlr.org/papers/v20/18-403.html
- [4] Siddharth Narayanaswamy, T Brooks Paige, Jan-Willem Van de Meent, Alban Desmaison, Noah Goodman, Pushmeet Kohli, Frank Wood, and Philip Torr. Learning disentangled representations with semi-supervised deep generative models. In Advances in Neural Information Processing Systems, pages 5925–5935, 2017.
- [5] Adam Paszke, Sam Gross, Soumith Chintala, Gregory Chanan, Edward Yang, Zachary DeVito, Zeming Lin, Alban Desmaison, Luca Antiga, and Adam Lerer. Automatic differentiation in pytorch. 2017.
- [6] Joshua V Dillon, Ian Langmore, Dustin Tran, Eugene Brevdo, Srinivas Vasudevan, Dave Moore, Brian Patton, Alex Alemi, Matt Hoffman, and Rif A Saurous. Tensorflow distributions. arXiv preprint arXiv:1711.10604, 2017.
- [7] Dustin Tran, Matthew W Hoffman, Dave Moore, Christopher Suter, Srinivas Vasudevan, and Alexey Radul. Simple, distributed, and accelerated probabilistic programming. In Advances in Neural Information Processing Systems, pages 7598–7609, 2018.
- [8] John Salvatier, Thomas V. Wiecki, and Christopher Fonnesbeck. Probabilistic programming in python using PyMC3. PeerJ Computer Science, 2:e55, apr 2016. doi: 10.7717/peerj-cs.55. URL https://doi.org/10.7717/peerj-cs.55
- [9] James Bradbury, Roy Frostig, Peter Hawkins, Matthew James Johnson, Chris Leary, Dougal Maclaurin, and Skye Wanderman-Milne. JAX: composable transformations of Python+NumPy programs, 2018. URL http://github.com/google/jax
- [10] Roy Frostig, Matthew Johnson, and Chris Leary. Compiling machine learning programs via high-level tracing. 2018. URL http://www.sysml.cc/doc/2018/146.pdf
- [11] XLA: Optimizing Compiler for Machine Learning. https://www.tensorflow.org/xla/
- [12] Dave Moore and Maria I. Gorinova. Effect handling for composable program transformations in edward2. CoRR, abs/1811.06150, 2018. URL http://arxiv.org/abs/1811.06150
- [13] Gordon Plotkin and Matija Pretnar. Handlers of algebraic effects. In Giuseppe Castagna, editor, Programming Languages and Systems, pages 80–94, Berlin, Heidelberg, 2009. Springer Berlin Heidelberg. ISBN 978-3-642-00590-9
- [14] The JAX Team. JAX PRNG Design. https://github.com/google/jax/blob/master/design_notes/prng.md, 2019
- [15] Matthew D. Hoffman and Andrew Gelman. The no-u-turn sampler: Adaptively setting path lengths in hamiltonian monte carlo. Journal of Machine Learning Research, 15:1593–1623, 2014. URL http://jmlr.org/papers/v15/hoffman14a.html
- [16] Simon Duane, Anthony D Kennedy, Brian J Pendleton, and Duncan Roweth. Hybrid monte carlo. Physics letters B, 195(2):216–222, 1987.
- [17] Radford Neal. MCMC Using Hamiltonian Dynamics. CRC Press, May 2011. doi: 10.1201/b10905-6. URL http://dx.doi.org/10.1201/b10905-6
- [18] Matthew D Hoffman, David M Blei, Chong Wang, and John Paisley. Stochastic variational inference. The Journal of Machine Learning Research, 14(1):1303–1347, 2013.
- [19] Bob Carpenter, Andrew Gelman, Matthew D Hoffman, Daniel Lee, Ben Goodrich, Michael Betancourt, Marcus Brubaker, Jiqiang Guo, Peter Li, and Allen Riddell. Stan: A probabilistic programming language. Journal of statistical software, 76(1), 2017.
- [20] Allen Riddell, Ari Hartikainen, Daniel Lee, riddell stan, Marco Inacio, Daniel Chen, Kenneth C. Arnold, Dougal J. Sutherland, Aki Vehtari, Shinya SUZUKI, Takahiro Kubo, Todd Small, Tobias Erhardt, Stephen Hoover, Stephan Hoyer, Richard C Gerkin, Joerg Rings, Jackie, J. J. Ramsey, Aaron Darling, seantalts, Skipper Seabold, Max Shron, Liam Brannigan, Kyle F...
- [21] Dheeru Dua and Casey Graff. UCI machine learning repository, 2017. URL http://archive.ics.uci.edu/ml
- [22] Raj Agrawal, Brian Trippe, Jonathan Huggins, and Tamara Broderick. The kernel interaction trick: Fast Bayesian discovery of pairwise interactions in high dimensions. In Kamalika Chaudhuri and Ruslan Salakhutdinov, editors, Proceedings of the 36th International Conference on Machine Learning, volume 97 of Proceedings of Machine Learning Research, pages 14...
- [23] PMLR. URL http://proceedings.mlr.press/v97/agrawal19a.html
- [24] Stan Development Team. Stan Modeling Language User’s Guide and Reference Manual, Version 2.18.0
- [25] URL http://mc-stan.org.

[Figure 3: A graphical representation of how binary trees are constructed in ITERATIVE BUILD TREE. The orange node is the leaf generated at the current step. Blue nodes are the leaves stored in memory for the purpose of checking the U-Turn condition. White nodes are past ...]