awesome-jax vs equinox
| | awesome-jax | equinox |
|---|---|---|
| Mentions | 3 | 31 |
| Stars | 1,312 | 1,837 |
| Growth | - | - |
| Activity | 6.2 | 9.2 |
| Latest commit | 8 days ago | 2 days ago |
| Language | Python | Python |
| License | Creative Commons Zero v1.0 Universal | Apache License 2.0 |
Stars - the number of stars that a project has on GitHub. Growth - month over month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.
awesome-jax
- [D] Any less-boilerplate framework for Jax/Flax/Haiku?
Have you looked in here?
- [D] JAX learning resources?
Here's a compilation of resources: https://github.com/n2cholas/awesome-jax
equinox
- Ask HN: What side projects landed you a job?
I wrote a JAX-based neural network library (Equinox [1]) and numerical differential equation solving library (Diffrax [2]).
At the time I was just exploring some new research ideas in numerics -- and frankly, procrastinating from writing up my PhD thesis!
But then one of the teams at Google started using them, so they offered me a job to keep developing them for their needs. Plus I'd get to work in biotech, which was a big interest of mine. This was a clear dream job offer, so I accepted.
Since then both have grown steadily in popularity (~2.6k GitHub stars) and now see pretty widespread use! I've since started writing several other JAX libraries and we now have a bit of an ecosystem going.
[1] https://github.com/patrick-kidger/equinox
- [P] Optimistix, nonlinear optimisation in JAX+Equinox!
The elevator pitch is that Optimistix is really fast, especially to compile. It plays nicely with Optax for first-order gradient-based methods, and takes a lot of design inspiration from Equinox, representing the state of all the solvers as standard JAX PyTrees.
- JAX – NumPy on the CPU, GPU, and TPU, with great automatic differentiation
If you like PyTorch then you might like Equinox, by the way. (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars now!)
- Equinox: Elegant easy-to-use neural networks in Jax
- Show HN: Equinox (1.3k stars), a JAX library for neural networks and sciML
- Pytrees
You're thinking of `jax.closure_convert`. :)
(Although technically that works by tracing and extracting all constants from the jaxpr, rather than introspecting the function's closure cells -- it sounds like your trick is the latter.)
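To make the distinction concrete, here is a minimal pure-Python sketch of the closure-cell trick (the helper name `closure_constants` is made up for illustration; this is introspecting `__closure__` directly, not what `jax.closure_convert` does):

```python
def make_scale(factor):
    # `factor` is captured as a closure cell of `scale`.
    def scale(x):
        return factor * x
    return scale

def closure_constants(fn):
    """Read the values captured in a function's closure cells."""
    if fn.__closure__ is None:
        return {}
    return dict(zip(fn.__code__.co_freevars,
                    (cell.cell_contents for cell in fn.__closure__)))

scale3 = make_scale(3)
print(closure_constants(scale3))  # {'factor': 3}
```

By contrast, tracing the function (as `jax.closure_convert` does) also picks up constants that are computed inside the function, not just those captured from an enclosing scope.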
When you discuss dynamic allocation, I'm guessing you're mainly referring to not being able to backprop through `jax.lax.while_loop`. If so, you might find `equinox.internal.while_loop` interesting, which is an unbounded while loop that you can backprop through! The secret sauce is to use a treeverse-style checkpointing scheme.
https://github.com/patrick-kidger/equinox/blob/f95a8ba13fb35...
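For intuition, here is a deliberately simplified, pure-Python sketch of the checkpointing idea: uniform sqrt-spaced checkpoints rather than the binomial treeverse schedule Equinox actually uses, and no autodiff, just showing how O(sqrt(n)) stored states buy back a full reverse pass over a loop:

```python
import math

def reverse_states(step, x0, n):
    """Recover states x_n, ..., x_0 of the loop x_{i+1} = step(x_i)
    while storing only O(sqrt(n)) checkpoints instead of all n states.

    Simplified uniform-checkpointing sketch; each segment is recomputed
    from its checkpoint during the reverse sweep.
    """
    k = max(1, math.isqrt(n))
    # Forward pass: store every k-th state as a checkpoint.
    checkpoints = {0: x0}
    x = x0
    for i in range(n):
        x = step(x)
        if (i + 1) % k == 0:
            checkpoints[i + 1] = x
    # Reverse pass: recompute each needed state from the nearest
    # earlier checkpoint.
    out = []
    for i in range(n, -1, -1):
        base = (i // k) * k
        x = checkpoints[base]
        for _ in range(i - base):
            x = step(x)
        out.append(x)
    return out  # states in reverse order: x_n, ..., x_0

print(reverse_states(lambda x: x + 1, 0, 5))  # [5, 4, 3, 2, 1, 0]
```

In the real autodiff setting the "reverse sweep" is the backward pass, which needs the forward states in reverse order — exactly what a non-invertible `while_loop` cannot otherwise provide without storing everything.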
- Writing Python like it’s Rust
I'm a big fan of using ABCs to declare interfaces -- so much so that I have an improved abc.ABCMeta that also handles abstract instance variables and abstract class variables: https://github.com/patrick-kidger/equinox/blob/main/equinox/_better_abstract.py
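A toy sketch of what such a metaclass might look like — my own illustration, not the `_better_abstract.py` linked above. It treats any class attribute that is annotated but never assigned as abstract, and only enforces this at instantiation time:

```python
import abc

class AbstractVarMeta(abc.ABCMeta):
    """ABCMeta variant that also supports 'abstract class variables':
    a name annotated on a base class without a value must be assigned
    somewhere in the subclass chain before the class can be instantiated.
    """
    def __call__(cls, *args, **kwargs):
        missing = {attr
                   for klass in cls.__mro__
                   for attr in getattr(klass, "__annotations__", {})
                   if not hasattr(cls, attr)}
        if missing:
            raise TypeError(
                f"Can't instantiate {cls.__name__} with unset "
                f"abstract class variables: {sorted(missing)}"
            )
        return super().__call__(*args, **kwargs)

class Base(metaclass=AbstractVarMeta):
    name: str  # abstract class variable: annotated, but no value

class Good(Base):
    name = "good"

Good()  # fine: `name` has been assigned

try:
    class Bad(Base):
        pass
    Bad()  # raises TypeError: `name` was never set
except TypeError:
    pass
```

Note the trade-off in this toy version: it would also flag ordinary dataclass-style instance annotations, which a production implementation has to distinguish from genuinely abstract variables.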
- [D] JAX vs PyTorch in 2023
For day-to-day research, I use Equinox (https://github.com/patrick-kidger/equinox) as a DL library in JAX.
- [Machinelearning] [D] Current state of JAX vs Pytorch?
- Training Deep Networks with Data Parallelism in Jax
It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations to handle their own thing? (E.g. `haiku.grad`.)
If so, then allow me to make my usual advert here for Equinox:
https://github.com/patrick-kidger/equinox
This actually works with JAX's native transformations. (There's no `equinox.vmap` for example.)
On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
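A dependency-free sketch of that "callable pytree" idea, with plain dataclasses standing in for `eqx.Module` and hand-rolled `tree_leaves`/`tree_map` standing in for `jax.tree_util`:

```python
import dataclasses

# Toy sketch of the central idea: a model is a callable whose
# parameters live in its tree structure. In JAX, registering the
# class as a pytree is what lets jax.grad / jax.vmap / optimizers
# operate on the parameters directly, with no wrapped transformations.

@dataclasses.dataclass(frozen=True)
class Linear:
    weight: float
    bias: float

    def __call__(self, x):  # callable: the forward pass
        return self.weight * x + self.bias

def tree_leaves(model):  # "flatten": the parameters are the leaves
    return [getattr(model, f.name) for f in dataclasses.fields(model)]

def tree_map(fn, model):  # rebuild a new model, leaf by leaf
    new = {f.name: fn(getattr(model, f.name))
           for f in dataclasses.fields(model)}
    return type(model)(**new)

model = Linear(weight=2.0, bias=1.0)
print(model(3.0))                          # forward pass -> 7.0
halved = tree_map(lambda p: p / 2, model)  # a parameter update looks like this
print(halved(3.0))                         # -> 3.5
```

Because the model is just data plus a `__call__`, transformations never need to know anything model-specific; they map over leaves and hand back a new model.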
What are some alternatives?
awesome-production-machine-learning - A curated list of awesome open source libraries to deploy, monitor, version and scale your machine learning
flax - Flax is a neural network library for JAX that is designed for flexibility.
get-started-with-JAX - The purpose of this repo is to make it easy to get started with JAX, Flax, and Haiku. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning about the JAX ecosystem.
dm-haiku - JAX-based neural network library
awesome-ocr
torchtyping - Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
awesome-ai-in-finance - 🔬 A curated list of awesome LLMs & deep learning strategies & tools in financial market.
treex - A Pytree Module system for Deep Learning in JAX
awesome-deep-learning - A curated list of awesome Deep Learning tutorials, projects and communities.
extending-jax - Extending JAX with custom C++ and CUDA code
Pytorch - Tensors and Dynamic neural networks in Python with strong GPU acceleration
diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/