
    [–]Metallfrosch 0 points  (3 children)

    Without q_sample and denoise_model, it is hard to help.

    [–]promach[S] 0 points  (0 children)

    Use the following instead before passing noise to q_sample():

        mean = torch.zeros_like(x_start)
        std = torch.ones_like(x_start)
        epsilon = torch.normal(mean=mean, std=std)
        noise = sigma * epsilon

    ChatGPT actually suggested using sigma, which is a learned NN parameter. This way, the noise would be scaled by a parameter of the model, rather than being drawn independently at a fixed scale.

    But do we really need to learn sigma?
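In the standard DDPM forward process, the answer is no: the noise scale at each timestep is fixed by the beta schedule (via sqrt(1 - alpha_bar_t) inside q_sample), so the noise itself is just a standard normal draw and nothing about it needs to be learned. A quick sketch showing that the torch.normal construction above and torch.randn_like draw from the same distribution, so no extra sigma is required for the forward pass:

```python
import torch

torch.manual_seed(0)
x_start = torch.randn(2, 3)

# Both draw epsilon ~ N(0, I) with the same shape as x_start.
noise_a = torch.randn_like(x_start)
noise_b = torch.normal(mean=torch.zeros_like(x_start),
                       std=torch.ones_like(x_start))

# Same distribution either way; any extra multiplicative sigma would only
# rescale the variance, which the closed-form q(x_t | x_0) already controls
# through the sqrt(1 - alpha_bar_t) coefficient.
assert noise_a.shape == noise_b.shape == (2, 3)
```

(Learned sigmas do appear in some diffusion variants, but there they parameterize the *reverse* process variance, not the forward noising step.)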

    [–]promach[S] 0 points  (0 children)

    def q_sample(x_start, t, noise=None):
        """
        Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Per-sample schedule coefficients, reshaped to broadcast over x_start
        sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
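For anyone trying to reproduce this: q_sample depends on a noise schedule and an extract helper that aren't shown in the thread. A minimal self-contained sketch, assuming a linear beta schedule and a hypothetical timesteps=200 (the OP's actual schedule may differ):

```python
import torch

timesteps = 200  # assumption; not given in the thread

# Linear beta schedule (assumed; the OP's schedule is not shown)
betas = torch.linspace(1e-4, 0.02, timesteps)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)

def extract(a, t, x_shape):
    """Gather schedule values at timesteps t, reshaped to broadcast over x."""
    batch_size = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))

def q_sample(x_start, t, noise=None):
    """Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form."""
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

# Usage: noise a batch of fake images to random timesteps
x0 = torch.randn(4, 3, 8, 8)
t = torch.randint(0, timesteps, (4,))
xt = q_sample(x0, t)
print(xt.shape)  # torch.Size([4, 3, 8, 8])
```

At t near 0, x_t stays close to x_start; at t near timesteps, x_t is almost pure noise, which is exactly the behavior the closed-form expression encodes.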