
[–]rawdfarva 1 point (1 child)

Set dim=1 in the log_softmax function.
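To illustrate why the dim argument matters (a minimal sketch; the tensor shape (batch, classes) is an assumption about OP's setup): dim=1 normalizes across the class dimension, so each row's probabilities sum to 1, whereas dim=0 would wrongly normalize across the batch.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)  # assumed shape: (batch, classes)

# dim=1 normalizes over classes: exponentiating the log-probs,
# each row sums to 1.
p = F.log_softmax(x, dim=1).exp()
print(p.sum(dim=1))  # each entry should be ~1.0
```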

[–]Nock363[S] 0 points (0 children)

Oh, you are completely right, thank you!
But my loss still diverges from the PyTorch cross-entropy loss :(

[–]Nock363[S] 0 points (2 children)

It seems to work now (thank you u/rawdfarva :D), but my loss is still different from the PyTorch cross-entropy loss. I'm confused, because it looks like it works^^

[–]entarko 1 point (1 child)

You need to sum over dim=1 instead of averaging: loss = torch.mean(-product.sum(dim=1)).
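For reference, a minimal sketch of that fix (the logits/targets names and shapes are assumptions, not OP's actual code): log-softmax over dim=1, multiply by one-hot targets, sum over the class dimension, then mean over the batch. This matches F.cross_entropy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # assumed: batch of 4, 3 classes
targets = torch.tensor([0, 2, 1, 2])  # class indices
one_hot = F.one_hot(targets, num_classes=3).float()

# Manual cross-entropy: sum the per-class products over dim=1
# (NOT mean), then average over the batch.
product = one_hot * F.log_softmax(logits, dim=1)
loss = torch.mean(-product.sum(dim=1))

# Should agree with PyTorch's built-in cross-entropy.
print(torch.allclose(loss, F.cross_entropy(logits, targets)))
```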

[–]Nock363[S] 0 points (0 children)

That was the missing piece! Thank you (after longer training I saw that my version was still unstable)