jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
>The memory usage and schedule of a given program/graph is statically optimized by the compiler
Is it really though? The only thing I see is
https://github.com/tensorflow/tensorflow/blob/95cdeaa8c848fd...
which traces back to
https://github.com/tensorflow/tensorflow/blob/54a8a3b373918b...
which doesn't do anything smart, as far as I can tell.
> The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.
Hmm. Jax's ease of debugging was the very first thing that caught my attention: https://blog.gpt4.org/jaxtpu#:~:text=pdb.set_trace()
> I ran it on the TPU VM, saw the loss curve go down, and it was like an electric shock. "Wow! That actually... worked? Huh. that's weird. Things never work on the first try. I'm impressed."
> Then I plopped `import pdb; pdb.set_trace()` in the middle of the `loss` function and ran it again. It dropped me into the Python debugger.
> There was a tensor named `X_bt`. I typed `X_bt`. The debugger printed the value of `X_bt`.
> I was able to print out all the values of every variable, just like you'd expect Python to be able to do.
> There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I was now staring at exactly what I expected: the sum of those two tensors.
> I could write `x + y`, or create new variables, or anything else I wanted.
> Now I was real impressed.
> If it sounds weird that I'm so easily impressed, it's because, you gotta understand: until now, TPUs were a complete pain in the ass to use. I kept my feelings to myself, because I understood that the Cloud TPU team were working hard to improve TPUs, and the TFRC support team was wonderful, and I had so many TPUs to play with. But holy moly, if you were expecting any of the above examples to just work on the first try when using Tensorflow V1 on TPUs, you were in for a rude awakening. And if you thought "Well, Tensorflow v2 is supposedly a lot better, right? Surely I'll be able to do basic things without worrying...."
> ... no. Not even close. Not until Jax + TPU VMs.
In the subsequent year, it's been nothing but joy.
If the problem is that you want to see tensor values in a JIT'ed function, use a host callback. You can run actual Python wherever you want: https://jax.readthedocs.io/en/latest/jax.experimental.host_c...
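To make this concrete, here's a minimal sketch of printing a value from inside a JIT'ed function. Note that in newer JAX releases the `host_callback` module was deprecated in favor of the `jax.debug` utilities, so this uses `jax.debug.print` as the modern equivalent of `id_print`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def loss(x):
    y = jnp.sum(x ** 2)
    # Prints from the host even though the function is JIT-compiled;
    # in older JAX this was spelled host_callback.id_print(y).
    jax.debug.print("loss = {}", y)
    return y

print(loss(jnp.arange(3.0)))  # 0^2 + 1^2 + 2^2 = 5.0
```

The device computation pauses to ship the value back to the host, runs your Python, then continues, which is exactly the mechanism the Autodidax branch below walks through.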
> This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.
The nice part is, there's no "magic" under the hood. If you get a chance, I highly recommend reading through Autodidax: https://jax.readthedocs.io/en/latest/autodidax.html
Autodidax is a pure-python implementation of jax. (Literally in one file, on that page.) It walks you through how every aspect of jax works.
Delightfully, I found a secret branch where autodidax also implements host callbacks: https://github.com/google/jax/blob/effect-types/docs/autodid...
If you scroll to the very bottom of that file, you'll see an example of compiling your own XLA JIT'ed code which subsequently calls back into Python. TPUs do precisely the same thing.
Point being:
> PyTorch, for better or for worse, will actually run your Python code as you wrote it.
... is also true of jax, to within a rounding error less than "I personally don't mind writing id_print(x) instead of print(x)." :)
The pytorch programming model is just really hard to adapt to an XLA-like compiler. Imperative python code doesn't translate to an ML graph compiler particularly well; Jax's API is functional, so it's easier to translate to the XLA API. By contrast, torch/xla uses "lazy tensors" that record the computation graph and compile when needed. The trouble is, if the compute graph changes from run to run, you end up recompiling a lot.
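You can see why the functional style translates well by asking JAX to show you the intermediate graph it traces: `jax.make_jaxpr` turns a pure Python function into a small functional IR that maps almost one-to-one onto XLA ops. A quick illustration (the function here is just an arbitrary example):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# make_jaxpr traces f once with abstract values and returns the
# functional graph (a "jaxpr") that jit would hand to XLA.
print(jax.make_jaxpr(f)(jnp.float32(1.0)))
```

Because the trace sees only pure array operations, there's no imperative state for the compiler to reconstruct, which is the part lazy tensors have to work hard to recover.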
I guess in Jax you'd just apply `jax.jit` only to the parts where the compute graph is static? I'd be curious to see examples of how this works in practice. FWIW, there's an offshoot of PyTorch that aims to provide this sort of API (see https://github.com/pytorch/functorch and look at eager_compilation.py).
(Disclaimer: I worked on this until quite recently.)