| | jax-models | dm-haiku |
|---|---|---|
| Mentions | 6 | 10 |
| Stars | 138 | 2,816 |
| Stars growth | - | 1.2% |
| Activity | 0.0 | 7.8 |
| Last commit | almost 2 years ago | 6 days ago |
| Language | Python | Python |
| License | Apache License 2.0 | Apache License 2.0 |
Stars - the number of stars that a project has on GitHub. Growth - month-over-month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.
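The recency weighting behind the activity number can be illustrated with a small sketch. The exact formula used for these stats isn't published here, so the exponential decay and the 30-day half-life below are assumptions, not the site's actual scoring:

```python
from datetime import datetime, timedelta

def activity_score(commit_dates, now, half_life_days=30.0):
    """Toy recency-weighted commit score: each commit contributes
    2 ** (-age_in_days / half_life_days), so recent commits count
    for more than older ones. Illustrative only."""
    return sum(
        2.0 ** (-(now - d).days / half_life_days)
        for d in commit_dates
    )

now = datetime(2024, 1, 1)
recent = [now - timedelta(days=i) for i in (1, 2, 3)]
stale = [now - timedelta(days=300 + i) for i in (1, 2, 3)]

# Three commits last week far outweigh three commits from ten months ago.
print(activity_score(recent, now) > activity_score(stale, now))  # True
```

This matches the behavior described above: a project with a burst of recent commits scores far higher than one with the same number of old commits.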
jax-models
-
[D] How to contribute to open source ML and DL without having access to high quality setup?
I was in the same position as you, and the best thing you can do is start reproducing papers (that's what I did with jax-models). This will
-
[D] Should We Be Using JAX in 2022?
I've been using JAX, especially Flax, for quite some time now for my reproducibility initiative (jax_models), and this is what I really appreciate about the framework
- Weekly updated open sourced model implementations in Flax
- Weekly updated open sourced deep learning model implementations in Flax
- [P] Weekly updated open sourced model implementations in Flax
dm-haiku
-
Maxtext: A simple, performant and scalable Jax LLM
Is t5x an encoder/decoder architecture?
Some more general options: the Flax ecosystem
https://github.com/google/flax?tab=readme-ov-file
and dm-haiku
https://github.com/google-deepmind/dm-haiku
were some of the best-developed communities in the JAX AI field.
Perhaps the “trax” repo? https://github.com/google/trax
Some HF examples https://github.com/huggingface/transformers/tree/main/exampl...
Sadly it seems much of the work is proprietary these days, but one example could be Grok-1, if you customize the details. https://github.com/xai-org/grok-1/blob/main/run.py
-
Help with installing python packages.
I am fresh to NixOS, especially when it comes to using Python on it. How do I install packages without using pip? I need to install:

numpy~=1.19.5 transformers~=4.8.2 tqdm~=4.45.0 setuptools~=51.3.3 wandb>=0.11.2 einops~=0.3.0 requests~=2.25.1 fabric~=2.6.0 optax==0.0.6 git+https://github.com/deepmind/dm-haiku git+https://github.com/EleutherAI/lm-evaluation-harness/ ray[default]==1.4.1 jax~=0.2.12 Flask~=1.1.2 cloudpickle~=1.3.0 tensorflow-cpu~=2.5.0 google-cloud-storage~=1.36.2 smart_open[gcs] func_timeout ftfy fastapi uvicorn lm_dataformat

With pip I could just do pip -r thetxtfile, but I don't know how to do this on NixOS. I would also be using Python 3.7. So far this is what I have come up with, but I know it's wrong:

{ pkgs ? import <nixpkgs> {} }:
let
  packages = python-packages: with python-packages; [
    mesh-transformer-jax/ jax==0.2.12 numpy~=1.19.5 transformers~=4.8.2 tqdm~=4.45.0 setuptools~=51.3.3 wandb>=0.11.2 einops~=0.3.0 requests~=2.25.1 fabric~=2.6.0 optax==0.0.6
    # the other packages
  ];
pkgs.mkShell {
  nativeBuildInputs = [ pkgs.buildPackages.python37 ];
}
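For reference, a minimal shell.nix along these lines might look like the sketch below. This is an assumption-laden sketch, not a complete answer: it only covers packages that exist in nixpkgs under these attribute names, the versions are whatever the nixpkgs snapshot carries (pip-style version specifiers like `numpy~=1.19.5` are not valid Nix), and the git-based packages such as dm-haiku would need overrides or an overlay:

```nix
# shell.nix — a sketch, assuming these attributes exist in your nixpkgs snapshot.
{ pkgs ? import <nixpkgs> {} }:

let
  # Build one Python 3.7 environment containing the selected packages.
  # Version pinning in Nix is done by pinning nixpkgs itself, not per-package.
  pythonEnv = pkgs.python37.withPackages (ps: with ps; [
    numpy
    requests
    tqdm
    setuptools
    flask
    # git-based dependencies (e.g. dm-haiku) need overrides or an overlay
  ]);
in
pkgs.mkShell {
  buildInputs = [ pythonEnv ];
}
```

The key differences from the attempt above: the `let … in` pair is complete, and the package set is built with `python37.withPackages` rather than listed as bare pip requirement strings.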
-
[D] Should We Be Using JAX in 2022?
What's your favorite Deep Learning API for JAX - Flax, Haiku, Elegy, something else?
-
[D] Current State of JAX vs Pytorch?
Just going to add that you should check out haiku if you are considering JAX: https://github.com/deepmind/dm-haiku
-
PyTorch vs. TensorFlow in 2022
As a researcher in RL & ML in a big industry lab, I would say most of my colleagues are moving to JAX [https://github.com/google/jax], which this article kind of ignores. JAX is XLA-accelerated NumPy; it's cool beyond just machine learning, but only provides low-level linear algebra abstractions. However, you can put something like Haiku [https://github.com/deepmind/dm-haiku] or Flax [https://github.com/google/flax] on top of it and get what the cool kids are using :)
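As a concrete illustration of those low-level abstractions, a minimal sketch using only jax itself — NumPy-style code composed with JAX's function transformations:

```python
import jax
import jax.numpy as jnp

# Plain NumPy-style code: a mean-squared-error loss.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Composed transformations: differentiate, then JIT-compile via XLA.
grad_fn = jax.jit(jax.grad(loss))

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
g = grad_fn(w, x, y)
print(g)  # gradient of the loss w.r.t. w; here each component is -2.0
```

Libraries like Haiku and Flax layer neural-network module abstractions on top of exactly this grad/jit machinery.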
-
[D] JAX learning resources?
- https://github.com/deepmind/dm-haiku/tree/main/examples
- Why would I want to develop yet another deep learning framework?
- Help with installing python packages
What are some alternatives?
datasets - TFDS is a collection of datasets ready to use with TensorFlow, Jax, ...
flax - Flax is a neural network library for JAX that is designed for flexibility.
jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
flaxmodels - Pretrained deep learning models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet, etc.
trax - Trax — Deep Learning with Clear Code and Speed
equinox - Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
GradCache - Run Effective Large Batch Contrastive Learning Beyond GPU/TPU Memory Constraint
elegy - A High Level API for Deep Learning in JAX
jax - Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more