diffrax
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
The Flux.jl example did this. A PR to the PyTorch example to do this would be welcome: https://github.com/chriselrod/LeNetTorch
...a C++ library with a CUDA backend. But these high-performance building blocks may only saturate the GPU fully if the data is large enough.
I haven't looked at implementing these things, but I imagine that if you have smaller networks and thus less data, the large building blocks may not be optimal. You may, for example, want to fuse some operations to reduce memory latency from repeated memory access.
In the PyTorch world, there are approaches for small networks as well, such as https://github.com/NVlabs/tiny-cuda-nn. As far as I understand from the first link in the README, it makes clever use of CUDA shared memory, which can hold all the weights of a tiny network (but not of larger ones).
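To make the fusion point concrete, here is a minimal pure-Python sketch of the elementwise part of a layer, written once as three separate passes and once as a single fused pass. The function names and the scalar weight `w` and bias `b` are made up for illustration; this is not how PyTorch, Flux.jl, or tiny-cuda-nn implement anything, only the general idea of avoiding intermediate buffers.

```python
import math

def unfused_layer(x, w, b):
    """Three passes over the data, with two temporary buffers in between."""
    scaled = [w * xi for xi in x]          # pass 1: writes a temporary
    shifted = [s + b for s in scaled]      # pass 2: reads it back, writes another
    return [math.tanh(s) for s in shifted] # pass 3: reads the second temporary

def fused_layer(x, w, b):
    """One pass: each element is read once and the result is written once."""
    return [math.tanh(w * xi + b) for xi in x]
```

Both versions perform the same arithmetic in the same order, so they return the same values. The fused version just touches memory less, which is what matters when the data is too small to keep a large kernel busy.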
Ask them to download Julia and try it, and file an issue if it is not fast enough. We try to have the latest available.
See for example: https://github.com/JuliaLinearAlgebra/RecursiveFactorization...
A library I designed a few years ago (https://github.com/Netflix/vectorflow) is also much faster than PyTorch/TensorFlow in these cases.
In "small" or "very sparse" setups, you're memory bound, not compute bound. TF and PyTorch are bad at that because they assume memory movements are worth it and do very few in-place operations.
Different tools for different jobs.
The article asks "Which Micro-optimizations matter for BLAS3?", implying small dimensions, but doesn't actually tell me. The problem is well-studied, depending on what you consider "small". The most important thing is to avoid the packing step below an appropriate threshold. Implementations include libxsmm, blasfeo, and the "sup" version in blis (with papers on libxsmm and blasfeo). Eigen might also be relevant.
https://libxsmm.readthedocs.io/
https://blasfeo.syscop.de/
https://github.com/flame/blis
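As a rough illustration of "avoid the packing step below a threshold", the sketch below dispatches between a kernel that multiplies the operands where they already live and one that first copies B into a contiguous layout. Everything here is an assumption for illustration: the names, the pure-Python kernels, and the cutoff `PACK_THRESHOLD = 16` are hypothetical, and real libraries such as libxsmm, blasfeo, or blis pack into cache- and register-sized panels with thresholds tuned per architecture.

```python
# Hypothetical cutoff; real libraries tune this per architecture and data type.
PACK_THRESHOLD = 16

def matmul_direct(a, b):
    """Multiply using the operands as stored (no packing copy)."""
    n, k, m = len(a), len(b), len(b[0])
    c = [[0.0] * m for _ in range(n)]
    for i in range(n):
        row_c = c[i]
        for p in range(k):
            a_ip = a[i][p]
            row_b = b[p]
            for j in range(m):
                row_c[j] += a_ip * row_b[j]
    return c

def pack_columns(b):
    """Copy B so that each column becomes one contiguous list."""
    return [list(col) for col in zip(*b)]

def matmul_packed(a, b):
    """Pay for one extra copy of B, then stream over contiguous columns."""
    packed_b = pack_columns(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in packed_b]
            for row in a]

def choose_kernel(n, k, m):
    """Skip packing when every dimension is at or below the threshold."""
    if max(n, k, m) <= PACK_THRESHOLD:
        return matmul_direct
    return matmul_packed

def matmul(a, b):
    kernel = choose_kernel(len(a), len(b), len(b[0]))
    return kernel(a, b)
```

The point of the dispatch is that the packing copy costs O(k*m) memory traffic up front, which only pays off once the multiplication itself is large enough to reuse the packed data many times.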
Taking a union of differential equations is a pretty bad idea for exactly the reasons you describe. But it's absolutely possible to parallelise multiple diffeq solves without needing to use the same time steps for each solve. That is, the time steps can be vectorised, rather than broadcast, over the batch axis.
So the only thing you actually need is the same number of steps for each solve, which is easily accomplished by padding out the solves that finish in slightly fewer steps. In practice this introduces negligible overhead whilst solving the above issue very neatly. For example, this is precisely what Diffrax (https://github.com/patrick-kidger/diffrax) does under `jax.vmap`.
I've not dug into what Julia does here; is this not already done when broadcasting `DifferentialEquations.solve`?
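The padding idea described above can be sketched in plain Python. Each batch element keeps its own time and its own adaptive step size, and the batch advances in lockstep "rounds"; an element that has already reached the end time simply takes no-op steps until the slowest one finishes. This is a toy illustration under stated assumptions, not Diffrax's implementation: the test problem dy/dt = -k*y, the Euler/Heun error estimate, the step-size controller constants, and all names are chosen here for illustration. In a genuinely vectorised setting the `continue` would be a masked no-op so that every lane executes the same instructions.

```python
import math

def heun_step(k, y, h):
    """One Euler/Heun step for dy/dt = -k*y; returns (new y, error estimate)."""
    k1 = -k * y
    k2 = -k * (y + h * k1)
    y_euler = y + h * k1
    y_heun = y + 0.5 * h * (k1 + k2)
    return y_heun, abs(y_heun - y_euler)

def solve_batch(rates, y0=1.0, t1=1.0, tol=1e-4, dt0=0.1, max_rounds=100_000):
    """Solve one ODE per rate from t=0 to t=t1, each with its own time steps."""
    n = len(rates)
    t = [0.0] * n       # per-element current time
    y = [y0] * n        # per-element state
    dt = [dt0] * n      # per-element step size
    attempts = [0] * n  # real (non-padded) step attempts per element
    rounds = 0          # lockstep iterations shared by the whole batch
    while any(ti < t1 for ti in t) and rounds < max_rounds:
        for i in range(n):
            if t[i] >= t1:
                continue  # padded step: this element is done and does nothing
            remaining = t1 - t[i]
            last = dt[i] >= remaining
            h = remaining if last else dt[i]
            y_new, err = heun_step(rates[i], y[i], h)
            if err <= tol:  # accept the step, otherwise retry with smaller h
                t[i] = t1 if last else t[i] + h
                y[i] = y_new
            factor = 0.9 * math.sqrt(tol / err) if err > 0.0 else 2.0
            dt[i] = h * min(2.0, max(0.2, factor))
            attempts[i] += 1
        rounds += 1
    return y, attempts, rounds
```

Calling `solve_batch([0.5, 2.0, 8.0])` gives each element a different number of real step attempts, while the number of rounds equals the largest of them. The stiffer problems set the length of the loop, and the easy ones are padded rather than forced onto the same time grid.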