equinox
pytorch-lightning
| | equinox | pytorch-lightning |
|---|---|---|
| Mentions | 31 | 19 |
| Stars | 1,819 | 19,188 |
| Growth | - | - |
| Activity | 9.2 | 9.9 |
| Latest commit | 16 days ago | almost 2 years ago |
| Language | Python | Python |
| License | Apache License 2.0 | Apache License 2.0 |
Stars - the number of stars that a project has on GitHub. Growth - month over month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.
equinox
-
Ask HN: What side projects landed you a job?
I wrote a JAX-based neural network library (Equinox [1]) and numerical differential equation solving library (Diffrax [2]).
At the time I was just exploring some new research ideas in numerics -- and frankly, procrastinating from writing up my PhD thesis!
But then one of the teams at Google started using them, so they offered me a job to keep developing them for their needs. Plus I'd get to work in biotech, which was a big interest of mine. This was a clear dream job offer, so I accepted.
Since then both have grown steadily in popularity (~2.6k GitHub stars) and now see pretty widespread use! I've since started writing several other JAX libraries and we now have a bit of an ecosystem going.
[1] https://github.com/patrick-kidger/equinox
-
[P] Optimistix, nonlinear optimisation in JAX+Equinox!
The elevator pitch: Optimistix is really fast, especially to compile. It plays nicely with Optax for first-order gradient-based methods, and takes a lot of design inspiration from Equinox, representing the state of all its solvers as standard JAX PyTrees.
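To illustrate what "solver state as a standard JAX PyTree" buys you, here is a minimal sketch in plain JAX. It is not Optimistix's API: `GDState`, `minimise`, `learning_rate`, `tol` and `max_steps` are hypothetical names for a toy gradient-descent minimiser. Because the state is a `NamedTuple` (which JAX treats as a PyTree), it can be carried through `jax.lax.while_loop` and the whole solve compiled as one function with `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GDState(NamedTuple):
    # Hypothetical solver state. NamedTuples are PyTrees, so JAX can carry
    # this through control flow and transformations with no extra work.
    y: jax.Array
    grad: jax.Array
    step: jax.Array


def minimise(fn, y0, learning_rate=0.1, tol=1e-5, max_steps=1000):
    """Toy gradient descent; stops when the gradient norm drops below `tol`."""
    grad_fn = jax.grad(fn)

    def cond(state):
        return (jnp.linalg.norm(state.grad) > tol) & (state.step < max_steps)

    def body(state):
        y = state.y - learning_rate * state.grad
        return GDState(y, grad_fn(y), state.step + 1)

    init = GDState(y0, grad_fn(y0), jnp.array(0))
    return jax.lax.while_loop(cond, body, init)


# The entire solve compiles to a single XLA computation.
solve = jax.jit(lambda y0: minimise(lambda y: jnp.sum((y - 3.0) ** 2), y0))
final = solve(jnp.zeros(2))
```

Since the result is just a PyTree, you can read `final.step` or `final.y` directly, or `jax.vmap` the whole solve over a batch of starting points.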
-
JAX – NumPy on the CPU, GPU, and TPU, with great automatic differentiation
If you like PyTorch then you might like Equinox, by the way. (https://github.com/patrick-kidger/equinox ; 1.4k GitHub stars now!)
- Equinox: Elegant easy-to-use neural networks in Jax
- Show HN: Equinox (1.3k stars), a JAX library for neural networks and sciML
-
Pytrees
You're thinking of `jax.closure_convert`. :)
(Although technically that works by tracing and extracting all constants from the jaxpr, rather than introspecting the function's closure cells -- it sounds like your trick is the latter.)
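For reference, here is a minimal sketch of `jax.closure_convert(fun, *example_args)`. It traces `fun` at the example arguments and returns a pair: a converted function that takes the hoisted constants as extra trailing positional arguments, and the list of those constants. As far as I know it only hoists closed-over values that autodiff might perturb (tracers), which is why the sketch calls it inside a function that is differentiated with `jax.grad`. The names `outer`, `w` and `x` are illustrative. Its main use is alongside `jax.custom_vjp`, which doesn't support differentiating with respect to closed-over values.

```python
import jax
import jax.numpy as jnp


def outer(w, x):
    # `f` closes over `w`; under `jax.grad(outer)`, `w` is a tracer.
    def f(x):
        return jnp.dot(w, x)

    # Trace `f` with `x` as the example argument. Closed-over values that
    # autodiff may perturb (here `w`) are hoisted out into `consts`.
    converted, consts = jax.closure_convert(f, x)
    # The converted function takes the hoisted constants after the usual args.
    return converted(x, *consts)


w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([4.0, 5.0, 6.0])
value = outer(w, x)            # dot(w, x)
grad_w = jax.grad(outer)(w, x)  # d/dw dot(w, x) == x
```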
When you discuss dynamic allocation, I'm guessing you're mainly referring to not being able to backprop through `jax.lax.while_loop`. If so, you might find `equinox.internal.while_loop` interesting, which is an unbounded while loop that you can backprop through! The secret sauce is to use a treeverse-style checkpointing scheme.
https://github.com/patrick-kidger/equinox/blob/f95a8ba13fb35...
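For context, here is a sketch of the limitation and of the simplest workaround. It is not the treeverse-style checkpointing used by `equinox.internal.while_loop`. Reverse-mode autodiff through `jax.lax.while_loop` raises an error, because the number of iterations (and so how many intermediate values to keep) isn't known ahead of time. A common fallback is to pick a hypothetical bound `max_steps`, run the loop with `jax.lax.scan`, and freeze the state once it has converged. This stores every step, so memory grows with `max_steps`. Checkpointing schemes exist precisely to trade recomputation for that memory. The Newton iteration for a square root is just an illustrative example.

```python
import jax
import jax.numpy as jnp


def newton_sqrt_while(a):
    # Unbounded loop: iterate x <- (x + a / x) / 2 until converged.
    # Forward evaluation works; reverse-mode differentiation does not.
    def cond(x):
        return jnp.abs(x * x - a) > 1e-6

    def body(x):
        return 0.5 * (x + a / x)

    return jax.lax.while_loop(cond, body, a)


def newton_sqrt_bounded(a, max_steps=32):
    # Bounded alternative: always run `max_steps` iterations via scan,
    # leaving the state unchanged once converged. Reverse-mode works
    # because scan stores every intermediate state.
    def step(x, _):
        done = jnp.abs(x * x - a) <= 1e-6
        x_new = jnp.where(done, x, 0.5 * (x + a / x))
        return x_new, None

    x, _ = jax.lax.scan(step, a, None, length=max_steps)
    return x
```

Here `jax.grad(newton_sqrt_bounded)(4.0)` is approximately `0.25`, matching d/da sqrt(a) = 1 / (2 sqrt(a)), while `jax.grad(newton_sqrt_while)(4.0)` raises an error.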
-
Writing Python like it’s Rust
I'm a big fan of using ABCs to declare interfaces -- so much so that I have an improved abc.ABCMeta that also handles abstract instance variables and abstract class variables: https://github.com/patrick-kidger/equinox/blob/main/equinox/_better_abstract.py
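As a rough illustration of the idea, here is a hypothetical `abc.ABCMeta` subclass. It is not the implementation at the link above. On top of the usual `@abc.abstractmethod` checks, it refuses to instantiate a class until every name listed in a hypothetical `__abstract_class_vars__` attribute is defined somewhere in the class hierarchy. `Optimiser`, `SGD` and `MissingName` are illustrative.

```python
import abc


class AbstractClassVarMeta(abc.ABCMeta):
    """Hypothetical ABCMeta extension that also checks abstract class variables."""

    def __call__(cls, *args, **kwargs):
        # Collect required class-variable names declared anywhere in the MRO.
        required = {
            name
            for klass in cls.__mro__
            for name in klass.__dict__.get("__abstract_class_vars__", ())
        }
        missing = sorted(name for name in required if not hasattr(cls, name))
        if missing:
            raise TypeError(
                f"Can't instantiate {cls.__name__}: "
                f"missing abstract class variables {missing}"
            )
        # The standard abstract-method check happens during normal instantiation.
        return super().__call__(*args, **kwargs)


class Optimiser(metaclass=AbstractClassVarMeta):
    # Subclasses must define a class-level `name`.
    __abstract_class_vars__ = ("name",)

    @abc.abstractmethod
    def step(self, x: float) -> float:
        ...


class SGD(Optimiser):
    name = "sgd"

    def step(self, x: float) -> float:
        return x - 0.1


class MissingName(Optimiser):
    # Implements the abstract method but forgets the class variable.
    def step(self, x: float) -> float:
        return x
```

`SGD()` instantiates normally, while `MissingName()` and `Optimiser()` raise `TypeError` at construction time rather than failing later when `name` is first accessed.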
-
[D] JAX vs PyTorch in 2023
For my daily research, I use Equinox (https://github.com/patrick-kidger/equinox) as my DL library in JAX.
- [Machinelearning] [D] Current state of JAX vs PyTorch?
-
Training Deep Networks with Data Parallelism in Jax
It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations to handle their own thing? (E.g. `haiku.grad`.)
If so, then allow me to make my usual advert here for Equinox:
https://github.com/patrick-kidger/equinox
This actually works with JAX's native transformations. (There's no `equinox.vmap` for example.)
On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
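To make the "callable that is also a pytree" idea concrete without depending on Equinox, here is a hand-rolled sketch using `jax.tree_util.register_pytree_node_class`. It is not `equinox.Module`. `Linear`, `loss` and the SGD-style update are illustrative. Because the model is a pytree, plain `jax.grad` differentiates with respect to it directly and returns gradients with the same structure, and plain `jax.vmap` maps its forward pass over a batch.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Linear:
    """A layer that is both callable (forward pass) and a pytree (parameters)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        return self.weight @ x + self.bias

    # Pytree protocol: parameters are the leaves; no static auxiliary data.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(model, x, y):
    pred = jax.vmap(model)(x)  # native vmap over the batch dimension
    return jnp.mean((pred - y) ** 2)


model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

grads = jax.grad(loss)(model, x, y)  # native grad; `grads` is also a Linear
new_model = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, model, grads)
```

No library-specific `grad` or `vmap` wrapper is involved. The layer simply participates in JAX's own transformations.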
pytorch-lightning
-
Problem with pytorch lightning and optuna with multiple callbacks
```python
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
    # Trainer calls `on_validation_end` for sanity check. Therefore, it is necessary to avoid
    # calling `trial.report` multiple times at epoch 0. For more details, see
    # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
    if trainer.sanity_checking:
        return
```
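A framework-agnostic sketch of the same guard, extended to avoid reporting any epoch twice (for instance when several callbacks trigger extra validation passes). `EpochReporter` and its `report`/`monitor` parameters are hypothetical. In a real project it would subclass Lightning's `Callback` and receive Optuna's `trial.report(value, step)` as `report`. Here it only relies on `trainer.sanity_checking`, `trainer.current_epoch` and `trainer.callback_metrics`, so a `SimpleNamespace` can stand in for the Trainer.

```python
from types import SimpleNamespace


class EpochReporter:
    """Hypothetical callback: reports a monitored metric at most once per epoch."""

    def __init__(self, report, monitor="val_loss"):
        self.report = report  # e.g. optuna's trial.report(value, step)
        self.monitor = monitor
        self._reported_epochs = set()

    def on_validation_end(self, trainer, pl_module):
        # Skip the sanity-check validation pass that runs before training.
        if trainer.sanity_checking:
            return
        epoch = trainer.current_epoch
        # Never report the same epoch twice.
        if epoch in self._reported_epochs:
            return
        value = trainer.callback_metrics.get(self.monitor)
        if value is None:
            return
        self._reported_epochs.add(epoch)
        self.report(float(value), epoch)


# Usage with a stand-in trainer object.
reports = []
cb = EpochReporter(lambda value, step: reports.append((step, value)))
trainer = SimpleNamespace(sanity_checking=True, current_epoch=0,
                          callback_metrics={"val_loss": 0.9})
cb.on_validation_end(trainer, None)  # sanity check: ignored
trainer.sanity_checking = False
cb.on_validation_end(trainer, None)  # epoch 0: reported
cb.on_validation_end(trainer, None)  # epoch 0 again: ignored
trainer.current_epoch, trainer.callback_metrics = 1, {"val_loss": 0.7}
cb.on_validation_end(trainer, None)  # epoch 1: reported
```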
-
Please comment on my planned research project structure
Under the hood, the ModelWrapper object will create an ML model based on the config (so far, an XGBoost model and a PyTorch Lightning model). Each of those will have a wrapper that conducts training and evaluation (since, as I understand Lightning, the Trainer has to live outside the model class). For lack of a better name, I call these wrappers Fitters. For uniformity, I thought about adding a common interface, IFitter, which all model wrappers inherit from.
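A minimal sketch of the proposed layout, reusing the post's names (`ModelWrapper`, `IFitter`). `MeanBaselineFitter` and the `FITTERS` registry are hypothetical stand-ins so that the sketch runs without XGBoost or Lightning. A real XGBoost or Lightning fitter would implement the same two methods, with the Lightning one creating its `Trainer` inside `fit`.

```python
from abc import ABC, abstractmethod
from typing import Dict, Sequence


class IFitter(ABC):
    """Common interface that every model wrapper ("Fitter") implements."""

    @abstractmethod
    def fit(self, X: Sequence[Sequence[float]], y: Sequence[float]) -> None:
        ...

    @abstractmethod
    def evaluate(self, X: Sequence[Sequence[float]], y: Sequence[float]) -> Dict[str, float]:
        ...


class MeanBaselineFitter(IFitter):
    """Hypothetical stand-in fitter: always predicts the mean training target."""

    def __init__(self) -> None:
        self.mean = 0.0

    def fit(self, X, y):
        self.mean = sum(y) / len(y)

    def evaluate(self, X, y):
        mse = sum((t - self.mean) ** 2 for t in y) / len(y)
        return {"mse": mse}


# Hypothetical registry mapping config values to Fitter classes.
FITTERS = {"mean_baseline": MeanBaselineFitter}


class ModelWrapper:
    """Builds the Fitter named in the config and delegates to it."""

    def __init__(self, config: Dict[str, str]) -> None:
        self.fitter: IFitter = FITTERS[config["model"]]()

    def run(self, X_train, y_train, X_test, y_test) -> Dict[str, float]:
        self.fitter.fit(X_train, y_train)
        return self.fitter.evaluate(X_test, y_test)
```

Adding a new model type then means writing one more `IFitter` subclass and registering it, without touching `ModelWrapper`.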
-
Watch out for the (PyTorch) Lightning
Join their Slack to ask the community questions, and check out their GitHub repository.
-
[P] Composer: a new PyTorch library to train models ~2-4x faster with better algorithms
PyTorch Lightning benchmarks against plain PyTorch on every PR, to make sure that it is not slower.
-
[D] What Repetitive Tasks Related to Machine Learning do You Hate Doing?
There is already a ton of momentum around automating ML workflows. I would suggest you contribute to a preexisting project like, for instance, PyTorch Lightning or fast.ai.
- PyTorch Lightning
-
[D] Are you using PyTorch or TensorFlow going into 2022?
Is the problem the sheer number of options, or the fact that they are all together in one place? Would it be better if they were organized into the different trainer entrypoints (fit, validate, ...)? If that is the case, there was an RFC proposing this which you might find interesting, feel free to drop by and comment on the issue: https://github.com/PyTorchLightning/pytorch-lightning/issues/10444
-
[D] Colab TPU low performance
I wanted to make a quick performance comparison between the GPU (Tesla K80) and TPU (v2-8) available in Google Colab with PyTorch. To do so quickly, I used an MNIST example from pytorch-lightning that trains a simple CNN.
-
[D] How to avoid CPU bottlenecking in PyTorch - training slowed by augmentations and data loading?
We've noticed that GPU 0 on our 3-GPU system is sometimes idle (which would explain the performance differences). However, it's unclear to us why that may be. This seems similar to this issue.
-
[P] An introduction to PyKale https://github.com/pykale/pykale, a PyTorch library that provides a unified pipeline-based API for knowledge-aware multimodal learning and transfer learning on graphs, images, texts, and videos to accelerate interdisciplinary research. Feedback and contributions welcome!
If you want a good example for reference, take a look at PyTorch Lightning's README (https://github.com/PyTorchLightning/pytorch-lightning). It answers the three questions of "what is this", "why should I care", and "how do I use it" almost instantly.
What are some alternatives?
flax - Flax is a neural network library for JAX that is designed for flexibility.
mmdetection - OpenMMLab Detection Toolbox and Benchmark
dm-haiku - JAX-based neural network library
pytorch-grad-cam - Advanced AI Explainability for computer vision. Support for CNNs, Vision Transformers, Classification, Object detection, Segmentation, Image similarity and more.
torchtyping - Type annotations and dynamic checking for a tensor's shape, dtype, names, etc.
detectron2 - Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
treex - A Pytree Module system for Deep Learning in JAX
fastai - The fastai deep learning library
extending-jax - Extending JAX with custom C++ and CUDA code
composer - Supercharge Your Model Training
diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
sparktorch - Train and run Pytorch models on Apache Spark.