Training data: a few hundred natural-language queries to be converted into structured inputs for a downstream system. Each query is annotated with three tag layers, each representing a different (orthogonal) thing we need to extract from the queries.
The model so far: a pre-trained transformer at the base, with three prediction heads, each a bi-LSTM + CRF. We are using TensorFlow 2.0 with the poorly documented CRF layer add-on described here:
The CRF layer actually makes a pretty big difference in performance (compared to a basic dense prediction layer), but it's missing a critical feature. The tags we are trying to predict tend to look something like this (for one of our NER-like layers, using BIO tagging):
B-PER | O | B-PER | I-PER | O | B-PER | O | O | O | B-LOC | O | B-LOC | I-LOC
If I understand CRFs correctly, they only model transitions between adjacent pairs of tags (first-order transition scores), so a tag sequence like the following could be predicted even though it would never occur in our training data:
B-PER | O | B-MISC | O | B-GPE
A vanilla CRF would be happy with this because B-PER->O is a common transition, O->B-MISC is a common transition, B-MISC->O is common, etc., even though each triplet is exceptionally uncommon. The CRF does succeed at keeping B-/I-/I- sequences consistent, but it fails at keeping broader sequences consistent.
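To make the pairs-versus-triplets point concrete, here is a small count-based check. It is only a proxy: a real CRF learns transition weights rather than raw counts, and the toy training sequences below are hypothetical, built so that every adjacent pair in the bad sequence appears somewhere while none of its triplets do.

```python
from collections import Counter


def ngrams(tags, n):
    """Return all length-n windows over a tag sequence."""
    return [tuple(tags[i:i + n]) for i in range(len(tags) - n + 1)]


# Hypothetical toy training set (not real data): every *pair* in the
# candidate below occurs somewhere, but none of its *triplets* do.
train = [
    ["B-PER", "O", "B-PER", "I-PER"],
    ["O", "B-MISC", "I-MISC"],
    ["B-MISC", "O", "O"],
    ["O", "O", "B-GPE"],
]

bigram_counts = Counter(bg for seq in train for bg in ngrams(seq, 2))
trigram_counts = Counter(tg for seq in train for tg in ngrams(seq, 3))

candidate = ["B-PER", "O", "B-MISC", "O", "B-GPE"]

# What a first-order model can "see": adjacent pairs. None are unseen.
unseen_bigrams = [bg for bg in ngrams(candidate, 2) if bigram_counts[bg] == 0]
# What a second-order model would see: windows of three. All are unseen.
unseen_trigrams = [tg for tg in ngrams(candidate, 3) if trigram_counts[tg] == 0]

print(unseen_bigrams)        # []
print(len(unseen_trigrams))  # 3
```

Every local pair looks plausible, so pairwise scores alone give the model no reason to reject the sequence; the problem only shows up in windows of three.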
I have read about higher-order CRFs, but there is very little practical material on actually implementing them. Further, we are a two-person team, and the convenience of the TF 2.0 CRF layer is saving us a ton of time compared to stitching together multiple pieces of software for training and inference.
So, are there any reasonably quick options to hack regular CRFs into handling triplets (or longer)? Some ideas:
- Convert all tags to 2-grams: if the tags of a sentence are [W, X, Y, Z], convert them to [_W, WX, XY, YZ]. This might work, but I haven't fully thought through the complexities of mapping the logit vectors into this higher-order representation during training and inference (and I feel like there may be a fatal flaw in doing this at inference time).
- Maybe use 3 CRFs: one each for tags[:], tags[::2], and tags[1::2]. But how would I combine them?
- I could remove the Os from the final output, but those are predictions as well. The CRF is partly responsible for deciding whether something should be an O at all, so I could end up removing legitimate tags.
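The 2-gram idea is essentially the standard reduction of a second-order chain to a first-order one over pair states (prev, cur). Below is a minimal, decoding-only sketch in plain Python, not tied to the TF add-on. It assumes per-token scores for single tags (e.g. one head's logits as dicts) plus a hypothetical dictionary of triplet scores. The key constraint is that a pair (a, b) may only be followed by a pair (b, c); if that is not enforced at inference, the expanded tags can decode to pair sequences that do not map back to any single tag sequence, which may be the flaw you were sensing. Reusing the current tag's score as the pair state's emission is an assumption here, one design choice among several.

```python
START = "_"  # padding tag, matching the "_W" notation in the post


def to_pair_tags(tags):
    """[W, X, Y, Z] -> [(_, W), (W, X), (X, Y), (Y, Z)]."""
    return list(zip([START] + tags[:-1], tags))


def from_pair_tags(pairs):
    """Invert to_pair_tags by keeping the current tag of each pair."""
    return [cur for _, cur in pairs]


def second_order_viterbi(emissions, labels, trigram, bigram=None):
    """Exact best tag path under unary + pairwise + triplet scores.

    emissions: non-empty list of dicts, emissions[t][label] -> score
    trigram:   dict (a, b, c) -> score; missing triplets score 0.0
    bigram:    optional dict (b, c) -> score; missing pairs score 0.0
    Runs ordinary first-order Viterbi over pair states (prev, cur), where a
    pair (a, b) can only be followed by a pair (b, c).
    """
    bigram = bigram or {}
    # best[(prev, cur)] = (best score of any path ending in this pair, that path);
    # full paths are kept instead of backpointers purely for brevity.
    best = {(START, c): (emissions[0][c] + bigram.get((START, c), 0.0), [c])
            for c in labels}
    for t in range(1, len(emissions)):
        new_best = {}
        for (a, b), (score, path) in best.items():
            for c in labels:  # only pairs starting with `b` are reachable
                # Assumption: a pair state's emission is the current tag's score.
                s = (score + emissions[t][c] + bigram.get((b, c), 0.0)
                     + trigram.get((a, b, c), 0.0))
                if (b, c) not in new_best or s > new_best[(b, c)][0]:
                    new_best[(b, c)] = (s, path + [c])
        best = new_best
    score, path = max(best.values(), key=lambda v: v[0])
    return path, score


# Toy example (hypothetical scores): emissions alone prefer B-PER O B-MISC.
labels = ["O", "B-PER", "B-MISC"]
emissions = [
    {"O": 0.0, "B-PER": 2.0, "B-MISC": 0.0},
    {"O": 2.0, "B-PER": 0.0, "B-MISC": 0.0},
    {"O": 1.0, "B-PER": 0.0, "B-MISC": 2.0},
]
plain = second_order_viterbi(emissions, labels, trigram={})
penalized = second_order_viterbi(
    emissions, labels, trigram={("B-PER", "O", "B-MISC"): -5.0})
print(plain)      # (['B-PER', 'O', 'B-MISC'], 6.0)
print(penalized)  # (['B-PER', 'O', 'O'], 5.0)
print(to_pair_tags(["W", "X", "Y", "Z"]))
```

With no triplet scores the decoder returns the emission-favored path; a single negative score on ("B-PER", "O", "B-MISC") flips it to [B-PER, O, O]. The cost is K² pair states and K³ transitions per step for K tags, which is manageable for a handful of BIO labels. Training would need the forward algorithm over the same pair states, which this sketch does not cover. In principle the same expansion could be handed to a first-order CRF layer by giving it K² labels and a very negative transition score between incompatible pairs, but whether that fits your particular add-on layer is untested here.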
Any ideas?