CleanRL now has a DDPG + JAX implementation roughly 2.5-4x faster than DDPG + PyTorch

This page summarizes the projects mentioned and recommended in the original post on /r/reinforcementlearning.

  • cleanrl

    High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)

  • I prototyped a PyTorch + JIT version of DDPG, but it's nowhere near as fast. So the speedup likely also comes from JAX's optimizations, i.e., compiling to XLA and the pure functional paradigm, which may make programs easier for the compiler to process and thus faster.

  • jax

    Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

  • These are the main gotchas I had:
      * https://github.com/google/jax/issues/2697
      * https://github.com/deepmind/optax/issues/366

  • optax

    Optax is a gradient processing and optimization library for JAX.

  • jaxrl

    JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.

  • https://github.com/ikostrikov/jaxrl would be another great reference implementation. You will probably also want to check out the docs for jax, flax, and optax.
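
To make the comment above about XLA and the pure functional paradigm more concrete, here is a minimal sketch of a DDPG-style critic update written as a single pure function and compiled with jax.jit. It is not CleanRL's or jaxrl's code. The network sizes, the names init_params, q_value, critic_loss and update_critic, the constants LR and TAU, and the random batch are all assumptions made for illustration, and plain SGD stands in for an optax optimizer to keep the example dependency-free.

```python
import jax
import jax.numpy as jnp

LR = 1e-2    # assumed learning rate, illustration only
TAU = 0.005  # assumed soft-update coefficient for the target network


def init_params(key, obs_dim, act_dim, hidden=32):
    # Hypothetical two-layer critic Q(s, a); sizes are illustrative.
    k1, k2 = jax.random.split(key)
    in_dim = obs_dim + act_dim
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, 1)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(1),
    }


def q_value(params, obs, act):
    # Forward pass: concatenate state and action, return one Q-value per row.
    x = jnp.concatenate([obs, act], axis=-1)
    h = jax.nn.relu(x @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"]).squeeze(-1)


def critic_loss(params, obs, act, target_q):
    # Mean squared error between predicted and target Q-values.
    return jnp.mean((q_value(params, obs, act) - target_q) ** 2)


@jax.jit
def update_critic(params, target_params, obs, act, target_q):
    # Pure function: takes parameters in, returns new parameters out.
    # Gradient, SGD step, and soft target update are traced together,
    # so XLA can compile the whole step as one program.
    loss, grads = jax.value_and_grad(critic_loss)(params, obs, act, target_q)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    target_params = jax.tree_util.tree_map(
        lambda t, p: (1.0 - TAU) * t + TAU * p, target_params, params
    )
    return params, target_params, loss


# Synthetic batch standing in for replay-buffer samples (made-up data).
key = jax.random.PRNGKey(0)
k_init, k_obs, k_act, k_tgt = jax.random.split(key, 4)
params = init_params(k_init, obs_dim=3, act_dim=1)
target_params = params
obs = jax.random.normal(k_obs, (64, 3))
act = jax.random.normal(k_act, (64, 1))
target_q = jax.random.normal(k_tgt, (64,))

first_loss = None
for step in range(200):
    params, target_params, loss = update_critic(
        params, target_params, obs, act, target_q
    )
    if first_loss is None:
        first_loss = float(loss)
final_loss = float(loss)
```

The first call to update_critic pays the tracing and compilation cost, and later calls with the same input shapes reuse the compiled program. Because the function has no side effects, all state (params and target_params here) has to be passed in and returned explicitly.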
