JAX – NumPy on the CPU, GPU, and TPU, with great automatic differentiation

This page summarizes the projects mentioned and recommended in the original post on news.ycombinator.com.

  • equinox

    Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

    I'm going to disagree here! Classes and functional programming can go very well together; just don't expect to do in-place mutation, as in OO-style programming.

    You might like Equinox (https://github.com/patrick-kidger/equinox; 1.4k GitHub stars) which deliberately offers a very PyTorch-like feel for JAX.
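    Here is a minimal sketch of the "classes without in-place mutation" idea in plain JAX rather than Equinox. The `Linear` class, `loss` and `sgd_step` are hypothetical names used only for illustration. Registering the class as a pytree with `jax.tree_util.register_pytree_node_class` lets it pass straight through `jax.grad`, and the update step returns a new object instead of mutating the old one. Equinox's `eqx.Module` handles this registration for you; its own API is not reproduced here.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Tiny linear layer (hypothetical example, not Equinox's API)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return x @ self.weight + self.bias

    # Tell JAX how to split the object into array leaves and rebuild it.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    # Mean squared error of the model's predictions.
    return jnp.mean((model(x) - y) ** 2)


def sgd_step(model, x, y, lr=0.1):
    # Gradients come back as another Linear with the same structure.
    grads = jax.grad(loss)(model, x, y)
    # Build a new model rather than mutating `model` in place.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, model, grads)


model = Linear(jnp.zeros((3, 1)), jnp.zeros((1,)))
x = jnp.ones((4, 3))
y = jnp.ones((4, 1))
new_model = sgd_step(model, x, y)
```

    Because `model` is never modified, the old and new parameters can both be kept around, and `sgd_step` can be wrapped in `jax.jit` as it stands.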

    Regarding speed, I would strongly recommend JAX over PyTorch for scientific computing. The XLA compiler seems to be much more effective for such use cases.

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

    Actually, that never changed. The README has always had an example of differentiating through native Python control flow.

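    A minimal sketch in the same spirit follows. It is not the README's actual example; `piecewise` and `repeated_square` are made-up functions. Without `jax.jit`, `jax.grad` evaluates the function on concrete values, so a Python `if` on the input and a Python `for` loop both differentiate without any special combinators.

```python
import jax


def piecewise(x):
    # Plain Python `if` on the value of x: fine under jax.grad (no jit).
    if x > 0:
        return x ** 3
    return -2.0 * x


def repeated_square(x, n):
    # Plain Python loop; n is an ordinary int and is not differentiated.
    for _ in range(n):
        x = x * x
    return x


dpiecewise = jax.grad(piecewise)
drepeated = jax.grad(repeated_square)  # gradient w.r.t. the first argument only
```

    For example, `dpiecewise(2.0)` takes the `x ** 3` branch and `dpiecewise(-1.0)` takes the linear one, with no changes to the Python code.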

    The constraints on control flow expressions come from jax.jit (because Python control flow can't be staged out) and jax.vmap (because we can't take multiple branches of Python control flow, which we might need to do for different batch elements). But autodiff of Python-native control flow works fine!
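    The `jax.jit` side of this can be sketched with the same kind of toy function (the names are illustrative). Under `jax.jit` the input becomes an abstract tracer, so the Python `if` raises a concretization error at trace time. `jax.lax.cond` stages out both branches instead, and works under `jit`, `vmap` and `grad` alike.

```python
import jax
import jax.numpy as jnp
from jax import lax


def piecewise(x):
    # Python `if` on a traced value: cannot be decided while jit is tracing.
    if x > 0:
        return x ** 3
    return -2.0 * x


try:
    jax.jit(piecewise)(2.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    # Newer JAX raises TracerBoolConversionError, a subclass of this error.
    jit_failed = True


def piecewise_structured(x):
    # Structured control flow: both branches are traced and staged out.
    return lax.cond(x > 0, lambda v: v ** 3, lambda v: -2.0 * v, x)


fast = jax.jit(jax.grad(piecewise_structured))  # jit + reverse-mode autodiff
batched = jax.vmap(piecewise_structured)        # per-element branch selection
```

    For a simple elementwise choice like this one, `jnp.where(x > 0, x ** 3, -2.0 * x)` would also work; `lax.cond` is shown because it is the general structured combinator.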

  • jax-experiments

    JAX is super useful for scientific computing, although n-body sims might not be the best application. A naive n-body sim is very easy to implement and accelerate in JAX (here’s my version: https://github.com/PWhiddy/jax-experiments/blob/main/nbody.i...), but it can be tricky to scale. This is because efficient n-body sims usually rely on either trees or spatial hashing/sorting, which are tricky to implement efficiently with JAX.
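    For reference, a naive all-pairs version looks roughly like the sketch below. It is not the linked nbody code: the softening length `eps`, the units (`G = 1`) and the kick-drift-kick leapfrog integrator are assumptions made for this example. The O(N²) broadcast is what makes it so easy to write and `jit`, and also why it stops scaling: memory and compute grow with the square of the number of bodies.

```python
import jax
import jax.numpy as jnp


def accelerations(pos, mass, G=1.0, eps=1e-3):
    # diff[i, j] = pos[j] - pos[i], shape (N, N, 3): the O(N^2) part.
    diff = pos[None, :, :] - pos[:, None, :]
    # Softening (eps) keeps the diagonal finite; diff is zero there anyway,
    # so a body exerts no force on itself.
    dist2 = jnp.sum(diff ** 2, axis=-1) + eps ** 2
    inv_d3 = dist2 ** -1.5
    # a_i = G * sum_j m_j * (pos_j - pos_i) / |r_ij|^3
    return G * jnp.einsum("ij,ijk->ik", mass[None, :] * inv_d3, diff)


@jax.jit
def leapfrog_step(pos, vel, mass, dt):
    # Kick-drift-kick leapfrog integrator.
    vel = vel + 0.5 * dt * accelerations(pos, mass)
    pos = pos + dt * vel
    vel = vel + 0.5 * dt * accelerations(pos, mass)
    return pos, vel


# Two equal masses on a symmetric orbit.
pos = jnp.array([[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
vel = jnp.array([[0.0, -0.5, 0.0], [0.0, 0.5, 0.0]])
mass = jnp.array([1.0, 1.0])
for _ in range(100):
    pos, vel = leapfrog_step(pos, vel, mass, 0.01)
```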

  • thinc

    🔮 A refreshing functional take on deep learning, compatible with your favorite libraries

    Agree, though I wouldn’t call PyTorch a drop-in replacement for NumPy either. CuPy is the drop-in: apart from some corner cases, you can use the same code for both. Thinc’s ops work with both NumPy and CuPy.

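    The drop-in pattern itself can be sketched without Thinc: write the function against an array module `xp` and pass in either `numpy` or `cupy`. The `softmax` function and its `xp` parameter are illustrative only, not Thinc's ops API. CuPy also offers `cupy.get_array_module(x)`, which returns `numpy` or `cupy` depending on where `x` lives.

```python
import numpy as np


def softmax(x, xp=np):
    # `xp` is the array module: numpy on the CPU, or cupy on the GPU.
    # Only calls that exist in both libraries are used.
    shifted = x - xp.max(x, axis=-1, keepdims=True)  # for numerical stability
    e = xp.exp(shifted)
    return e / xp.sum(e, axis=-1, keepdims=True)


probs = softmax(np.array([[0.0, 0.0], [1.0, 2.0]]))
```

    On a machine with CuPy installed, `softmax(cupy.asarray(x), xp=cupy)` should run the same code on the GPU.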

  • jaxonnxruntime

    A user-friendly tool chain that enables the seamless execution of ONNX models using JAX as the backend.

  • jax-md

    Differentiable, Hardware Accelerated, Molecular Dynamics

  • autograd

    Efficiently computes derivatives of numpy code.

    Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow always worked. It's jax.jit and jax.vmap that place constraints on control flow, requiring structured control-flow combinators such as jax.lax.cond, jax.lax.while_loop and jax.lax.scan.
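    The same split applies to loops. In this sketch (`decay` and `decay_fori` are hypothetical), a plain Python `for` loop differentiates fine with `jax.grad`. The structured version uses `jax.lax.fori_loop`, which stages the loop out as a single primitive instead of unrolling it into the traced program. Here `n` is marked static so the loop bounds stay concrete. That matters because `fori_loop` with traced bounds lowers to `lax.while_loop`, which only supports forward-mode differentiation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def decay(k, n):
    # Python loop: jax.grad handles it directly because n is a plain int.
    y = 1.0
    for _ in range(n):
        y = y - 0.1 * k * y
    return y


def decay_fori(k, n):
    # Structured loop: one fori_loop op instead of n unrolled copies.
    return lax.fori_loop(0, n, lambda i, y: y - 0.1 * k * y, jnp.ones_like(k))


grad_python = jax.grad(decay)
# static_argnums=1 keeps n concrete, so fori_loop lowers to a scan and
# stays reverse-mode differentiable; jit recompiles for each new n.
grad_fori = jax.jit(jax.grad(decay_fori), static_argnums=1)
```

    Both versions compute the derivative of (1 - 0.1k)^n with respect to k; the `fori_loop` one keeps the compiled program small when `n` is large.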
