JAX – NumPy on the CPU, GPU, and TPU, with great automatic differentiation

This page summarizes the projects mentioned and recommended in the original post on news.ycombinator.com

  • equinox

    Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

  • I'm going to disagree here! Classes and functional programming can go very well together; just don't expect to do in-place mutation (i.e., OO-style stateful programming).

    You might like Equinox (https://github.com/patrick-kidger/equinox; 1.4k GitHub stars) which deliberately offers a very PyTorch-like feel for JAX.

    Regarding speed, I would strongly recommend JAX over PyTorch for SciComp. The XLA compiler seems to be much more effective for such use cases.
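
    For a concrete sense of that PyTorch-like feel, here is a minimal sketch in the style of the Equinox docs (the class and names are illustrative; eqx.Module and its pytree registration are real Equinox features):

      import jax
      import jax.numpy as jnp
      import equinox as eqx

      class Linear(eqx.Module):
          weight: jax.Array
          bias: jax.Array

          def __init__(self, in_size, out_size, key):
              wkey, bkey = jax.random.split(key)
              self.weight = jax.random.normal(wkey, (out_size, in_size))
              self.bias = jax.random.normal(bkey, (out_size,))

          def __call__(self, x):
              return self.weight @ x + self.bias

      model = Linear(3, 2, jax.random.PRNGKey(0))

      # An eqx.Module is a pytree, so plain JAX transformations apply directly:
      @jax.jit
      @jax.grad
      def loss(model, x, y):
          return jnp.sum((model(x) - y) ** 2)

      grads = loss(model, jnp.ones(3), jnp.zeros(2))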

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
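
    As a minimal illustration of that composability (a hand-written sketch, not taken from the README):

      import jax
      import jax.numpy as jnp

      def predict(w, x):
          return jnp.tanh(jnp.dot(x, w))

      # Transformations compose: per-example gradients, JIT-compiled.
      per_example_grads = jax.jit(jax.vmap(jax.grad(predict), in_axes=(None, 0)))

      w = jnp.ones(3)
      xs = jnp.arange(12.0).reshape(4, 3)
      print(per_example_grads(w, xs).shape)  # (4, 3)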

  • Actually that never changed. The README has always had an example of differentiating through native Python control flow:

    https://github.com/google/jax/commit/948a8db0adf233f333f3e5f...

    The constraints on control flow expressions come from jax.jit (because Python control flow can't be staged out) and jax.vmap (because we can't take multiple branches of Python control flow, which we might need to do for different batch elements). But autodiff of Python-native control flow works fine!
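
    A small sketch of the distinction (hand-written here, not the README's example):

      import jax

      # jax.grad traces with concrete values, so ordinary Python
      # control flow differentiates fine:
      def f(x):
          if x > 0:
              return 2.0 * x ** 2
          else:
              return 3.0 * x

      print(jax.grad(f)(1.0))   # 4.0
      print(jax.grad(f)(-1.0))  # 3.0

      # Under jax.jit the `if` cannot be staged out; a structured
      # combinator such as jax.lax.cond works instead:
      g = jax.jit(lambda x: jax.lax.cond(
          x > 0, lambda x: 2.0 * x ** 2, lambda x: 3.0 * x, x))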

  • jax-experiments

  • JAX is super useful for scientific computing, although n-body sims might not be the best application. A naive n-body sim is very easy to implement and accelerate in JAX (here's my version: https://github.com/PWhiddy/jax-experiments/blob/main/nbody.i..., and a sketch of the naive approach follows below), but it can be tricky to scale: efficient n-body sims usually rely on trees or spatial hashing/sorting, which are hard to implement efficiently in JAX.
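
    For reference, the naive O(n^2) approach takes only a few lines (a sketch, not the linked implementation):

      import jax
      import jax.numpy as jnp

      def acceleration(pos, mass, eps=1e-3):
          # Pairwise displacements, shape (n, n, 3); softened by eps
          # to avoid the singularity at zero separation (units with G = 1).
          dx = pos[None, :, :] - pos[:, None, :]
          inv_r3 = (jnp.sum(dx ** 2, axis=-1) + eps ** 2) ** -1.5
          return jnp.einsum("ij,ijk,j->ik", inv_r3, dx, mass)

      @jax.jit
      def step(pos, vel, mass, dt=0.01):
          # Symplectic (semi-implicit) Euler update.
          vel = vel + dt * acceleration(pos, mass)
          pos = pos + dt * vel
          return pos, vel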

  • thinc

    🔮 A refreshing functional take on deep learning, compatible with your favorite libraries

  • Agree, though I wouldn't call PyTorch a drop-in for NumPy either. CuPy is the drop-in: apart from some corner cases, you can use the same code for both. Thinc's ops work with both NumPy and CuPy:

    https://github.com/explosion/thinc/blob/master/thinc/backend...
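
    To illustrate the drop-in claim (a sketch of the general pattern, not Thinc's internals):

      import numpy as np

      # CuPy mirrors the NumPy API, so array-module-agnostic code
      # runs unchanged on CPU (NumPy) or GPU (CuPy):
      def softmax(xp, x):
          e = xp.exp(x - x.max(axis=-1, keepdims=True))
          return e / e.sum(axis=-1, keepdims=True)

      print(softmax(np, np.array([1.0, 2.0, 3.0])))
      # import cupy as cp
      # print(softmax(cp, cp.array([1.0, 2.0, 3.0])))  # same code, on GPU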

  • jaxonnxruntime

    A user-friendly toolchain that enables the seamless execution of ONNX models using JAX as the backend.

  • jax-md

    Differentiable, Hardware Accelerated, Molecular Dynamics

  • autograd

    Efficiently computes derivatives of NumPy code.

  • Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow has always worked. It's jax.jit and jax.vmap that place constraints on control flow, requiring structured control-flow combinators such as jax.lax.cond and jax.lax.scan (see the sketch below).
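
    A minimal sketch of the vmap case (hand-written, using only public JAX APIs):

      import jax
      import jax.numpy as jnp

      # A Python `if` would have to pick one branch for the whole batch;
      # jax.lax.cond lets vmap resolve the branch per element:
      def f(x):
          return jax.lax.cond(x > 0, lambda x: x ** 2, lambda x: -x, x)

      print(jax.vmap(f)(jnp.array([-1.0, 2.0, 3.0])))  # [1. 4. 9.]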
