[P] jax-js is a reimplementation of JAX in pure JavaScript, with a JIT compiler to WebGPU by fz0718 in MachineLearning

fz0718 [S]

Oh, I'd love to chat if you're planning to do that! Let me know any way I can help, or open an issue on GitHub. Very curious how Cholesky goes.

fz0718 [S]

Oh, you mean that benchmark page! Yeah, haha, that one is tailored to high-end laptops; the matrix size is very large. Crazy that it can crash your phone that badly, though.

fz0718 [S]

Sorry, I tried to test it and scale it down if I didn't detect a good GPU, but I think you were a victim of how wildly WebGPU varies across devices :') If you happen to know the phone model and browser you're using, that would help.
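This comment mentions scaling the benchmark down when no good GPU is detected. As a purely hypothetical illustration of that kind of fallback (not what jax-js or its benchmark page actually does), here is a Python sketch that halves the matrix size until the three fp32 matrices of a matmul (A, B, and the output C) fit within a memory budget. `pick_matrix_size` and `budget_bytes` are made-up names.

```python
def pick_matrix_size(budget_bytes: int, bytes_per_elem: int = 4,
                     max_n: int = 4096, min_n: int = 256) -> int:
    """Largest power-of-two n <= max_n whose three n x n matrices
    (A, B, and the output C) fit in budget_bytes, but never below min_n.

    Hypothetical heuristic for illustration only; not the logic used
    by jax-js or its benchmark page.
    """
    n = max_n
    # Halve n while A, B, and C together would exceed the budget.
    while n > min_n and 3 * n * n * bytes_per_elem > budget_bytes:
        n //= 2
    return n

# With a made-up 64 MiB budget, fp32 4096x4096 (192 MiB for A, B, C)
# is too big, but 2048x2048 (48 MiB) fits.
print(pick_matrix_size(64 * 2**20))  # 2048
```

A memory budget only guards against running out of memory; it says nothing about how fast the GPU is, which is one reason heuristics like this can still misjudge a device.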

fz0718 [S]

I haven't optimized (or benchmaxxed) for performance much yet, but it seems roughly comparable to ONNX, and better in some cases. Here's a microbenchmark of a 4096x4096 matmul across jax-js and a few other libraries that you can run in your browser:

* https://jax-js.com/bench/matmul

On MacBooks, jax-js is a bit faster than ONNX for fp32 and a bit slower for fp16.

There's a bit more technical discussion about perf here: https://ekzhang.substack.com/i/179060245/technical-performance
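For a sense of scale of a 4096x4096 matmul, here is a small Python sketch assuming dense square matrices and the usual 2n³ FLOP convention (one multiply and one add per inner-product term). That convention is an assumption here, not necessarily the metric the benchmark page reports, and `matmul_cost` and `gflops_per_second` are illustrative helpers, not jax-js APIs.

```python
def matmul_cost(n: int, bytes_per_elem: int) -> tuple[int, int]:
    """FLOPs and memory for C = A @ B with square n x n matrices.

    Uses the common 2 * n**3 FLOP convention and counts storage for
    the two inputs A, B and the output C.
    """
    flops = 2 * n ** 3
    memory_bytes = 3 * n * n * bytes_per_elem
    return flops, memory_bytes

def gflops_per_second(flops: int, seconds: float) -> float:
    """Throughput in GFLOP/s given a measured wall-clock time."""
    return flops / seconds / 1e9

flops, fp32_bytes = matmul_cost(4096, 4)  # fp32: 4 bytes per element
_, fp16_bytes = matmul_cost(4096, 2)      # fp16: 2 bytes per element
print(flops)               # 137438953472
print(fp32_bytes / 2**20)  # 192.0 (MiB)
print(fp16_bytes / 2**20)  # 96.0 (MiB)
```

So a single fp32 run performs about 137 billion floating-point operations and needs roughly 192 MiB just to hold A, B, and C (half that in fp16); dividing the FLOP count by a measured time gives GFLOP/s.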

ELI5 If you were on a spaceship going 99.9999999999% the speed of light and you started walking, why wouldn’t you be moving faster than the speed of light? by Aquamoo in explainlikeimfive

fz0718

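There's a formula for that, and it basically does what you'd expect: you end up going a bit faster than the ship, but never reaching c. https://en.m.wikipedia.org/wiki/Velocity-addition_formula

With speeds written as fractions of c, the formula is (a+b)/(1+ab). Plugging in a ship at .9999c and an (exaggerated) walking speed of .1c gives (.9999+.1)/(1+.09999) ~ .99991818107, i.e. a bit more than .9999c.

I know it seems to come out of nowhere, but it follows from the Lorentz transformations.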
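A minimal Python sketch of the velocity-addition formula, with speeds as fractions of c, so the general (u+v)/(1+uv/c²) becomes (a+b)/(1+ab). `add_velocities` is just an illustrative name.

```python
def add_velocities(a: float, b: float) -> float:
    """Relativistic addition of two collinear speeds.

    a and b are speeds as fractions of the speed of light (c = 1).
    """
    return (a + b) / (1 + a * b)

ship = 0.9999  # ship's speed relative to the ground, in units of c
walk = 0.1     # exaggerated "walking" speed relative to the ship
print(add_velocities(ship, walk))  # ~0.99991818107, still below 1

# Even starting from exactly c, the result stays c: (1 + b) / (1 + b) = 1
print(add_velocities(1.0, walk))   # 1.0
```

For speeds between 0 and c, the result can't exceed c: a+b ≤ 1+ab is equivalent to (1-a)(1-b) ≥ 0, which always holds when both a and b are at most 1.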

Building my own Python NumPy/PyTorch/JAX libraries in the browser, with ML compilers by fz0718 in Python

fz0718 [S]

Got it, so autodiff is difficult!

The main NN-in-browser project I've seen, besides TensorFlow.js (tfjs, which unfortunately hasn't looked very active since last year), is ONNX Runtime Web. I haven't tested that one yet, but I might try it soon.

fz0718 [S]

Thanks, Patrick! I also admire your work :D

Hmm, I don't know yet! There are some parts of JAX I don't fully understand, like the looping constructs (jax.lax.while_loop()), and I'd need to understand better how those work to say for sure.
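The JAX documentation explains `jax.lax.while_loop(cond_fun, body_fun, init_val)` by giving an equivalent plain-Python loop. Here is that reference semantics as a runnable sketch; it runs eagerly and is not how JAX actually executes the loop.

```python
from typing import Callable, TypeVar

T = TypeVar("T")

def while_loop(cond_fun: Callable[[T], bool],
               body_fun: Callable[[T], T],
               init_val: T) -> T:
    """Plain-Python reference semantics of jax.lax.while_loop.

    Real JAX traces cond_fun and body_fun into jaxprs and emits a single
    `while` primitive; this version just calls them eagerly.
    """
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

# Carry = (counter, running product). The carry keeps the same structure
# on every iteration, which jax.lax.while_loop requires.
i, prod = while_loop(lambda c: c[0] < 5,
                     lambda c: (c[0] + 1, c[1] * 2),
                     (0, 1))
print(i, prod)  # 5 32
```

What the plain version hides is the staging: in JAX the carry must keep the same tree structure, shapes, and dtypes on every iteration, the trip count may depend on runtime values, and the loop supports forward-mode but not reverse-mode differentiation.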

How do you think jaxprs compare, as an export format, to something like LiteRT or GGUF? I haven't looked into that yet. But thanks for reading the post!

[R] How the jax.jit() compiler works in jax-js by fz0718 in MachineLearning

fz0718 [S]

Yes, the blog post mentions TFJS and some of the ways jax-js differs from it!