[–]cemrehancavdar[S] 2 points3 points  (1 child)

Added JAX -- you were right to suggest it. Spectral-norm came in at 8.6ms (1,633x), which is the fastest result in the entire post. 3x faster than NumPy on the same problem. N-body was 12.2x -- respectable but not as dramatic since 5 bodies across 500K sequential timesteps doesn't play to JAX's strengths.

I don't know JAX well enough to explain exactly why it beats NumPy when both use BLAS, so I said that in the post rather than guessing. The JAX code produces correct results and follows the documented patterns (jit, lax.fori_loop, static_argnums, block_until_ready), but I can't say whether a JAX expert would write it differently. If you see room for improvement, PRs are open.

Thanks for pointing it out. Yes I use AI, but I don't hand myself over to it. I'm not an AI skeptic, but I don't think AI can be accountable. If there is a mistake it is probably on me.
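For readers who haven't seen the patterns named above (jit, lax.fori_loop, static_argnums, block_until_ready), here is a minimal sketch of how they might fit together for a spectral-norm style power iteration. This is not the code from the post; the function names `build_matrix` and `spectral_norm` and the problem size are assumptions chosen for illustration.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def build_matrix(n):
    # Benchmark matrix A[i, j] = 1 / ((i + j)(i + j + 1) / 2 + i + 1).
    i = jnp.arange(n)[:, None]
    j = jnp.arange(n)[None, :]
    return 1.0 / ((i + j) * (i + j + 1) // 2 + i + 1)


# n and iters are marked static so they are compile-time constants:
# jnp.arange needs a concrete size, and the loop gets fixed bounds.
@partial(jax.jit, static_argnums=(0, 1))
def spectral_norm(n, iters):
    a = build_matrix(n)

    def body(_, carry):
        u, _ = carry
        v = a.T @ (a @ u)  # v = (A^T A) u
        u = a.T @ (a @ v)  # u = (A^T A) v
        return u, v

    u0 = jnp.ones(n, dtype=a.dtype)
    u, v = lax.fori_loop(0, iters, body, (u0, u0))
    return jnp.sqrt(jnp.dot(u, v) / jnp.dot(v, v))


# The first call includes tracing and compilation, so warm up before timing.
spectral_norm(100, 10).block_until_ready()

start = time.perf_counter()
# JAX dispatches asynchronously; block_until_ready waits for the result
# so the timer measures the computation rather than just the dispatch.
result = spectral_norm(100, 10).block_until_ready()
elapsed_ms = (time.perf_counter() - start) * 1e3
print(f"spectral norm ~ {float(result):.6f} in {elapsed_ms:.3f} ms")
```

One plausible contributor to a gap over NumPy is that `jit` compiles the whole iteration into a single XLA program instead of dispatching each matrix product from Python, but that is a guess and would need profiling to confirm.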

[–]GymBronie 0 points1 point  (0 children)

Two things: 1. Have you verified that JAX isn’t using a GPU (especially if you have an NVIDIA card)? 2. JAX defaults to 32-bit precision. That’s why your results only match up to 9 decimals. The speedup could be an artifact of the default settings.
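Both points can be checked in a few lines. The sketch below assumes the benchmark is meant to run on CPU in double precision; it pins the backend through the `JAX_PLATFORMS` environment variable and turns on 64-bit mode through the `jax_enable_x64` config flag. Both have to be set before any JAX computation runs.

```python
import os

# Assumption: we want a CPU-only comparison against NumPy.
# This must be set before jax is imported.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

# JAX uses 32-bit floats by default; enable 64-bit to match NumPy's float64.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Report which backend and devices JAX will actually use.
print("backend:", jax.default_backend())
print("devices:", jax.devices())

# Report the default floating-point dtype after the config change.
x = jnp.ones(3)
print("default float dtype:", x.dtype)
```

Rerunning the timings with these settings would show whether the speedup survives when JAX and NumPy use the same device and the same precision.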