-
What's your favorite Deep Learning API for JAX - Flax, Haiku, Elegy, something else?
-
jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
It really is quite hard to tell at this point. If we're talking just about Deep Learning, I think that JAX could be an awesome supplement to TensorFlow. Since they both use XLA, it's easy to move a model from JAX to TensorFlow, so hypothetically you could build in JAX and move to TF for deployment, but I'm not sure how useful that would be in an industry setting.
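To make "composable transformations" concrete, here is a minimal sketch that differentiates, vectorizes, and JIT-compiles one small function. The function squared_output and the input values are made up for illustration; only jax.grad, jax.vmap, and jax.jit are part of JAX itself.

```python
import jax
import jax.numpy as jnp


def squared_output(w, x):
    # Hypothetical toy "model": a scalar weight applied to a scalar input.
    return (w * x) ** 2


# Differentiate with respect to the first argument: d/dw (w*x)^2 = 2*w*x^2.
grad_w = jax.grad(squared_output)

# Vectorize over a batch of x values while sharing the same w.
batched_grad = jax.vmap(grad_w, in_axes=(None, 0))

# Compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = 3.0
xs = jnp.array([1.0, 2.0, 3.0])
grads = fast_batched_grad(w, xs)
print(grads)  # 2 * 3 * [1, 4, 9] = [6, 24, 54]
```

Each transformation returns an ordinary Python callable, which is why they can be stacked in any order that makes sense for the problem.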
-
equinox
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Author of Equinox here. I'm glad to see it being mentioned in the wild!
-
diffrax
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Some nice examples of this -- and in fact the whole reason Equinox exists -- can be found throughout the Diffrax library (a new JAX-based suite of diffeq solvers). For example, diffrax.AbstractSolver is an abstract parameterised function; diffrax.PIDController is a concrete instantiation of another abstract parameterised function. You can do some pretty cool stuff with this :)
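The "abstract parameterised function" idea can be sketched in plain JAX, without depending on Equinox or Diffrax. This is only an illustration of the pattern, not how either library implements it: the classes AbstractStep and DecayEulerStep are hypothetical names invented here, and the pytree registration is done by hand with jax.tree_util.register_pytree_node_class.

```python
import abc

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class AbstractStep(abc.ABC):
    """Hypothetical abstract parameterised function: subclasses hold
    parameters and define how to advance a state y by a step dt."""

    @abc.abstractmethod
    def __call__(self, y, dt):
        ...


@register_pytree_node_class
class DecayEulerStep(AbstractStep):
    """Concrete instance: one explicit Euler step of dy/dt = -rate * y."""

    def __init__(self, rate):
        self.rate = rate

    def __call__(self, y, dt):
        return y + dt * (-self.rate * y)

    # Tell JAX that `rate` is the only leaf of this object.
    def tree_flatten(self):
        return (self.rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (rate,) = children
        return cls(rate)


@jax.jit
def apply_step(step, y, dt):
    # The parameterised function is passed through jit like any other pytree.
    return step(y, dt)


def loss(step, y, dt):
    return jnp.sum(step(y, dt) ** 2)


step = DecayEulerStep(rate=jnp.asarray(2.0))
y0 = jnp.array([1.0, 2.0])

y1 = apply_step(step, y0, 0.1)
# Gradients come back as another DecayEulerStep holding d(loss)/d(rate).
grads = jax.grad(loss)(step, y0, 0.1)
```

Because the object is a pytree, the same instance works with jax.jit and jax.grad directly, which is the property that makes this pattern convenient for solvers and controllers.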
-
I've been using JAX, especially Flax, for quite some time now for my reproducibility initiative (jax_models), and this is what I really appreciate about the framework
-
You can check out this or this for more info. I think it is safe to assume that it is less stable than PyTorch - some other commenters have spoken about running into trouble with XLA in certain corner cases, but I have not experienced this so I can't speak to it.