There is no free lunch :).
I remember spending a summer using Template Model Builder (TMB), a useful R/C++ automatic differentiation (AD) framework, to work with accelerated failure time models. For these models, survival beyond time t given covariates X is defined by S(t|X) = P(T>t|X) = S_0(t exp(-beta^T X)) for a baseline survival function S_0(t). I wanted to use splines for the baseline survival and then use AD for gradients and random effects. Unfortunately, after implementing the splines in templated C++, I found a wiki page titled "Things you should NOT do in TMB" (https://github.com/kaskr/adcomp/wiki/Things-you-should-NOT-d...), which includes using if statements whose conditions depend on the coefficients. Here, S_0 is evaluated at t exp(-beta^T X), so choosing which spline segment applies means branching on a value that depends on beta: exactly the excluded case :(. An older framework (ADMB) did not have this constraint, but distributing code was more difficult. Finally, PyTorch had neither an implementation of B-splines nor one of Laplace's approximation. Returning to my opening comment, there is no free lunch.
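A minimal Python sketch of why a beta-dependent branch breaks a taping AD tool. It assumes a single covariate x, a degree-1 (piecewise-linear) spline for S_0 with made-up knots and values, and a hypothetical `record_tape` helper that mimics how a taping tool (TMB builds on CppAD, which records one operation sequence) fixes the branch at recording time. This is not TMB's API; it only illustrates the failure mode.

```python
import bisect
import math

# Hypothetical degree-1 (piecewise-linear) B-spline for the baseline survival S_0:
# values at fixed knots, made up purely for illustration.
KNOTS = [0.0, 1.0, 2.0, 4.0, 8.0]
S0_VALUES = [1.0, 0.8, 0.5, 0.2, 0.05]


def segment(u):
    """Index i with KNOTS[i] <= u < KNOTS[i + 1], clamped to the valid segments."""
    # The comparisons inside bisect are the data-dependent branches.
    i = bisect.bisect_right(KNOTS, u) - 1
    return min(max(i, 0), len(KNOTS) - 2)


def s0_on_segment(u, i):
    """Evaluate the straight line through the endpoints of segment i at u."""
    t0, t1 = KNOTS[i], KNOTS[i + 1]
    w = (u - t0) / (t1 - t0)
    return (1.0 - w) * S0_VALUES[i] + w * S0_VALUES[i + 1]


def survival(t, x, beta):
    """AFT survival S(t|x) = S_0(t * exp(-beta * x)); the segment is re-chosen on every call."""
    u = t * math.exp(-beta * x)
    return s0_on_segment(u, segment(u))


def record_tape(t, x, beta_at_recording):
    """Hypothetical taping helper: the segment index is frozen when the tape is recorded."""
    frozen_i = segment(t * math.exp(-beta_at_recording * x))

    def replay(beta):
        u = t * math.exp(-beta * x)
        # Reuses the stale segment for any beta, extrapolating its line if needed.
        return s0_on_segment(u, frozen_i)

    return replay
```

Recording with `record_tape(3.0, 1.0, 0.0)` puts u = 3 in the segment [2, 4]. At beta = 1, u = 3/e ≈ 1.10 lies in [1, 2], so `survival(3.0, 1.0, 1.0)` ≈ 0.769, while the replayed tape keeps using the old line and returns ≈ 0.634. Any gradient taken from that tape is likewise for the wrong piece of the spline.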
"Maybe they let you declare some subgraph as 'dynamic' to avoid static optimizations?" What you just described is essentially TensorFlow Eager, and it is why Eager has some performance issues. XLA makes some pretty strong assumptions, and I don't think that should change. TensorFlow's ability to automatically generate good parallelized production code stems from the restrictions it imposes. So I wouldn't even try for a "one true AD to rule them all", since making things more flexible reduces the compiler optimizations that can be performed automatically.
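A toy illustration of the trade-off, purely hypothetical and not how TensorFlow or XLA are implemented: the function is traced once with a symbolic placeholder into a static graph, so a Python `if` on a traced value has nothing concrete to test. The branch has to be written as a graph-level `cond` node instead, the same role played by real primitives such as `tf.cond` and `jax.lax.cond`. The names `Node`, `cond`, `trace`, and `run` are made up for this sketch.

```python
class Node:
    """A node in a toy static dataflow graph (hypothetical)."""

    def __init__(self, op, *args):
        self.op, self.args = op, args

    def __add__(self, other):
        return Node("add", self, lift(other))

    def __mul__(self, other):
        return Node("mul", self, lift(other))

    def __gt__(self, other):
        return Node("gt", self, lift(other))

    def __bool__(self):
        # A symbolic value has no concrete truth value at trace time.
        raise TypeError("cannot branch on a traced value; use cond()")


def lift(v):
    """Wrap Python constants as graph nodes."""
    return v if isinstance(v, Node) else Node("const", v)


def cond(pred, a, b):
    """Record both branches in the graph; the choice is made only at run time."""
    return Node("cond", pred, lift(a), lift(b))


def trace(fn):
    """Trace fn once with a placeholder input and return the resulting static graph."""
    return fn(Node("input"))


def run(node, x):
    """Evaluate a traced graph for a concrete input x."""
    if node.op == "input":
        return x
    if node.op == "const":
        return node.args[0]
    vals = [run(a, x) for a in node.args]  # this toy evaluates both cond branches
    if node.op == "add":
        return vals[0] + vals[1]
    if node.op == "mul":
        return vals[0] * vals[1]
    if node.op == "gt":
        return vals[0] > vals[1]
    if node.op == "cond":
        return vals[1] if vals[0] else vals[2]
    raise ValueError(f"unknown op {node.op}")


def scale_if(x):
    if x > 0:  # Python branch on a traced value: fails during tracing
        return x * 2.0
    return x * 0.5


def scale_cond(x):
    return cond(x > 0, x * 2.0, x * 0.5)
```

`trace(scale_if)` raises `TypeError`, while `trace(scale_cond)` produces one graph that gives 6.0 for input 3.0 and -2.0 for input -4.0. Because the whole computation is known before any data arrives, a compiler is free to optimize it ahead of time; that is the flexibility being traded away.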
To get the more flexible form, you really want to use a full programming language's IR as the target. Trying to use the IR of a fully dynamic language (Python, R, etc.) directly would be pretty insane, because it would be hard to enforce rules and get performance. So a language whose front end sits on top of an optimizing compiler (LLVM) would probably make the most sense. Zygote and Diffractor use Julia's IR, but there are other ways to do this as well. Enzyme (https://github.com/wsmoses/Enzyme.jl) uses the LLVM IR directly for source-to-source translation. Using an MLIR dialect (MLIR provides an LLVM dialect, among others) might be an interesting way to write a more ML-focused, flexible AD system. Swift for TensorFlow used Swift's IR. This mindset starts to show why those tools chose the IRs they did.
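To make the IR-targeting idea concrete, here is a toy forward-mode source-to-source transformation over a made-up SSA instruction list. It is not LLVM IR and not how Enzyme works internally. Because the transformation emits new instructions instead of recording values, the `select` survives into the derivative program and is decided at run time from the primal values, so the derivative is correct on either side of the switch point (the function itself is discontinuous at x = 1).

```python
import math

# A toy SSA "IR": each instruction is (dest, op, operands). Hypothetical, not LLVM IR.
# Operands are variable names, except literal constants in "gt" and "const".
PROGRAM = [
    ("x", "input", ()),
    ("c", "gt", ("x", 1.0)),           # c = x > 1.0
    ("a", "mul", ("x", "x")),          # a = x * x
    ("b", "exp", ("x",)),              # b = exp(x)
    ("y", "select", ("c", "a", "b")),  # y = c ? a : b
]


def differentiate(program):
    """Forward mode: emit each primal instruction followed by its tangent ('d'-prefixed)."""
    out = []
    for dest, op, args in program:
        out.append((dest, op, args))
        d = "d" + dest
        if op == "input":
            out.append((d, "const", (1.0,)))  # seed dx/dx = 1
        elif op in ("gt", "const"):
            out.append((d, "const", (0.0,)))  # comparisons and constants have zero tangent
        elif op == "add":
            u, v = args
            out.append((d, "add", ("d" + u, "d" + v)))
        elif op == "mul":
            u, v = args  # product rule: d(uv) = du*v + u*dv
            out.append((d + "_1", "mul", ("d" + u, v)))
            out.append((d + "_2", "mul", (u, "d" + v)))
            out.append((d, "add", (d + "_1", d + "_2")))
        elif op == "exp":
            out.append((d, "mul", (dest, "d" + args[0])))  # d exp(u) = exp(u) du
        elif op == "select":
            c, a, b = args  # the branch is kept in the derivative code
            out.append((d, "select", (c, "d" + a, "d" + b)))
        else:
            raise ValueError(f"unknown op {op}")
    return out


def interpret(program, x):
    """Run a program for input x and return every variable's value."""
    env = {}

    def val(a):
        return env[a] if isinstance(a, str) else a

    for dest, op, args in program:
        v = [val(a) for a in args]
        if op == "input":
            env[dest] = x
        elif op == "const":
            env[dest] = v[0]
        elif op == "gt":
            env[dest] = v[0] > v[1]
        elif op == "add":
            env[dest] = v[0] + v[1]
        elif op == "mul":
            env[dest] = v[0] * v[1]
        elif op == "exp":
            env[dest] = math.exp(v[0])
        elif op == "select":
            env[dest] = v[1] if v[0] else v[2]
        else:
            raise ValueError(f"unknown op {op}")
    return env
```

Running `interpret(differentiate(PROGRAM), 3.0)` takes the `a` branch, giving y = 9.0 and dy = 6.0 (that is, 2x), while an input of 0.5 takes the `b` branch and gives dy = exp(0.5). Unlike a frozen tape, the same derivative program is valid for both branches.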
Related posts
- Custom gradients in Enzyme
- Engineering Trade-Offs in Automatic Differentiation: from TensorFlow and PyTorch to Jax and Julia
- Enzyme – High-performance automatic differentiation of LLVM (r/MachineLearning)
- Enzyme – High-performance automatic differentiation of LLVM