Training Deep Networks with Data Parallelism in Jax

This page summarizes the projects mentioned and recommended in the original post on news.ycombinator.com

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

  • Thanks for taking the time to explain these.

    > It's been a bit, but I think the most frustrating errors were around mapping pytrees (like this issue https://github.com/google/jax/issues/9928).

    We've improved some of these pytree error messages but it seems that vmap one is still not great. Thanks for the ping on it.

    > Also the barriers where I couldn't disable jit. IIRC pmap automatically jits, so there was no way to avoid staging that part out.

    That was indeed a longstanding issue in pmap's implementation. And since people came to expect jit to be "built in" to pmap, it wasn't easy to revise.

    However, we recently (https://github.com/google/jax/pull/11854) made `jax.disable_jit()` work with pmap, in the sense that it makes pmap execute eagerly, so that you can print/pdb/etc to your heart's content. (The pmap successor, shard_map (https://jax.readthedocs.io/en/latest/jep/14273-shard-map.htm...), is eager by default. Also it has uniformly good error messages from the start!)

    > Next time I encounter something particularly opaque, I'll share on the github issue tracker.

    Thank you for the constructive feedback!
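The eager-pmap behaviour described in the comment above can be tried with a small sketch like the following. The function name `shard_loss`, the axis name `"devices"`, and the toy data are illustrative assumptions, not code from the thread; the only APIs used are `jax.pmap`, `jax.lax.pmean`, and the `jax.disable_jit()` context manager.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # typically 1 on a plain CPU install

def shard_loss(x):
    # Hypothetical per-device computation: mean squared activation of this shard.
    local = jnp.mean((2.0 * x) ** 2)
    # When pmap stages the function out, this prints an abstract tracer once;
    # under jax.disable_jit() it should print actual numbers instead.
    print("local loss:", local)
    # Average the per-device values across the mapped axis (data parallelism).
    return jax.lax.pmean(local, axis_name="devices")

p_loss = jax.pmap(shard_loss, axis_name="devices")

# One row of 4 examples per device; the leading axis must equal the device count.
x = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

staged = p_loss(x)            # default path: pmap compiles shard_loss

with jax.disable_jit():
    eager = p_loss(x)         # eager path, convenient for print/pdb debugging
```

Both calls should return one copy of the averaged loss per device, so the eager run can be used to inspect intermediate values before switching back to the compiled path.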

  • equinox

    Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

  • It sounds like you're concerned about how downstream libraries tend to wrap JAX transformations to handle their own thing? (E.g. `haiku.grad`.)

    If so, then allow me to make my usual advert here for Equinox:

    https://github.com/patrick-kidger/equinox

    This actually works with JAX's native transformations. (There's no `equinox.vmap` for example.)

    On higher-order functions more generally, Equinox offers a way to control these quite carefully, by making ubiquitous use of callables that are also pytrees. E.g. a neural network is both a callable in that it has a forward pass, and a pytree in that it records its parameters in its tree structure.
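The "callable that is also a pytree" idea can be illustrated in plain JAX. This is a minimal sketch under stated assumptions: the `Linear` class below is a hypothetical stand-in registered by hand with `jax.tree_util.register_pytree_node_class`, not Equinox's actual implementation, and the `loss` function and data are made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Linear:
    """Hypothetical layer: a callable (forward pass) and a pytree (parameters)."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def __call__(self, x):
        # Forward pass for a single example.
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # The parameters are the pytree leaves; there is no static metadata.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

def loss(model, x, y):
    pred = jax.vmap(model)(x)          # native jax.vmap over the batch axis
    return jnp.mean((pred - y) ** 2)

model = Linear(jnp.ones((2, 3)), jnp.zeros(2))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# Native jax.grad differentiates with respect to the model itself;
# the result has the same tree structure, i.e. it is another Linear.
grads = jax.grad(loss)(model, x, y)
```

Because the model is an ordinary pytree, no library-specific `grad` or `vmap` wrapper is needed here: the gradients come back as an object of the same type, with `grads.weight` and `grads.bias` matching the shapes of the parameters.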

NOTE: The number of mentions on this list indicates mentions on common posts plus user suggested alternatives. Hence, a higher number means a more popular project.
