Scientific computing in JAX

This page summarizes the projects mentioned and recommended in the original post on /r/ScientificComputing.

  • sympy2jax

    Turn SymPy expressions into trainable JAX expressions.

  • sympy2jax: SymPy-to-JAX conversion.

  • jaxtyping

    Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

  • jaxtyping: rich shape & dtype annotations for arrays and tensors (also supports PyTorch/TensorFlow/NumPy).

  • eqxvision

    A Python package of computer vision models for the Equinox ecosystem.

  • Eqxvision: computer vision.

  • diffrax

    Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

  • Sure. So I've got some PyTorch benchmarks here. The main takeaway so far has been that for a neural ODE, the backward pass takes about 50% longer in PyTorch, and the forward (inference) pass takes an incredible 100x longer.
