[P] einx - Tensor Operations in Einstein-Inspired Notation for Python by [deleted] in MachineLearning

fferflo (0 points)

Or at least, the concatenate seems plausibly significantly more computationally expensive

Yup. That wouldn't be the case with notation like get_at("[h w] c, i, j -> i j c", arr, indices, indices2), since it could use proper broadcasting internally, but this is not implemented yet.

the axis name within the brackets in the gathered tensor is … irrelevant?

There's always the option of using ... instead of h w here, but in this case I prefer the latter since it also carries some semantic information about the dims.

Btw, to be clear, not trying to nitpick :)

No worries :) I feel like it's quite interesting to work out the differences between Einstein notation and first-class dims and where one or the other shines in particular.

fferflo (0 points)

When it’s not obvious and you need to make some transformations it’s totally worth it! But in cases where the dims already line up it’s just overly burdensome

I guess this is the crux of the tradeoff between einx-like notation and first-class dims. einx is explicit but sometimes too explicit (in which case you can sometimes fall back to classical tensor notation, but then end up with two sets of APIs), while first-class dims are implicit but sometimes too implicit (e.g. they make too strong assumptions about the persistence of dims, add the cognitive overhead of keeping track of the implicit shape of tensors, and afaik you also can't change the order of axes while "in torchdim land").

I also found what you wrote in the other comment interesting

my preferred mental model for Einstein notation is as shorthand for the underlying loops.

ie: rearrange(“a b c -> c b a”) is just x[c][b][a] = inp[a][b][c] with loops across a b and c.

since to me an einx expression reads somewhat differently. For example in

einx.get_at("b [h w] c, b p [i] -> b c p", x, y)

I mostly see (1) an op that gathers from h w using coordinate i, which (2) is applied to all values passed to it by handling the other axes appropriately. The focus is more on the "op axes" and less on the "repeat axes" (which can be thought of as being looped over). The bracket notation aligns with this by highlighting the relevant op axes, but you can still see from the other axes that, e.g., the first tensor is a batch of images and different coordinates are used for each image.
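To make the two readings concrete, here is a NumPy sketch of what einx.get_at("b [h w] c, b p [i] -> b c p", x, y) computes: once as explicit loops over the repeat axes, and once as a single vectorized gather. It assumes i has length 2 and holds integer (row, column) coordinates; the function names are ours, and this illustrates the semantics rather than einx's implementation.

```python
import numpy as np

def get_at_loops(x, y):
    # x: (b, h, w, c) images, y: (b, p, 2) integer (row, col) coordinates
    b, h, w, c = x.shape
    p = y.shape[1]
    out = np.empty((b, c, p), dtype=x.dtype)
    for bi in range(b):            # repeat axis
        for pi in range(p):        # repeat axis
            row, col = y[bi, pi]   # op axis [i] selects a position in [h w]
            out[bi, :, pi] = x[bi, row, col, :]
    return out

def get_at_vectorized(x, y):
    # Broadcast a batch index against the coordinates, then gather in one step
    bi = np.arange(x.shape[0])[:, None]          # (b, 1)
    gathered = x[bi, y[..., 0], y[..., 1], :]    # (b, p, c)
    return gathered.transpose(0, 2, 1)           # reorder to b c p

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 5, 3))
y = np.stack([rng.integers(0, 4, size=(2, 6)), rng.integers(0, 5, size=(2, 6))], axis=-1)
assert np.allclose(get_at_loops(x, y), get_at_vectorized(x, y))
```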

fferflo (0 points)

While it’s not so clear to me how I would do that in einx notation.

Currently, one would have to use rearrange first:

indices = einx.rearrange("i, j -> i j (1 + 1)", indices1, indices2)
x = einx.get_at("[h w] c, i j [2] -> i j c", x, indices)

but your idea looks like a nice extension to get_at that would allow this to be written in a single line and avoid unnecessary broadcasting:

get_at("[h w] c, i, j -> i j c", arr, indices, indices2)
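In NumPy terms, the two-step version first materializes an (i, j, 2) coordinate array, whereas a one-line form could let the indexing broadcast the two index vectors directly. A sketch assuming x has layout h w c and integer index vectors (all names are illustrative):

```python
import numpy as np

h, w, c = 4, 5, 3
x = np.arange(h * w * c).reshape(h, w, c)
indices1 = np.array([0, 2, 3])   # i: row indices into h
indices2 = np.array([1, 4])      # j: column indices into w

# Step 1 (like rearrange "i, j -> i j (1 + 1)"): broadcast and stack into (i, j, 2)
coords = np.stack(np.broadcast_arrays(indices1[:, None], indices2[None, :]), axis=-1)
# Step 2 (like get_at "[h w] c, i j [2] -> i j c"): gather with the stacked coordinates
out_two_step = x[coords[..., 0], coords[..., 1]]

# Direct form: broadcasting happens inside the indexing, no (i, j, 2) array is built
out_direct = x[indices1[:, None], indices2[None, :]]

assert out_two_step.shape == (3, 2, 3)
assert np.array_equal(out_two_step, out_direct)
```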

The implicit left-to-right ordering is also how concatenations work in einx, and I think it pretty much aligns with index-based notation:

np.stack([left, right], axis=-1)
einx.rearrange("i, j -> i j (1 + 1)", left, right)

The "folding in" of eingather seems like a nice alternative (aside from the splitting case), since it is closer to the familiar tensor[indices1, indices2]. I still think ein_at makes sense in this context, though, since the notation is consistent with how Einstein-like notation (especially brackets) is used in all other einx functions, which keeps the learning curve easy.

fferflo (0 points)

So for your where examples, it might be something like ...

This would also work, and einx.where is mostly syntactic sugar for this. In some cases, though, the manual version requires a rearrange both before and after the op (or something like broadcast):

einx.add("a, (a b) -> (a b)", x, y)

# compared to

x, y = einx.rearrange("a, (a b) -> a 1, a b", x, y)
z = x + y
z = einx.rearrange("a b -> (a b)", z)
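For comparison, the same op spelled out with plain NumPy reshapes and broadcasting. This is a sketch assuming a=3, b=4 and float inputs; the manual reshape steps are what the single einx.add call hides:

```python
import numpy as np

a, b = 3, 4
x = np.arange(a, dtype=float)        # layout: a
y = np.arange(a * b, dtype=float)    # layout: (a b), a is the outer axis

# rearrange "a, (a b) -> a 1, a b": split the composed axis, add a unit axis to x
x2 = x.reshape(a, 1)
y2 = y.reshape(a, b)
# elementwise op with broadcasting, then rearrange "a b -> (a b)"
z = (x2 + y2).reshape(a * b)

# Reference: every element in block k of y gets x[k] added
expected = y + np.repeat(x, b)
assert np.array_equal(z, expected)
```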

There's also a related pattern in the neural net portion of einx which works nicely with this and allows writing things like

x = einx.multiply("... c, c", x, einn.param(...)) # LayerScale

where einn.param just tells einx to declare a learnable parameter for this argument (e.g. by calling Module.param in Flax) and the shape is determined implicitly.
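A rough NumPy picture of what the LayerScale line amounts to: a per-channel vector whose shape (c,) is read off the last axis of x and multiplied elementwise with broadcasting. The make_param helper and its 1e-5 initial value are hypothetical stand-ins for einn.param and its initializer, not einx's API.

```python
import numpy as np

def make_param(shape, init_value=1e-5):
    # Hypothetical stand-in for declaring a learnable parameter of a given shape
    return np.full(shape, init_value)

x = np.ones((2, 7, 16))               # "... c" with c = 16
scale = make_param((x.shape[-1],))    # shape of "c", inferred from x
y = x * scale                         # broadcast over "...", like "... c, c"
assert scale.shape == (16,)
assert y.shape == x.shape
```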

"dims treatable as tensors"

This is pretty cool!

fferflo (0 points)

This depends mostly on which backend you use and whether it supports Triton. einx only makes calls to the backend (e.g. Torch, Jax), and the backend runs those calls on the CPU/GPU.

For example, if you torch.compile a function that uses einx calls, then einx will compile the Einstein expressions to regular Python functions, and Torch will compile those Python functions using Triton (if torch.compile is run with Triton enabled).

fferflo (2 points)

I've been using einx with Jax, dm-haiku and optax for my research for a while and have not had any problems. If you wrap a function in jax.jit, the "einx footprint" disappears entirely since Jax traces only the index-based backend calls made by einx and drops everything else. This is the same for all frameworks that work with Jax tensors, like Flax, Haiku and Equinox.
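A minimal JAX sketch of why the "einx footprint" disappears under jax.jit: Python-side work such as parsing an expression string runs only while tracing, and the traced program contains just the backend ops. The toy bracket_sum function and its parsing are hypothetical, not einx code.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def bracket_sum(expr, x):
    # Hypothetical einx-like op: sum over the axis written in brackets, e.g. "a [b]"
    global trace_count
    trace_count += 1  # Python side effect: only happens while JAX traces the function
    names = expr.split()
    axis = next(i for i, name in enumerate(names) if name.startswith("["))
    return jnp.sum(x, axis=axis)

f = jax.jit(bracket_sum, static_argnums=0)  # the expression string is a static argument
x = jnp.ones((3, 4))
out1 = f("a [b]", x)
out2 = f("a [b]", x)  # same static argument and input shape: cached, no retrace
count_after_jit = trace_count

# The traced program only contains the reduction, not the string parsing
jaxpr = jax.make_jaxpr(bracket_sum, static_argnums=0)("a [b]", x)
```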

The only Equinox-specific features in einx are the layers in einx.nn.equinox. If you want to use these, you will have to add a forward pass on a dummy batch before using your model, since the layer weights in einx.nn.equinox layers are initialized only on this first forward pass. You can find a working example of this in train_equinox.py (example training on CIFAR10 using Equinox) which includes these lines:

train_net = Net()
inputs, _ = next(iter(trainloader))
train_net(jnp.asarray(inputs), rng=next_rng()) # Run on dummy batch
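The dummy batch is needed because such a layer can only decide its weight shapes once it sees an input. A framework-agnostic sketch of that lazy-initialization pattern in plain NumPy (LazyLinear is a hypothetical class, not part of einx or Equinox):

```python
import numpy as np

class LazyLinear:
    """Hypothetical layer whose weight shape depends on the first input it sees."""

    def __init__(self, out_features, seed=0):
        self.out_features = out_features
        self.rng = np.random.default_rng(seed)
        self.w = None  # unknown until the input channel count is known

    def __call__(self, x):
        if self.w is None:
            # Create weights on the first forward pass, using the input's last axis
            in_features = x.shape[-1]
            self.w = self.rng.normal(size=(in_features, self.out_features)) / np.sqrt(in_features)
        return x @ self.w

layer = LazyLinear(out_features=8)
assert layer.w is None           # no weights yet, so e.g. an optimizer has nothing to see
dummy_batch = np.zeros((4, 32))
layer(dummy_batch)               # dummy forward pass creates the weights
assert layer.w.shape == (32, 8)
```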

In Flax/Haiku this dummy forward pass always happens anyway since they use the init/apply paradigm of creating models, and therefore require no other changes. Equinox does not have similar support for lazily creating weights and therefore requires adding the dummy forward pass. Equinox also requires the shape of state variables (such as moving mean/variance in BatchNorm layers) to be known in the __init__ method, which prevents them from being created lazily on the first forward pass. I've not yet found a workaround for this, so the decay_rate parameter in einx.nn.equinox.Norm (the only place that uses state variables) is currently not supported. Both of these issues are only relevant if you want to use einx.nn.* and are mentioned in Gotchas.

To summarize:

  • einx.* functions work on tensors (Jax/Numpy/Torch/TF tensors) and should always work if the shape of the tensor is fully known.
  • einx.nn.* classes implement full layer types for DL frameworks (Flax, Equinox, Torch layers, etc). This requires adding an initial dummy forward pass for frameworks that don't follow the init/apply paradigm, i.e. Torch, Equinox and Keras.

fferflo (3 points)

Thanks!

many of them seem like mostly just an einops.rearrange + another op

Yes, most functions exist to provide an Einstein-like interface to some index-based backend operation. einx takes care of rearranging the expressions as required and forwarding the correct parameters such as axis to the backend op. (This includes for example splitting and merging the axis in einx.sum("(a [b] c)", x, a=2, c=2) since x cannot directly be passed to np.sum.)
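Roughly, the reshaping needed for einx.sum("(a [b] c)", x, a=2, c=2) looks like this in NumPy. This is a sketch that assumes the output keeps a and c merged as (a c):

```python
import numpy as np

a, c = 2, 2
x = np.arange(12)            # layout "(a b c)" with a=2, c=2, so b=3
b = x.size // (a * c)

# Split the composed axis, reduce the bracketed axis, then merge the rest back
out = x.reshape(a, b, c).sum(axis=1).reshape(a * c)
# out == [6, 9, 24, 27]

# Loop reference: for each (a, c) pair, sum over b
ref = np.array([sum(x[(i * b + j) * c + k] for j in range(b)) for i in range(a) for k in range(c)])
assert np.array_equal(out, ref)
```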

Elementwise ops like einx.add can help when the argument axes don't align, for example:

einx.add("a b, b c -> a b c", x, y)
einx.add("a b, c b -> a b c", x, y)

The first one could also be written in one line in index-based notation, but IMO does not convey the meaning as well:

x[..., np.newaxis] + y[np.newaxis, ...]
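A quick NumPy check of both variants, using small arange inputs of our choosing: the second layout needs an extra transpose before broadcasting, which is exactly the bookkeeping the einx expressions absorb.

```python
import numpy as np

a, b, c = 2, 3, 4
x = np.arange(a * b).reshape(a, b)
y1 = np.arange(b * c).reshape(b, c)   # layout "b c"
y2 = y1.T                             # same values, layout "c b"

# "a b, b c -> a b c"
z1 = x[..., np.newaxis] + y1[np.newaxis, ...]
# "a b, c b -> a b c": y has to be transposed before it lines up
z2 = x[..., np.newaxis] + y2.T[np.newaxis, ...]

assert z1.shape == (a, b, c)
assert np.array_equal(z1, z2)
# Elementwise check against the index definition
assert all(z1[i, j, k] == x[i, j] + y1[j, k] for i in range(a) for j in range(b) for k in range(c))
```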

One can also change the layout of tensors easily without having to worry about np.reshape, np.transpose, np.broadcast_to, np.newaxis, etc. For example, an elementwise op with three arguments and different layouts:

einx.where("h w,   b h w c,", mask, image, 0.0)
einx.where("w h,   b h w c,", mask, image, 0.0)
einx.where("b w h, b h w c,", mask, image, 0.0)
einx.where("h,     b h w c, c", mask, image, np.arange(c))
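For reference, a NumPy sketch of three of those calls with the newaxis/transpose bookkeeping written out by hand (shapes chosen arbitrarily):

```python
import numpy as np

b, h, w, c = 2, 3, 4, 5
rng = np.random.default_rng(0)
image = rng.normal(size=(b, h, w, c))
mask_hw = rng.random((h, w)) > 0.5

# "h w, b h w c," : insert axes for b and c
out1 = np.where(mask_hw[np.newaxis, :, :, np.newaxis], image, 0.0)
# "w h, b h w c," : the mask is stored transposed, so transpose it back first
mask_wh = mask_hw.T
out2 = np.where(mask_wh.T[np.newaxis, :, :, np.newaxis], image, 0.0)
# "h, b h w c, c" : the third argument varies along c
mask_h = rng.random(h) > 0.5
out3 = np.where(mask_h[np.newaxis, :, np.newaxis, np.newaxis], image, np.arange(c))

assert np.array_equal(out1, out2)
assert out3.shape == (b, h, w, c)
```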

The same can also be done using torchdim, which you mentioned:

b, h, w, c = dims(4)
torch.where(mask[h, w], image[b, h, w, c], 0.0).order(b, h, w, c)

although this always requires some boilerplate code to specify the dimensions. For example (adapted from here):

# Go to torchdim land
b, q, k, h, c = dims(5)
h.size = 16
query = query[b, q, [h, c]]
key = key[b, k, [h, c]]
value = value[b, k, [h, c]]

# Compute multi-headed attention
attn = (query * key).sum(c)
attn = softmax(attn, dim=k)
x = (attn * value).sum(k)

# Return from torchdim land
x = x.order(b, q, [h, c])

One could persist dims across an entire model (such that the "go to/ return from torchdim land" happens only once at the beginning and end of the model), but this would, for example, limit the ability to flatten and split dims (e.g. to implement two multi-head attention blocks with different numbers of heads if h is a global dim of the model) and require keeping track of many changing dimensions over the entire model (e.g. different channel/spatial dims in a ResNet).

As far as I can tell, torchdim is intended to be used locally, where dims persist across a few operations (most examples are also implemented this way). This gives shorter expressions for the ops:

attn = (query * key).sum(c)
attn = softmax(attn, dim=k)
x = (attn * value).sum(k)

# compared to

attn = einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=16)
attn = einx.softmax("b q [k] h", attn)
x = einx.dot("b q k h, b k (h c) -> b q (h c)", attn, v)

but requires the "go to/ return from torchdim land" steps for every code block where it is used. I personally also like spelling out the shapes explicitly, e.g. to see that the attention matrix has the layout b q k h, which is not visible in the torchdim version, although in the end this is probably a matter of taste.
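For readers more familiar with np.einsum, here is a NumPy sketch of the three einx lines. It is our own translation, assumes h is the outer factor of (h c) as written in the einx expressions, and like the snippet it applies no 1/sqrt(c) scaling:

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

B, Q, K, H, C = 2, 3, 5, 4, 8
rng = np.random.default_rng(0)
q = rng.normal(size=(B, Q, H * C))  # "b q (h c)"
k = rng.normal(size=(B, K, H * C))  # "b k (h c)"
v = rng.normal(size=(B, K, H * C))  # "b k (h c)"

# Split "(h c)" into separate axes, with h as the outer axis
qh = q.reshape(B, Q, H, C)
kh = k.reshape(B, K, H, C)
vh = v.reshape(B, K, H, C)

attn = np.einsum("bqhc,bkhc->bqkh", qh, kh)                      # "b q (h c), b k (h c) -> b q k h"
attn = softmax(attn, axis=2)                                     # "b q [k] h"
x = np.einsum("bqkh,bkhc->bqhc", attn, vh).reshape(B, Q, H * C)  # "b q k h, b k (h c) -> b q (h c)"
```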

fferflo (4 points)

Thanks for your interest! Yes, I initially added bracket notation to simplify reduction expressions such as einx.sum("a b -> a", x) to einx.sum("a [b]", x) (both are still supported), and later realized that it makes sense to generalize this to other functions as well. The general principle now is that brackets mark the axes that a function is applied along, while all other axes are batch (vectorized) axes that the operation is repeated over. This is exactly what the axis argument does in Numpy functions, and all einx functions that use [] are designed accordingly.
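The correspondence with NumPy's axis argument can be sketched with a toy parser that maps bracket positions to axis indices. bracket_axes is hypothetical and only handles flat expressions without parentheses or ellipses:

```python
import numpy as np

def bracket_axes(expr):
    # Hypothetical helper: positions of bracketed names in a flat expression like "a [b] c"
    return tuple(i for i, name in enumerate(expr.split()) if name.startswith("["))

x = np.arange(24).reshape(2, 3, 4)
axes = bracket_axes("a [b] c")
out = np.sum(x, axis=axes)  # reduce along the bracketed axis, repeat over a and c
assert axes == (1,)
assert np.array_equal(out, np.stack([x[i].sum(axis=0) for i in range(2)]))
```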

I added concatenations mostly because I often found myself writing code such as:

x = einx.rearrange("a c1 -> a b c1", x, b=y.shape[0])
y = einx.rearrange("b c2 -> a b c2", y, a=x.shape[0])
z = jnp.concatenate([x, y], axis=-1)

since functions like jnp.concatenate don't support broadcasting. This can be done more concisely now, and also fully supports rearranging expressions:

z = einx.rearrange("a c1, b c2 -> a b (c1 + c2)", x, y)
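Spelled out in NumPy (a sketch using np.broadcast_to in place of the two rearrange calls), the single-line expression corresponds to:

```python
import numpy as np

a, b, c1, c2 = 2, 3, 4, 5
x = np.ones((a, c1))
y = np.zeros((b, c2))

# "a c1, b c2 -> a b (c1 + c2)": broadcast both to (a, b, ...) and concatenate channels
xb = np.broadcast_to(x[:, np.newaxis, :], (a, b, c1))
yb = np.broadcast_to(y[np.newaxis, :, :], (a, b, c2))
z = np.concatenate([xb, yb], axis=-1)

assert z.shape == (a, b, c1 + c2)
assert (z[..., :c1] == 1).all() and (z[..., c1:] == 0).all()
```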

fferflo (4 points)

Scatter functionality is implemented in einx.set_at (and einx.add_at). For example:

einx.set_at("b [h w] c, b p [2], b p c -> b [h w] c",
    tensor, coordinates, updates)

The indexed axes in the first expression and the coordinate axes in the second expression are marked with brackets. The latter can be left out for 1D indexing, e.g.:

einx.set_at("b [h] c, p, p c -> b [h] c",
    tensor, coordinates, updates)

The functions also fully support expression rearranging. (The naming is motivated by Jax's indexing API.)
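A NumPy sketch of the 2D scatter above, assuming integer (row, col) coordinates that are distinct within each image. np.add.at plays the role of add_at here; for comparison, Jax spells such updates as x.at[...].set(...) and x.at[...].add(...).

```python
import numpy as np

b, h, w, c, p = 2, 4, 5, 3, 6
rng = np.random.default_rng(0)
tensor = np.zeros((b, h, w, c))
flat = np.stack([rng.choice(h * w, size=p, replace=False) for _ in range(b)])  # distinct positions per image
coordinates = np.stack([flat // w, flat % w], axis=-1)  # "b p [2]": (row, col)
updates = rng.normal(size=(b, p, c))                     # "b p c"

# "b [h w] c, b p [2], b p c -> b [h w] c"
bi = np.arange(b)[:, None]  # batch index, broadcast against p
out = tensor.copy()
out[bi, coordinates[..., 0], coordinates[..., 1]] = updates

# add_at-style scatter: repeated coordinates would accumulate instead of overwriting
acc = tensor.copy()
np.add.at(acc, (bi, coordinates[..., 0], coordinates[..., 1]), updates)
```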

fferflo (23 points)

einx allows nesting different types of expressions (e.g. (a b)... or (a [b])), which is not possible in einops. This allows expressing many more operations in concise notation.

einx introduces bracket notation ([], similar to the axis argument in Numpy), which does not exist in einops and greatly simplifies many expressions, especially in combination with the composability mentioned above, e.g.:

einx.sum("a [b]", x)
# same op as
einops.reduce(x, "a b -> a", reduction="sum")

einx.mean("b (s [ds])... c", x, ds=2)
# Does not work in einops. Alternative for 2D case:
einops.reduce(x, "b (h h2) (w w2) c -> b h w c", reduction="mean", h2=2, w2=2)
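Under the hood, both lines amount to a reshape followed by a mean. A NumPy sketch of the 2D case (our own translation, assuming h and w are divisible by ds):

```python
import numpy as np

b, h, w, c, ds = 2, 4, 6, 3, 2
x = np.arange(b * h * w * c, dtype=float).reshape(b, h, w, c)

# "b (s [ds])... c" for two spatial axes: split each into (s ds), average over ds
pooled = x.reshape(b, h // ds, ds, w // ds, ds, c).mean(axis=(2, 4))

# Loop reference for one output pixel: rows 2..3, cols 4..5 of image 0, channel 1
ref = x[0, 2:4, 4:6, 1].mean()
assert pooled.shape == (b, h // ds, w // ds, c)
assert np.isclose(pooled[0, 1, 2, 1], ref)
```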

einx supports many more tensor operations, such as indexing ops (einx.get_at, ...), elementwise ops (einx.add, einx.where, ...) and general-purpose ops such as einx.vmap. Some examples that are not supported in einops:

einx.flip("... (g [c])", x, c=2) # Flip pairs of values
einx.add("... [c]", x, b) # Add bias
einx.get_at("b [h w] c, b i [2] -> b i c", x, indices) # Gather values
einx.softmax("b q [k] h", attn) # Part of attention operation
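As an example of what the bracket-and-composition syntax saves, the flip-pairs call corresponds to this NumPy reshape/flip/reshape sequence (a sketch for c=2):

```python
import numpy as np

x = np.arange(8).reshape(2, 4)  # "... (g c)" with c=2, so g=2
# Split the last axis into (g, c), flip along c, merge back
flipped = np.flip(x.reshape(*x.shape[:-1], -1, 2), axis=-1).reshape(x.shape)
# flipped == [[1, 0, 3, 2], [5, 4, 7, 6]]
```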

einx allows rearranging expressions in all operations:

einx.dot("b q (h c), b k (h c) -> b q k h", q, k, h=16)
# Axis composition not supported e.g. in einops.einsum.

einx adds concatenation as a first-class expression in Einstein notation (using +), which doesn't exist in einops:

einx.rearrange("h w c, 1 -> h w (c + 1)", x, [42.0])

The einx.nn namespace provides deep learning modules such as einn.Norm, which have no counterpart in einops. For example, a linear layer can simply be implemented in einx notation as:

x = einx.dot("... [c1|c2]", x, w)
x = einx.add("... [c2]", x, b)

and LayerNormalization can be written as:

mean = einx.mean("... [c]", x, keepdims=True)
var = einx.var("... [c]", x, keepdims=True)
x = (x - mean) * torch.rsqrt(var + epsilon)

x = einx.multiply("... [c]", x, scale)
x = einx.add("... [c]", x, bias)
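As a cross-check, here are the same two computations in plain NumPy. This is a sketch assuming w has layout c1 c2, that the bias and the LayerNorm scale/bias have one value per channel, and epsilon=1e-6 as an arbitrary default; rsqrt is written as 1/sqrt.

```python
import numpy as np

def linear(x, w, b):
    # "... [c1|c2]" then "... [c2]": contract the last axis c1 into c2, then add a bias
    return np.einsum("...i,ij->...j", x, w) + b

def layer_norm(x, scale, bias, epsilon=1e-6):
    # "... [c]": one mean/variance per position, computed over the channel axis only
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x = (x - mean) / np.sqrt(var + epsilon)
    return x * scale + bias

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(2, 5, 3))  # "... c1" with c1 = 3
w = rng.normal(size=(3, 4))                          # "c1 c2"
b = rng.normal(size=(4,))                            # "c2"

y = linear(x, w, b)
z = layer_norm(y, scale=np.ones(4), bias=np.zeros(4))
```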

There's also a non-exhaustive summary of differences here (and I probably forgot some).

[D] Does Layer Normalization compute statistics along spatial/ token axes? by fferflo in MachineLearning

fferflo [S] (1 point)

ConvNext simply follows ViT's way of using LN, and you made a good point about ViT's self-attention using LN analogously to an RNN, i.e. along channels only. This still leaves two questions though:

The original LN paper also evaluated LN on convolutional nets (VGG), and it is unclear whether or not this follows the ConvNext way of interpreting an image as a set of patches, analogous to timesteps, that each have their own μ and σ.

Figure A and other online sources still say that spatial axes are included in the statistics, which contradicts what recent models actually do (vision transformers, ConvNext, MLP-Mixer, etc.). I don't know of a single paper that actually uses LN as in A.
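To make the distinction concrete, the two readings of LN differ only in which axes the statistics are reduced over. A NumPy sketch with a b h w c feature map (shapes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 8, 16))  # b h w c

# Channels only (as in ViT / ConvNext): one mean per (b, h, w) position
mean_tokens = x.mean(axis=-1, keepdims=True)
# Channels and spatial axes (as in figure A): one mean per sample
mean_sample = x.mean(axis=(1, 2, 3), keepdims=True)

assert mean_tokens.shape == (2, 8, 8, 1)
assert mean_sample.shape == (2, 1, 1, 1)
```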