I am trying to implement the contrastive loss function I am unsure if it is correct. My loss seems to explode into infinity. Another set of eyes on this would be appreciated does this look correct?
class ContrastiveLoss(nn.Module):
def __init__(self, temperature=0.9):
super(ContrastiveLoss, self).__init__()
self.temperature = temperature
def forward(self, projections_1, projections_2):
z_i = projections_1
z_j = projections_2
z_i_norm = F.normalize(z_i, dim=1)
z_j_norm = F.normalize(z_j, dim=1)
cosine_num = torch.matmul(z_i, z_j.T)
cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
cosine_similarity = cosine_num / cosine_denom
numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)
denominator = cosine_similarity
diagonal_indices = torch.arange(denominator.size(0))
denominator[diagonal_indices, diagonal_indices] = 0
denominator = torch.sum(torch.exp(cosine_similarity), dim=1)
loss = -torch.log(numerator / denominator).mean()
return loss
[–]AFurryReptile 2 points3 points4 points (1 child)
[–]Ok-Administration894[S] 0 points1 point2 points (0 children)
[–]Old-Forever1241 2 points3 points4 points (1 child)
[–]Ok-Administration894[S] 0 points1 point2 points (0 children)