JAX-AMG: A GPU-Accelerated Differentiable Sparse Linear Solver Library for JAX
Sparse linear systems from PDE discretizations are central to scientific computing, yet no existing JAX-ecosystem solver simultaneously provides GPU-accelerated algebraic multigrid (AMG), automatic differentiation (AD), and distributed multi-GPU execution. JAX-AMG fills this gap by wrapping the NVIDIA AmgX solver suite as a native JAX primitive, exposing AMG and Krylov methods with configurable preconditioners through a unified interface compatible with JIT compilation, reverse-mode AD via adjoint methods, batched solves, and MPI-based distributed execution. Solver caching amortizes setup costs across repeated solves, making JAX-AMG practical for PDE-constrained optimization and inverse problems. The result is a robust, scalable sparse linear algebra layer that integrates seamlessly into differentiable simulation and scientific machine learning pipelines.
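To make "reverse-mode AD via adjoint methods" concrete, the sketch below shows how a linear solve x = A^{-1} b can be given a hand-written backward rule in JAX. It is not the JAX-AMG API, which the abstract does not describe: `linear_solve`, `poisson_1d`, and `loss` are hypothetical names, and a dense `jnp.linalg.solve` stands in for the external AMG/Krylov solver. The backward pass solves the adjoint system A^T λ = x̄ and returns b̄ = λ and Ā = -λ x^T.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def linear_solve(A, b):
    # Hypothetical stand-in for an external solver call.
    # A dense solve is used here only so the sketch is self-contained.
    return jnp.linalg.solve(A, b)


def _solve_fwd(A, b):
    x = linear_solve(A, b)
    # Save what the backward pass needs: the operator and the solution.
    return x, (A, x)


def _solve_bwd(residuals, x_bar):
    A, x = residuals
    # Adjoint system: A^T lam = x_bar (one extra solve per backward pass).
    lam = jnp.linalg.solve(A.T, x_bar)
    A_bar = -jnp.outer(lam, x)  # cotangent with respect to A
    b_bar = lam                 # cotangent with respect to b
    return A_bar, b_bar


linear_solve.defvjp(_solve_fwd, _solve_bwd)


def poisson_1d(n, shift):
    # Dense 1D Laplacian stencil [-1, 2 + shift, -1]; illustrative only.
    main = (2.0 + shift) * jnp.eye(n)
    off = jnp.eye(n, k=1) + jnp.eye(n, k=-1)
    return main - off


def loss(shift, b):
    A = poisson_1d(b.shape[0], shift)
    x = linear_solve(A, b)
    return 0.5 * jnp.sum(x ** 2)


b = jnp.ones(8)
# Gradient of the loss with respect to a parameter of the operator,
# computed through the custom adjoint rule and compiled with JIT.
grad_shift = jax.jit(jax.grad(loss))(0.5, b)
```

The design point this illustrates is that the backward pass costs one additional linear solve with the transposed operator, rather than differentiating through the solver's iterations. In this sketch the same factorization or solver setup could in principle be reused for both solves, which is one place where the solver caching mentioned in the abstract would plausibly help.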