Python Jax

Open-source Python projects categorized as Jax

Top 23 Python Jax Projects

  • transformers

    🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.

    Project mention: Schedule-Free Learning – A New Way to Train | 2024-04-06

    * Superconvergence + LR range finder + Fast AI's Ranger21 optimizer was the go-to optimizer for CNNs and worked fabulously well, but on transformers the learning rate range finder said 1e-3 was the best, whilst 1e-5 was better. However, the 1 cycle learning rate stuck.

  • Keras

    Deep Learning for humans

    Project mention: Getting Started with Gemma Models | 2024-04-15

    After setting the environment variables, the next step is to install dependencies. Gemma depends on KerasNLP, a collection of natural language processing (NLP) models implemented in Keras and runnable on JAX, PyTorch, and TensorFlow.

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

    Project mention: The Elements of Differentiable Programming | 2024-03-22

    The dual numbers exist just as surely as the real numbers and have been used for well over 100 years.

    PyTorch has had them for many years.

    JAX implements them and uses them exactly as stated in this thread.

    As you so eloquently stated, "you shouldn't be proclaiming things you don't actually know on a public forum," and doubly so when your claimed "corrections" are so demonstrably and totally incorrect.

  • d2l-en

    Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.

  • best-of-ml-python

    🏆 A ranked list of awesome machine learning Python libraries. Updated weekly.

  • ivy

    The Unified AI Framework

    Project mention: Keras 3.0 | 2023-11-28

    See also which I have not tried but seems along the lines of what you are describing, working with all the major frameworks

  • trax

    Trax — Deep Learning with Clear Code and Speed

    Project mention: Replit's new Code LLM was trained in 1 week | 2023-05-03

    and the implementation if you are interested.

    Hope you get to look into this!

  • einops

    Flexible and powerful tensor operations for readable and reliable code (for pytorch, jax, TF and others)

    Project mention: Einops: Flexible and powerful tensor operations for readable and reliable code | 2023-12-12

  • flax

    Flax is a neural network library for JAX that is designed for flexibility.

    Project mention: What is the JAX/Flax equivalent of torch.nn.Parameter? | /r/JAX | 2023-04-24

  • datasets

    TFDS is a collection of datasets ready to use with TensorFlow, Jax, ... (by tensorflow)

  • alpa

    Training and serving large-scale neural networks with auto parallelization.

  • scenic

    Scenic: A Jax Library for Computer Vision Research and Beyond (by google-research)

  • dm-haiku

    JAX-based neural network library

  • thinc

    🔮 A refreshing functional take on deep learning, compatible with your favorite libraries

    Project mention: JAX – NumPy on the CPU, GPU, and TPU, with great automatic differentiation | 2023-09-28

    Agree, though I wouldn’t call PyTorch a drop-in for NumPy either. CuPy is the drop-in. Excepting some corner cases, you can use the same code for both. Thinc’s ops work with both NumPy and CuPy:

  • foolbox

    A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX

    Project mention: More snake-oil | /r/DefendingAIArt | 2023-06-26

    Go ahead, play with any adversarial attacks from you will not find an attack that is both robust to perturbations and almost visually imperceptible

  • deepxde

    A library for scientific machine learning and physics-informed learning

  • EasyLM

    Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.

    Project mention: How To Fine-Tune LLaMA, OpenLLaMA, And XGen, With JAX On A GPU Or A TPU | /r/LocalLLaMA | 2023-07-04

  • mctx

    Monte Carlo tree search in JAX

    Project mention: About Monte Carlo tree search in Jax | 2023-11-23

  • pennylane

    PennyLane is a cross-platform Python library for differentiable programming of quantum computers. Train a quantum computer the same way as a neural network.

  • numpyro

    Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.

    Project mention: Bayesian Analysis with Python | 2024-02-10

  • equinox

    Elegant easy-to-use neural networks + scientific computing in JAX.

    Project mention: Ask HN: What side projects landed you a job? | 2023-12-03

    I wrote a JAX-based neural network library (Equinox [1]) and numerical differential equation solving library (Diffrax [2]).

    At the time I was just exploring some new research ideas in numerics -- and frankly, procrastinating from writing up my PhD thesis!

    But then one of the teams at Google started using them, so they offered me a job to keep developing them for their needs. Plus I'd get to work in biotech, which was a big interest of mine. This was a clear dream job offer, so I accepted.

    Since then both have grown steadily in popularity (~2.6k GitHub stars) and now see pretty widespread use! I've since started writing several other JAX libraries and we now have a bit of an ecosystem going.


  • machine_learning_refined

    Notes, examples, and Python demos for the 2nd edition of the textbook "Machine Learning Refined" (published by Cambridge University Press).

  • TransformerEngine

    A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.

    Project mention: Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) | /r/nvidia | 2023-04-30

    The 4090 now has its 8-bit float enabled as well; see the transformer engine issue.

NOTE: The open source projects on this list are ordered by number of GitHub stars. The number of mentions indicates repo mentions in the last 12 months or since we started tracking (Dec 2020). The latest post mention was on 2024-04-15.
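
The comment quoted under the jax entry refers to dual numbers, the algebra behind forward-mode automatic differentiation. As a rough illustration of the idea only (this is not how JAX or PyTorch implement it internally; in JAX the forward-mode entry point is `jax.jvp`), here is a minimal pure-Python sketch. The names `Dual`, `_lift`, `sin`, `derivative`, and `f` are hypothetical and chosen just for this example.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Hypothetical dual number a + b*eps, where eps**2 == 0."""
    real: float        # value of the function
    eps: float = 0.0   # derivative carried along with the value

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.real + other.real, self.eps + other.eps)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # (a + b*eps)(c + d*eps) = ac + (ad + bc)*eps, because eps**2 == 0
        return Dual(self.real * other.real,
                    self.real * other.eps + self.eps * other.real)

    __rmul__ = __mul__


def _lift(x):
    """Promote a plain number to a Dual with zero derivative part."""
    return x if isinstance(x, Dual) else Dual(float(x))


def sin(x):
    """sin extended to dual numbers via the chain rule."""
    return Dual(math.sin(x.real), math.cos(x.real) * x.eps)


def derivative(f, x):
    """Evaluate f at x + 1*eps and read the derivative off the eps part."""
    return f(Dual(x, 1.0)).eps


def f(x):
    return 3 * x * x + sin(x)


# Analytically, f'(x) = 6x + cos(x).
print(derivative(f, 2.0))
```

Seeding the `eps` part with 1.0 makes every arithmetic step propagate the derivative alongside the value, so a single evaluation of `f` yields both.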

What are some of the best open-source Jax projects in Python? This list will help you:

Rank  Project                   Stars
1     transformers              124,115
2     Keras                     60,854
3     jax                       27,735
4     d2l-en                    21,564
5     best-of-ml-python         15,284
6     ivy                       14,016
7     trax                      7,948
8     einops                    7,875
9     flax                      5,474
10    datasets                  4,157
11    alpa                      2,971
12    scenic                    2,963
13    dm-haiku                  2,797
14    thinc                     2,790
15    foolbox                   2,650
16    deepxde                   2,311
17    EasyLM                    2,215
18    mctx                      2,192
19    pennylane                 2,098
20    numpyro                   2,025
21    equinox                   1,775
22    machine_learning_refined  1,578
23    TransformerEngine         1,395