-
This might be a tangent, but why does JAX only support saving/serializing AOT-compiled executables on TPU [1]? It would be great to be able to save compiled functions instead of having to JIT-compile everything again every time you restart a session.
(Julia used to have this problem too, but they've made great progress on caching JIT compiled functions to reduce latency.)
[1]: https://github.com/google/maxtext?tab=readme-ov-file#ahead-o...
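For the restart-latency part of the question, here is a minimal sketch that assumes a recent JAX release. The `jax.jit(...).lower(...).compile()` stages are JAX's public ahead-of-time API. The persistent compilation cache, configured through `jax_compilation_cache_dir`, is the mechanism JAX offers for reusing compiled programs across processes. Backend support for that cache has changed between releases, so check the docs for your version. The function `layer`, the shapes, and the temporary cache directory are illustrative only.

```python
import tempfile

import jax
import jax.numpy as jnp

# Persistent compilation cache: compiled executables are written to this
# directory and can be reused by later processes compiling the same program.
# A temp dir keeps the sketch self-contained; use a stable path in practice.
jax.config.update("jax_compilation_cache_dir", tempfile.mkdtemp())
# Cache entries even when compilation is fast, so small programs are stored too.
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


def layer(w, x):
    """Illustrative computation: a single tanh layer."""
    return jnp.tanh(x @ w)


w = jnp.ones((4, 4))
x = jnp.ones((2, 4))

# Explicit AOT stages: trace and lower for these shapes/dtypes, then compile.
lowered = jax.jit(layer).lower(w, x)
compiled = lowered.compile()

# The compiled executable only accepts arguments matching the lowered shapes.
y = compiled(w, x)
print(y.shape)
```

The cache stores the compiled executable on disk rather than handing you a serialized object. A fresh process that lowers the same program with the same shapes and configuration can then skip XLA compilation, although Python tracing still runs.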
-
EasyLM
Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
-
[3]: https://github.com/google-research/t5x
I'm asking because I've worked extensively on training a large model on a TPU cluster. I started with Levanter, then tried MaxText, and finally ended up on EasyLM. My thoughts are:
- Levanter is well-intentioned but unproven and lacking in features. For instance, its sharding is odd in that it requires the embedding dimension to be a multiple of the number of devices, so I can't test a model with embedding dimension 768 on a 512-device pod. I lost confidence in Levanter after finding some glaring correctness bugs (and helping get them fixed). Also, while I'm a huge fan of Equinox's approach, it's sadly underdeveloped: for instance, there's no way to specify non-default weight initialization strategies without manually doing model surgery to set the weights.
- MaxText was just very difficult to work with. We felt like we were fighting it every time we needed to change something, because we had to dig through numerous needless layers of abstraction. My favorite: after one long day of debugging, I found a function whose only purpose was to pass its arguments, untouched, to another function; that function's only purpose was to pass them, untouched, to a third function, which slightly changed them and passed them to a fourth function that actually did the work.
- EasyLM is, as the name says, easy. But on a deeper dive, its sharding functionality seems underdeveloped. What it calls "FSDP" is not necessarily true FSDP; it's literally just one axis of the JAX mesh that happens to shard some data axes and some model weight axes.
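The sharding points in the Levanter and EasyLM bullets can be sketched in plain JAX. This is not Levanter's or EasyLM's code. The mesh axis name `fsdp`, the array sizes, the helper `evenly_shardable`, and the fall-back-to-replication choice are all illustrative assumptions. The sketch shows two things. First, it shows the divisibility arithmetic behind the 768-on-512 problem: placing an array with a `NamedSharding` that splits a dimension over a mesh axis generally expects that dimension to divide evenly. Second, it shows a single mesh axis that shards both the batch axis of the data and one axis of the weights, which is the shape of the "FSDP" setup described above.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def evenly_shardable(dim: int, n_devices: int) -> bool:
    """Hypothetical helper: True if `dim` splits into equal shards over `n_devices`."""
    return dim % n_devices == 0


# The numbers from the comment: 768 / 512 = 1.5, so no even split exists.
print(evenly_shardable(768, 512))  # False
print(evenly_shardable(768, 256))  # True

# A 1-D mesh over all local devices with one axis, named "fsdp" for illustration.
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))
n = mesh.shape["fsdp"]

batch, embed, hidden = 8 * n, 768, 3072  # illustrative sizes

# The same mesh axis shards the batch dimension of the data...
x = jax.device_put(jnp.ones((batch, embed)), NamedSharding(mesh, P("fsdp", None)))

# ...and the first dimension of the weights, replicating instead if it doesn't divide.
w_spec = P("fsdp", None) if evenly_shardable(embed, n) else P(None, None)
w = jax.device_put(jnp.ones((embed, hidden)) * 0.01, NamedSharding(mesh, w_spec))


@jax.jit
def layer(w, x):
    # Under jit, XLA's SPMD partitioner inserts whatever collectives
    # (for example, all-gathering weight shards) this matmul needs.
    return jnp.tanh(x @ w)


y = layer(w, x)
print(y.shape, y.sharding)
```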
I'm still searching for a "perfect" JAX LLM codebase - any pointers?
-
Is t5x an encoder/decoder architecture?
Some more general options.
The Flax ecosystem (https://github.com/google/flax?tab=readme-ov-file) or dm-haiku (https://github.com/google-deepmind/dm-haiku) were some of the best-developed communities in the JAX AI field.
Perhaps the “trax” repo? https://github.com/google/trax
Some HF examples https://github.com/huggingface/transformers/tree/main/exampl...
Sadly, it seems much of this work is proprietary these days, but one example could be Grok-1, if you customize the details: https://github.com/xai-org/grok-1/blob/main/run.py
-
transformers
🤗 Transformers: the model-definition framework for state-of-the-art machine learning models in text, vision, audio, and multimodal models, for both inference and training.