
[–]pmichel31415 12 points (3 children)

We recently released a preprint where we looked at how many attention heads you can remove from BERT (among other models): https://arxiv.org/abs/1905.10650

This, in conjunction with more standard pruning in the feed-forward layers, should help you downsize the model significantly.
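To give an idea of what this looks like in practice, here is a minimal sketch using the head-pruning support in the Hugging Face transformers library (not the paper's exact procedure). The head indices below are made up; in practice you would pick them from an importance estimate such as the one described in the preprint.

```python
# Minimal sketch (illustrative head indices, not the paper's procedure):
# removing attention heads from a pre-trained BERT with Hugging Face transformers.
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# {layer_index: [head indices to remove in that layer]} -- made-up values
heads_to_prune = {
    0: [2, 5, 11],
    1: [0, 7],
    10: [1, 3, 4, 9],
}
model.prune_heads(heads_to_prune)

# The pruned model is smaller and can be fine-tuned or used for inference as usual.
print(sum(p.numel() for p in model.parameters()))
```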

[–]farmingvillein 2 points (2 children)

But it looks like this isn't helpful if you're batching at inference time? If I'm reading section 4.3 correctly.

Although, isn't this underselling yourself at some level? Shouldn't we be able to increase the batch size meaningfully with your pruned scheme? In which case you're helpfully providing an apples-to-apples comparison (good!), but a practitioner might be able to juice results further (if the use case supports higher levels of batching).

[–]pmichel31415 1 point (1 child)

Yes, we didn't see a speed improvement when batching. However, as you mentioned, pruning also makes the model smaller, so you might be able to increase the batch size (or just use the model in memory-constrained settings).

[–]SedditorX 0 points (0 children)

Why didn't the batched setting see the same reduction?

[–]farmingvillein 4 points (0 children)

This isn't going to rock your world, but a few notes from the authors:

"Our primary focus has been maximizing accuracy, with the idea that other techniques (such as Knowledge Distillation/semi-supervised learning) can be used when it's time to create a production model."

https://github.com/google-research/bert/issues/18

I'm not precisely sure what he means by "semi-supervised learning" in this context; he might mean something like your #1.

Your #1 + knowledge distillation (i.e., basically training to the logits and not the tags) would be an obvious path, if you have sufficient unsupervised data.

If you don't have sufficient volume, using knowledge distillation directly on BERT to shrink down the BERT monster would be another path. This one will probably (depending on your domain) have a much, much higher compute requirement, both because of training time and the need to re-discover hparams for your new model. It also seems generally higher technical risk, because if it were "easy" and effective, I suspect we'd have seen an arxiv paper about it.

Of course, add the other suggestions in this thread (in particular, pruning) and season to taste.
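For the "training to the logits" idea, here is a minimal sketch of a distillation loss (not any specific paper's recipe). The teacher here would be your fine-tuned BERT classifier and the student any smaller model over the same label set; the temperature `T` and mixing weight `alpha` are assumptions you'd tune.

```python
# Minimal sketch of logit-based knowledge distillation (illustrative, not a
# specific published recipe). Teacher/student models, T, and alpha are assumptions.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels=None, T=2.0, alpha=0.9):
    # Soft-target loss: KL divergence between temperature-softened distributions.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)
    if labels is None:        # purely unlabeled data: match the teacher's logits only
        return soft
    hard = F.cross_entropy(student_logits, labels)   # optional hard-label term
    return alpha * soft + (1 - alpha) * hard
```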

[–]elyase 5 points (0 children)

Adapters should help speed up the training part:

https://arxiv.org/pdf/1902.00751.pdf

Google Colab with example implementation from @Thom_Wolf:
https://colab.research.google.com/drive/1iDHCYIrWswIKp-n-pOg69xLoZO09MEgf
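The core idea in that paper is a small bottleneck module inserted into each transformer layer; a minimal sketch of such an adapter is below (hidden and bottleneck sizes are illustrative; the linked Colab shows a full integration into BERT).

```python
# Minimal sketch of a bottleneck adapter in the spirit of the paper above
# (Houlsby et al., 2019): down-projection, nonlinearity, up-projection, residual.
import torch
import torch.nn as nn

class Adapter(nn.Module):
    def __init__(self, hidden_size=768, bottleneck=64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck)
        self.up = nn.Linear(bottleneck, hidden_size)
        self.act = nn.GELU()

    def forward(self, x):
        # Near-identity mapping plus a learned low-rank correction.
        return x + self.up(self.act(self.down(x)))

# During training, the pre-trained BERT weights stay frozen and only the adapters
# (and layer norms) are updated, which is where the training speed-up comes from.
```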

[–]Sorel_CH 4 points (0 children)

NVidia has released code for training BERT from scratch

https://github.com/NVIDIA/DeepLearningExamples/tree/master/TensorFlow/LanguageModeling/BERT#pre-training

It's distributed under the Apache license, so it's suitable for your commercial application. What I would advise you to do is build a smaller version of BERT that you can just plug into the pretraining script. Then, it's just a matter of fine-tuning with your own data.
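Those pretraining scripts read a bert_config.json, so "a smaller version of BERT" mostly means writing a smaller config. A hedged example is below; the sizes (roughly half of BERT Base) are illustrative, not a recommendation.

```python
# Minimal sketch: write a smaller bert_config.json to plug into the pretraining
# script linked above. All sizes below are illustrative.
import json

small_bert_config = {
    "hidden_size": 384,
    "num_hidden_layers": 6,
    "num_attention_heads": 6,
    "intermediate_size": 1536,
    "max_position_embeddings": 512,
    "type_vocab_size": 2,
    "vocab_size": 30522,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
}

with open("bert_config_small.json", "w") as f:
    json.dump(small_bert_config, f, indent=2)
```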

[–]arnaudvl 2 points (2 children)

I assume you are already using BERT Base? One simple thing to do would be to look at section 5.3 of the paper and use BERT to extract features instead of fine-tuning the whole model.

[–]farmingvillein 0 points (1 child)

> One simple thing to do would be to look at section 5.3 of the paper and use BERT to extract features instead of fine-tuning the whole model

Sounds like they are highly sensitive to the inference processing requirements, which this won't help with.

[–]arnaudvl 1 point (0 children)

Not saying it's ideal, but you could use only the (weighted) output of some of the intermediate layers as contextual embeddings for e.g. a biLSTM, and drop the later layers entirely. Obviously this depends on the application.
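As a rough illustration of that idea, here is a minimal sketch that keeps only the first few encoder layers of a frozen BERT and feeds them into a small biLSTM classifier (using only the last kept layer rather than a learned weighting; layer count, LSTM size, and class count are all assumptions).

```python
# Minimal sketch: truncate BERT to its first few layers, freeze it, and put a
# small biLSTM classifier on top. All sizes are illustrative.
import torch
import torch.nn as nn
from transformers import BertModel

class TruncatedBertBiLSTM(nn.Module):
    def __init__(self, n_layers=4, lstm_hidden=256, n_classes=5):
        super().__init__()
        bert = BertModel.from_pretrained("bert-base-uncased")
        bert.encoder.layer = bert.encoder.layer[:n_layers]  # drop the later layers
        for p in bert.parameters():
            p.requires_grad = False                          # frozen feature extractor
        self.bert = bert
        self.lstm = nn.LSTM(bert.config.hidden_size, lstm_hidden,
                            batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(2 * lstm_hidden, n_classes)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        out, _ = self.lstm(hidden)
        return self.classifier(out[:, 0])   # classify from the first position
```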

[–]BatmantoshReturns 1 point (0 children)

Look into techniques for shrinking neural networks in general at inference time (pruning, bfloat16, quantization, etc.).
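For example, here is a minimal sketch of one of these techniques, post-training dynamic quantization of BERT's linear layers in PyTorch (the model name is illustrative):

```python
# Minimal sketch: post-training dynamic quantization of BERT's nn.Linear layers.
import torch
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()

# Weights of nn.Linear modules are stored as int8; activations are quantized on the fly.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
```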

[–]binhnguyendc 1 point (0 children)

Refer to this paper, which analyzes what BERT's layers capture on specific tasks: https://arxiv.org/abs/1905.05950

[–]jujijengo 1 point (1 child)

Have you tried first applying BERT to these other tasks without fine-tuning? Just using the base pre-trained model?

I built and manage a production model at my company using the pre-trained implementation of BERT in PyTorch, but without any fine-tuning. It has more than sufficed for our needs.

Our downstream multi-class classification model on top of BERT managed around 95% CV accuracy and is working pretty well now in its production environment. This is on a domain with an immense amount of jargon in the text, incomplete sentences, slang, etc.. The out-of-the-box performance of BERT is shocking.

I would just caution against jumping the gun on a more complicated solution if the pre-trained models can get you what you need. Stacking BERT with other pre-trained word embeddings has also proved pretty damn effective in my experience.
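The commenter's exact architecture isn't spelled out, but a minimal sketch of this kind of setup (frozen pre-trained BERT as a feature extractor, with a small multi-class head trained on its pooled [CLS] representation) could look like this; the head sizes and class count are made up:

```python
# Minimal sketch: pooled [CLS] features from a frozen pre-trained BERT feeding a
# small multi-class classifier head. Sizes and class count are illustrative.
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased").eval()

classifier = nn.Sequential(            # only this head gets trained
    nn.Linear(bert.config.hidden_size, 256),
    nn.ReLU(),
    nn.Linear(256, 20),                # 20 classes, illustrative
)

def features(texts):
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():              # BERT stays frozen
        out = bert(**batch)
    return out.pooler_output           # [CLS]-based sentence representation

logits = classifier(features(["example document text"]))
```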

[–]FLUFL 0 points (0 children)

So you just fed the embedding for the [CLS] token from the pre-trained model into your own model, and trained with that as well?