-
equinox
Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
I'm going to disagree here! Classes and functional programming can go very well together; just don't expect to do in-place mutation (i.e. OO-style programming).
You might like Equinox (https://github.com/patrick-kidger/equinox; 1.4k GitHub stars), which deliberately offers a very PyTorch-like feel for JAX.
Regarding speed, I would strongly recommend JAX over PyTorch for SciComp. The XLA compiler seems to be much more effective for such use cases.
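For a sense of the style, here's a minimal sketch along the lines of the Equinox README: a model is an ordinary class, but it's registered as an immutable pytree, so it composes with jax.jit and jax.grad instead of relying on in-place mutation.

    import jax
    import jax.numpy as jnp
    import equinox as eqx

    class Linear(eqx.Module):
        # fields declared up front; instances are immutable pytrees
        weight: jax.Array
        bias: jax.Array

        def __init__(self, in_size, out_size, key):
            wkey, bkey = jax.random.split(key)
            self.weight = jax.random.normal(wkey, (out_size, in_size))
            self.bias = jax.random.normal(bkey, (out_size,))

        def __call__(self, x):
            return self.weight @ x + self.bias

    @jax.jit
    @jax.grad
    def loss_grad(model, x, y):
        return jnp.mean((model(x) - y) ** 2)

    model = Linear(2, 3, key=jax.random.PRNGKey(0))
    grads = loss_grad(model, jnp.ones(2), jnp.zeros(3))  # grads is itself a Linear pytree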
-
jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Actually, that never changed. The README has always had an example of differentiating through native Python control flow:
https://github.com/google/jax/commit/948a8db0adf233f333f3e5f...
The constraints on control flow expressions come from jax.jit (because Python control flow can't be staged out) and jax.vmap (because we can't take multiple branches of Python control flow, which we might need to do for different batch elements). But autodiff of Python-native control flow works fine!
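For example (a minimal sketch in the spirit of that README example):

    import jax

    def f(x):
        # plain Python control flow, branching on the value of x
        if x < 3:
            return 3.0 * x ** 2
        return -4.0 * x

    print(jax.grad(f)(2.0))  # 12.0 -- autodiff traces straight through the branch
    print(jax.grad(f)(4.0))  # -4.0

Wrapping f in jax.jit or jax.vmap, by contrast, would fail at the if statement, for exactly the reasons above.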
-
JAX is super useful for scientific computing, although N-body sims might not be the best application. A naive N-body sim is very easy to implement and accelerate in JAX (here’s my version: https://github.com/PWhiddy/jax-experiments/blob/main/nbody.i...), but it can be hard to scale, because efficient N-body sims usually rely on either trees or spatial hashing/sorting, which are tricky to implement efficiently in JAX.
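For reference, here's what the naive O(N^2) version looks like (a sketch in the same spirit, not the linked notebook's actual code): all pairwise displacements fall out of one broadcast, and jit compiles the whole step for the accelerator.

    import jax
    import jax.numpy as jnp

    def acceleration(pos, eps=1e-3):
        # all-pairs displacements via broadcasting: d[i, j] = pos[j] - pos[i]
        d = pos[None, :, :] - pos[:, None, :]
        r2 = jnp.sum(d * d, axis=-1) + eps**2  # softening avoids the i == j singularity
        return jnp.sum(d * (r2 ** -1.5)[..., None], axis=1)  # unit masses, G = 1

    @jax.jit
    def step(pos, vel, dt=0.01):
        # semi-implicit Euler update
        vel = vel + dt * acceleration(pos)
        return pos + dt * vel, vel

    pos = jax.random.normal(jax.random.PRNGKey(0), (1024, 3))
    vel = jnp.zeros_like(pos)
    for _ in range(100):
        pos, vel = step(pos, vel)

The dense all-pairs step maps cleanly onto a GPU; it's the tree- and hash-based variants, with their irregular data structures, that resist this style.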
-
Agree, though I wouldn’t call PyTorch a drop-in for NumPy either. CuPy is the drop-in. Excepting some corner cases, you can use the same code for both. Thinc’s ops work with both NumPy and CuPy:
https://github.com/explosion/thinc/blob/master/thinc/backend...
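That drop-in property is what makes backend-generic code possible: dispatch on the input array's module and the same body runs on both. (A hypothetical helper for illustration, not Thinc's actual ops code.)

    import numpy as np

    def normalize(x):
        # cupy.get_array_module returns numpy for host arrays, cupy for device arrays
        try:
            import cupy
            xp = cupy.get_array_module(x)
        except ImportError:
            xp = np
        return (x - xp.mean(x)) / xp.std(x)

    normalize(np.arange(5.0))      # runs via NumPy on the CPU
    # normalize(cupy.arange(5.0))  # identical code on the GPU, if CuPy is installed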
-
jaxonnxruntime
A user-friendly toolchain that enables seamless execution of ONNX models using JAX as the backend.
-
Actually, that's never been a constraint for JAX autodiff. JAX grew out of the original Autograd (https://github.com/hips/autograd), so differentiating through Python control flow has always worked. It's jax.jit and jax.vmap that place constraints on control flow, requiring structured control flow combinators like lax.cond and lax.scan.
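A minimal sketch of what that combinator style looks like (standard jax.lax API, toy example):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # lax.cond stages out both branches, so it works under jit and vmap,
        # where a Python if on a traced value would raise an error
        return jax.lax.cond(x < 0, lambda x: -x, lambda x: x ** 2, x)

    print(jax.vmap(f)(jnp.array([-1.0, 2.0])))  # [1. 4.]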