Equinox: neural networks in JAX via callable PyTrees and filtered transformations
JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. Because this appears to be a fundamental difference, current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax), or have introduced OO-to-functional transformations and multiple new abstractions while being limited in the extent to which they integrate with JAX (Flax, Haiku, Objax). Either way, this OO/functional difference has been a source of tension. Here, we introduce 'Equinox', a small neural network library showing how a PyTorch-like class-based approach may be admitted without sacrificing JAX-like functional programming. We provide two main ideas. One: parameterised functions are themselves represented as 'PyTrees', which means that the parameterisation of a function is transparent to the JAX framework. Two: we filter a PyTree to isolate just those components that should be treated when transforming (with jit, grad or vmap) a higher-order function of a parameterised function, such as a loss function applied to a model. Overall, Equinox resolves the above tension without introducing any new programmatic abstractions: only PyTrees and transformations, just as with regular JAX. Equinox is available at https://github.com/patrick-kidger/equinox.
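The first idea, that a parameterised function is itself a PyTree, can be illustrated with plain JAX. The following is a minimal sketch, assuming a JAX release that provides `jax.tree_util.register_pytree_node_class`; the `Linear` class and `loss` function are hypothetical, and the hand-registered class only stands in for what Equinox's `eqx.Module` base class automates. Because the model is a PyTree, `jax.grad`, `jax.jit` and `jax.vmap` accept it like any other argument, and the gradient comes back as another `Linear`.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Linear:
    """Hypothetical affine layer that is both callable and a PyTree."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # The parameters are the PyTree's children; there is no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    # A higher-order function of the parameterised function `model`.
    return jnp.mean((model(x) - y) ** 2)


model = Linear(jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([0.5, -0.5]))
x = jnp.array([1.0, 1.0])
y = jnp.array([0.0, 0.0])

# grad, jit and vmap all accept the model directly, because it is a PyTree.
grads = jax.grad(loss)(model, x, y)  # grads is another Linear
value = jax.jit(loss)(model, x, y)
per_example = jax.vmap(loss, in_axes=(None, 0, 0))(
    model, jnp.stack([x, x]), jnp.stack([y, y])
)

# A plain SGD step is an ordinary tree_map over two Linear PyTrees.
new_model = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)
```

No separate state-extraction step is involved here: the model's parameters are simply its PyTree leaves.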
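The second idea matters once a model holds leaves that are not arrays, such as an activation function. Calling `jax.grad(loss)` directly on the `Dense` layer below would fail, because the function stored in `activation` is not a valid JAX type. The sketch below, assuming only `jax` and `jax.numpy`, filters the PyTree by hand. `Dense`, `is_array`, `partition`, `combine` and `filter_grad` are hypothetical helpers written to mirror the roles of Equinox's `eqx.is_array`, `eqx.partition`, `eqx.combine` and `eqx.filter_grad`; they are not those functions' actual implementations.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Dense:
    """Hypothetical layer whose PyTree holds arrays and a Python function."""

    def __init__(self, weight, bias, activation):
        self.weight = weight
        self.bias = bias
        self.activation = activation

    def __call__(self, x):
        return self.activation(self.weight @ x + self.bias)

    def tree_flatten(self):
        return (self.weight, self.bias, self.activation), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def is_array(leaf):
    return isinstance(leaf, jax.Array)


def partition(tree, keep):
    # Split one PyTree into two of the same structure; unselected leaves become None.
    selected = jax.tree_util.tree_map(lambda l: l if keep(l) else None, tree)
    rest = jax.tree_util.tree_map(lambda l: None if keep(l) else l, tree)
    return selected, rest


def combine(selected, rest):
    # Inverse of `partition`: take each leaf from whichever tree is not None.
    return jax.tree_util.tree_map(
        lambda a, b: b if a is None else a, selected, rest, is_leaf=lambda l: l is None
    )


def filter_grad(fun):
    # Differentiate `fun` with respect to the array leaves of its first argument only.
    def wrapper(model, *args):
        params, static = partition(model, is_array)
        return jax.grad(lambda p: fun(combine(p, static), *args))(params)

    return wrapper


def loss(model, x, y):
    return jnp.mean((model(x) - y) ** 2)


model = Dense(jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([0.5, -0.5]), jax.nn.relu)
x = jnp.array([1.0, 1.0])
y = jnp.array([0.0, 0.0])

grads = filter_grad(loss)(model, x, y)  # Dense(d_weight, d_bias, None)
```

In Equinox, `eqx.filter_jit` applies the same kind of split to `jit`: array leaves are traced and all other leaves are treated as static.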
Forward citations
Cited by 11 Pith papers
-
Provable Data Scaling Law for Meta Learning via Complexity Minimization
A novel complexity minimization meta-learning framework provably demonstrates that few-shot adaptation error decreases as meta-training data volume increases.
-
Observer-robust energy condition verification for warp drive spacetimes
Warpax toolkit demonstrates that observer-robust optimization finds more extensive and severe energy-condition violations in warp drive metrics than single-frame Eulerian analysis.
-
AMIGO: a Data-Driven Calibration of the JWST Interferometer
AMIGO is an end-to-end differentiable forward model of JWST AMI that corrects detector systematics to recover high-precision astrometry and detect close high-contrast companions.
-
Learning partially observed systems with neural Hamiltonian ordinary differential equations
NHODE framework learns partially observed dynamical systems by combining Hamiltonian neural networks with neural ODEs, enforcing energy conservation and improving long-horizon stability over data-driven baselines on m...
-
Convex Optimization for Alignment and Preference Learning on a Single GPU
COALA applies convex optimization reformulations of neural networks to direct preference optimization, claiming single-GPU training with ~18% of DPO's TFLOPs and competitive performance on multiple datasets and models...
-
Closed-form predictive coding via hierarchical Gaussian filters
Predictive coding is recast as deep hierarchical Gaussian filters to restore precision-weighted message passing, yielding closed-form inference and online precision learning that matches backpropagation speed on Fashi...
-
A Unifying Framework for Parallelizing Sequential Models with Linear Dynamical Systems
A framework based on linear dynamical systems unifies fixed-point iteration schemes such as Newton, Picard, and Jacobi as approximate linearizations of nonlinear recursions for parallelizing sequential models.
-
On the boundary cost of source-consistent warp shells
Source-consistent warp shells fail energy conditions at the source-vacuum boundary in all examined constructions and parameter scans.
-
GCImOpt: Learning efficient goal-conditioned policies by imitating optimal trajectories
GCImOpt trains compact goal-conditioned neural policies by imitating efficiently generated optimal trajectories, achieving high success rates and near-optimal performance on cart-pole, quadcopter, and robot arm tasks ...
-
Uncertainty in Physics and AI: Taxonomy, Quantification, and Validation
A unified taxonomy of uncertainty in ML for physics is introduced together with validation tools such as coverage, calibration, and proper scoring rules, illustrated on regression and classification tasks.
-
jNO: A JAX Library for Neural Operator and Foundation Model Training
jNO introduces a unified JAX tracing system for data-driven and physics-informed neural operator training that compiles domains, residuals, losses, and diagnostics into one pipeline.