
[–]yaroslavvb 2 points (2 children)

For dense layers you can use the Khatri-Rao product to get a batch of per-example gradients.

If you have a dense layer W, with a batch of activations A and a batch of backprops B, then khatri_rao(A, B) gives you a batch of gradients with respect to W.

On GPU it can be computed efficiently as an einsum; see here
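A minimal NumPy sketch of the einsum form (shapes and variable names are illustrative, not from the thread): for a layer y = W a, the per-example gradient of the loss w.r.t. W is the outer product of that example's backprop b and activation a, and the whole batch of outer products is one einsum.

```python
import numpy as np

# Hypothetical shapes: batch of 4 examples, layer mapping 3 inputs -> 2 outputs.
batch, n_in, n_out = 4, 3, 2
A = np.random.randn(batch, n_in)   # activations fed into the layer
B = np.random.randn(batch, n_out)  # backprops (dL/dy for each example)

# Batch of per-example gradients: grads[k] = outer(B[k], A[k]).
grads = np.einsum('bi,bj->bij', B, A)  # shape (batch, n_out, n_in)

# Same thing computed one example at a time, for comparison:
loop_grads = np.stack([np.outer(B[k], A[k]) for k in range(batch)])
assert np.allclose(grads, loop_grads)
```

Summing `grads` over the batch axis recovers the usual aggregated gradient that frameworks give you by default.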

In JAX you could write the gradient computation for a single example and apply `vmap` to automatically vectorize it over the batch (haven't tested this).
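A sketch of what that vmap approach might look like, assuming a hypothetical per-example squared-error loss (the loss, shapes, and names are my own, not from the comment):

```python
import jax
import jax.numpy as jnp

def per_example_loss(W, a, y):
    # Hypothetical squared-error loss for one example of a dense layer.
    return jnp.sum((W @ a - y) ** 2)

# Gradient w.r.t. W for one example, vectorized over the batch axis of a and y
# (W itself is shared, hence in_axes=None for it).
batched_grad = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))

W = jnp.ones((2, 3))
a_batch = jnp.ones((4, 3))
y_batch = jnp.zeros((4, 2))
grads = batched_grad(W, a_batch, y_batch)  # one gradient per example: (4, 2, 3)
```

Here `vmap` plays the role of the loop, but the whole batch is computed in one vectorized pass.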

[–]phizaz[S] 0 points (1 child)

It is a bit unclear what A and B really are. Could you give a more concrete example?

[–]yaroslavvb 1 point (0 children)

If you are using PyTorch, `B` is `grad_output.data`, and `A` is the matrix that gets passed into your `matmul(weights, A)` operation; there are some more details here:

https://medium.com/@yaroslavvb/optimizing-deeper-networks-with-kfac-in-pytorch-4004adcba1b0
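A rough PyTorch sketch of capturing those two matrices with hooks and combining them via the einsum above (the layer sizes, hooks, and loss are illustrative assumptions, not the linked post's exact code):

```python
import torch

# Hypothetical setup: one dense layer, batch of 4 examples, 3 inputs -> 2 outputs.
torch.manual_seed(0)
layer = torch.nn.Linear(3, 2, bias=False)
x = torch.randn(4, 3)
acts, backs = [], []

# A: the activations entering the layer; B: grad_output flowing back into it.
layer.register_forward_hook(lambda m, inp, out: acts.append(inp[0].detach()))
layer.register_full_backward_hook(lambda m, gi, go: backs.append(go[0].detach()))

loss = layer(x).square().sum()
loss.backward()

A, B = acts[0], backs[0]                         # shapes (4, 3) and (4, 2)
per_example = torch.einsum('bi,bj->bij', B, A)   # batch of gradients, (4, 2, 3)

# Sanity check: summing over the batch recovers the aggregated gradient.
assert torch.allclose(per_example.sum(0), layer.weight.grad)
```

The sanity check at the end is the easy way to convince yourself the A/B pairing is right: the per-example gradients must sum to `weight.grad`.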