Trade-Offs in Automatic Differentiation: TensorFlow, PyTorch, Jax, and Julia

This page summarizes the projects mentioned and recommended in the original post on news.ycombinator.com

  • julia

    The Julia Programming Language

  • You probably could get perturbation confusion in reverse mode, but it's not an easy trap like it is with forward mode. The problem with forward mode AD is that it's deceptively easy: you transform every number a into a pair of numbers (a,b), you change every function f so that it maps (a,b) to (f(a), f'(a)*b), which is just the chain rule, and you're done. Whether you call it dual number arithmetic or a compiler transformation, it's the same thing in the end. The issue with perturbation confusion is that you've created this "secret extra number" to store the derivative data, so if two things are differentiating code at the same time you need to make sure you turn (a,b) into (a,a',b,b') and that all layers of AD always grab the right value out of there. Note that in the way I wrote it, if you assume "the b term is always num[2] in the tuple", oops, perturbation confusion; so your generated code needs to be "smart" (but not lose efficiency!). Thus the fixes are proofs and tagging systems that ensure the right perturbation terms are always used in the right places.
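
    To make that concrete, here is a minimal sketch in plain Python (not code from any of the libraries discussed; the Dual class and deriv helper are made up for illustration) of untagged dual numbers falling into exactly this trap: the inner derivative picks up the outer perturbation because nothing marks whose epsilon is whose.

    class Dual:
        """Pair (value, perturbation): the "secret extra number" described above."""
        def __init__(self, val, eps=0.0):
            self.val, self.eps = val, eps

        def __add__(self, other):
            other = other if isinstance(other, Dual) else Dual(other)
            return Dual(self.val + other.val, self.eps + other.eps)
        __radd__ = __add__

        def __mul__(self, other):
            other = other if isinstance(other, Dual) else Dual(other)
            # product rule on the perturbation slot
            return Dual(self.val * other.val,
                        self.val * other.eps + self.eps * other.val)
        __rmul__ = __mul__

    def deriv(f, x):
        """Forward-mode derivative using a single, untagged perturbation."""
        out = f(Dual(x, 1.0))
        return out.eps if isinstance(out, Dual) else 0.0

    # f(x) = x * d/dy (x + y) = x * 1 = x, so f'(x) should be 1.
    f = lambda x: x * deriv(lambda y: x + y, 3.0)
    print(deriv(f, 2.0))  # prints 2.0: the inner deriv read the outer x's epsilon,
                          # which is the perturbation confusion described above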

    With reverse mode AD this is much less likely to be an issue, because the AD system isn't necessarily storing and working on hidden extensions to the values: it runs a function forwards and then runs a separate function backwards, having remembered some values from the forward pass. If the remembered values are correct and never modified, then generating a higher order derivative is just as safe as the first. But that last little detail is what I think is most akin to perturbation confusion in reverse mode: reverse mode assumes that the objects captured in the forward pass will not be changed (or will at least be back in the correct state) when it is trying to reverse. Breaking this assumption doesn't even require second derivatives; the easiest way to break it is mutation. If you walk forward by doing Ax, the reverse pass wants to do A'v, so it just keeps the pointer to A, but if A gets mutated in the meantime then using that pointer is incorrect. This is the reason why most AD systems simply disallow mutation except in very special unoptimized cases (PyTorch, Jax, Zygote, ...).
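
    As a quick hedged sketch of that hazard (the exact error text varies across PyTorch versions), the snippet below has the forward pass save A so the reverse pass can form A'v; mutating A in place before backward typically trips PyTorch's saved-tensor version check rather than silently using the wrong matrix.

    import torch

    A = torch.randn(3, 3)
    x = torch.randn(3, requires_grad=True)

    y = A @ x       # the forward pass keeps A around for the A'v step of the reverse pass
    loss = y.sum()

    A.add_(1.0)     # in-place mutation of the matrix the reverse pass still points to

    loss.backward()  # raises a RuntimeError about a variable needed for gradient
                     # computation having been modified by an inplace operation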

    Enzyme.jl is an exception because it does a global analysis of the program it's differentiating (with proper escape analysis etc. passes at the LLVM level) in order to know that any mutation going forward will be reversed during the reverse pass, so by the time it gets back to A'*v it knows A will be the same. Higher-level ADs could go cowboy YOLO style and just assume the reversed matrix is correct (and it might be a lot of the time), though that causes some pretty major concerns for correctness. The other option is to simply make a full copy of A every time you mutate an element, so have fun if you loop through your weight matrix. The near-future Diffractor.jl approach is more like Haskell's GHC, where it just wants you to give it the non-mutating code so it can try to generate the mutating code when that would be more efficient (https://github.com/JuliaLang/julia/pull/42465).
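
    For contrast, a small sketch of the mutation-free style that higher-level ADs push you toward: JAX arrays are immutable, so an element "mutation" is written as a functional update that semantically returns a fresh copy, and the value the reverse pass captured can never change underneath it. (The shapes and the specific update are arbitrary illustration, not anything from the thread.)

    import jax
    import jax.numpy as jnp

    def loss(A, x):
        A = A.at[0, 0].set(2.0)   # functional update: returns a new array, the old A is untouched
        return jnp.sum(A @ x)

    A = jnp.ones((3, 3))
    x = jnp.ones(3)
    print(jax.grad(loss)(A, x))   # differentiates cleanly, with no aliasing or mutation hazard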

    So with forward-mode AD there was an entire literature around schemes that are provably safe against perturbation confusion, and I'm surprised we haven't already started seeing papers about provable safety with respect to mutation in higher-level reverse-mode AD. I suspect the only reason it hasn't started is that the people who write type-theoretic proofs tend to be the functional-programming, pure-function folks who tell people never to mutate anyway, so the literature might instead go in the direction of escape-analysis proofs for optimizing immutable array code to (and beyond) the performance of mutating code on commonly mutating applications. Either way it's getting there with the same purpose in mind.

  • autograd

    Efficiently computes derivatives of numpy code.

  • > fun fact, the Jax folks at Google Brain did have a Python source code transform AD at one point but it was scrapped essentially because of these difficulties

    I assume you mean autograd?

    https://github.com/HIPS/autograd

  • tangent

    (Discontinued) Source-to-Source Debuggable Derivatives in Pure Python

  • No, autograd acts similarly to PyTorch in that it builds a tape that it reverses, while PyTorch just comes with more optimized kernels (and kernels that act on GPUs). The AD that I was referencing was tangent (https://github.com/google/tangent). It was an interesting project, but it's hard to see who the audience is. Generating Python source code makes things harder to analyze, and you cannot JIT compile the generated code unless you can JIT compile Python. So you might as well first trace to a JIT-compilable sublanguage and do the transformations there, which is precisely what Jax does. In theory tangent is a bit more general, and maybe you could mix it with Numba, but then it's hard to justify. If it's more general then it's not for the standard ML community, for the same reason as the Julia tools, but then it had better beat the Julia tools in the specific niche they are targeting. Jax just makes much more sense for the people who were building it; it chose its niche very well.
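
    To make the "trace to a JIT-compilable sublanguage" point concrete, here is a small hedged JAX sketch (the function f is just an arbitrary example): the Python code is staged out into a jaxpr, and grad and jit are transformations of that intermediate representation rather than of Python source, which is the contrast with tangent's source-to-source approach.

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * x

    print(jax.make_jaxpr(f)(2.0))            # the traced intermediate representation
    print(jax.make_jaxpr(jax.grad(f))(2.0))  # reverse-mode AD applied to that IR
    print(jax.jit(jax.grad(f))(2.0))         # which XLA then compiles: cos(2)*2 + sin(2)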

  • kotlingrad

    🧩 Shape-Safe Symbolic Differentiation with Algebraic Data Types

  • and that there is a mature library for autodiff https://github.com/breandan/kotlingrad

  • dex-lang

    Research language for array processing in the Haskell/ML family

  • You might want to look at dex[1].

    [1]: https://github.com/google-research/dex-lang

  • Enzyme

    High-performance automatic differentiation of LLVM and MLIR. (by EnzymeAD)

  • That seems to be one of the points of enzyme[1], which was mentioned in the article.

    [1] - https://enzyme.mit.edu/

    Being able, in effect, to do interprocedural cross-language analysis seems awesome.
