* cleanrl: High-quality single-file implementations of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
* jax: Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
* jaxrl: JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces
I prototyped a PyTorch + JIT version of DDPG here, but it's nowhere near as fast. So the speedup likely also comes from JAX's own optimizations: everything is compiled to XLA, and the pure functional paradigm may make programs easier to optimize and therefore faster.
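To make that concrete, here is a minimal sketch (not jaxrl's actual code) of the pattern: the whole update is a pure function of its inputs, with parameters and optimizer state passed in and returned explicitly, and `jax.jit` compiles the entire step into one XLA program. The loss, shapes, and names below are made up purely for illustration.

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # Hypothetical quadratic loss, just to have something to differentiate.
    preds = batch["x"] @ params["w"]
    return jnp.mean((preds - batch["y"]) ** 2)

optimizer = optax.adam(1e-3)

@jax.jit
def update_step(params, opt_state, batch):
    # Pure function: no hidden state, everything comes in and goes out.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Usage: thread params and optimizer state explicitly through each call.
params = {"w": jnp.zeros((4, 1))}
opt_state = optimizer.init(params)
batch = {"x": jnp.ones((32, 4)), "y": jnp.ones((32, 1))}
params, opt_state, loss = update_step(params, opt_state, batch)
```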
These are the main gotchas I had:
* https://github.com/google/jax/issues/2697
* https://github.com/deepmind/optax/issues/366
https://github.com/ikostrikov/jaxrl would be another great reference implementation. You probably also want to check out the docs for jax, flax, and optax.