| | DiffEqGPU.jl | jax |
|---|---|---|
| Mentions | 2 | 89 |
| Stars | 296 | 31,945 |
| Growth | 2.0% | 1.6% |
| Activity | 6.7 | 10.0 |
| Latest commit | 2 days ago | 5 days ago |
| Language | Julia | Python |
| License | MIT License | Apache License 2.0 |
Stars - the number of stars that a project has on GitHub. Growth - month over month growth in stars.
Activity is a relative number indicating how actively a project is being developed. Recent commits have higher weight than older ones.
For example, an activity of 9.0 indicates that a project is amongst the top 10% of the most actively developed projects that we are tracking.
DiffEqGPU.jl
- 2023 was the year that GPUs stood still
Indeed, and this year we created a system for compiling ODE code into not just optimized CUDA kernels but also oneAPI kernels, AMD GPU kernels, and Metal kernels. The peer-reviewed version is here (https://www.sciencedirect.com/science/article/abs/pii/S00457...), the open-access version is here (https://arxiv.org/abs/2304.06835), and the open source code is at https://github.com/SciML/DiffEqGPU.jl. The key point the paper describes is that, in this case, kernel generation is about 20x-100x faster than PyTorch and Jax (see the Jax compilation attempted in multiple ways in this notebook: https://colab.research.google.com/drive/1d7G-O5JX31lHbg7jTzz...; there is extra overhead from calling Julia from Python, but it still shows a 10x speedup).
The point really is that while deep learning libraries are amazing, at the end of the day they are DSLs that pull you toward one specific way of computing and parallelizing. It turns out that way of parallelizing is good for deep learning, but not for everything you may want to accelerate. Sometimes (i.e., in cases that aren't dominated by large linear algebra) building problem-specific kernels is a major win, and it's over-extrapolating to see ML frameworks do well with GPUs and conclude that's all that's required. There are many ways to parallelize code; ML libraries hardcode a very specific one, and it's good for what they are used for but not for every problem that can arise.
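To make the contrast concrete, here is a minimal JAX sketch of the array-level ensemble approach the comment is describing: a fixed-step RK4 integrator vmapped over a batch of parameters. The `rhs` oscillator, step count, and function names are illustrative stand-ins, not code from the paper or from DiffEqGPU.jl.

```python
# Minimal sketch (not from the paper): batching many small ODE solves in JAX
# by vmapping a fixed-step RK4 integrator over a batch of parameters.
import jax
import jax.numpy as jnp

def rhs(u, p):
    # Illustrative ODE right-hand side: a damped oscillator.
    x, v = u
    return jnp.array([v, -p * x - 0.1 * v])

def rk4_solve(p, u0, dt=0.01, steps=1000):
    # Classic fixed-step RK4 loop, expressed with lax.fori_loop so it traces once.
    def step(_, u):
        k1 = rhs(u, p)
        k2 = rhs(u + 0.5 * dt * k1, p)
        k3 = rhs(u + 0.5 * dt * k2, p)
        k4 = rhs(u + dt * k3, p)
        return u + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return jax.lax.fori_loop(0, steps, step, u0)

# vmap turns the single solve into one big batched array program: this is the
# "one specific way of parallelizing" that ML frameworks steer you toward.
batched_solve = jax.jit(jax.vmap(rk4_solve, in_axes=(0, None)))
params = jnp.linspace(0.5, 2.0, 10_000)   # one ODE per parameter value
u0 = jnp.array([1.0, 0.0])
final_states = batched_solve(params, u0)  # shape (10000, 2)
```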
- Julia GPU-based ODE solver 20x-100x faster than those in Jax and PyTorch
Link to GitHub repo from the abstract: https://github.com/SciML/DiffEqGPU.jl
jax
- I want a good parallel computer
- Show HN: Localscope–Limit scope of Python functions for reproducible execution
localscope is a small Python package that disassembles functions to check whether they access global variables they shouldn't. I wrote this a few years ago to detect scope bugs, which are common in Jupyter notebooks. It's recently come in handy when writing jax code (https://github.com/jax-ml/jax) because jax requires pure functions. Thought I'd share.
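For context, here is the class of scope bug this targets, shown with plain JAX rather than localscope's own API (which is not reproduced here): a jitted function that reads a global silently bakes that value in at trace time.

```python
# Illustration of the scope-bug class (assumed example, not localscope's API):
# a jitted JAX function that captures a global when it is first traced.
import jax
import jax.numpy as jnp

scale = 2.0  # global state a notebook cell might later change

@jax.jit
def f(x):
    return scale * x  # impure: reads the global `scale`

print(f(jnp.ones(3)))   # [2. 2. 2.] -- `scale` is baked in during tracing

scale = 10.0
print(f(jnp.ones(3)))   # still [2. 2. 2.]; the cached compilation ignores the update
```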
- Zest
- KlongPy: High-Performance Array Programming in Python
If you like high-performance array programming à la "numpy with JIT", I suggest looking at JAX. It's very suitable for general numeric computing (not just ML) and has a very mature ecosystem.
https://github.com/jax-ml/jax
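A rough illustration of the "numpy with JIT" point, assuming nothing beyond jax.numpy and jax.jit (the routine itself is just an arbitrary example):

```python
# Plain numerical code, no ML, written as in NumPy and compiled with XLA.
import jax
import jax.numpy as jnp

@jax.jit
def trapezoid(y, dx):
    # Composite trapezoidal rule over equally spaced samples.
    return dx * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))

x = jnp.linspace(0.0, jnp.pi, 10_001)
print(trapezoid(jnp.sin(x), x[1] - x[0]))  # ~2.0
```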
- PyTorch is dead. Long live Jax
Nope, changing graph shape requires recompilation: https://github.com/google/jax/discussions/17191
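A quick way to see this: jit caches compiled code keyed on input shapes, so a call with a new shape re-traces and recompiles. The sketch below uses a Python print as a trace marker; the function itself is just an assumed example.

```python
# jit caches per input shape: a new shape triggers a fresh trace.
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing for shape", x.shape)  # side effect only runs during tracing
    return (x ** 2).sum()

f(jnp.ones(3))   # prints "tracing for shape (3,)", then compiles
f(jnp.ones(3))   # cache hit: no print, no recompilation
f(jnp.ones(4))   # new shape -> prints again and recompiles
```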
- cuDF – GPU DataFrame Library
- Rebuilding TensorFlow 2.8.4 on Ubuntu 22.04 to patch vulnerabilities
I found a GitHub issue that seemed similar (missing ptxas) and saw a suggestion to install nvidia-cuda-toolkit. Alright, but that exploded the container size from 6.5 GB to 12.13 GB … unacceptable 😤 (Incidentally, that is too large for Cloud Shell to build on its limited persistent disk.)
- The Elements of Differentiable Programming
The dual numbers exist just as surely as the real numbers and have been in use for well over 100 years:
https://en.m.wikipedia.org/wiki/Dual_number
PyTorch has had them for many years:
https://pytorch.org/docs/stable/generated/torch.autograd.for...
JAX implements them and uses them exactly as stated in this thread.
https://github.com/google/jax/discussions/10157#discussionco...
As you so eloquently stated, "you shouldn't be proclaiming things you don't actually know on a public forum," and doubly so when your claimed "corrections" are so demonstrably and totally incorrect.
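For reference, a small sketch of forward-mode differentiation in JAX via jax.jvp, which mirrors dual-number arithmetic a + b·ε with ε² = 0: the primal carries a, the tangent carries b. The function f here is an arbitrary example, not taken from the thread.

```python
# Forward-mode AD in JAX behaves like arithmetic on dual numbers.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2

x = jnp.array(1.5)   # evaluation point
v = jnp.array(1.0)   # tangent (seed) direction
primal, tangent = jax.jvp(f, (x,), (v,))
print(primal)   # f(1.5)
print(tangent)  # f'(1.5) = cos(1.5)*1.5**2 + 2*1.5*sin(1.5)
```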
- Julia GPU-based ODE solver 20x-100x faster than those in Jax and PyTorch
On your last point, as long as you jit the topmost level, it doesn't matter whether or not you have inner jitted functions. The end result should be the same.
Source: https://github.com/google/jax/discussions/5199#discussioncom...
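A minimal sketch of that claim, with hypothetical functions `inner` and `outer_*` (not from the linked discussion): jitting the top level with an inner jitted call gives the same result as writing everything inline.

```python
# Nested jit: wrapping the top level is what matters; the result matches the
# fully inlined version.
import jax
import jax.numpy as jnp

@jax.jit
def inner(x):
    return jnp.tanh(x) + 1.0

@jax.jit
def outer_with_inner_jit(x):
    return inner(x) * 2.0

@jax.jit
def outer_plain(x):
    return (jnp.tanh(x) + 1.0) * 2.0

x = jnp.linspace(-1.0, 1.0, 5)
print(jnp.allclose(outer_with_inner_jit(x), outer_plain(x)))  # True
```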
- Apple releases MLX for Apple Silicon
The design of MLX is inspired by frameworks like NumPy, PyTorch, Jax, and ArrayFire.
What are some alternatives?
DiffEqBase.jl - The lightweight Base library for shared types and functionality for defining differential equation and scientific machine learning (SciML) problems
Numba - NumPy aware dynamic Python compiler using LLVM
GPUODEBenchmarks - Comparison of Julia's GPU kernel based ODE solvers with other open-source GPU ODE solvers
dex-lang - Research language for array processing in the Haskell/ML family
SciMLSensitivity.jl - A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
julia - The Julia Programming Language