When I first started my project in Google JAX, I knew that Google JAX was a fairly new library built for accelerators such as GPUs, so examples and teaching materials would be scarce. On top of that, I was new to GPU programming in general. That meant an extra hurdle: learning at least the basics of GPU programming and understanding why it is advantageous.
I would say that this video of one of the architects of CUDA explaining GPU programming is a must-watch for learning the fundamentals and for understanding why GPUs are so much faster than CPUs at large-scale, highly parallel calculations.
Other resources that came in handy when I was starting out with Google JAX were this video and this blog post, both of which cover the basics of JAX programming.
The Google JAX documentation is fairly extensive, but I would like to point out one particular page about the peculiarities of JAX programming: 🔪 JAX - The Sharp Bits 🔪 — JAX documentation. It explains how JAX favors vectorized array operations, why common tactics such as in-place array manipulation do not work (JAX arrays are immutable), and why we need to write pure, functional code in order to use JAX's very powerful just-in-time (JIT) compilation.
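To make two of these sharp bits concrete, here is a minimal sketch, assuming a recent JAX version where arrays support the `.at[...]` update syntax. The function names `scaled_sum` and `impure_double` are purely illustrative. The sketch shows that item assignment fails on a JAX array, the functional `.at[...].set(...)` alternative, and how a side effect inside a `jax.jit` function runs only while JAX traces it.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)

# In NumPy you might write x[0] = 10, but JAX arrays are immutable,
# so item assignment raises a TypeError.
try:
    x[0] = 10
except TypeError as err:
    print("in-place update failed:", type(err).__name__)

# The functional alternative: .at[...].set(...) returns a *new* array
# and leaves the original array untouched.
y = x.at[0].set(10)
print(x)  # [0 1 2 3 4]
print(y)


# A pure function: the output depends only on the inputs, with no side effects,
# so it is safe to JIT-compile.
@jax.jit
def scaled_sum(a, scale):
    return jnp.sum(a * scale)


print(scaled_sum(x, 2.0))  # 20.0

# An impure function: mutating a Python global is a side effect that only
# happens while JAX traces the function, not on later cached calls.
trace_count = 0


@jax.jit
def impure_double(a):
    global trace_count
    trace_count += 1  # runs at trace time only
    return a * 2


impure_double(x)
impure_double(x)  # same shape and dtype: reuses the cached compiled code
print(trace_count)  # 1, not 2
```

The counter ending at 1 after two calls is the kind of surprise the Sharp Bits page warns about: JIT compilation assumes the function is pure, so any side effects are silently skipped once the compiled version is cached.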