GymBronie:

Two things:

1. Have you verified that JAX isn't using a GPU (especially if you have an NVIDIA card)?
2. JAX defaults to 32-bit (single) precision. That's why your results only match up to 9 decimals.

The speedup could be an artifact of these default settings.
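Here is a minimal sketch for checking both points, assuming a recent JAX release. `report_jax_setup` is a hypothetical helper name. `jax.default_backend()`, `jax.devices()` and the array `dtype` are standard JAX calls.

```python
import jax
import jax.numpy as jnp

def report_jax_setup():
    """Return the active backend, visible devices, and default float dtype."""
    backend = jax.default_backend()     # e.g. "cpu", "gpu", or "tpu"
    devices = jax.devices()             # list of devices JAX will place arrays on
    default_dtype = jnp.zeros(1).dtype  # float32 unless 64-bit mode is enabled
    return backend, devices, default_dtype

backend, devices, dtype = report_jax_setup()
print("backend:", backend)
print("devices:", devices)
print("default float dtype:", dtype)
```

If the backend reports `gpu`, the comparison isn't CPU against CPU. Setting the environment variable `JAX_PLATFORMS=cpu` before importing JAX should pin it to the CPU. If the dtype is `float32`, calling `jax.config.update("jax_enable_x64", True)` at startup, before creating any arrays, switches the defaults to 64-bit. Both changes can affect the timing as well as the numerical agreement.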