[D] Should We Be Using JAX in 2022?

This page summarizes the projects mentioned and recommended in the original post on /r/MachineLearning.

  1. flax

    Flax is a neural network library for JAX that is designed for flexibility.

    What's your favorite Deep Learning API for JAX - Flax, Haiku, Elegy, something else?

  3. dm-haiku

    JAX-based neural network library

  4. elegy

    A High Level API for Deep Learning in JAX

  5. jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

    It really is quite hard to tell at this point. If we're talking just about deep learning, I think JAX could be an awesome supplement to TensorFlow: since they both use XLA, it's easy to move a model from JAX to TensorFlow, so hypothetically you could build in JAX and move to TF for deployment. I don't know how useful that will be in an industry setting, though.

  6. equinox

    Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

    Author of Equinox here. I'm glad to see it being mentioned in the wild!

  7. diffrax

    Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

    Some nice examples of this (and in fact the whole reason Equinox exists) can be found throughout the Diffrax library, a new JAX-based suite of diffeq solvers. For example, diffrax.AbstractSolver is an abstract parameterised function; diffrax.PIDController is a concrete instantiation of another abstract parameterised function. You can do some pretty cool stuff with this :)

  8. jax-models

    Unofficial JAX implementations of deep learning research papers

    I've been using JAX, especially Flax, for quite some time now for my reproducibility initiative (jax_models), and this is what I really appreciate about the framework

  9. extending-jax

    Extending JAX with custom C++ and CUDA code

    You can check out this or this for more info. I think it is safe to assume that it is less stable than PyTorch: some other commenters have spoken about running into trouble with XLA in certain corner cases, but I have not experienced this, so I can't speak to it.
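
The jax entry describes JAX as composable transformations that differentiate, vectorize, and JIT-compile Python+NumPy programs. As a minimal sketch, assuming only the core `jax` package is installed (the `loss` function and the data are hypothetical), the three transformations can be stacked on a single function:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example squared error for a 1-D linear model y = w * x.
def loss(w, x, y):
    return (w * x - y) ** 2

# grad: differentiate with respect to the first argument, w.
dloss_dw = jax.grad(loss)

# vmap: map over a batch of (x, y) pairs while sharing one w (in_axes=None).
batched_grad = jax.vmap(dloss_dw, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = 2.0
xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([1.0, 2.0, 3.0])
grads = fast_batched_grad(w, xs, ys)  # one gradient per example
print(grads)
```

Each transformation returns an ordinary Python function, which is why they compose: here `jit` compiles the `vmap` of the `grad`. With `w = 2.0`, the per-example gradient of (w x - y)^2 is 2(w x - y) x, which gives 2, 8 and 18 for the three pairs.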

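The Equinox author's comment describes Diffrax solvers and step-size controllers as parameterised functions. One way to read that idea, sketched here in plain JAX rather than with Equinox itself (the `Linear` class is hypothetical, and this is not Equinox's `eqx.Module` API), is a callable object registered as a pytree, so that `jax.grad` can differentiate with respect to the object's fields directly:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical parameterised function: f(x) = weight @ x + bias."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # The arrays are the pytree leaves; there is no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    return jnp.sum((model(x) - y) ** 2)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.0, 0.0])

# grad returns a Linear whose fields hold d(loss)/d(weight) and d(loss)/d(bias).
grads = jax.grad(loss)(model, x, y)

# One gradient-descent step, applied leaf by leaf across matching pytrees.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, model, grads)
print(grads.weight, grads.bias)
```

Because `grads` has the same pytree structure as `model`, a plain `jax.tree_util.tree_map` applies the update field by field. Equinox registers its modules as pytrees automatically; the hand-written `tree_flatten`/`tree_unflatten` pair is only there to keep this sketch dependency-free.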

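Diffrax is described as autodifferentiable numerical differential equation solvers. The sketch below is not Diffrax's API. It is a hypothetical fixed-step explicit Euler solver (`euler_solve`) written with `jax.lax.scan`, assuming only core `jax`, to show what "autodifferentiable" means in practice: `jax.grad` differentiates the solution with respect to a parameter of the vector field.

```python
import jax
import jax.numpy as jnp


def euler_solve(f, y0, t0, t1, num_steps, args):
    """Hypothetical fixed-step explicit Euler for dy/dt = f(t, y, args); returns y(t1)."""
    t0 = jnp.asarray(t0, dtype=jnp.float32)
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        t, y = carry
        y_next = y + dt * f(t, y, args)
        return (t + dt, y_next), None

    # scan runs the loop inside one traced computation, so grad sees every step.
    (_, y1), _ = jax.lax.scan(step, (t0, y0), None, length=num_steps)
    return y1


# Hypothetical vector field: exponential decay dy/dt = -k * y.
def decay(t, y, k):
    return -k * y


def final_value(k):
    return euler_solve(decay, jnp.array(1.0), 0.0, 1.0, 1000, k)


k = 0.5
y1 = final_value(k)                # approximates y(1) = exp(-k)
dy1_dk = jax.grad(final_value)(k)  # approximates d/dk exp(-k) = -exp(-k)
print(y1, dy1_dk)
```

For dy/dt = -k y with y(0) = 1 the exact solution is y(1) = exp(-k), so at k = 0.5 the value and the gradient should both be close to ±exp(-0.5) ≈ 0.607. Diffrax itself adds adaptive step sizes (via controllers such as `diffrax.PIDController`) and several adjoint methods; this sketch only differentiates directly through the unrolled solver loop.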