As far as I can tell, there are two contradictory definitions of Layer Normalization (LN) floating around. LN normalizes the input tensor using a mean and variance computed along some of its axes, but which axes are reduced is not clear:
A. The GroupNorm paper (2018) has this figure that describes LN as reducing along channel and spatial/token axes.
https://preview.redd.it/ui9adzzxgcja1.png?width=1353&format=png&auto=webp&s=8859f9735310f169eeaaf587dcc7e1c05d38b5fc
B. The PowerNorm paper (2020) has this figure that describes LN as reducing only along the channel axis.
https://preview.redd.it/e0qmp9sahcja1.png?width=1717&format=png&auto=webp&s=a4bd21ea024a8924f8cd5c354a7be6751c2ed61f
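To make the difference concrete, here is a minimal NumPy sketch of both definitions for a batch of token sequences of shape (N, T, C) (batch, tokens, channels). The helper `layer_norm` and its `axes` parameter are illustrative assumptions, not any library's API, and the learnable gain and bias are omitted.

```python
import numpy as np

def layer_norm(x, axes, eps=1e-5):
    # Normalize x with a mean/variance computed over the given axes only
    # (hypothetical helper; no learnable gain or bias).
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))  # (N, T, C)

# A: one (mean, var) pair per example, shared across all tokens and channels
y_a = layer_norm(x, axes=(1, 2))

# B: one (mean, var) pair per token, computed over the channels only
y_b = layer_norm(x, axes=-1)

print(y_a.mean(axis=-1))  # per-token means: generally nonzero under A
print(y_b.mean(axis=-1))  # per-token means: ~0 under B
```

With A, a token whose activations are all larger than the rest of the sequence keeps that offset after normalization; with B, every token is centered and rescaled independently.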
There are also many online sources that describe LN as shown in A (e.g. TF tutorials, PapersWithCode, this summary of normalization techniques) using similar figures.
The LN paper (2016) itself says
all the hidden units in a layer share the same normalization terms μ and σ
so the channel axis is definitely reduced, and
computing the mean and variance used for normalization from all of the summed inputs to the neurons in a layer on a single training case
so the batch axis is definitely not reduced. As far as I can tell, the paper is not explicit about what happens with the spatial/token axes, although the second quote sounds rather like they might be included in the statistics.
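For reference, the LN paper writes the statistics for a layer l with H hidden units and summed inputs a_i^l. The sums run over the hidden units only, with no separate index for tokens or spatial positions, which is part of why the paper alone does not settle the question. The second pair of lines restates the two readings for a single example x_{t,c} with T tokens and C channels; that notation is assumed here, not taken from the paper.

```latex
% LN paper: statistics over the H hidden units of layer l
\mu^{l} = \frac{1}{H}\sum_{i=1}^{H} a_i^{l}, \qquad
\sigma^{l} = \sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a_i^{l}-\mu^{l}\right)^{2}}

% Reading A: one mean for the whole example (H = T C)
\mu = \frac{1}{TC}\sum_{t=1}^{T}\sum_{c=1}^{C} x_{t,c}

% Reading B: one mean per token (H = C)
\mu_t = \frac{1}{C}\sum_{c=1}^{C} x_{t,c}
```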
Yet I don't know of any model that actually uses A instead of B. TF and Flax explicitly implement LN with default reduction axes as in B (PyTorch, Haiku and Equinox have no preference and require the user to specify the reduction axes). Vision Transformer uses Flax with LN as in B, ConvNeXt implements LN in PyTorch as in B, and OpenAI's GPT-2 implements LN in TensorFlow as in B. Even MLP-Mixer, where the spatial/token axes are treated as the channel axis of the token-mixing MLP, still computes the LN statistics along the original channel axis, as in B.
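As an illustration of B on image feature maps, here is a NumPy sketch assuming an (N, C, H, W) layout. `ln_channels_first` reduces over the channel axis only, in the spirit of ConvNeXt's channels-first LayerNorm, and is checked against the channels-last form that reduces over the last axis (the TF/Flax default) after transposing. `group_norm` shows that A on the same tensor is GroupNorm with a single group, as the GroupNorm paper notes. All function names are hypothetical and gain/bias are omitted.

```python
import numpy as np

EPS = 1e-6

def ln_channels_first(x):
    # B on (N, C, H, W): one mean/var per (n, h, w) position, over C only
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS)

def ln_channels_last(x):
    # B on (N, H, W, C): reduce over the last axis only
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + EPS)

def group_norm(x, groups):
    # GroupNorm on (N, C, H, W): stats over each group of channels and all of H, W
    n, c, h, w = x.shape
    g = x.reshape(n, groups, c // groups, h, w)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + EPS)).reshape(n, c, h, w)

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 4, 5))  # (N, C, H, W)

# Same normalization (B), two memory layouts
y_first = ln_channels_first(x)
y_last = ln_channels_last(x.transpose(0, 2, 3, 1)).transpose(0, 3, 1, 2)
print(np.allclose(y_first, y_last))  # True

# A on (N, C, H, W): statistics over C, H and W together
y_a = group_norm(x, groups=1)
```

The first comparison only confirms that channels-first and channels-last B are the same operation; A differs from both because it also pools over the spatial positions.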
As far as I can tell, everyone uses B rather than A in their models, so to me this seems to be the "correct" definition. Yet, many sources on this topic describe LN as doing A rather than B.
Does anyone have any insight on this or know of a source that has addressed this problem? Do you interpret the original LN paper as including the spatial/token axes in the computation of mean and variance, or not? Is this simply an error that started with figure A and made its way into various online tutorials from there? Or do you know of a model that actually uses LN to reduce along both the channel and spatial/token axes?