jax-models
equinox
| | jax-models | equinox |
|---|---|---|
| Mentions | 6 | 31 |
| Stars | 138 | 1,837 |
| Growth | - | - |
| Activity | 0.0 | 9.2 |
| Last commit | almost 2 years ago | 6 days ago |
| Language | Python | Python |
| License | Apache License 2.0 | Apache License 2.0 |
Stars - the number of stars that a project has on GitHub. Growth - month over month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.
jax-models
-
[D] How to contribute to open source ML and DL without having access to high quality setup?
I was in the same position as you are and the best thing you can do is to start reproducing papers (that's what I did with jax-models). This will
-
[D] Should We Be Using JAX in 2022?
I've been using JAX, especially Flax for quite some time now for my reproducibility initiative (jax_models) and this is what I really appreciate about the framework
- Weekly updated open sourced model implementations in Flax
- Weekly updated open sourced deep learning model implementations in Flax
- [P] Weekly updated open sourced model implementations in Flax
equinox
-
Ask HN: What side projects landed you a job?
I wrote a JAX-based neural network library (Equinox [1]) and numerical differential equation solving library (Diffrax [2]).
At the time I was just exploring some new research ideas in numerics -- and frankly, procrastinating from writing up my PhD thesis!
But then one of the teams at Google started using them, so they offered me a job to keep developing them for their needs. Plus I'd get to work in biotech, which was a big interest of mine. This was a clear dream job offer, so I accepted.
Since then both have grown steadily in popularity (~2.6k GitHub stars) and now see pretty widespread use! I've since started writing several other JAX libraries and we now have a bit of an ecosystem going.
[1] https://github.com/patrick-kidger/equinox
-
[P] Optimistix, nonlinear optimisation in JAX+Equinox!
The elevator pitch is Optimistix is really fast, especially to compile. It plays nicely with Optax for first-order gradient-based methods, and takes a lot of design inspiration from Equinox, representing the state of all the solvers as standard JAX PyTrees.
-
JAX – NumPy on the CPU, GPU, and TPU, with great automatic differentiation
If you like PyTorch then you might like Equinox, by the way. (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars now!)
- Equinox: Elegant easy-to-use neural networks in Jax
- Show HN: Equinox (1.3k stars), a JAX library for neural networks and sciML
-
Pytrees
You're thinking of `jax.closure_convert`. :)
(Although technically that works by tracing and extracting all constants from the jaxpr, rather than introspecting the function's closure cells -- it sounds like your trick is the latter.)
When you discuss dynamic allocation, I'm guessing you're mainly referring to not being able to backprop through `jax.lax.while_loop`. If so, you might find `equinox.internal.while_loop` interesting, which is an unbounded while loop that you can backprop through! The secret sauce is to use a treeverse-style checkpointing scheme.
https://github.com/patrick-kidger/equinox/blob/f95a8ba13fb35...
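To illustrate the limitation mentioned in the post above, here is a minimal sketch using only `jax.lax.while_loop` from core JAX. The function `newton_sqrt` and its tolerance are hypothetical names chosen for this example; it does not use or reproduce `equinox.internal.while_loop`. The loop's trip count depends on the input, and forward-mode differentiation through it works.

```python
import jax
import jax.numpy as jnp


def newton_sqrt(a, tol=1e-6):
    """Approximate sqrt(a) by Newton iteration (hypothetical example)."""

    def cond(state):
        x, prev = state
        # Keep iterating until two successive iterates agree to within tol.
        return jnp.abs(x - prev) > tol

    def body(state):
        x, _ = state
        # Newton update for f(x) = x**2 - a; carry the old iterate along.
        return 0.5 * (x + a / x), x

    # Start from x = a, with "previous" set to inf so the loop runs at least once.
    init = (a, jnp.full_like(a, jnp.inf))
    x, _ = jax.lax.while_loop(cond, body, init)
    return x


a = jnp.float32(4.0)
value = newton_sqrt(a)               # close to 2.0
slope = jax.jacfwd(newton_sqrt)(a)   # forward-mode derivative, close to 1 / (2 * sqrt(4))
```

Replacing `jax.jacfwd` with `jax.grad` here is expected to raise an error, because reverse-mode differentiation is not supported for `jax.lax.while_loop` when the number of iterations is not known statically. That is the gap the checkpointed loop described in the post is meant to fill.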
-
Writing Python like it’s Rust
I'm a big fan of using ABCs to declare interfaces -- so much so that I have an improved abc.ABCMeta that also handles abstract instance variables and abstract class variables: https://github.com/patrick-kidger/equinox/blob/main/equinox/_better_abstract.py
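As a rough illustration of declaring an interface with an ABC, the sketch below uses only the standard-library `abc` module. The `Solver` and `Euler` classes are hypothetical and are not taken from Equinox. The standard library has no direct way to declare an abstract instance or class variable, so an abstract property is used here as the usual workaround; the post says its improved metaclass handles these cases.

```python
import abc


class Solver(abc.ABC):
    """Hypothetical interface for a one-step ODE solver."""

    @property
    @abc.abstractmethod
    def order(self) -> int:
        """Order of accuracy; every concrete solver must provide it."""

    @abc.abstractmethod
    def step(self, f, y: float, dt: float) -> float:
        """Advance the state y by one step of size dt for y' = f(y)."""


class Euler(Solver):
    # A plain class attribute satisfies the abstract property.
    order = 1

    def step(self, f, y: float, dt: float) -> float:
        return y + dt * f(y)


class Incomplete(Solver):
    # Missing `step`, so this class stays abstract.
    order = 2


solver = Euler()
y_next = solver.step(lambda y: -y, 1.0, 0.5)  # one Euler step for y' = -y

try:
    Incomplete()
    instantiated = True
except TypeError:
    # ABCs refuse to instantiate classes with unimplemented abstract members.
    instantiated = False
```

The interface is enforced at instantiation time rather than at call time, so a missing method shows up as soon as the object is created.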
-
[D] JAX vs PyTorch in 2023
For daily research, I use Equinox (https://github.com/patrick-kidger/equinox) as a DL library in JAX.
- [Machinelearning] [D] Current state of JAX vs Pytorch?
-
Training Deep Networks with Data Parallelism in Jax
It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations to handle their own thing? (E.g. `haiku.grad`.)
If so, then allow me to make my usual advert here for Equinox:
https://github.com/patrick-kidger/equinox
This actually works with JAX's native transformations. (There's no `equinox.vmap` for example.)
On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
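The idea of a callable that is also a pytree can be sketched with core JAX alone. The `Linear` class below is a hypothetical stand-in registered through `jax.tree_util.register_pytree_node_class`; it is not Equinox's implementation, only an assumption-laden illustration of why native transformations such as `jax.grad` and `jax.vmap` can be applied without library-specific wrappers.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """Hypothetical layer: callable, with its parameters stored as pytree leaves."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # Forward pass.
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children are the differentiable leaves; there is no static aux data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    return jnp.sum((model(x) - y) ** 2)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.zeros(2)

# jax.grad differentiates with respect to the model itself and
# returns another Linear whose leaves hold the gradients.
grads = jax.grad(loss)(model, x, y)

# jax.vmap maps the model's forward pass over a batch of inputs.
batch_out = jax.vmap(model)(jnp.ones((4, 3)))
```

Because the gradients come back in the same tree structure as the model, a parameter update can be written as a `jax.tree_util.tree_map` over `model` and `grads`.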
What are some alternatives?
datasets - TFDS is a collection of datasets ready to use with TensorFlow, Jax, ...
flax - Flax is a neural network library for JAX that is designed for flexibility.
dm-haiku - JAX-based neural network library
flaxmodels - Pretrained deep learning models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet, etc.
torchtyping - Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
GradCache - Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
treex - A Pytree Module system for Deep Learning in JAX
elegy - A High Level API for Deep Learning in JAX
extending-jax - Extending JAX with custom C++ and CUDA code
jax - Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/