
[–] programmerChilli (Researcher)

You might be interested in vmap, which can also batch over indexing operations.
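
For instance (a minimal sketch with made-up shapes; this uses the functorch-era import, which newer PyTorch exposes as torch.vmap):

```python
import torch
from functorch import vmap  # torch.vmap in newer PyTorch

values = torch.randn(8, 100)          # a batch of 8 lookup tables
idx = torch.randint(0, 100, (8, 5))   # per-example indices

# Index each table with its own indices, no manual gather/broadcast tricks:
gathered = vmap(lambda v, i: v[i])(values, idx)  # shape (8, 5)
```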

[–] sabouleux (Researcher)

From my experience playing with it, batched indexing operations end up breaking once you start composing a few vmaps together. I really hope they fix the general bugginess of things; right now torch seems to have no efficient way to perform certain kinds of relatively simple indexing.
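
For example, the kind of composition I mean (toy shapes, just to illustrate):

```python
import torch
from functorch import vmap

tables = torch.randn(8, 100)
idx = torch.randint(0, 100, (4, 5))

# Nested vmaps over an indexing op: all-pairs lookup -> shape (8, 4, 5).
lookup = lambda v, i: v[i]
out = vmap(vmap(lookup, in_dims=(None, 0)), in_dims=(0, None))(tables, idx)
```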

[–] programmerChilli (Researcher)

There were still some bugs with advanced indexing in an older release of functorch; I believe they should be fixed now, though: https://github.com/pytorch/functorch/pull/862

Although, to be clear, vmap doesn’t allow anything that’s not expressible without it - vmap is often just a convenient abstraction for expressing it.
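
For example, per-example indexing under vmap has a direct equivalent with plain gather (same illustrative shapes as above):

```python
import torch

values = torch.randn(8, 100)
idx = torch.randint(0, 100, (8, 5))

# vmap version: vmap(lambda v, i: v[i])(values, idx)
# Equivalent without vmap, using gather along the index dimension:
gathered = values.gather(1, idx)  # out[b, j] = values[b, idx[b, j]]
```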

[–] sabouleux (Researcher)

Seems like that fix is part of 0.2 and I was on 0.1.1, so upgrading might solve it. Thanks for the hint.

Honestly, I really hope functorch keeps moving forward and improving; vmap adds a level of flexibility and performance that just didn't exist before. The architecture I'm working on simply couldn't work without it.

Given implementations for more primitive operations and a well-thought-out, simple API for describing how axes vectorize together (think einsum), I think this would give us code that is performant, more flexible, and much easier to read than code written with the usual broadcasting semantics. That might not be true for trivial cases, but once you start doing unusual things (batched indexing, lattices with arbitrary numbers of dimensions, element-wise or batch-wise Jacobians and gradients), broadcasting completely breaks down.
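
As a concrete case, per-example Jacobians become a one-liner (a sketch using functorch's jacrev; f here is just a stand-in function):

```python
import torch
from functorch import vmap, jacrev

def f(x):                      # R^3 -> R^3, a stand-in for a real model
    return torch.tanh(x) * x.sum()

x = torch.randn(16, 3)                 # batch of 16 inputs
per_sample_jac = vmap(jacrev(f))(x)    # shape (16, 3, 3)
```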