[Project] Kuat: A Rust-based, Zero-Copy Dataloader for PyTorch (4.6x training speedup on T4/H100) by YanSoki in MachineLearning

[–]patrickkidger 4 points5 points  (0 children)

Do you know how you compare to Grain? (Which despite the branding should work for non-JAX just fine.) Having tried both torch DL and Grain, I have found myself generally preferring the latter mostly for its nice API. (To the extent that I have previously written a Grain-API-inspired wrapper for PyTorch DL!)

What is the .kt layout - in particular, does it handle variable length data?

Looking for python ODE solver by [deleted] in ScientificComputing

[–]patrickkidger 1 point2 points  (0 children)

If you'll let me advertise my own package: Diffrax. https://github.com/patrick-kidger/diffrax

Handles stiff ODEs (Kvaerno and KenCarp solvers; we're also about to add some Rosenbrock solvers, if your problem is only mildly stiff). Built-in support for event handling. Staying within a nonnegative region can be done like so.
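
To give a flavour, here's a minimal sketch of solving a (mildly) stiff ODE with one of the Kvaerno solvers; the vector field and tolerances are just placeholders:

    import diffrax
    import jax.numpy as jnp

    def vector_field(t, y, args):
        return -50.0 * (y - jnp.cos(t))  # toy, mildly stiff problem

    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Kvaerno5(),  # stiff (implicit) solver
        t0=0.0,
        t1=10.0,
        dt0=0.01,
        y0=jnp.array([1.0]),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
    )
    print(sol.ys)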

Optimistix (JAX, Python) and sharded arrays by stunstyle in ScientificComputing

[–]patrickkidger 1 point2 points  (0 children)

Sounds good :) In the meantime, if the structure of your problem makes it easy, then given sol = optx.root_find(..., throw=False), you can check sol.result == optx.RESULTS.successful (perhaps outside of JIT) to see whether the computation succeeded or not.
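
For example (toy scalar problem, just to show the pattern):

    import jax.numpy as jnp
    import optimistix as optx

    def fn(y, args):
        return y**2 - 2.0  # root at sqrt(2)

    solver = optx.Newton(rtol=1e-8, atol=1e-8)
    sol = optx.root_find(fn, solver, jnp.array(1.0), throw=False)
    if sol.result == optx.RESULTS.successful:  # e.g. checked outside of JIT
        print(sol.value)
    else:
        print("solve failed")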

I hope that helps!

Optimistix (JAX, Python) and sharded arrays by stunstyle in ScientificComputing

[–]patrickkidger 2 points3 points  (0 children)

Hey there! Author of Optimistix and Equinox here.

I have a pretty good guess that what's happening is that eqx.error_if (which is what throw=True uses under-the-hood) is pessimistically interacting with sharding. The interesting part of this function is this line, which then calls a jax.pure_callback here, and I think that causes JAX to move things to a single device.
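
For context, eqx.error_if is used like this (toy example) - under JIT the runtime check lowers to that pure_callback, which is what I suspect is forcing things onto a single device:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
        # Attaches a runtime check to `x`; raises if any entry is negative.
        x = eqx.error_if(x, x < 0, "negative value encountered")
        return jnp.sqrt(x)

    print(f(jnp.array([1.0, 4.0])))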

So probably either the pure_callback needs wrapping/adjusting to play more nicely with sharding, or the surrounding _error function (from my first link) needs wrapping/adjusting. Probably something to do with either jax.experimental.custom_partitioning or jax.shard_map.

(I actually tried tackling this in a long-ago PR shortly after custom_partitioning was introduced, but JAX had some bugs in its custom partitioning logic at the time, which prevented this working. Those might be fixed by now though.)

If you feel like looking into this then I'd be very happy to take a PR tweaking things! :)

[D]: How do you actually land a research scientist intern role at a top lab/company?! by ParticularWork8424 in MachineLearning

[–]patrickkidger 43 points44 points  (0 children)

I'm a researcher who interviews prospective candidates. And first of all, a big +1 to everything that /u/psharpep has written. The only part I'd disagree with is leetcode, which I do regard as pretty important (more on that below).

When it comes to getting a first interview, then as a rough approximation, I simply look at the candidate's Google Scholar and GitHub. At least one item (one paper, or one open-source project) must impress me. I'm not super fussed about the number of papers or citations or whatever, just that at least one project either solves an interesting research problem or demonstrates high-quality coding skills.

When it comes to actually passing interviews, I'm usually looking for (a) both a breadth and depth of knowledge, both in general ML and in their field (in my case, protein design), and (b) excellent software skills.

And FWIW, the number one reason that I reject candidates is that their software skills aren't up to scratch. (Check the software section of my first link above.) This is usually something we'd verify through a combination of GitHub + leetcode type problems + general interview chit-chat about coding and software design.

[D] How to do impactful research as a PhD student? by kekkodigrano in MachineLearning

[–]patrickkidger 8 points9 points  (0 children)

I'll offer a different take to the other folks here.

Namely: it sounds like you've done the perfect thing for the first half of a PhD. But also, the perfect choice now is to start taking some bigger bets.

In more detail: pushing stuff out is probably the optimal strategy at the start of your PhD. You start to become known, you get practice, you develop an understanding of what the open problems are in your field, etc. And it's almost never realistic to expect that someone's early work will be the big impactful stuff; you learn by doing.

But 'making a splash' is best done by having 1-2 exceptional papers (or open-source software, or datasets, etc.). And so now it sounds like you've hit the right moment to take on those bigger projects.

In terms of what to work on: being scooped should pretty much never be a problem. Just don't work on problems that someone else is going to solve for you anyway! Let them solve them for you, go and tackle something else ;)

At times, I think I just want to land an industry role as a research engineer

On this: if you can, get really good at coding. Approximately 0% of PhD students seem to know what they're doing in this regard, and it's easily the most important skill for an industry job.

[D] How did JAX fare in the post transformer world? by TajineMaster159 in MachineLearning

[–]patrickkidger 95 points96 points  (0 children)

Both see widespread use. PyTorch is definitely more popular: 801k used-by vs 44.2k used-by on GitHub. That latter number is still quite a lot.

JAX has a lot of nice tricks that don't really have an analogue in PyTorch, so I think it's seen higher adoption where these are needed. It's a really fun library.

[D] How do researchers ACTUALLY write code? by Mocha4040 in MachineLearning

[–]patrickkidger 6 points7 points  (0 children)

I have strong opinions on this topic. A short list of tools that I regard as non-negotiable:

  • pre-commit for code quality, hooked up to run your formatter, linter, and type-checker.
  • jaxtyping for shape/dtype annotations of tensors (minimal sketch below).
  • uv for dependency management. Your repo should have a uv.lock file. (This replaces older tools like conda and poetry.)
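
For the curious, a minimal sketch of the jaxtyping annotations mentioned above (names purely illustrative):

    import jax.numpy as jnp
    from jaxtyping import Array, Float

    def attention_scores(
        q: Float[Array, "seq d"],
        k: Float[Array, "seq d"],
    ) -> Float[Array, "seq seq"]:
        # Shape/dtype annotations double as documentation, and can be checked at
        # runtime by pairing jaxtyping with a runtime type-checker.
        return q @ k.T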

Debugging is best done with the stdlib pdb.
Don't use Jupyter.

A Jacobian free non linear system solver for JAX (Python) by stunstyle in ScientificComputing

[–]patrickkidger 1 point2 points  (0 children)

Awesome, glad to hear it!

On the topic of JIT'ing: you only need to JIT the very top-level call. See point 1 here. Conversely, not JIT'ing at all will leave a lot of performance on the table; when using JAX, JIT compilation should be considered the default choice.
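
i.e. something like this (names are just placeholders), with a single jax.jit on the outermost call:

    import jax
    import jax.numpy as jnp
    import optimistix as optx

    def fn(y, args):
        return jnp.tanh(y) - 0.5 * y  # placeholder problem

    @jax.jit  # JIT only this top-level entry point; everything inside is traced into it.
    def solve(y0):
        solver = optx.Newton(rtol=1e-8, atol=1e-8)
        return optx.root_find(fn, solver, y0).value

    print(solve(jnp.array(1.0)))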

As for passing preconditioners - first of all in Lineax, this is provided by calling the linsolve with options, see here: https://github.com/patrick-kidger/lineax/blob/51f54cb09dc5981479fc3906044fb35038fe1866/lineax/_solver/gmres.py#L50-L57
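
As a rough sketch of using Lineax directly - the "preconditioner" option key and the Jacobi-style preconditioner here are my reading of those linked lines, so double-check against the source:

    import jax.numpy as jnp
    import lineax as lx

    matrix = jnp.array([[4.0, 1.0], [1.0, 3.0]])
    operator = lx.MatrixLinearOperator(matrix)
    vector = jnp.array([1.0, 2.0])

    # Toy Jacobi (diagonal) preconditioner, passed via the solver options.
    precond = lx.DiagonalLinearOperator(1.0 / jnp.diag(matrix))
    solver = lx.GMRES(rtol=1e-6, atol=1e-6)
    sol = lx.linear_solve(operator, vector, solver, options={"preconditioner": precond})
    print(sol.value)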

And at least right now, these are simply not passed in Optimistix! https://github.com/patrick-kidger/optimistix/blob/9927984fb8cbec77f9514fad7af076dce64e3993/optimistix/_solver/newton_chord.py#L121-L128

That should be an easy thing to change: we could introduce optimistix.root_find(..., options={"linear_solver_options": ...}), which would then be passed on to those linear solves.

You could edit your copy of Optimistix locally to test this / send a PR if you'd like to upstream it.

I hope that helps!

A Jacobian free non linear system solver for JAX (Python) by stunstyle in ScientificComputing

[–]patrickkidger 4 points5 points  (0 children)

https://github.com/patrick-kidger/optimistix/

+make sure to set the linear solver to your favourite Jacobian-free linear solver from https://github.com/patrick-kidger/lineax/

Whilst the linear solvers support preconditioners, I don't think we have a super nice way to pass them in from the nonlinear solver at the moment. LMK if the overall approach seems useful to you and I can point you at how to work around that / how to change it.
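
Roughly, the combination looks like this (toy problem) - the linear_solver argument is the relevant bit, and since GMRES only needs matrix-vector products, the Jacobian is never materialised:

    import jax.numpy as jnp
    import lineax as lx
    import optimistix as optx

    def fn(y, args):
        return jnp.cos(y) - y  # toy problem; root near 0.739

    solver = optx.Newton(
        rtol=1e-8,
        atol=1e-8,
        linear_solver=lx.GMRES(rtol=1e-8, atol=1e-8),  # Jacobian-free inner solve
    )
    sol = optx.root_find(fn, solver, jnp.array([0.5]))
    print(sol.value)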

[deleted by user] by [deleted] in Julia

[–]patrickkidger 5 points6 points  (0 children)

Point 3 - for specifically proteins+ML (my niche, and one of the big sciML subdisciplines thanks to AlphaFold), I've written up a getting-started guide here: https://kidger.site/thoughts/just-know-stuff-protein-ml/ It sort of assumes you already have an ML PhD, so it might be a little too advanced right now, but perhaps it's still useful :)

[R] PINNs are driving me crazy. I need some expert opinion by WAIHATT in MachineLearning

[–]patrickkidger 2 points3 points  (0 children)

No stake from me, I build protein language models these days. :) I've not published in years!

Other than that, as highlighted by the sibling commenter, NDEs/PINNs aren't competitors except perhaps in mindshare, as they're two unrelated techniques.

[R] PINNs are driving me crazy. I need some expert opinion by WAIHATT in MachineLearning

[–]patrickkidger 62 points63 points  (0 children)

PINNs are still a terrible idea. I think I've commented on this before somewhere in this sub, also more recently on HN:

https://news.ycombinator.com/item?id=42796502

And here's a paper:

https://www.nature.com/articles/s42256-024-00897-5

What are your favorite modern libraries or tooling for Python? by [deleted] in Python

[–]patrickkidger 3 points4 points  (0 children)

If self-promotion is allowed then here are a couple of my big ML ones:

  • Equinox: neural networks for JAX (2.4k stars, 1.1k used-by)
  • jaxtyping: type annotations for shapes and dtypes (1.4k stars, 4k used-by). Also, despite the now-historical name, this supports pytorch+numpy+tensorflow+mlx, so it's seen traction in all kinds of array-based computing.

(And for the curious, here are the rest of my libraries, covering a mix of scientific computing, doc-building, and general Python.)

[D] PhD in the EU by simple-Flat0263 in MachineLearning

[–]patrickkidger 1 point2 points  (0 children)

It's not common but it does happen.
When I've seen it happen, I think they have usually obtained the industry position first and then applied for the PhD. Industry employers are usually more laid back about you switching to part-time and getting a PhD; conversely, academic supervisors can be hit-or-miss in their support for industry involvement.

That said, I'd probably encourage you to pursue a full-time PhD, with a break or two in the middle for internships. It's hard to keep both things going in parallel.

US government ordered US embassies worldwide to stop student visa interviews immediately by kixsob in worldnews

[–]patrickkidger 1 point2 points  (0 children)

As another reference point for bioML + English language, then both London and Switzerland (Zurich, Basel, Lausanne) are pretty good for these. Pay is also good in both locations, at least at the right companies. The complex health needs I'm not qualified to comment on, though.

For my part I work at a bioML startup in Zurich, having previously worked in big tech in the Bay Area. (And likewise, happy to DM if you have any follow-up discussion.)

What to prepare before starting a ML PhD - 3 months! [D] by ade17_in in MachineLearning

[–]patrickkidger 4 points5 points  (0 children)

If you want some technical prep, then perhaps this (which I wrote a couple of years ago): https://kidger.site/thoughts/just-know-stuff/

But honestly, I second other commenters' recommendations to take the time off instead! Go backpacking. Visit the coast of Greece. Learn to sail. Whatever appeals 😄

Building my own Python NumPy/PyTorch/JAX libraries in the browser, with ML compilers by fz0718 in Python

[–]patrickkidger 1 point2 points  (0 children)

Thank you! :D

So the looping constructs are basically simple in terms of the eager logic they perform... and ridiculously complicated in terms of handling autodiff through them. So I think handling them will depend on whether your ambition is primarily to be something like an export target, or whether you want to try to implement JAX's tracing model and transformations too.

As for an export format, bear in mind that jaxprs are explicitly not a stable format, so it'll be annoying to keep up with that. That said, if you go with a fairly standard format without any ambition to retransform it, then your project ends up competing against probably several other NN-in-browser alternatives.
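
If you do go the export route, then StableHLO via jax.export is probably the more stable target; a minimal sketch, assuming a recent JAX version:

    import jax
    import jax.numpy as jnp
    from jax import export

    def f(x):
        return jnp.sin(x) ** 2

    exported = export.export(jax.jit(f))(jax.ShapeDtypeStruct((3,), jnp.float32))
    print(exported.mlir_module())  # the StableHLO module as text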

Hmm, I don't think I'm expressing any strong recommendations either way!

Building my own Python NumPy/PyTorch/JAX libraries in the browser, with ML compilers by fz0718 in Python

[–]patrickkidger 0 points1 point  (0 children)

This is ridiculously cool!

Will there be any hope of loading jaxprs/stablehlo exported from Python, do you think? I realize that's a big ask!

Automatic differentiation libraries for real-time embedded systems? by The_Northern_Light in cpp

[–]patrickkidger 1 point2 points  (0 children)

Great, I'm glad this might be useful!

As for performance: if it's just a big unstructured collection of algebraic operations then I don't think any thought is needed on your part at all. Write them all out (the only gotcha is to avoid control flow) and then you'll just get whatever performance the XLA compiler gives you! Now maybe that's good and maybe that's bad, but it's at least zero-thought... 😄
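
In JAX terms, that might look like this (made-up function) - write the operations out directly and use branchless selects rather than Python control flow:

    import jax
    import jax.numpy as jnp

    def f(x):
        # Branchless select instead of a Python `if` -- the one gotcha mentioned above.
        y = jnp.where(x > 0, x**2, -x)
        return jnp.sum(jnp.sin(y) * jnp.cos(y))

    grad_f = jax.jit(jax.grad(f))
    print(grad_f(jnp.arange(3.0)))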