mpi4jax
extending-jax
Metric | mpi4jax | extending-jax
---|---|---
Mentions | 1 | 2
Stars | 371 | 352
Growth | 7.3% | -
Activity | 6.7 | 3.5
Latest commit | 15 days ago | 6 months ago
Language | Python | Python
License | MIT License | MIT License
Stars - the number of stars that a project has on GitHub. Growth - month-over-month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.
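The page does not publish its exact formulas. As a rough illustration only, the sketch below assumes that growth is the percentage change in stars over one month, that "recent commits have higher weight" means an exponential decay with a hypothetical 30-day half-life, and that the 0-10 activity scale is a percentile rank among tracked projects (consistent with 9.0 meaning the top 10%). The function names and the half-life are hypothetical, not the site's actual method.

```python
from datetime import date


def growth_pct(stars_now: int, stars_month_ago: int) -> float:
    """Month-over-month star growth, in percent (assumed definition)."""
    if stars_month_ago <= 0:
        raise ValueError("need a positive star count for the previous month")
    return (stars_now - stars_month_ago) * 100.0 / stars_month_ago


def weighted_commits(commit_dates: list[date], today: date,
                     half_life_days: float = 30.0) -> float:
    """Sum of commits where each commit's weight halves every
    `half_life_days` (hypothetical decay), so recent commits count more."""
    return sum(0.5 ** ((today - d).days / half_life_days) for d in commit_dates)


def activity_score(raw: float, all_raw: list[float]) -> float:
    """Map a raw score onto a 0-10 scale by percentile rank among all
    tracked projects, so 9.0 means the project beats 90% of them."""
    below = sum(1 for s in all_raw if s < raw)
    return round(10.0 * below / len(all_raw), 1)
```

With this decay, a commit made today contributes 1.0 and one made 30 days earlier contributes 0.5; a percentile rank is used so that the score stays relative to the other projects being tracked.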
mpi4jax
- [D] Jax (or other libraries) when not using GPUs/TPUs but CPUs.
I've seen a couple of posts from folks using JAX for scientific computing (e.g. physics) workloads without much issue. The parallel primitives work just as well across multiple CPUs as they do on accelerators. If you're on a cluster, it's also worth looking into https://github.com/PhilipVinc/mpi4jax.
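To try the CPU-only parallelism mentioned above on a single machine, XLA can be asked to expose the host CPU as several virtual devices. The sketch below assumes a recent JAX release in which the `--xla_force_host_platform_device_count` XLA flag and `jax.pmap` are available; the function `normalize` and the axis name `"devices"` are illustrative. It runs the `psum` collective, the same kind of primitive one would use across accelerators. mpi4jax instead provides MPI-based collectives for multi-node clusters; its API is not shown here.

```python
import os

# Expose 4 virtual CPU devices to XLA. This must happen before JAX is
# imported, otherwise the flag has no effect.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
)

import jax
import jax.numpy as jnp

cpus = jax.devices("cpu")  # the virtual CPU devices created above
n = len(cpus)


def normalize(x):
    # psum is a collective: every device receives the sum over all devices.
    total = jax.lax.psum(jnp.sum(x), axis_name="devices")
    return x / total


# One row of 3 values per CPU device; pmap maps over the leading axis.
xs = jnp.arange(1.0, 1.0 + n * 3).reshape(n, 3)
out = jax.pmap(normalize, axis_name="devices", devices=cpus)(xs)
print(out.sum())  # every shard was divided by the global total, so this is ~1.0
```

The virtual devices share the same physical cores, so this is mainly useful for testing parallel code paths locally before moving to a real multi-device or multi-node setup.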
extending-jax
- [D] Should We Be Using JAX in 2022?
You can check out this or this for more info. I think it is safe to assume that it is less stable than PyTorch; some other commenters have reported running into trouble with XLA in certain corner cases, but I have not experienced this, so I can't speak to it.
- Extending JAX with custom C++ and CUDA code
What are some alternatives?
horovod - Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet.
einops - Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)
Dask - Parallel computing with task scheduling
thinc - 🔮 A refreshing functional take on deep learning, compatible with your favorite libraries
Bulk - A modern interface for implementing bulk-synchronous parallel programs.
equinox - Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
devito - DSL and compiler framework for automated finite-differences and stencil computation
trax - Trax — Deep Learning with Clear Code and Speed
elegy - A High Level API for Deep Learning in JAX