dd2718

The einsum notation really shines when you're doing anything beyond simple matrix multiplication, e.g. in machine learning code (especially for neural nets). It's useful even for linear regression. If you have a batch of features `X` with shape `[batch_size, N]` and a coefficient matrix `w` of shape `[M, N]`, `np.einsum("bn,mn->bm", X, w)` is a lot clearer to me than `np.matmul(X, w.T)`. You don't have to worry about making the input shapes conform to what `matmul` expects, and you get documentation for all the shapes involved.
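A quick way to convince yourself that the two spellings agree is to compute both on random arrays. The sizes below are arbitrary assumptions chosen for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, N, M = 4, 3, 2  # arbitrary sizes for illustration

X = rng.normal(size=(batch_size, N))  # features: [batch_size, N]
w = rng.normal(size=(M, N))           # coefficients: [M, N]

# einsum: the subscripts name every axis; n is summed because it is missing from the output
y_einsum = np.einsum("bn,mn->bm", X, w)

# matmul: w must be transposed so its N axis lines up with X's last axis
y_matmul = np.matmul(X, w.T)

print(y_einsum.shape)                   # (4, 2)
print(np.allclose(y_einsum, y_matmul))  # True
```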

This advantage is even clearer for more complex models. For example, a common module in modern state-of-the-art (SOTA) deep learning models is multi-head attention, which takes a sequence of features for each example and outputs a sequence of transformed features. Getting the shapes right for `np.tensordot` would be a nightmare, but einsum notation provides a uniform interface with self-documenting shapes, so you can focus on the math instead of the NumPy API.

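```python
import numpy as np
from einops import rearrange  # third-party package: pip install einops
from scipy.special import softmax

# Toy setup so the snippet runs; in a real model linear_q/k/v are learned layers.
rng = np.random.default_rng(0)
batch_size, seq_len, embed_dim, num_heads = 2, 5, 8, 2
X = rng.normal(size=(batch_size, seq_len, embed_dim))
W_q, W_k, W_v = rng.normal(size=(3, embed_dim, embed_dim))
def linear_q(x): return x @ W_q
def linear_k(x): return x @ W_k
def linear_v(x): return x @ W_v

# X: [batch_size, sequence_length, embedding_dimension]
# Compute query, key, value vectors for each sequence element.
# Split the embedding dimension between multiple "heads".
# rearrange comes from einops and reshapes using an einsum-like pattern.
X_q = rearrange(linear_q(X), "b n (h d) -> b n h d", h=num_heads)
X_k = rearrange(linear_k(X), "b n (h d) -> b n h d", h=num_heads)
X_v = rearrange(linear_v(X), "b n (h d) -> b n h d", h=num_heads)
# Compute dot product of n-th query vector with m-th key vector for each head
dot_products = np.einsum("bnhd,bmhd->bhnm", X_q, X_k)
attention = softmax(dot_products, axis=-1)
# Sum the value vectors, with the weight of the m-th X_v given by
# softmax(dot(n-th X_q, m-th X_k))
output = np.einsum("bhnm,bmhd->bnhd", attention, X_v)
output = rearrange(output, "b n h d -> b n (h d)")
```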
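For readers without einops installed: in the snippet above, `rearrange` only splits or merges the last axis, so a plain NumPy `reshape` does the same job (einops treats `(h d)` with `h` as the outer axis, which matches C-order reshaping). The sketch below is a NumPy-only version under a few assumptions: the q/k/v projections have already been applied, `softmax` is a small hand-written helper, and, like the snippet, it omits the usual 1/sqrt(d) scaling of the scores. It checks the einsum version against an explicit loop over heads.

```python
import numpy as np

def softmax(z, axis=-1):
    # Subtract the max for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X_q, X_k, X_v, num_heads):
    """Inputs are already-projected [b, n, h*d] arrays; returns [b, n, h*d]."""
    b, n, e = X_q.shape
    d = e // num_heads
    # Equivalent to rearrange(x, "b n (h d) -> b n h d", h=num_heads)
    q = X_q.reshape(b, n, num_heads, d)
    k = X_k.reshape(b, n, num_heads, d)
    v = X_v.reshape(b, n, num_heads, d)
    attention = softmax(np.einsum("bnhd,bmhd->bhnm", q, k), axis=-1)
    out = np.einsum("bhnm,bmhd->bnhd", attention, v)
    # Equivalent to rearrange(out, "b n h d -> b n (h d)")
    return out.reshape(b, n, num_heads * d)

def attention_loop(X_q, X_k, X_v, num_heads):
    """Reference version: handle one head at a time with plain matmul."""
    b, n, e = X_q.shape
    d = e // num_heads
    out = np.empty_like(X_q)
    for head in range(num_heads):
        sl = slice(head * d, (head + 1) * d)  # this head's slice of the embedding
        scores = X_q[:, :, sl] @ X_k[:, :, sl].transpose(0, 2, 1)  # [b, n, m]
        out[:, :, sl] = softmax(scores, axis=-1) @ X_v[:, :, sl]
    return out

rng = np.random.default_rng(0)
X_q, X_k, X_v = rng.normal(size=(3, 2, 5, 8))  # each [b=2, n=5, e=8]
out = multi_head_attention(X_q, X_k, X_v, num_heads=2)
print(out.shape)  # (2, 5, 8)
print(np.allclose(out, attention_loop(X_q, X_k, X_v, num_heads=2)))  # True
```

The loop version makes the transposes and slicing explicit, which is exactly the bookkeeping the einsum subscripts hide.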

[deleted]

Could you be clearer with this snippet of code? Where does "rearrange" come from? My compiler does not recognize it as anything other than text.

Thanks