I used Ray to train a massive GPT model by putting each layer on a separate TPU. Ray was able to send all the gradients back and forth as needed.
It scaled fine up to 33 TPUs (i.e. 33 layers).
Ray is impressive as hell.
By the way, I didn't write the code to do any of that. kindiana, aka "the guy that wrote GPT-J", also happened to write this: https://github.com/kingoflolz/swarm-jax/tree/master/swarm_ja...
I just ran it and it worked. Which is extraordinarily unusual for TPUs, historically speaking.
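The layer-per-device idea above can be sketched without Ray or TPUs: each layer lives on its own worker and forwards its activation to the next one over a channel. This is a minimal pure-Python sketch (threads and queues standing in for Ray actors on TPU hosts; the names and the toy "layer" computation are invented for illustration), not swarm-jax's actual implementation.

```python
import queue
import threading

def make_layer(weight):
    # Toy stand-in for a transformer block pinned to one device.
    def layer(x):
        return x * weight
    return layer

def pipeline(layers, x):
    # One channel between each pair of adjacent layers, plus input/output ends.
    channels = [queue.Queue() for _ in range(len(layers) + 1)]
    workers = []
    for i, layer in enumerate(layers):
        def run(i=i, layer=layer):
            # Each worker blocks until its predecessor hands it an activation,
            # applies its layer, and forwards the result downstream.
            activation = channels[i].get()
            channels[i + 1].put(layer(activation))
        t = threading.Thread(target=run)
        t.start()
        workers.append(t)
    channels[0].put(x)          # feed the input to the first layer
    result = channels[-1].get() # collect the output of the last layer
    for t in workers:
        t.join()
    return result

layers = [make_layer(w) for w in (2.0, 3.0, 0.5)]
print(pipeline(layers, 4.0))  # 4 * 2 * 3 * 0.5 = 12.0
```

In the real system the workers are long-lived Ray actors on separate hosts and the channels are Ray's object transfers, so activations and gradients cross machine boundaries instead of thread boundaries.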
Given how Ray "provides [...] exactly-once semantics" for its actors, you could draw similarities between it and workflow-as-code frameworks such as https://temporal.io. The way that Ray splits up actors and tasks looks similar to Temporal's Workflows + Activities split: Workflows (Ray actors) contain orchestration logic and have their method calls/results durably logged. Activities (Ray tasks) perform the expensive computations and any interaction with external systems and are not durably logged.
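The orchestration-vs-computation split described above can be sketched in plain Python (this is neither Ray nor Temporal code; the class and function names are invented): a stateful coordinator records what it decided, while stateless workers do the heavy lifting.

```python
from concurrent.futures import ThreadPoolExecutor

def expensive_task(x):
    """Stateless unit of work -- the role of a Ray task / Temporal Activity."""
    return x * x

class Orchestrator:
    """Stateful coordinator -- the role of a Ray actor / Temporal Workflow.
    Its decisions are appended to a log, standing in for durable logging."""
    def __init__(self, pool):
        self.pool = pool
        self.log = []  # stand-in for the durable event log

    def run(self, inputs):
        self.log.append(("started", list(inputs)))
        # Fan the expensive work out to stateless tasks, then gather results.
        futures = [self.pool.submit(expensive_task, x) for x in inputs]
        results = [f.result() for f in futures]
        self.log.append(("finished", results))
        return results

with ThreadPoolExecutor(max_workers=4) as pool:
    orch = Orchestrator(pool)
    print(orch.run([1, 2, 3]))  # [1, 4, 9]
```

The point of the split is recoverability: because only the coordinator's (cheap, deterministic) decisions are logged, a crashed run can be replayed from the log without redoing work whose results were already recorded.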
If you're in the .NET ecosystem or interested in distributed systems in general, you may like Orleans (https://github.com/dotnet/orleans), which I work on at Microsoft. Orleans contributed the Virtual Actor model, which other modern actor frameworks are starting to adopt since it is well suited to the hectic, failure-prone environment of distributed systems (where those so-called Cloud Native Apps live). The Ray paper linked from the article (https://www.usenix.org/system/files/osdi18-moritz.pdf) discusses some similarities. Slight correction on the paper: it states that "For message delivery, Orleans provides at-least-once [...] semantics". It's at-most-once. At-least-once messaging semantics (usually implemented via automatic retries) aren't ideal for these kinds of systems, in my opinion.
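The distinction in that correction is easy to illustrate with a toy sketch (this is not Orleans code; the names and failure model are made up). At-most-once sends one attempt and accepts loss; at-least-once retries until acknowledged, so a lost acknowledgment produces a duplicate on the receiver:

```python
def at_most_once(inbox, msg, drop):
    """One attempt; if the message is dropped, it is simply lost."""
    if not drop:
        inbox.append(msg)

def at_least_once(inbox, msg, ack_lost_attempts):
    """Retry until the sender sees an ack. Deliveries succeed here, but a
    lost ack triggers a resend, so the receiver sees a duplicate."""
    attempt = 0
    while True:
        inbox.append(msg)  # delivery itself succeeds
        acked = attempt not in ack_lost_attempts
        attempt += 1
        if acked:
            break

lost = []
at_most_once(lost, "hello", drop=True)
print(lost)   # [] -- message lost, no retry

duped = []
at_least_once(duped, "hello", ack_lost_attempts={0})
print(duped)  # ['hello', 'hello'] -- resend after a lost ack duplicates it
```

Duplicates are the price of retries: handlers must either be idempotent or deduplicate, which is one reason a framework might prefer at-most-once delivery and leave retry policy to the application.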
There is also ucx-py, which can be used with dask_cuda for fast GPU-to-GPU communication.