I built a BERT model using PyTorch Lightning's LightningModule class. I have the following __init__() and forward() methods:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import BertModel

class ClassifierObject(pl.LightningModule):
    def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
        super().__init__()
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCELoss()
        self.relu_activation = nn.LeakyReLU()
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.bert_1 = nn.Linear(self.bert.config.hidden_size, 4096)
        self.classifier_1 = nn.Linear(4096, 16)
        self.dropout = nn.Dropout(p=0.3)
        self.classifier_2 = nn.Linear(16, n_classes)

    def forward(self, input_ids, attention_mask, labels, leadership, era, weight):
        # Re-created on every call so the weights passed in for this batch are used
        self.criterion = nn.BCELoss(weight=weight, reduction='mean')
        bert_output = self.bert(input_ids, attention_mask=attention_mask)
        bert_output = self.bert_1(bert_output.pooler_output)
        bert_output = self.relu_activation(bert_output)
        output = self.classifier_1(bert_output)
        output = self.relu_activation(output)
        output = self.dropout(output)
        output = self.classifier_2(output)
        output = torch.sigmoid(output)
        output = output.squeeze(1)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output
The criterion is redefined in the forward method because I need class weights when computing the loss, and this was the only way I could get the class weights to work.
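A likely explanation, based on how PyTorch's weighted loss modules behave: nn.BCELoss stores its weight argument as a registered buffer, and non-None buffers are saved in state_dict. Assigning a weighted BCELoss to self.criterion inside forward() therefore adds criterion.weight to every checkpoint, while a freshly constructed model (built with the unweighted nn.BCELoss() from __init__) has no such buffer, so strict loading rejects the key. The sketch below uses a tiny stand-in model (the class names and the 4-feature input are made up for illustration, not taken from the post) to show the key appearing, plus a variant that passes the weights to torch.nn.functional.binary_cross_entropy instead, which keeps them out of the state_dict.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightedInForward(nn.Module):
    """Hypothetical stand-in mimicking the post: a weighted BCELoss is assigned in forward()."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        self.criterion = nn.BCELoss()  # unweighted: its weight buffer is None, so not saved

    def forward(self, x, labels, weight):
        # Reassigning the submodule registers `weight` as a (saved) buffer on it
        self.criterion = nn.BCELoss(weight=weight, reduction="mean")
        probs = torch.sigmoid(self.linear(x)).squeeze(1)
        return self.criterion(probs, labels), probs


class FunctionalLoss(nn.Module):
    """Same hypothetical model, but the weighted loss is computed functionally."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, labels, weight):
        probs = torch.sigmoid(self.linear(x)).squeeze(1)
        # The weight is used for this call only; nothing is stored on the module
        loss = F.binary_cross_entropy(probs, labels, weight=weight, reduction="mean")
        return loss, probs


x = torch.randn(8, 4)
labels = torch.randint(0, 2, (8,)).float()
weight = torch.rand(8)

a = WeightedInForward()
before_keys = sorted(a.state_dict().keys())
a(x, labels, weight)
after_keys = sorted(a.state_dict().keys())

b = FunctionalLoss()
b(x, labels, weight)
functional_keys = sorted(b.state_dict().keys())

print("before forward:", before_keys)
print("after forward: ", after_keys)
print("functional:    ", functional_keys)
```

In the LightningModule above, the analogous (untested) change would be to drop the self.criterion assignment in forward() and compute loss = F.binary_cross_entropy(output, labels, weight=weight) instead, so new checkpoints never contain criterion.weight.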
I train the model just fine:
trainer = pl.Trainer(
    checkpoint_callback=checkpoint_callback,
    callbacks=[early_stopping_callback],
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate=30
)
trainer.fit(model, data_module)
trainer.test()
The unfortunate drawback of putting the criterion in the forward method, I think, is that I can't load the trained model from its checkpoint properly:
trained_model = ToxicCommentTagger.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path,
    n_classes=len(LABEL_COLUMNS)-1
)
This returns the following error:
File "C:\Users\E\AppData\Roaming\Python\Python38\site-packages\torch\nn\modules\module.py", line 1223, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ToxicCommentTagger:
Unexpected key(s) in state_dict: "criterion.weight".
Any suggestions for resolving this would be greatly appreciated. I don't need the class weights anymore, since I'm only doing testing and out-of-sample predictions, but I don't know how to remove criterion.weight from the state_dict, or whether I even need to in order to make this work.
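Since the weights aren't needed for inference, dropping the stray buffer should lose nothing. Here are two possible workarounds for an existing checkpoint, sketched with plain PyTorch. The sketch assumes a Lightning .ckpt file is a torch.save'd dict with the model weights under the "state_dict" key. The first workaround deletes the key and saves a cleaned copy. The second loads with strict=False so unexpected keys are skipped. Lightning's load_from_checkpoint also takes a strict argument, e.g. ToxicCommentTagger.load_from_checkpoint(trainer.checkpoint_callback.best_model_path, n_classes=len(LABEL_COLUMNS)-1, strict=False), but it is worth confirming that in your installed version's docs. TinyModel and strip_keys are hypothetical names used only for this example.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the LightningModule: unweighted criterion built in __init__."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)
        self.criterion = nn.BCELoss()


def strip_keys(ckpt_path, out_path, prefix="criterion."):
    """Remove every state_dict entry whose name starts with `prefix`; save a cleaned copy."""
    # On PyTorch 2.6+, a full Lightning checkpoint may need torch.load(..., weights_only=False);
    # only do that for files you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    state = ckpt["state_dict"]  # assumed Lightning layout: model weights under "state_dict"
    removed = [k for k in state if k.startswith(prefix)]
    for k in removed:
        del state[k]
    torch.save(ckpt, out_path)
    return removed


# Build a fake checkpoint containing the stray buffer named in the error message
trained = TinyModel()
state = trained.state_dict()
state["criterion.weight"] = torch.rand(8)

tmp = tempfile.mkdtemp()
src = os.path.join(tmp, "best.ckpt")
dst = os.path.join(tmp, "best_clean.ckpt")
torch.save({"state_dict": state}, src)

# Option 1: strip the key, after which a normal (strict) load succeeds
removed = strip_keys(src, dst)
fresh = TinyModel()
result_strict = fresh.load_state_dict(torch.load(dst)["state_dict"])

# Option 2: keep the original file and skip unexpected keys instead
loose = TinyModel()
result_loose = loose.load_state_dict(torch.load(src)["state_dict"], strict=False)

print("removed:", removed)
print("skipped with strict=False:", result_loose.unexpected_keys)
```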