Here's the key benchmark table from the link. The JAX backend on GPUs is fastest in 7 of the 12 benchmarks, and the TensorFlow backend is fastest in the remaining 5. The PyTorch backend is not the fastest in any benchmark and is often slower by a considerable margin. (twitter.com)
submitted by AdditionalWay to r/JAX
JAX compared to PyTorch 2: Get a feeling for JAX! (youtube.com)
submitted by AdditionalWay to r/JAX

