Keras Core: Keras for TensorFlow, Jax, and PyTorch

This page summarizes the projects mentioned and recommended in the original post on news.ycombinator.com.

  • keras-nlp

    Modular Natural Language Processing workflows with Keras

  • Yes, you can check out KerasCV and KerasNLP, which host pretrained models like ResNet, BERT, and many more. As of the latest releases (today), they run on all backends, and converting them to be backend-agnostic was pretty smooth! It took a couple of weeks to convert both packages in full.

    https://github.com/keras-team/keras-nlp/tree/master/keras_nl...

  • keras-cv

    Industry-strength Computer Vision workflows with Keras

  • keras-core

    A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.

  • We are still working on this feature. We aim to expose it via model.compile(jit_compile=True). https://github.com/keras-team/keras-core/blob/v0.1.0/keras_c...

  • returnn

    The RWTH extensible training framework for universal recurrent neural networks

  • That looks very interesting.

    I have actually developed (and am still developing) something very similar, which we call the RETURNN frontend: a new frontend plus new backends for our RETURNN framework. The new frontend supports Python model-definition code very similar to what you see in PyTorch or Keras, i.e. a core Tensor class, a base Module class you can derive from, a Parameter class, and a core functional API that performs all the computations. It supports multiple backends, currently mostly TensorFlow (graph-based) and PyTorch; JAX is something I also planned. Some details here: https://github.com/rwth-i6/returnn/issues/1120
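    A minimal sketch of the design described above: a core Tensor class, a Parameter class, a derivable Module base class, and a functional API that dispatches every computation to a pluggable backend. All names here (`register_backend`, `set_backend`, `matmul`, `add`, `Linear`) are hypothetical and are not the RETURNN frontend's or Keras Core's actual API; a toy pure-Python backend on nested lists stands in for TensorFlow, PyTorch, or JAX.

```python
from typing import Callable, Dict, List

# Hypothetical sketch: one frontend API, pluggable backends.
_BACKENDS: Dict[str, Dict[str, Callable]] = {}
_current = {"name": None}


def register_backend(name: str, ops: Dict[str, Callable]) -> None:
    """Register a table of primitive ops under a backend name."""
    _BACKENDS[name] = ops


def set_backend(name: str) -> None:
    if name not in _BACKENDS:
        raise ValueError(f"unknown backend: {name}")
    _current["name"] = name


def _op(op_name: str) -> Callable:
    # Look up the primitive in the currently selected backend.
    return _BACKENDS[_current["name"]][op_name]


class Tensor:
    """Frontend tensor wrapping a backend-native value."""

    def __init__(self, raw):
        self.raw = raw


class Parameter(Tensor):
    """A trainable tensor; Module.parameters() collects these."""


# Functional API: every computation goes through the dispatch table.
def matmul(a: Tensor, b: Tensor) -> Tensor:
    return Tensor(_op("matmul")(a.raw, b.raw))


def add(a: Tensor, b: Tensor) -> Tensor:
    return Tensor(_op("add")(a.raw, b.raw))


class Module:
    def parameters(self) -> List[Parameter]:
        return [v for v in vars(self).values() if isinstance(v, Parameter)]

    def __call__(self, x: Tensor) -> Tensor:
        return self.forward(x)


class Linear(Module):
    def __init__(self, weight, bias):
        self.weight = Parameter(weight)
        self.bias = Parameter(bias)

    def forward(self, x: Tensor) -> Tensor:
        return add(matmul(x, self.weight), self.bias)


# Toy pure-Python backend on nested lists.
def _py_matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def _py_add(a, b):
    # Adds the 1-D bias list b to every row of the 2-D list a.
    return [[x + y for x, y in zip(row, b)] for row in a]


register_backend("python", {"matmul": _py_matmul, "add": _py_add})
set_backend("python")

layer = Linear(weight=[[1.0, 0.0], [0.0, 2.0]], bias=[0.5, -0.5])
out = layer(Tensor([[1.0, 1.0]]))
print(out.raw)  # [[1.5, 1.5]]
```

    Because modules only ever call the functional API, adding another backend in this sketch amounts to registering another op table under a new name.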

    (Note that we went a step further and made named dimensions a core principle of the framework.)
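    To illustrate what making named dimensions a core principle can buy, here is a hypothetical `NamedTensor` (not RETURNN's implementation) in which addition aligns operands by dimension name and reductions take a name instead of an axis index:

```python
from itertools import product
from typing import Callable, Dict, Tuple


class NamedTensor:
    """Hypothetical tensor whose axes are addressed by name, never by position."""

    def __init__(self, dims: Tuple[str, ...], sizes: Dict[str, int],
                 data: Dict[Tuple[int, ...], float]):
        self.dims = dims    # axis names, in storage order
        self.sizes = sizes  # axis name -> length
        self.data = data    # index tuple (ordered like self.dims) -> value

    @classmethod
    def from_fn(cls, dims: Tuple[str, ...], sizes: Dict[str, int],
                fn: Callable[[Dict[str, int]], float]) -> "NamedTensor":
        # Build every element by calling fn with a {dim_name: index} mapping.
        data = {}
        for idx in product(*(range(sizes[d]) for d in dims)):
            data[idx] = fn(dict(zip(dims, idx)))
        return cls(tuple(dims), dict(sizes), data)

    def get(self, index: Dict[str, int]) -> float:
        # Look up by name, so the caller never needs to know the axis order.
        return self.data[tuple(index[d] for d in self.dims)]

    def __add__(self, other: "NamedTensor") -> "NamedTensor":
        # Operands are aligned by dimension name; the result keeps self's order.
        if self.sizes != other.sizes:
            raise ValueError("dimensions must match by name and size")
        return NamedTensor.from_fn(self.dims, self.sizes,
                                   lambda ix: self.get(ix) + other.get(ix))

    def reduce_sum(self, dim: str) -> "NamedTensor":
        # Sum over the named axis; the remaining axes keep their names.
        keep = tuple(d for d in self.dims if d != dim)
        sizes = {d: self.sizes[d] for d in keep}
        return NamedTensor.from_fn(
            keep, sizes,
            lambda ix: sum(self.get({**ix, dim: i}) for i in range(self.sizes[dim])))


a = NamedTensor.from_fn(("batch", "time"), {"batch": 2, "time": 3},
                        lambda ix: 10 * ix["batch"] + ix["time"])
b = NamedTensor.from_fn(("time", "batch"), {"time": 3, "batch": 2},
                        lambda ix: 1.0)
c = a + b                 # different axis order, still aligned correctly
s = c.reduce_sum("time")  # reduce by name, not by axis number
print(s.dims, s.get({"batch": 0}), s.get({"batch": 1}))  # ('batch',) 6.0 36.0
```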

    (Example beam search implementation: https://github.com/rwth-i6/i6_experiments/blob/14b66c4dc74c0...)

    One difficulty I found was how to design the API so that it works well for both eager-mode frameworks (PyTorch, TF eager mode) and graph-based frameworks (TF graph mode, JAX). This mostly involves anything with state, or code that should not simply execute in the inner training loop but, e.g., only at initialization, after each epoch, or similar. For example:

    - Parameter initialization.

    - Anything involving buffers, e.g. batch normalization.

    - Other custom training loops, e.g. an outer and an inner loop (as in GAN training)?

    - How to implement something like weight normalization? In PyTorch, module.param is renamed, and a forward pre-hook recomputes module.param on the fly before each forward call. Should the same logic simply be followed for both eager mode and graph mode?

    - How to deal with control-flow contexts, e.g. accessing values outside a loop that were computed inside it. Such things are naturally possible in eager mode, where you simply get the most recent value and there is no real control-flow context.

    - Device logic: should the device be defined explicitly for each tensor (like PyTorch), or should tensors be moved to the GPU automatically and eagerly (like TensorFlow)? Should moving a tensor from one device to another (or to the CPU) be automatic or explicit?
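    The weight-normalization pattern described in the list (rename the parameter, then recompute it in a forward pre-hook) can be sketched in eager style as follows. The `Module`, `Dot`, and `apply_weight_norm` names are hypothetical stand-ins, not PyTorch's API; the recomputed weight is w = g * v / ||v||.

```python
import math
from typing import Callable, List


class Module:
    """Hypothetical eager-mode module base class with forward pre-hooks."""

    def __init__(self) -> None:
        self._pre_hooks: List[Callable[["Module"], None]] = []

    def register_forward_pre_hook(self, hook: Callable[["Module"], None]) -> None:
        self._pre_hooks.append(hook)

    def __call__(self, x: List[float]) -> float:
        for hook in self._pre_hooks:
            hook(self)  # runs before every forward call
        return self.forward(x)


class Dot(Module):
    """A single linear unit: y = sum_i weight[i] * x[i]."""

    def __init__(self, weight: List[float]) -> None:
        super().__init__()
        self.weight = list(weight)

    def forward(self, x: List[float]) -> float:
        return sum(w * xi for w, xi in zip(self.weight, x))


def apply_weight_norm(module: Module) -> Module:
    """Replace `weight` by (weight_g, weight_v); rebuild weight = g * v / ||v|| per call."""
    v = list(module.weight)
    module.weight_v = v
    module.weight_g = math.sqrt(sum(x * x for x in v))  # g = ||v|| keeps outputs unchanged
    del module.weight  # the original attribute is "renamed" away

    def recompute(m: Module) -> None:
        norm = math.sqrt(sum(x * x for x in m.weight_v))
        m.weight = [m.weight_g * x / norm for x in m.weight_v]

    module.register_forward_pre_hook(recompute)
    return module


layer = apply_weight_norm(Dot([3.0, 4.0]))
out1 = layer([1.0, 1.0])  # weight recomputed as [3.0, 4.0]
layer.weight_g = 10.0     # e.g. after an optimizer update of g
out2 = layer([1.0, 1.0])  # weight recomputed as [6.0, 8.0]
print(out1, out2)         # 7.0 14.0
```

    In this eager version the hook simply overwrites an attribute before each call. Since it only derives `weight` from `weight_g` and `weight_v`, one could in principle trace the same recomputation into the forward pass of a graph-based backend instead; whether that is enough is exactly the open question raised above.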

    I see that you have keras_core.callbacks.LambdaCallback, which is perhaps similar, but can you effectively update the module's logic from there?
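    A minimal sketch of the kind of outer-loop hook being asked about: a hypothetical `LambdaCallbackSketch`, shaped like a Keras-style `LambdaCallback(on_epoch_end=...)` that receives `epoch` and `logs`, driving a toy eager training loop. None of this is Keras Core's actual implementation.

```python
from typing import Callable, Dict, List, Optional


class LambdaCallbackSketch:
    """Hypothetical stand-in shaped like a Keras-style LambdaCallback."""

    def __init__(self, on_epoch_end: Optional[Callable[[int, Dict], None]] = None):
        self.on_epoch_end = on_epoch_end or (lambda epoch, logs: None)


class ToyModel:
    def __init__(self) -> None:
        self.scale = 1.0  # plain Python attribute read at every forward call

    def __call__(self, x: float) -> float:
        return self.scale * x


def fit(model: ToyModel, data: List[float], epochs: int,
        callbacks: List[LambdaCallbackSketch]) -> List[float]:
    history = []
    for epoch in range(epochs):
        # Eager "training" step: just evaluate the model on every example.
        loss = sum(model(x) for x in data)
        logs = {"loss": loss}
        history.append(loss)
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)  # outer-loop code, outside the inner step
    return history


model = ToyModel()
halve = LambdaCallbackSketch(
    on_epoch_end=lambda epoch, logs: setattr(model, "scale", model.scale / 2))
history = fit(model, [1.0, 2.0, 3.0], epochs=3, callbacks=[halve])
print(history)  # [6.0, 3.0, 1.5]
```

    Here the change to `model.scale` takes effect on the next epoch because the loop runs eagerly. If the inner step were instead traced once by a JIT compiler, a plain Python attribute like `scale` could be baked in as a constant at trace time, so the update would be silently ignored unless it were stored as state passed into the compiled step. That is one concrete form of the eager-versus-graph difficulty described above.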

  • i6_experiments

