Python Jax

Open-source Python projects categorized as Jax

Top 23 Python Jax Projects

  • transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.

    Project mention: GPU Comparisons: RTX 6000 ADA vs A100 80GB vs 2x 4090s | | 2022-12-02

    Looked into this last night and yeah, NVLink works the way you described because of misleading marketing: there is no contiguous memory pool, just a faster interconnect, so model parallelisation may scale a bit better, but you still have to implement it. I also saw an example where some PyTorch GPT-2 models scaled horrifically in training with multiple PCIe V100s and 3090s that didn't have NVLink, so that's a caveat with dual 4090s not having NVLink.

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

    Project mention: [D] Pytorch or TensorFlow for development and deployment? | | 2022-11-26

    PyTorch. Datapoint 1: it's part of the Linux Foundation. Datapoint 2: Jax.
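    The composability in JAX's description is the selling point: `grad`, `vmap`, and `jit` nest freely. A minimal sketch, assuming `jax` is installed (the function and arrays here are illustrative, not from any project above):

    ```python
    import jax
    import jax.numpy as jnp

    def loss(w, x):
        # simple scalar loss: 0.5 * (w . x)^2
        return 0.5 * (w @ x) ** 2

    # differentiate with respect to the first argument (w),
    # vectorize over a batch of inputs x, then JIT-compile the whole thing
    batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

    w = jnp.array([1.0, 2.0])
    xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])  # batch of two inputs

    print(batched_grad(w, xs).shape)  # one gradient per batch element: (2, 2)
    ```

    Each transformation returns an ordinary Python callable, which is why they compose in any order.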

  • ivy

    The Unified Machine Learning Framework

    Project mention: CoreML Stable Diffusion | | 2022-12-01

    ROCm's great for data centers, but good luck finding anything about desktop GPUs on their site apart from this lone blog post:

    There's a good explanation of AMD's ROCm targets here:

    It's currently a PITA to get common Python libs like Numba to even talk to AMD cards (admittedly Numba won't talk to older Nvidia cards either and they deprecate ruthlessly; I had to downgrade 8 versions to get it working with a 5yo mobile workstation). YC-backed Ivy claims to be working on unifying ML frameworks in a hardware-agnostic way but I don't have enough experience to assess how well they're succeeding yet:

    I was happy to see DiffusionBee does talk to the GPU in my late-model Intel Mac, though for some reason it only uses 50% of its power right now. I'm sure the situation will improve as Metal 3.0 and Vulkan get more established.

  • trax

    Trax — Deep Learning with Clear Code and Speed

  • einops

    Deep learning operations reinvented (for pytorch, tensorflow, jax and others)

    Project mention: Delimiter-First Code | | 2022-12-09

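    The notation is einops' point: axes are named and regrouped in one pattern string. A small sketch with the numpy backend (the same patterns work unchanged for torch, tensorflow, and jax arrays):

    ```python
    import numpy as np
    from einops import rearrange, reduce

    images = np.zeros((32, 30, 40, 3))  # batch, height, width, channel

    # flatten each image into a vector: '(h w c)' merges three axes into one
    flat = rearrange(images, "b h w c -> b (h w c)")
    print(flat.shape)  # (32, 3600)

    # global average pool over the spatial axes
    pooled = reduce(images, "b h w c -> b c", "mean")
    print(pooled.shape)  # (32, 3)
    ```

    Compare `x.reshape(32, -1)` or `x.mean(axis=(1, 2))`: the einops version carries the axis names in the code itself.
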
  • flax

    Flax is a neural network library for JAX that is designed for flexibility.

    Project mention: Announcing flax 0.2 - A fully featured ECS | | 2022-09-11

    Just as an FYI, you might be competing against another big open source project with the same name.

  • datasets

    TFDS is a collection of datasets ready to use with TensorFlow, Jax, ... (by tensorflow)

  • thinc

    🔮 A refreshing functional take on deep learning, compatible with your favorite libraries

    Project mention: Tinygrad: A simple and powerful neural network framework | | 2022-11-03

    I love those tiny DNN frameworks. Some examples I studied in the past (I still use PyTorch for work-related projects): Thinc, from the creators of spaCy.

  • dm-haiku

    JAX-based neural network library

    Project mention: Help with installing python packages. | | 2022-08-18

    I am fresh to NixOS, especially when it comes to using Python on it. How do I install packages without using pip? I need to install: numpy~=1.19.5, transformers~=4.8.2, tqdm~=4.45.0, setuptools~=51.3.3, wandb>=0.11.2, einops~=0.3.0, requests~=2.25.1, fabric~=2.6.0, optax==0.0.6, git+ git+ ray[default]==1.4.1, jax~=0.2.12, Flask~=1.1.2, cloudpickle~=1.3.0, tensorflow-cpu~=2.5.0, google-cloud-storage~=1.36.2, smart_open[gcs], func_timeout, ftfy, fastapi, uvicorn, lm_dataformat. Normally I could just do pip -r thetxtfile, but I don't know how to do this on NixOS. I would also be using Python 3.7. So far this is what I have come up with, but I know it's wrong:

        { pkgs ? import <nixpkgs> {} }:
        let
          packages = python-packages: with python-packages; [
            mesh-transformer-jax/
            jax==0.2.12
            numpy~=1.19.5
            transformers~=4.8.2
            tqdm~=4.45.0
            setuptools~=51.3.3
            wandb>=0.11.2
            einops~=0.3.0
            requests~=2.25.1
            fabric~=2.6.0
            optax==0.0.6
            # the other packages
          ];
        in
        pkgs.mkShell { nativeBuildInputs = [ pkgs.buildPackages.python37 ]; }

  • scenic

    Scenic: A Jax Library for Computer Vision Research and Beyond (by google-research)

    Project mention: Google Research Proposes an Artificial Intelligence (AI) Model to Utilize Vision Transformers on Videos | | 2022-11-25

    Quick Read: Paper: ViViT: A Video Vision Transformer (ICCV 2021). GitHub link:

  • numpyro

    Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
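    To make "probabilistic programming" concrete, here is a toy random-walk Metropolis sampler for a one-parameter model, written in plain numpy. Nothing below is NumPyro's API; NumPyro automates this kind of inference with NUTS/HMC, JIT-compiled through JAX:

    ```python
    import numpy as np

    rng = np.random.default_rng(0)
    data = rng.normal(loc=2.0, scale=1.0, size=200)  # observed samples

    def log_post(mu):
        # flat prior on mu, Gaussian likelihood with known scale = 1
        return -0.5 * np.sum((data - mu) ** 2)

    # random-walk Metropolis over the single parameter mu
    mu, samples = 0.0, []
    for _ in range(5000):
        prop = mu + rng.normal(scale=0.2)
        if np.log(rng.uniform()) < log_post(prop) - log_post(mu):
            mu = prop
        samples.append(mu)

    print(np.mean(samples[1000:]))  # close to the sample mean of `data` (~2)
    ```

    In NumPyro you would instead declare the model with `numpyro.sample` statements and hand it to a NUTS kernel, with the log-density and its gradients derived for you.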

  • deepxde

    A library for scientific machine learning and physics-informed learning

    Project mention: [Dev-Showcase] Airfoil Optimisation using Physics Informed Neural Networks(PINNs) | | 2022-05-24

    Due to certain limitations in MODULUS (we are unable to directly access the point cloud), we are now also exploring other available PINN libraries and frameworks and stumbled onto DeepXDE. DeepXDE is a little different from Modulus, and I'm currently exploring it.

  • alpa

    Training and serving large-scale neural networks

    Project mention: Alpa: Auto-parallelizing large model training and inference (by UC Berkeley) | | 2022-06-23

  • equinox

    Callable PyTrees and filtered transforms => neural networks in JAX.

    Project mention: Python 3.11 is much faster than 3.8 | | 2022-10-26

    +1 for JAX. Basically designed to be the successor to TensorFlow, and much nicer to work with. Strangely I've not seen it discussed around HN much but it's what I do 100% of my work in these days.

    Whilst I'm here: shameless self-promotion for Equinox and Diffrax:
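    Equinox represents whole models as "callable PyTrees". The PyTree concept itself is core JAX, not Equinox-specific; a minimal illustration using only `jax.tree_util` (not Equinox's own API):

    ```python
    import jax
    import jax.numpy as jnp

    # A "PyTree" is any nested structure of containers with array leaves.
    # Representing a model this way lets JAX transformations touch every
    # parameter at once.
    params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)},
              "scale": jnp.array(3.0)}

    # Apply one function to every leaf, e.g. a naive gradient-descent step
    updated = jax.tree_util.tree_map(lambda p: p * 0.5, params)
    print(updated["scale"])  # 1.5
    ```

    Equinox's twist is registering the model class itself as a PyTree, so the module is simultaneously a parameter container and a callable.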

  • diffrax

    Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

    Project mention: PyTorch 2.0 | | 2022-12-02

    At least prior to this announcement: JAX was much faster than PyTorch for differentiable physics. (Better JIT compiler; reduced Python-level overhead.)

    E.g. for numerical ODE simulation, I've found that Diffrax is ~100 times faster than torchdiffeq on the forward pass. The backward pass is much closer; there, Diffrax is about 1.5 times faster.

    It remains to be seen how PyTorch 2.0 will compare, of course!

    Right now my job is actually building out the scientific computing ecosystem in JAX, so feel free to ping me with any other questions.
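    For readers unfamiliar with the workload being compared: an ODE forward pass is just repeated time-stepping. A toy fixed-step RK4 integrator in plain numpy shows the shape of the computation (Diffrax provides adaptive, differentiable, JIT-compiled solvers; this sketch is only illustrative):

    ```python
    import numpy as np

    def rk4_step(f, t, y, h):
        # classic 4th-order Runge-Kutta step for dy/dt = f(t, y)
        k1 = f(t, y)
        k2 = f(t + h / 2, y + h * k1 / 2)
        k3 = f(t + h / 2, y + h * k2 / 2)
        k4 = f(t + h, y + h * k3)
        return y + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

    def integrate(f, y0, t0, t1, n_steps):
        h = (t1 - t0) / n_steps
        y, t = y0, t0
        for _ in range(n_steps):
            y = rk4_step(f, t, y, h)
            t += h
        return y

    # dy/dt = y  =>  y(1) = e
    print(integrate(lambda t, y: y, np.array(1.0), 0.0, 1.0, 100))  # ~2.7182818
    ```

    The per-step Python overhead of a loop like this is exactly what a JIT compiler removes, which is where claims like the one above come from.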

  • jaxopt

    Hardware accelerated, batchable and differentiable optimizers in JAX.

  • long-range-arena

    Long Range Arena for Benchmarking Efficient Transformers

    Project mention: [R] The Annotated S4: Efficiently Modeling Long Sequences with Structured State Spaces | | 2022-01-16

    The Structured State Space for Sequence Modeling (S4) architecture is a new approach to very long-range sequence modeling tasks for vision, language, and audio, showing a capacity to capture dependencies over tens of thousands of steps. Especially impressive are the model’s results on the challenging Long Range Arena benchmark, showing an ability to reason over sequences of up to 16,000+ elements with high accuracy.

  • fast-soft-sort

    Fast Differentiable Sorting and Ranking
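    The core idea is relaxing the piecewise-constant (hence gradient-free) ranking operator into a smooth one. A toy numpy sketch using pairwise sigmoids conveys the spirit; the library itself uses a much faster O(n log n) projection-based formulation, so nothing here is its actual algorithm or API:

    ```python
    import numpy as np

    def soft_rank(x, temperature=1.0):
        # Smooth ranks: rank_i ~ 1 + sum_j sigmoid((x_i - x_j) / T).
        # As temperature -> 0 this approaches the hard ranks; for T > 0
        # it is differentiable everywhere.
        diff = (x[:, None] - x[None, :]) / temperature
        sig = 1.0 / (1.0 + np.exp(-diff))
        return 1.0 + sig.sum(axis=1) - 0.5  # drop the self-comparison term

    x = np.array([0.3, 1.2, -0.5])
    print(soft_rank(x, temperature=0.01))  # ~ [2., 3., 1.]
    ```

    With a differentiable rank you can put sorting-based losses (e.g. Spearman correlation) directly inside a gradient-trained model.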

  • elegy

    A High Level API for Deep Learning in JAX

    Project mention: [D] Any less-boilerplate framework for Jax/Flax/Haiku? | | 2022-09-29

    Elegy might be worth a look.

  • git-re-basin

    Code release for "Git Re-Basin: Merging Models modulo Permutation Symmetries"

    Project mention: I'm testing if the 1.5 and 2.0 model combine in Automatic 1111 now... | | 2022-11-28

  • PDEBench

    PDEBench: An Extensive Benchmark for Scientific Machine Learning

    Project mention: [D] what are the SOTA neural PDE solvers besides FNO? | | 2022-11-22

  • prompt-tuning

    Original Implementation of Prompt Tuning from Lester, et al, 2021

    Project mention: Need advice on learning to apply techniques from research papers. | | 2022-05-10

    For instance, right now I want to try applying the prompt-tuning technique on simple classification tasks. The GitHub page looks as follows:

  • pyhpc-benchmarks

    A suite of benchmarks for CPU and GPU performance of the most popular high-performance libraries for Python :rocket:
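    In miniature, such a suite just times one numerical kernel against several implementations. A sketch using only the standard library and numpy (the kernel and names are illustrative, not the benchmark's own):

    ```python
    import timeit
    import numpy as np

    def kernel_python(xs):
        # naive pure-Python sum of squares
        total = 0.0
        for v in xs:
            total += v * v
        return total

    def kernel_numpy(xs):
        # vectorized equivalent
        return float(np.dot(xs, xs))

    data = np.random.default_rng(0).standard_normal(100_000)
    listed = data.tolist()

    for name, fn, arg in [("python", kernel_python, listed),
                          ("numpy", kernel_numpy, data)]:
        t = timeit.timeit(lambda: fn(arg), number=20)
        print(f"{name:>6}: {t:.4f}s")
    ```

    The real suite does this across CPU and GPU backends (numpy, jax, pytorch, etc.) with kernels drawn from ocean-modeling code, but the measurement pattern is the same.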

NOTE: The open source projects on this list are ordered by number of GitHub stars. The number of mentions indicates repo mentions in the last 12 months or since we started tracking (Dec 2020). The latest post mention was on 2022-12-09.

What are some of the best open-source Jax projects in Python? This list will help you:

Project Stars
1 transformers 75,627
2 jax 21,086
3 ivy 7,751
4 trax 7,201
5 einops 6,102
6 flax 3,774
7 datasets 3,483
8 thinc 2,627
9 dm-haiku 2,270
10 scenic 1,593
11 numpyro 1,593
12 deepxde 1,377
13 alpa 1,201
14 equinox 884
15 diffrax 691
16 jaxopt 601
17 long-range-arena 481
18 fast-soft-sort 442
19 elegy 425
20 git-re-basin 344
21 PDEBench 270
22 prompt-tuning 263
23 pyhpc-benchmarks 254