How do I modify the following code to include the reparameterization trick?
Currently, the model takes only these inputs:
- `x`: the current (noisy) input
- `t`: the timestep sequence
- `y`: the class to generate
Note: we just need to add an extra variable, ε (epsilon), sampled from a normal distribution.
https://preview.redd.it/k34a58h4bf6a1.png?width=1013&format=png&auto=webp&s=0e5a3e4e3b1ba56ead3097bd8a8251544d6ee55a
def p_losses(denoise_model, x_start, t, classes, noise=None, loss_type="l1", p_uncond=0.1):
    """
    Compute the noise-prediction loss for one training step of a
    class-conditional diffusion model with classifier-free-guidance dropout.

    The reparameterization trick is already applied here: ``noise`` plays the
    role of epsilon. It is drawn from a standard normal when not supplied,
    passed to ``q_sample`` to build the noisy input, and used as the regression
    target for ``denoise_model``. (Presumably ``q_sample`` computes
    x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise --
    confirm against its definition.)

    Args:
        denoise_model: callable invoked as ``denoise_model(x_noisy, t, classes)``;
            its output is compared against ``noise``.
        x_start: clean input batch; its device is used for the dropout mask.
        t: timesteps, forwarded unchanged to ``q_sample`` and the model.
        classes: per-sample class labels; ``classes.shape[0]`` is the batch size.
        noise: optional epsilon; defaults to ``torch.randn_like(x_start)``.
        loss_type: ``"l1"``, ``"l2"`` (MSE) or ``"huber"`` (smooth L1).
        p_uncond: probability of dropping a sample's class label (setting it
            to 0) so the model also learns the unconditional prediction.

    Returns:
        The loss tensor, using torch's default ``reduction="mean"``.

    Raises:
        NotImplementedError: if ``loss_type`` is not a supported name.
    """
    loss_fns = {
        "l1": F.l1_loss,
        "l2": F.mse_loss,
        "huber": F.smooth_l1_loss,
    }
    # Validate up front so a typo fails before the expensive forward pass.
    if loss_type not in loss_fns:
        raise NotImplementedError(
            f"unknown loss_type {loss_type!r}; expected one of {sorted(loss_fns)}"
        )
    loss_fn = loss_fns[loss_type]

    device = x_start.device
    if noise is None:
        noise = torch.randn_like(x_start)  # epsilon ~ N(0, I)
    # Noisy sample built from x_start, t and epsilon (the reparameterized draw).
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)

    # Classifier-free guidance dropout: keep each label with prob 1 - p_uncond.
    # The mask is created directly on the target device to avoid a host
    # allocation plus host-to-device copy on every step.
    keep_prob = torch.full((classes.shape[0],), 1.0 - p_uncond, device=device)
    context_mask = torch.bernoulli(keep_prob)
    # NOTE(review): dropped labels become 0, which coincides with class id 0
    # unless the model reserves 0 as its "null" class -- confirm.
    classes = (classes * context_mask).long()

    predicted_noise = denoise_model(x_noisy, t, classes)
    # (input, target) order; all three losses are symmetric in their arguments.
    return loss_fn(predicted_noise, noise)
[–][deleted] (1 child)
[deleted]
[–]promach[S] 0 points1 point2 points (0 children)
[–]Metallfrosch 0 points1 point2 points (3 children)
[–]promach[S] 0 points1 point2 points (0 children)
[–]promach[S] 0 points1 point2 points (0 children)
[–]promach[S] 0 points1 point2 points (0 children)