dex-lang VS jax

Compare dex-lang vs jax and see what are their differences.


dex-lang - Research language for array processing in the Haskell/ML family (by google-research)

jax - Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more (by google)
                dex-lang                                  jax
Mentions        12                                        27
Stars           1,118                                     15,176
Growth          3.6%                                      3.0%
Activity        9.5                                       9.9
Latest commit   5 days ago                                6 days ago
Language        Haskell                                   Python
License         BSD 3-clause "New" or "Revised" License   Apache License 2.0
The number of mentions indicates the total number of mentions that we've tracked plus the number of user-suggested alternatives.
Stars - the number of stars that a project has on GitHub. Growth - month-over-month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.


Posts with mentions or reviews of dex-lang. We have used some of these posts to build our list of alternatives and similar projects. The last one was on 2021-11-09.


Posts with mentions or reviews of jax. We have used some of these posts to build our list of alternatives and similar projects. The last one was on 2021-11-26.
  • PyTorch: Where we are headed and why it looks a lot like Julia (but not exactly)
    19 projects | 26 Nov 2021
  • JAX on WSL2 - The "Couldn't read CUDA driver version." problem.
    1 project | 18 Nov 2021
    As noted here, the path (file):
  • Show HN: How does Jax allocate memory on a TPU? An interactive C++ walkthrough
    4 projects | 6 Nov 2021
    > The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.

    Hmm. Jax's ease of debugging was the very first thing that caught my attention:

    > I ran it on the TPU VM, saw the loss curve go down, and it was like an electric shock. "Wow! That actually... worked? Huh. that's weird. Things never work on the first try. I'm impressed."

    > Then I plopped `import pdb; pdb.set_trace()` in the middle of the `loss` function and ran it again. It dropped me into the Python debugger.

    > There was a tensor named `X_bt`. I typed `X_bt`. The debugger printed the value of `X_bt`.

    > I was able to print out all the values of every variable, just like you'd expect Python to be able to do.

    > There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I was now staring at exactly what I expected: the sum of those two tensors.

    > I could write `x + y`, or create new variables, or anything else I wanted.

    > Now I was real impressed.

    > If it sounds weird that I'm so easily impressed, it's because, you gotta understand: until now, TPUs were a complete pain in the ass to use. I kept my feelings to myself, because I understood that the Cloud TPU team were working hard to improve TPUs, and the TFRC support team was wonderful, and I had so many TPUs to play with. But holy moly, if you were expecting any of the above examples to just work on the first try when using Tensorflow V1 on TPUs, you were in for a rude awakening. And if you thought "Well, Tensorflow v2 is supposedly a lot better, right? Surely I'll be able to do basic things without worrying...."

    > ... no. Not even close. Not until Jax + TPU VMs.

    In the subsequent year, it's been nothing but joy.

    If the problem is that you want to see tensor values in a JIT'ed function, use a host callback. You can run actual Python wherever you want:

    > This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.
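    The pattern being described can be sketched in plain Python, with no jax dependency: a tap point inside the computation sends a value to a host-side Python function and then passes it through unchanged. The names below mirror jax's `id_tap()` semantics but are illustrative, not jax's actual implementation:

    ```python
    # Toy sketch of the host-callback pattern: a computation that invokes
    # user-defined Python on the host at designated tap points.
    # Illustrative only -- not jax's actual API or implementation.

    taps = []  # values sent from the "device" computation back to the host

    def id_tap(tap_func, value):
        # Invoke the host-side Python function, then return the value
        # unchanged, mirroring the identity semantics of jax's id_tap().
        tap_func(value)
        return value

    def loss(x):
        y = x * x
        # Peek at an intermediate value without altering the computation.
        y = id_tap(lambda v: taps.append(v), y)
        return y + 1.0

    print(loss(3.0))   # prints 10.0; taps now holds [9.0]
    ```

    In real jax code, the tap survives `jit` because the compiled XLA program calls back into the host, which is exactly what the quoted module provides.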

    The nice part is, there's no "magic" under the hood. If you get a chance, I highly recommend reading through Autodidax:

    Autodidax is a pure-python implementation of jax. (Literally in one file, on that page.) It walks you through how every aspect of jax works.
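    The core trick Autodidax teaches, interpreting a Python function by tracing it with wrapper objects instead of running it on concrete numbers, fits in a few lines. This is a heavily condensed sketch of the idea, not Autodidax's actual code:

    ```python
    # Minimal flavor of jax-style tracing (see Autodidax for the real thing):
    # run the user's ordinary Python function once with Tracer objects,
    # recording each primitive application instead of computing eagerly.

    trace = []  # recorded primitives: (op, lhs, rhs, out)

    class Tracer:
        def __init__(self, name):
            self.name = name
        def __mul__(self, other):
            out = Tracer(f"({self.name} * {other.name})")
            trace.append(("mul", self.name, other.name, out.name))
            return out
        def __add__(self, other):
            out = Tracer(f"({self.name} + {other.name})")
            trace.append(("add", self.name, other.name, out.name))
            return out

    def f(x, y):
        return x * y + x   # plain Python; never sees concrete numbers here

    out = f(Tracer("x"), Tracer("y"))
    print(out.name)     # ((x * y) + x)
    print(len(trace))   # 2 recorded primitives, ready to transform/compile
    ```

    Everything jax does -- grad, vmap, jit -- is built on interpreting such recorded traces in different ways.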

    Delightfully, I found a secret branch where autodidax also implements host callbacks:

    If you scroll to the very bottom of that file, you'll see an example of compiling your own XLA JIT'ed code which subsequently calls back into Python. TPUs do precisely the same thing.

    Point being:

    > PyTorch, for better or for worse, will actually run your Python code as you wrote it.

    ... is also true of jax, to within a rounding error less than "I personally don't mind writing id_print(x) instead of print(x)." :)

  • [N] Jax now Supports Apple Silicon [CPU ONLY]
    1 project | 30 Oct 2021
    Check this thread to install jaxlib:
  • An Introduction to Probabilistic Programming
    2 projects | 22 Oct 2021
    Note that these are not exclusive. You could divide ML into a traditional statistical approach and a probabilistic one that is concerned with deriving the underlying probability distribution. Probabilistic programming is kind of like a domain-specific language for achieving this. There is also differentiable programming, which works on the same principle, and there are certainly industrial usages of this paradigm. Look up Pyro for probabilistic programming and JAX for differentiable programming.
  • [R] Google AI Open Sources ‘FedJAX’, A JAX-based Python Library for Federated Learning Simulations
    2 projects | 5 Oct 2021
    A new Google study introduces FedJAX, a JAX-based open-source library for federated learning simulations that emphasizes ease of use in research. FedJAX aims to let researchers build and evaluate federated algorithms faster and more easily by providing basic building blocks for implementing federated algorithms, preloaded datasets, models, and algorithms, and fast simulation speed.
  • [P] Training a spiking NN to produce images
    2 projects | 10 Sep 2021
    We used the spiking neuron modules in Rockpool to build the network. These particular modules are based on Jax, which gives us compilation to GPU/TPU/CPU as well as automatic differentiation "for free". There are similar modules in Rockpool based on Torch, if you prefer to use a torch training pipeline.
  • Running AlphaFold on an IBM Power9 cluster?
    1 project | 7 Sep 2021
    Much of AlphaFold2 is implemented in JAX, and JAX does not have Power9 builds. See this related issue, which is still open at the time of writing.
  • JaxNetwork
    1 project | 5 Sep 2021
    JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.
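    The automatic-differentiation piece of that combination can be illustrated with forward-mode dual numbers in plain Python. This is a sketch of the principle, not jax's implementation (jax derives gradients by tracing and transforming programs, and compiles them with XLA):

    ```python
    # Forward-mode autodiff via dual numbers: each value carries its
    # derivative alongside it. A sketch of the principle jax automates.
    # Names (Dual, grad) are illustrative, not jax's internals.

    class Dual:
        def __init__(self, val, dot):
            self.val, self.dot = val, dot   # value and its derivative
        def __mul__(self, other):
            # Product rule: (uv)' = u'v + uv'
            return Dual(self.val * other.val,
                        self.val * other.dot + self.dot * other.val)
        def __add__(self, other):
            return Dual(self.val + other.val, self.dot + other.dot)

    def grad(f):
        # Derivative of a scalar function at x: seed the tangent with 1.
        return lambda x: f(Dual(x, 1.0)).dot

    f = lambda x: x * x * x      # f(x) = x^3, so f'(x) = 3x^2
    print(grad(f)(2.0))          # 12.0
    ```

    In jax the same call shape, `jax.grad(f)(2.0)`, works on NumPy-style array code and composes with `jit` and `vmap`.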
  • Jax and Haskell
    1 project | 27 Aug 2021

What are some alternatives?

When comparing dex-lang and jax you can also consider the following projects:

mesh-transformer-jax - Model parallel transformers in JAX and Haiku

futhark - :boom::computer::boom: A data-parallel functional programming language

julia - The Julia Programming Language

hasktorch - Tensors and neural networks in Haskell

Pytorch - Tensors and Dynamic neural networks in Python with strong GPU acceleration

tensorflow - An Open Source Machine Learning Framework for Everyone

functorch - functorch is a prototype of JAX-like composable function transforms for PyTorch.