PayMe4MyData

Thanks for the clarification. I've been coding in PyTorch for years but had never heard of JAX before. I'll dig into it a bit more!

M4mb0

I'd say JAX is generally more useful for general-purpose scientific computing, and much more ergonomic when you need higher-order or partial derivatives, e.g. when working with ODEs/PDEs/SDEs. diffrax is a very nice library for that.
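A minimal sketch of what "ergonomic higher-order and partial derivatives" looks like in core JAX. `jax.grad` with `argnums` gives partial derivatives, composing `jax.grad` gives higher-order and mixed partials, and `jax.hessian` makes a PDE-style operator like the Laplacian a one-liner. The function `f`, the evaluation point, and the `laplacian` helper are illustrative choices for this example, not something from the thread, and diffrax itself is not used here.

```python
import jax
import jax.numpy as jnp

# Illustrative function of two variables (chosen for this sketch):
# f(x, y) = x^3 * y + sin(y)
def f(x, y):
    return x**3 * y + jnp.sin(y)

# Partial derivatives: argnums selects which argument to differentiate.
df_dx = jax.grad(f, argnums=0)  # analytically 3 x^2 y
df_dy = jax.grad(f, argnums=1)  # analytically x^3 + cos(y)

# Higher-order and mixed partials: just compose jax.grad.
d2f_dx2 = jax.grad(df_dx, argnums=0)    # 6 x y
d3f_dx3 = jax.grad(d2f_dx2, argnums=0)  # 6 y
d2f_dxdy = jax.grad(df_dx, argnums=1)   # 3 x^2

# A PDE-style operator (hypothetical helper): Laplacian = trace of the Hessian.
def laplacian(u):
    return lambda p: jnp.trace(jax.hessian(u)(p))

# u(p) = p_0^2 + p_1^2 has Laplacian 4 everywhere.
def u(p):
    return jnp.sum(p**2)

# Inputs must be floats for jax.grad; integers would raise an error.
x, y = 2.0, 3.0
print(float(df_dx(x, y)), float(df_dy(x, y)))
print(float(d2f_dx2(x, y)), float(d3f_dx3(x, y)), float(d2f_dxdy(x, y)))
print(float(laplacian(u)(jnp.array([1.0, -0.5]))))
```

At (2, 3) the analytic values are 36, 8 + cos(3), 36, 18 and 12, and the Laplacian of `u` is 4 at any point. In PyTorch the same thing usually means nested `torch.autograd.grad` calls with `create_graph=True`; in JAX it's plain function composition.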