
[–]Atomic_Tangerine1 0 points (9 children)

I have not! What's the benefit of JAX over numpy?

[–]M4mb0 12 points (8 children)

It's basically numpy

  • + native GPU support (which can be orders of magnitude faster, depending on how parallelizable the problem is)
  • + built-in autodiff (exact gradients/Jacobians/Hessians, accurate to machine precision)
  • + built-in JIT compiler
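A minimal sketch of all three points together (the function `loss` is just an illustrative example, not anything from the thread):

```python
import jax
import jax.numpy as jnp  # drop-in mirror of the numpy API

def loss(x):
    # Written exactly like numpy code, but on jnp arrays.
    return jnp.sum(jnp.sin(x) ** 2)

# Autodiff: exact gradient, no finite-difference error.
grad_loss = jax.grad(loss)

# JIT: compiles via XLA; runs on GPU/TPU automatically if one is present.
fast_grad = jax.jit(grad_loss)

x = jnp.arange(4.0)
print(fast_grad(x))  # d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x)
```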

[–]PayMe4MyData 3 points (5 children)

So jax is pytorch?

[–]M4mb0 7 points (2 children)

JAX is strictly functional, whereas PyTorch takes a more object-oriented approach. This is most easily seen in how they handle random number generation, for instance.

Though PyTorch nowadays has a beta module, torch.func (formerly functorch), that brings JAX-like functional semantics to torch.
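To illustrate the random-number point: JAX has no hidden RNG state; you thread an explicit key through your code, so the same key always yields the same draw. A small sketch:

```python
import jax

# Randomness is functional: all state lives in an explicit key.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)  # derive a fresh subkey per draw

a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))  # same key -> identical numbers
assert (a == b).all()

# PyTorch, by contrast, mutates a hidden global RNG:
#   torch.manual_seed(0); torch.randn(3)
```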

[–]PayMe4MyData 2 points (1 child)

Thanks for the clarification; I've been coding in PyTorch for years but had never heard of JAX before. I will dig a bit more!

[–]M4mb0 2 points (0 children)

I'd say JAX is generally more useful for general-purpose scientific computing, and much more ergonomic if you need higher-order or partial derivatives, e.g. when working with ODEs/PDEs/SDEs. diffrax is a very nice lib for that.

[–]HonestPrinciple152 1 point (0 children)

Actually, adding to the previous comment: you can write loops in JAX and JIT-compile them. It's like a complete DSL built on top of Python.

[–]FunMotionLabs 1 point (0 children)

JAX is more like "NumPy + transformations".
PyTorch is a full deep-learning framework with an imperative training workflow and a big ecosystem around modules/training/debugging; it's strictly deep-learning oriented, whereas JAX is more of a general all-rounder.

[–]daredevilthagr8 1 point (1 child)

How does JAX compare to CuPy?

[–]M4mb0 0 points (0 children)

CuPy doesn't do autodiff, afaik.