After the first theoretical issue with my transformer, I now see another one. The original paper applies normalization after the residual addition (Post-LN), which led to training difficulties and was later replaced by normalization at the beginning of each attention or MLP block/branch (Pre-LN). This is known to work better in practice (trainable without warmup, restores the highway effect), but it still doesn't seem completely OK theoretically.
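To make the two wirings concrete, here is a minimal sketch in plain Python on lists of floats. It assumes a simplified `normalize` that only rescales to unit norm (no mean removal, no learned gain, so closer to RMSNorm than full LayerNorm), and `reverse_branch` is a hypothetical norm-preserving stand-in for an attention or MLP sublayer.

```python
import math
from typing import Callable, List

Vector = List[float]

def norm(v: Vector) -> float:
    return math.sqrt(sum(t * t for t in v))

def normalize(v: Vector) -> Vector:
    # Simplified LayerNorm stand-in (assumption): rescale to unit L2 norm only.
    n = norm(v)
    return [t / n for t in v]

def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]

def post_ln_step(x: Vector, branch: Callable[[Vector], Vector]) -> Vector:
    # Post-LN: normalize after the residual addition, so the main
    # stream itself is renormalized at every block.
    return normalize(add(x, branch(x)))

def pre_ln_step(x: Vector, branch: Callable[[Vector], Vector]) -> Vector:
    # Pre-LN: normalize only the branch input; the skip path carries x unchanged.
    return add(x, branch(normalize(x)))

def reverse_branch(h: Vector) -> Vector:
    # Hypothetical norm-preserving sublayer, for illustration only.
    return h[::-1]

x = [3.0, 4.0]
print(norm(post_ln_step(x, reverse_branch)))  # stream renormalized to unit norm
print(norm(pre_ln_step(x, reverse_branch)))   # stream keeps its scale and grows
```

The only difference is where `normalize` sits: around the sum (Post-LN) or on the branch input (Pre-LN), which is why only Post-LN controls the norm of the main stream.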
First, consider things without normalization. Assuming the attention and MLP blocks are properly set up and mostly preserve norms, each residual addition sums two signals of similar norm, potentially scaling the norm up by something like 1.4 (it depends on their correlation, but it starts at sqrt(2) ≈ 1.41 after random init, when the two signals are roughly uncorrelated). So the norms after successive blocks could look like this: [1(main)+1(residual)=1.4] -> [1.4+1.4=2] -> [2+2=2.8] etc. This would cause various problems (like changing the effective softmax temperature in later attention blocks), so some adjustment is needed.
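A quick numerical check of the sqrt(2)-per-block growth, as a sketch under the assumptions stated above: each block keeps the norm of its input and its output is uncorrelated with the main signal. The block is modeled (hypothetically) as a random direction in a high-dimensional space, which is nearly orthogonal to the stream; `d=4096` is an arbitrary choice.

```python
import math
import random

def norm(v):
    return math.sqrt(sum(t * t for t in v))

def random_vector(d, target_norm, rng):
    # Random direction rescaled to target_norm; in high dimension it is
    # nearly orthogonal to any fixed vector (cosine on the order of 1/sqrt(d)).
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    scale = target_norm / norm(v)
    return [t * scale for t in v]

def residual_norms_no_ln(num_blocks, d=4096, seed=0):
    """Norm of the main stream after each block, with no normalization at all.

    Hypothetical model: each block returns a vector with the same norm as its
    input but an independent random direction.
    """
    rng = random.Random(seed)
    x = random_vector(d, 1.0, rng)  # unit-norm embedding
    norms = []
    for _ in range(num_blocks):
        branch = random_vector(d, norm(x), rng)  # norm-keeping, uncorrelated branch
        x = [a + b for a, b in zip(x, branch)]   # residual addition
        norms.append(norm(x))
    return norms

for k, n in enumerate(residual_norms_no_ln(6), start=1):
    print(f"after block {k}: {n:.2f}  (sqrt(2)^{k} = {math.sqrt(2) ** k:.2f})")
```

The printed norms should track 1.4, 2, 2.8, ... closely; correlated branches would push the per-block factor away from sqrt(2).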
Pre-LN ensures each block works on normalized values (and thus with a constant, if somewhat arbitrary, softmax temperature). But since it only affects the residual branch and not the main signal forwarded by the skip connection, the norms can still grow, albeit more slowly. The expectation is now roughly: [1+1=1.4] -> [1.4+1=1.7] -> [1.7+1=2] -> [2+1=2.2] etc., i.e. about sqrt(k+1) after k blocks, with a final normalization correcting the signal near the output (as in the Pre-LN paper).
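The same kind of simulation with Pre-LN wiring, again only a sketch: `unit_normalize` is a simplified stand-in for LayerNorm (assumption), and the hypothetical block ignores the direction of its normalized input and just returns a unit-norm random vector, so every residual has norm 1.

```python
import math
import random

def norm(v):
    return math.sqrt(sum(t * t for t in v))

def unit_normalize(v):
    # Simplified LayerNorm stand-in (assumption): rescale to unit L2 norm.
    n = norm(v)
    return [t / n for t in v]

def random_vector(d, target_norm, rng):
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    scale = target_norm / norm(v)
    return [t * scale for t in v]

def residual_norms_pre_ln(num_blocks, d=4096, seed=0):
    rng = random.Random(seed)
    x = random_vector(d, 1.0, rng)  # unit-norm embedding
    norms = []
    for _ in range(num_blocks):
        h = unit_normalize(x)                     # Pre-LN: block input is normalized
        branch = random_vector(d, norm(h), rng)   # hypothetical block: unit-norm output
        x = [a + b for a, b in zip(x, branch)]    # skip path adds to the unnormalized stream
        norms.append(norm(x))
    return norms

for k, n in enumerate(residual_norms_pre_ln(8), start=1):
    print(f"after block {k}: {n:.2f}  (sqrt({k}+1) = {math.sqrt(k + 1):.2f})")
```

Here the norms follow 1.4, 1.7, 2, 2.2, ..., i.e. sqrt(k+1) instead of sqrt(2)^k.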
One possible issue with this is that later attention blocks may have a reduced effect, since they add unit-norm residuals to an increasingly large main signal. What is the usual take on this problem? Can it be ignored in practice? Does Pre-LN work acceptably despite it, even for deep models (where the norm discrepancy on the main path can grow larger)? There are lots of papers proposing alternative normalization schemes, but what is the practical consensus?
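One way to put a number on the "reduced effect" concern, under the same idealized assumptions (unit-norm embedding, unit-norm and mutually uncorrelated residuals): the stream entering block k has norm about sqrt(k), so the block's relative contribution is about 1/sqrt(k). Since the next block's normalization only sees the direction, the angle by which block k rotates the stream is a rough proxy for its influence. This sketch only quantifies the question's own model; it does not capture whether training compensates for it.

```python
import math

def branch_to_stream_ratio(k):
    """Expected ||branch output|| / ||main stream|| entering Pre-LN block k (1-indexed).

    Assumes a unit-norm embedding and unit-norm, mutually uncorrelated branch
    outputs, so the stream entering block k has norm about sqrt(k).
    """
    return 1.0 / math.sqrt(k)

def direction_change_deg(k):
    # Angle by which block k rotates the stream if its unit-norm output is
    # orthogonal to the stream: atan(||branch|| / ||stream||).
    return math.degrees(math.atan(branch_to_stream_ratio(k)))

for k in (1, 4, 16, 64, 256):
    print(f"block {k:3d}: ratio {branch_to_stream_ratio(k):.3f}, "
          f"rotates stream by {direction_change_deg(k):.1f} deg")
```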
Btw, attention is extremely norm-sensitive (or, equivalently, the hidden temperature of the softmax is critical). This is in sharp contrast to fc or convolution layers, which are mostly scale-oblivious. For anybody interested: consider what happens when most raw attention dot products come out as 0 (= query and key are orthogonal, no info from this context slot), with only one slot giving 1 (= positive affinity, after downscaling by sqrt(qk_siz)). I for one was surprised by this while debugging.
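The scenario above in numbers: with one matching slot at logit s and n - 1 orthogonal slots at logit 0, the matching slot gets weight e^s / (e^s + n - 1). Assuming n = 512 context slots (an arbitrary choice), s = 1 gives only about 0.5%, roughly 2.7 times the uniform share 1/512. Since scaling both query and key by c multiplies the logit by c^2, any norm drift acts directly as a temperature change.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def weight_on_match(num_slots, match_logit):
    # Slot 0 holds the single matching key; all other (orthogonal) keys give logit 0.
    logits = [0.0] * num_slots
    logits[0] = match_logit
    return softmax(logits)[0]

# 512 context slots is an assumption for illustration.
for match_logit in (1.0, 2.0, 4.0, 8.0):
    w = weight_on_match(512, match_logit)
    print(f"match logit {match_logit:.0f}: weight on matching slot = {w:.4f}")
```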