Show HN: How does Jax allocate memory on a TPU? An interactive C++ walkthrough

This page summarizes the projects mentioned and recommended in the original post on news.ycombinator.com

  • tensorflow

    An Open Source Machine Learning Framework for Everyone

  • > The memory usage and schedule of a given program/graph is statically optimized by the compiler

    Is it really though? The only thing I see is

    https://github.com/tensorflow/tensorflow/blob/95cdeaa8c848fd...

    which traces back to

    https://github.com/tensorflow/tensorflow/blob/54a8a3b373918b...

    which doesn't do anything smart that I can tell.

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

  • > The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.

    Hmm. Jax's ease of debugging was the very first thing that caught my attention: https://blog.gpt4.org/jaxtpu#:~:text=pdb.set_trace()

    > I ran it on the TPU VM, saw the loss curve go down, and it was like an electric shock. "Wow! That actually... worked? Huh. that's weird. Things never work on the first try. I'm impressed."

    > Then I plopped `import pdb; pdb.set_trace()` in the middle of the `loss` function and ran it again. It dropped me into the Python debugger.

    > There was a tensor named `X_bt`. I typed `X_bt`. The debugger printed the value of `X_bt`.

    > I was able to print out all the values of every variable, just like you'd expect Python to be able to do.

    > There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I was now staring at exactly what I expected: the sum of those two tensors.

    > I could write `x + y`, or create new variables, or anything else I wanted.

    > Now I was real impressed.

    > If it sounds weird that I'm so easily impressed, it's because, you gotta understand: until now, TPUs were a complete pain in the ass to use. I kept my feelings to myself, because I understood that the Cloud TPU team were working hard to improve TPUs, and the TFRC support team was wonderful, and I had so many TPUs to play with. But holy moly, if you were expecting any of the above examples to just work on the first try when using TensorFlow v1 on TPUs, you were in for a rude awakening. And if you thought "Well, TensorFlow v2 is supposedly a lot better, right? Surely I'll be able to do basic things without worrying...."

    > ... no. Not even close. Not until Jax + TPU VMs.

    In the subsequent year, it's been nothing but joy.

    If the problem is that you want to see tensor values in a JIT'ed function, use a host callback (a short sketch follows at the end of this comment). You can run actual Python wherever you want: https://jax.readthedocs.io/en/latest/jax.experimental.host_c...

    > This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.

    The nice part is, there's no "magic" under the hood. If you get a chance, I highly recommend reading through Autodidax: https://jax.readthedocs.io/en/latest/autodidax.html

    Autodidax is a pure-python implementation of jax. (Literally in one file, on that page.) It walks you through how every aspect of jax works.

    Delightfully, I found a secret branch where autodidax also implements host callbacks: https://github.com/google/jax/blob/effect-types/docs/autodid...

    If you scroll to the very bottom of that file, you'll see an example of compiling your own XLA JIT'ed code which subsequently calls back into Python. TPUs do precisely the same thing.

    Point being:

    > PyTorch, for better or for worse, will actually run your Python code as you wrote it.

    ... is also true of jax, to within a rounding error less than "I personally don't mind writing id_print(x) instead of print(x)." :)
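
    A minimal sketch of the two debugging styles described above, assuming the `jax.experimental.host_callback` module from the linked docs: ordinary `pdb` in eager (non-jit) code, and `id_print` inside a jit'ed function. The names `loss`, `X_bt`, and `Y_bt` come from the quotes; everything else is illustrative, not code from the original post.

    ```python
    import jax
    import jax.numpy as jnp
    from jax.experimental import host_callback as hcb  # module from the linked docs


    def loss(params, X_bt, Y_bt):
        # Eager (non-jit) execution: dropping into pdb here gives concrete device
        # arrays, so typing `X_bt` or `X_bt + Y_bt` at the prompt just works.
        # import pdb; pdb.set_trace()
        preds = X_bt @ params
        return jnp.mean((preds - Y_bt) ** 2)


    @jax.jit
    def loss_jit(params, X_bt, Y_bt):
        preds = X_bt @ params
        # Under jit the values are tracers, so print(preds) only shows abstract
        # shapes; id_print ships the runtime value from the device to the host.
        preds = hcb.id_print(preds, what="preds")
        return jnp.mean((preds - Y_bt) ** 2)


    key = jax.random.PRNGKey(0)
    X_bt = jax.random.normal(key, (4, 3))
    Y_bt = jnp.ones((4,))
    params = jnp.zeros((3,))
    print(loss(params, X_bt, Y_bt))      # eager call, pdb-friendly
    print(loss_jit(params, X_bt, Y_bt))  # compiled call, prints via host callback
    ```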

  • functorch

    functorch is JAX-like composable function transforms for PyTorch.

  • The PyTorch programming model is just really hard to adapt to an XLA-like compiler. Imperative Python code doesn't translate to an ML graph compiler particularly well; Jax's API is functional, so it's easier to translate to the XLA API. By contrast, torch/xla uses "lazy tensors" that record the computation graph and compile when needed. The trouble is, if the compute graph changes from run to run, you end up recompiling a lot.

    I guess in Jax you'd just apply `jax.jit` only to the parts where the compute graph is static? I'd be curious to see examples of how this works in practice; a rough sketch follows below. FWIW, there's an offshoot of PyTorch that aims to provide this sort of API (see https://github.com/pytorch/functorch and look at eager_compilation.py).

    (Disclaimer: I worked on this until quite recently.)
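
    A rough sketch of the pattern suggested above (my own illustration; `update`, `train`, and `batches` are made-up names): keep `jax.jit` on the part whose compute graph is static, and leave data-dependent Python control flow outside it.

    ```python
    import jax
    import jax.numpy as jnp


    @jax.jit
    def update(params, x, y):
        # Static compute graph: traced and compiled once per input shape/dtype,
        # then reused for every call with the same shapes.
        grads = jax.grad(lambda p: jnp.mean((x @ p - y) ** 2))(params)
        return params - 0.1 * grads


    def train(params, batches, target_loss=1e-3):
        # Dynamic part: plain Python. Branching here doesn't trigger recompilation,
        # because only `update` is compiled.
        for x, y in batches:
            params = update(params, x, y)
            loss = jnp.mean((x @ params - y) ** 2)
            if loss < target_loss:  # data-dependent control flow stays outside jit
                break
        return params


    params = train(jnp.zeros((3,)), [(jnp.ones((4, 3)), jnp.ones((4,)))] * 10)
    ```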

NOTE: The number of mentions on this list indicates mentions on common posts plus user-suggested alternatives. Hence, a higher number means a more popular project.


Related posts

  • Show HN: Designing Bridges with PyTorch

    4 projects | news.ycombinator.com | 11 Jan 2024
  • Side Quest Devblog #1: These Fakes are getting Deep

    3 projects | dev.to | 29 Apr 2024
  • My Favorite DevTools to Build AI/ML Applications!

    9 projects | dev.to | 23 Apr 2024
  • TensorFlow-metal on Apple Mac is junk for training

    1 project | news.ycombinator.com | 16 Jan 2024
  • Open Source Advent Fun Wraps Up!

    10 projects | dev.to | 5 Jan 2024