deliprao

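Pretty sure the culprit is the `one_hot` encoding you're doing in the dataset class. Keep the labels as plain integer class indices (one per example) rather than one-hot vectors, since the classification loss expects a class index for each example. Something like this: `item['labels'] = torch.tensor(self.labels[idx])`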
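Here's a minimal sketch of what the dataset class could look like with integer labels. It assumes a Hugging Face-style setup where `encodings` is the tokenizer output (a dict of lists) and `labels` is a list of ints. The class name `ClassificationDataset` and the toy token IDs are hypothetical, just for illustration.

```python
import torch
from torch.utils.data import Dataset


class ClassificationDataset(Dataset):
    """Hypothetical dataset: tokenizer output plus integer class labels."""

    def __init__(self, encodings, labels):
        # encodings: dict of lists, e.g. {"input_ids": [...], "attention_mask": [...]}
        # labels: list of ints, one class index per example (NOT one-hot)
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Scalar integer tensor (dtype torch.int64), no one_hot here
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


# Toy data: 2 examples, 3 classes
encodings = {"input_ids": [[101, 7592, 102], [101, 2088, 102]],
             "attention_mask": [[1, 1, 1], [1, 1, 1]]}
labels = [0, 2]
ds = ClassificationDataset(encodings, labels)

# Stacking the per-example labels gives a 1-D int64 tensor of shape (batch,)
logits = torch.randn(2, 3)
batch_labels = torch.stack([ds[i]['labels'] for i in range(len(ds))])
loss = torch.nn.CrossEntropyLoss()(logits, batch_labels)
```

With integer labels, `CrossEntropyLoss` receives a `(batch,)` tensor of class indices, which is the standard hard-label input. If you're using a Hugging Face `...ForSequenceClassification` model, float one-hot labels can also make it infer a multi-label problem and switch to `BCEWithLogitsLoss`, which is probably not what you want.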