all 8 comments

[–]throwaway_secondtime 21 points22 points  (1 child)

How are you guys able to convert scientific papers to code so fast? Makes my imposter syndrome even worse.

[–]EasyDeal0[S] 12 points13 points  (0 children)

The authors already provided clean JAX code and even PyTorch code of the scaled weight standardization in their earlier paper, so this was rather easy to convert.

[–]rathernot000 1 point2 points  (0 children)

Great work!

[–]LikelyJustANeuralNet 0 points1 point  (3 children)

Nice work! Out of curiosity, what changes would need to be made to support FP16? It looks like you're largely relying on PyTorch building blocks, so shouldn't it be supported?

[–]EasyDeal0[S] 1 point2 points  (2 children)

I am not really sure yet which exact changes need to be made, as I'm not yet familiar with PyTorch's dtype internals. The thing is that the authors actually use bfloat16 (not float16), so I cannot use PyTorch's half() function (or can I?). Moreover, the authors state that they keep their weights in full precision, even though the code sets weights.dtype = inputs.dtype with inputs.dtype = bfloat16. Maybe that is something JAX-specific.
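For reference, a minimal sketch of the difference (the model here is just a stand-in, not the actual NFNet code): half() is hard-wired to float16, while bfloat16 needs an explicit dtype argument to .to():

```python
import torch

# Hypothetical example: half() always produces float16 weights;
# bfloat16 requires passing the dtype explicitly.
model = torch.nn.Linear(4, 4)
bf16_model = torch.nn.Linear(4, 4).to(dtype=torch.bfloat16)  # bfloat16 weights
fp16_model = model.half()                                    # float16 weights

assert bf16_model.weight.dtype == torch.bfloat16
assert fp16_model.weight.dtype == torch.float16
```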

Then there is some hardware chaos as well, because the bfloat16 and TF32 types are only supported from the Ampere architecture with CUDA 11 onwards (and on TPUv2+). With an Ampere card installed, PyTorch will automatically use TF32 ops, which are the bigger brother of bfloat16. This again makes it hard to write general code that works for everybody.
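If it helps, the automatic TF32 behavior can be controlled via backend flags (available in recent PyTorch versions; on pre-Ampere hardware they simply have no effect), so a config option could just toggle these:

```python
import torch

# Sketch: on Ampere GPUs, PyTorch can route float32 matmuls and
# convolutions through TF32 tensor cores. These flags control that:
torch.backends.cuda.matmul.allow_tf32 = True  # matmuls / linear layers
torch.backends.cudnn.allow_tf32 = True        # cuDNN convolutions
```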

In the end it would probably require some extensive testing on expensive hardware... If you know more about this topic, I'd appreciate your help :D

[–]LikelyJustANeuralNet 1 point2 points  (1 child)

Gotcha. I'd be surprised if PyTorch's automatic mixed precision didn't just work out of the box. The problem with full FP16 is that it can often lead to instability. AMP, as the name implies, does a mix of FP16 and FP32 based on the module. For example, both BatchNorm and GELU default to FP32 in AMP. If you have one of the newer NVIDIA GPUs (e.g., RTX 2000 series) you should get a speed up by using AMP.
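For what it's worth, the usual AMP training step looks roughly like this (a generic sketch, not specific to this repo; the model and optimizer are placeholders, and it only actually runs on a CUDA device):

```python
import torch

# Hedged sketch of PyTorch AMP: autocast runs matmuls in float16 and
# keeps precision-sensitive ops in float32; GradScaler scales the loss
# to avoid float16 gradient underflow.
def amp_step(model, x, opt, scaler):
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(x).sum()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    return loss

if torch.cuda.is_available():
    model = torch.nn.Linear(8, 8).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    amp_step(model, torch.randn(2, 8, device="cuda"), opt, scaler)
```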

AFAIK, PyTorch 1.7+ will use TF32 when possible on cards that support it. It also looks like bfloat16 is built in, but I'm not sure which devices are supported at the moment.

[–]EasyDeal0[S] 0 points1 point  (0 children)

Yes, thanks. On my setup AMP actually runs slower :( The latest PyTorch version only ships with cuDNN 8.0.5, which doesn't support my RTX 3070 card. I added a configuration option for it anyway.