Very interesting that it's coming from Google. I did my master's in tokamak simulation, so my first question is about performance. Python is very rarely used in this space, purely for performance reasons. Even though Python can call out to BLAS or whatever, it's still usually worth it to write the code in Fortran or C, or maybe Julia.
I just noticed this podcast episode on Deep RL for fusion reactors was recently published, if anyone likes this stuff. I have not listened yet, but this podcast in general is great.
I recently started using JAX for some ion-optics work in accelerator physics, and I have found it very good. The autodiff stuff is magical for optimisation work, but even just as a compiled NumPy, I have found it very easy to get highly performant code. For reference, I previously tried roughly the same thing in Numba and wasn't able to get anywhere near the same performance as JAX, even running on the CPU, which I understand is JAX's weakest backend. By and large I have just written idiomatic Python/NumPy code, sprinkled a few `vmap`s and `scan`s around, and got great results. I'm very pleased with JAX.
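
A minimal sketch of the pattern described above, not the commenter's actual code. It assumes a toy 1D beamline of thin lenses separated by equal drifts; the names `track_one`, `track_beam`, `spot_size` and the `DRIFT` value are hypothetical and only there for illustration. `jax.lax.scan` walks one ray through the elements, `jax.vmap` batches that over many rays, and `jax.grad` gives the derivative of a beam-size figure of merit with respect to the lens strengths.

```python
import jax
import jax.numpy as jnp

DRIFT = 0.5  # hypothetical drift length between lenses (assumed, in metres)


def track_one(ray, strengths):
    """Push one ray (x, x') through a chain of thin lenses and drifts."""

    def element(ray, k):
        x, xp = ray[0], ray[1]
        xp = xp - k * x        # thin-lens kick with focusing strength k
        x = x + DRIFT * xp     # drift to the next element
        return jnp.stack([x, xp]), None

    # scan replaces the Python for-loop over lattice elements
    final, _ = jax.lax.scan(element, ray, strengths)
    return final


# vmap over the ray axis only; the lens strengths are shared by all rays
track_beam = jax.jit(jax.vmap(track_one, in_axes=(0, None)))


def spot_size(strengths, rays):
    """Mean squared position at the end of the line (toy figure of merit)."""
    final = track_beam(rays, strengths)
    return jnp.mean(final[:, 0] ** 2)


# autodiff through scan and vmap: d(spot_size) / d(strengths)
grad_spot = jax.jit(jax.grad(spot_size))

rays = jnp.array([[1.0, 0.1], [-0.5, 0.0]])
strengths = jnp.array([0.1, -0.2, 0.3])
print(track_beam(rays, strengths))
print(grad_spot(strengths, rays))
```

The gradient from `grad_spot` could be fed to any gradient-based optimiser to tune the lens strengths; the same functions run unchanged on CPU or GPU.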