from typing import Optional, List, Union

import numpy as np
import torch

from labml import monit

from .latent_diffusion import LatentDiffusion


def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)


# Call the function to set the seed
# set_seed(42)


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    See Section 3.4
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor `guidance_rescale` to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
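
# A minimal, hypothetical usage sketch of `rescale_noise_cfg` (the tensors and the
# guidance scale below are illustrative only): the classifier-free-guidance result is
# rescaled to the standard deviation of the conditional prediction and then blended
# back by `guidance_rescale`.
#
#     e_uncond = torch.randn(2, 4, 64, 64)
#     e_cond = torch.randn(2, 4, 64, 64)
#     e_cfg = e_uncond + 7.5 * (e_cond - e_uncond)
#     e_cfg = rescale_noise_cfg(e_cfg, e_cond, guidance_rescale=0.7)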


class DiffusionSampler:
    """
    ## Base class for sampling algorithms
    """
    model: LatentDiffusion

    def __init__(self, model: LatentDiffusion):
        """
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        """
        super().__init__()
        # Set the model $\epsilon_\text{cond}(x_t, c)$
        self.model = model
        # Get number of steps the model was trained with $T$
        self.n_steps = model.n_steps


class SDFSampler(DiffusionSampler):
    """
    ## DDIM Sampler

    This extends the [`DiffusionSampler` base class](index.html).

    Sampling runs over a selected subsequence of time steps
    $\tau_1, \tau_2, \dots, \tau_S$ and removes noise step by step with the DDIM update,

    \begin{align}
    x_{\tau_{i-1}} &= \sqrt{\bar\alpha_{\tau_{i-1}}} x_0
        + \sqrt{1 - \bar\alpha_{\tau_{i-1}} - \sigma_{\tau_i}^2}\,\epsilon_\theta
        + \sigma_{\tau_i} \epsilon \\
    x_0 &= \frac{1}{\sqrt{\bar\alpha_{\tau_i}}} x_{\tau_i}
        - \Big(\sqrt{\frac{1}{\bar\alpha_{\tau_i}} - 1}\Big)\epsilon_\theta \\
    \sigma_{\tau_i}^2 &= \frac{1 - \bar\alpha_{\tau_{i-1}}}{1 - \bar\alpha_{\tau_i}}
        \Big(1 - \frac{\bar\alpha_{\tau_i}}{\bar\alpha_{\tau_{i-1}}}\Big) \\
    \end{align}
    """

    model: LatentDiffusion

    def __init__(
            self,
            model: LatentDiffusion,
            max_l,
            h,
            is_autocast=False,
            is_show_image=False,
            device=None,
            debug_mode=False
    ):
        """
        :param model: is the model to predict noise $\epsilon_\text{cond}(x_t, c)$
        """
        super().__init__(model)
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        # selected time steps ($\tau$) out of $1, 2, \dots, T$
        # self.time_steps = np.asarray(list(range(self.n_steps)), dtype=np.int32)
        self.tau = torch.tensor([13, 53, 116, 193, 310, 443, 587, 730, 845, 999], device=self.device)
        # torch.tensor([999, 845, 730, 587, 443, 310, 193, 116, 53, 13])
        # self.tau = torch.tensor(np.asarray(list(range(self.n_steps)), dtype=np.int32), device=self.device)
        self.used_n_steps = len(self.tau)

        self.is_show_image = is_show_image

        self.autocast = torch.cuda.amp.autocast(enabled=is_autocast)

        self.out_channel = self.model.eps_model.out_channels
        self.max_l = max_l
        self.h = h
        self.debug_mode = debug_mode

        self.guidance_scale = 7.5
        self.guidance_rescale = 0.7

        # now, we set the coefficients
        with torch.no_grad():
            # $\bar\alpha_t$
            self.alpha_bar = self.model.alpha_bar
            # $\beta_t$ schedule
            beta = self.model.beta
            # $\bar\alpha_{t-1}$
            self.alpha_bar_prev = torch.cat([self.alpha_bar.new_tensor([1.]), self.alpha_bar[:-1]])
            # $\sigma_t$ in DDIM
            self.sigma_ddim = torch.sqrt(
                (1 - self.alpha_bar_prev) / (1 - self.alpha_bar) * (1 - self.alpha_bar / self.alpha_bar_prev)
            )  # this choice of $\sigma_t$ matches the DDPM noise schedule (DDIM with $\eta = 1$)

            # $\frac{1}{\sqrt{\bar\alpha}}$
            self.one_over_sqrt_alpha_bar = 1 / (self.alpha_bar ** 0.5)
            # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
            self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar = (1 - self.alpha_bar) ** 0.5 / self.alpha_bar ** 0.5
            # $\sqrt{\bar\alpha}$
            self.sqrt_alpha_bar = self.alpha_bar ** 0.5
            # $\sqrt{1 - \bar\alpha}$
            self.sqrt_1m_alpha_bar = (1 - self.alpha_bar) ** 0.5
            # # $\sqrt{\bar\alpha_{t-1}}$
            # self.sqrt_alpha_bar_prev = self.alpha_bar_prev ** 0.5
            # # $\sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}$
            # self.sqrt_1m_alpha_bar_prev_m_sigma2 = (1 - self.alpha_bar_prev - self.sigma_ddim ** 2) ** 0.5

    # @property
    # def d_cond(self):
    #     return self.model.eps_model.d_cond

    def get_eps(
            self,
            x: torch.Tensor,
            t: torch.Tensor,
            background_cond: Optional[torch.Tensor],
            uncond_scale: Optional[float],
    ):
        """
        ## Get $\epsilon(x_t, c)$

        :param x: is $x_t$ of shape `[batch_size, channels, height, width]`
        :param t: is $t$ of shape `[batch_size]`
        :param background_cond: background condition, concatenated to `x` along the channel dimension
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s\big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
            (only relevant to the commented-out conditional branches below)
        """
        # When the scale $s = 1$
        # $$\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c)$$
        batch_size = x.size(0)

        # if hasattr(self.model, 'style_enc'):
        #     if external_cond is not None:
        #         external_cond = self.model.external_cond_enc(external_cond)
        #         if uncond_scale is None or uncond_scale == 1:
        #             external_uncond = (-torch.ones_like(external_cond)).to(self.device)
        #         else:
        #             external_uncond = None
        #         # if random.random() < 0.2:
        #         #     external_cond = (-torch.ones_like(external_cond)).to(self.device)
        #     else:
        #         external_cond = -torch.ones(batch_size, 4, self.d_cond, device=x.device, dtype=x.dtype)
        #         external_uncond = None
        #     cond = torch.cat([autoreg_cond, external_cond], 1)
        #     if external_uncond is None:
        #         uncond = None
        #     else:
        #         uncond = torch.cat([autoreg_cond, external_uncond], 1)
        # else:
        #     cond = autoreg_cond
        #     uncond = None

        if background_cond is not None:
            x = torch.cat([x, background_cond], 1)

        # if uncond is None:
        #     e_t = self.model(x, t, cond)
        # else:
        #     e_t_cond = self.model(x, t, cond)
        #     e_t_uncond = self.model(x, t, uncond)
        #     e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        e_t = self.model(x, t)

        return e_t
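
    # Classifier-free guidance combines two `get_eps` calls, one with the real condition
    # and one with a "null" condition; `p_sample` below does exactly this. A schematic
    # sketch (variable names are illustrative):
    #
    #     e_null = self.get_eps(x, t, null_background_cond, uncond_scale=None)
    #     e_real = self.get_eps(x, t, real_background_cond, uncond_scale=None)
    #     e = e_null + self.guidance_scale * (e_real - e_null)
    #     e = rescale_noise_cfg(e, e_real, guidance_rescale=self.guidance_rescale)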

    @torch.no_grad()
    def p_sample(
            self,
            x: torch.Tensor,
            background_cond: Optional[torch.Tensor],
            # autoreg_cond: Optional[torch.Tensor],
            # external_cond: Optional[torch.Tensor],
            t: torch.Tensor,
            step: int,
            repeat_noise: bool = False,
            temperature: float = 1.,
            uncond_scale: float = 1.,
            same_noise_all_measure: bool = False,
            X0EditFunc=None,
            use_classifier_free_guidance=False,
            use_lsh=False,
            reduce_extra_notes=True,
            rhythm_control="Yes",
    ):
        """
        ### Sample $x_{\tau_{i-1}}$ from $p_\theta(x_{\tau_{i-1}} | x_{\tau_i})$

        :param x: is $x_{\tau_i}$ of shape `[batch_size, channels, height, width]`
        :param background_cond: background condition
        :param t: is $t$ of shape `[batch_size]`
        :param step: is the step $t$ as an integer
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s\big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        :param same_noise_all_measure: reuses the same noise for every 16-frame measure along the time axis
        :param X0EditFunc: optional callable that edits the predicted $x_0$ at each step
        :param use_classifier_free_guidance: combines the predictions for the real and null background conditions
        :param use_lsh: whether `background_cond` includes the extra lsh onset/sustain channels
        """
        # print("p_sample")  # leftover debug output

        # Get current $\tau_i$
        tau_i = self.tau[t]
        step_tau_i = self.tau[step]

        # Get $\epsilon_\theta$
        with self.autocast:
            if use_classifier_free_guidance:
                if use_lsh:
                    # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain, lsh_onset, lsh_sustain
                    assert background_cond.shape[1] == 6
                    null_lsh = -torch.ones_like(background_cond[:, 4:, :, :])
                    null_background_cond = torch.cat([background_cond[:, 2:4, :, :], null_lsh], axis=1)
                    real_background_cond = torch.cat([background_cond[:, :2, :, :], background_cond[:, 4:, :, :]], axis=1)
                    e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
                    e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
                    e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real - e_tau_i_null)
                    if self.guidance_rescale > 0:
                        e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
                else:
                    # chd_onset, chd_sustain, null_chd_onset, null_chd_sustain
                    assert background_cond.shape[1] == 4
                    null_background_cond = background_cond[:, 2:, :, :]
                    real_background_cond = background_cond[:, :2, :, :]
                    e_tau_i_null = self.get_eps(x, tau_i, null_background_cond, uncond_scale=uncond_scale)
                    e_tau_i_real = self.get_eps(x, tau_i, real_background_cond, uncond_scale=uncond_scale)
                    e_tau_i = e_tau_i_null + self.guidance_scale * (e_tau_i_real - e_tau_i_null)
                    if self.guidance_rescale > 0:
                        e_tau_i = rescale_noise_cfg(e_tau_i, e_tau_i_real, guidance_rescale=self.guidance_rescale)
            else:
                if use_lsh:
                    # chd_onset, chd_sustain, lsh_onset, lsh_sustain
                    assert background_cond.shape[1] == 4
                    e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)
                else:
                    # chd_onset, chd_sustain
                    assert background_cond.shape[1] == 2
                    e_tau_i = self.get_eps(x, tau_i, background_cond, uncond_scale=uncond_scale)

        # Get batch size
        bs = x.shape[0]

        # $\frac{1}{\sqrt{\bar\alpha}}$
        one_over_sqrt_alpha_bar = x.new_full(
            (bs, 1, 1, 1), self.one_over_sqrt_alpha_bar[step_tau_i]
        )
        # $\frac{\sqrt{1-\bar\alpha}}{\sqrt{\bar\alpha}}$
        sqrt_1m_alpha_bar_over_sqrt_alpha_bar = x.new_full(
            (bs, 1, 1, 1), self.sqrt_1m_alpha_bar_over_sqrt_alpha_bar[step_tau_i]
        )
        # $\sigma_t$ in DDIM
        sigma_ddim = x.new_full(
            (bs, 1, 1, 1), self.sigma_ddim[step_tau_i]
        )

        # Calculate $x_0$ with the current $\epsilon_\theta$ (the predicted $x_0$ in DDIM)
        predicted_x0 = one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] \
            - sqrt_1m_alpha_bar_over_sqrt_alpha_bar * e_tau_i

        # edit predicted $x_0$ and recompute $\epsilon_\theta$ from the edited $x_0$
        if X0EditFunc is not None:
            predicted_x0 = X0EditFunc(predicted_x0, background_cond,
                                      reduce_extra_notes=reduce_extra_notes, rhythm_control=rhythm_control)
            e_tau_i = (one_over_sqrt_alpha_bar * x[:, 0: e_tau_i.size(1)] - predicted_x0) \
                / sqrt_1m_alpha_bar_over_sqrt_alpha_bar

        # Do not add noise when $t = 1$ (the final step of the sampling process).
        # Note that `step` is `0` when $t = 1$.
        if step == 0:
            noise = 0
        # If same noise is used for all samples in the batch
        elif repeat_noise:
            if same_noise_all_measure:
                noise = torch.randn((1, predicted_x0.shape[1], 16, predicted_x0.shape[3]),
                                    device=self.device).repeat(1, 1, int(predicted_x0.shape[2] / 16), 1)
            else:
                noise = torch.randn((1, *predicted_x0.shape[1:]), device=self.device)
        # Different noise for each sample
        else:
            if same_noise_all_measure:
                noise = torch.randn(predicted_x0.shape[0], predicted_x0.shape[1], 16, predicted_x0.shape[3],
                                    device=self.device).repeat(1, 1, int(predicted_x0.shape[2] / 16), 1)
            else:
                noise = torch.randn(predicted_x0.shape, device=self.device)

        # Multiply noise by the temperature
        noise = noise * temperature

        if step > 0:
            step_tau_i_m_1 = self.tau[step - 1]
            # $\sqrt{\bar\alpha_{\tau_{i-1}}}$
            sqrt_alpha_bar_prev = x.new_full(
                (bs, 1, 1, 1), self.sqrt_alpha_bar[step_tau_i_m_1]
            )
            # $\sqrt{1-\bar\alpha_{\tau_{i-1}}-\sigma_{\tau_i}^2}$
            sqrt_1m_alpha_bar_prev_m_sigma2 = x.new_full(
                (bs, 1, 1, 1),
                (1 - self.alpha_bar[step_tau_i_m_1] - self.sigma_ddim[step_tau_i] ** 2) ** 0.5
            )
            # "direction pointing to $x_{\tau_i}$" in the DDIM update
            direction_to_xt = sqrt_1m_alpha_bar_prev_m_sigma2 * e_tau_i
            x_prev = sqrt_alpha_bar_prev * predicted_x0 + direction_to_xt + sigma_ddim * noise
        else:
            x_prev = predicted_x0 + sigma_ddim * noise

        # $x_{\tau_{i-1}}$ sampled from $p_\theta(x_{\tau_{i-1}} | x_{\tau_i})$
        return x_prev, predicted_x0, e_tau_i
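
    # `X0EditFunc` is only constrained by its call site in `p_sample` above: it receives
    # the predicted $x_0$ and the background condition and must return an edited tensor of
    # the same shape. A minimal, hypothetical sketch (the editing rule is illustrative only):
    #
    #     def clamp_x0(predicted_x0, background_cond, reduce_extra_notes=True, rhythm_control="Yes"):
    #         return predicted_x0.clamp(-1., 1.)
    #
    #     sampler.generate(background_cond, X0EditFunc=clamp_x0)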

    @torch.no_grad()
    def q_sample(
            self,
            x0: torch.Tensor,
            index: int,
            noise: Optional[torch.Tensor] = None
    ):
        """
        ### Sample from $q(x_t|x_0)$

        $$q(x_t|x_0) = \mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$$

        :param x0: is $x_0$ of shape `[batch_size, channels, height, width]`
        :param index: is the time step $t$ (an index into the full $T$-step schedule, e.g. `self.tau[step]`)
        :param noise: is the noise, $\epsilon$
        """
        # Random noise, if noise is not specified
        if noise is None:
            noise = torch.randn_like(x0, device=self.device)

        # Sample from $\mathcal{N} \Big(x_t; \sqrt{\bar\alpha_t} x_0, (1-\bar\alpha_t) \mathbf{I} \Big)$
        return self.sqrt_alpha_bar[index] * x0 + self.sqrt_1m_alpha_bar[index] * noise
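
    # A hedged sketch of how `q_sample` pairs with `sample(..., x_last=..., t_start=k)` to
    # resume from a partially noised latent (`x0_latent` and `k` are illustrative): skipping
    # the first `k` reverse steps means the input should be noised to the level of
    # `self.tau[self.used_n_steps - 1 - k]`.
    #
    #     k = 3
    #     x_k = sampler.q_sample(x0_latent, sampler.tau[sampler.used_n_steps - 1 - k])
    #     x0_hat = sampler.sample(list(x_k.shape), background_cond, x_last=x_k, t_start=k)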

    @torch.no_grad()
    def sample(
            self,
            shape: List[int],
            background_cond: Optional[torch.Tensor] = None,
            # autoreg_cond: Optional[torch.Tensor] = None,
            # external_cond: Optional[torch.Tensor] = None,
            repeat_noise: bool = False,
            temperature: float = 1.,
            uncond_scale: float = 1.,
            x_last: Optional[torch.Tensor] = None,
            t_start: int = 0,
            same_noise_all_measure: bool = False,
            X0EditFunc=None,
            use_classifier_free_guidance=False,
            use_lsh=False,
            reduce_extra_notes=True,
            rhythm_control="Yes",
    ):
        """
        ### Sampling Loop

        :param shape: is the shape of the generated images in the form `[batch_size, channels, height, width]`
        :param background_cond: background condition
        :param repeat_noise: specifies whether the noise should be the same for all samples in the batch
        :param temperature: is the noise temperature (random noise gets multiplied by this)
        :param x_last: is $x_T$. If not provided, random noise will be used.
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s\big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        :param t_start: is the number of initial (noisiest) sampling steps to skip
        """
        # Get batch size
        bs = shape[0]

        # print(shape)

        # Get $x_T$
        if same_noise_all_measure:
            x = x_last if x_last is not None else \
                torch.randn(shape[0], shape[1], 16, shape[3], device=self.device).repeat(1, 1, int(shape[2] / 16), 1)
        else:
            x = x_last if x_last is not None else torch.randn(shape, device=self.device)

        # Time steps to sample at: $\tau_S, \tau_{S-1}, \dots, \tau_1$
        time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]

        # Sampling loop
        for step in monit.iterate('Sample', time_steps):
            # Time step $t$
            ts = x.new_full((bs,), step, dtype=torch.long)

            x, pred_x0, e_t = self.p_sample(
                x,
                background_cond,
                # autoreg_cond,
                # external_cond,
                ts,
                step,
                repeat_noise=repeat_noise,
                temperature=temperature,
                uncond_scale=uncond_scale,
                same_noise_all_measure=same_noise_all_measure,
                X0EditFunc=X0EditFunc,
                use_classifier_free_guidance=use_classifier_free_guidance,
                use_lsh=use_lsh,
                reduce_extra_notes=reduce_extra_notes,
                rhythm_control=rhythm_control
            )

            s1 = step + 1
            # if self.is_show_image:
            #     if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
            #         show_image(x, f"exp/img/x{s1}.png")

        # Return $x_0$
        # if self.is_show_image:
        #     show_image(x, f"exp/img/x0.png")
        return x

    @torch.no_grad()
    def paint(
            self,
            x: Optional[torch.Tensor] = None,
            background_cond: Optional[torch.Tensor] = None,
            # autoreg_cond: Optional[torch.Tensor] = None,
            # external_cond: Optional[torch.Tensor] = None,
            t_start: int = 0,
            orig: Optional[torch.Tensor] = None,
            mask: Optional[torch.Tensor] = None,
            orig_noise: Optional[torch.Tensor] = None,
            uncond_scale: float = 1.,
            same_noise_all_measure: bool = False,
            X0EditFunc=None,
            use_classifier_free_guidance=False,
            use_lsh=False,
    ):
        """
        ### Painting Loop

        :param x: is $x_{S'}$ of shape `[batch_size, channels, height, width]`
        :param background_cond: background condition
        :param t_start: is the sampling step to start from, $S'$
        :param orig: is the original image in latent space that we are inpainting.
            If this is not provided, it'll be an image-to-image transformation.
        :param mask: is the mask to keep the original image.
        :param orig_noise: is fixed noise to be added to the original image.
        :param uncond_scale: is the unconditional guidance scale $s$. This is used for
            $\epsilon_\theta(x_t, c) = \epsilon_\text{cond}(x_t, c_u) + s\big(\epsilon_\text{cond}(x_t, c) - \epsilon_\text{cond}(x_t, c_u)\big)$
        """
        # $x_{S'}$ defaults to random noise with the original's shape
        if x is None:
            x = torch.randn(orig.shape, device=self.device)
        # Get batch size
        bs = x.size(0)

        # Time steps to sample at $\tau_{S'}, \tau_{S'-1}, \dots, \tau_1$
        # time_steps = np.flip(self.time_steps[: t_start])
        time_steps = np.flip(np.asarray(list(range(self.used_n_steps)), dtype=np.int32))[t_start:]

        for i, step in monit.enum('Paint', time_steps):
            # Index $i$ in the list $[\tau_1, \tau_2, \dots, \tau_S]$
            # index = len(time_steps) - i - 1
            # Time step $\tau_i$
            ts = x.new_full((bs,), step, dtype=torch.long)

            # Sample $x_{\tau_{i-1}}$
            x, _, _ = self.p_sample(
                x,
                background_cond,
                # autoreg_cond,
                # external_cond,
                t=ts,
                step=step,
                uncond_scale=uncond_scale,
                same_noise_all_measure=same_noise_all_measure,
                X0EditFunc=X0EditFunc,
                use_classifier_free_guidance=use_classifier_free_guidance,
                use_lsh=use_lsh,
            )

            # Replace the masked area with original image
            if orig is not None:
                assert mask is not None
                # Get the $q_{\sigma,\tau}(x_{\tau_i}|x_0)$ for original image in latent space
                orig_t = self.q_sample(orig, self.tau[step], noise=orig_noise)
                # Replace the masked area
                x = orig_t * mask + x * (1 - mask)

            s1 = step + 1
            # if self.is_show_image:
            #     if s1 % 100 == 0 or (s1 <= 100 and s1 % 25 == 0):
            #         show_image(x, f"exp/img/x{s1}.png")

        # if self.is_show_image:
        #     show_image(x, f"exp/img/x0.png")
        return x

    def generate(
            self,
            background_cond=None,
            batch_size=1,
            uncond_scale=None,
            same_noise_all_measure=False,
            X0EditFunc=None,
            use_classifier_free_guidance=False,
            use_lsh=False,
            reduce_extra_notes=True,
            rhythm_control="Yes"
    ):
        shape = [batch_size, self.out_channel, self.max_l, self.h]
        if self.debug_mode:
            return torch.randn(shape, dtype=torch.float)
        return self.sample(
            shape,
            background_cond,
            uncond_scale=uncond_scale,
            same_noise_all_measure=same_noise_all_measure,
            X0EditFunc=X0EditFunc,
            use_classifier_free_guidance=use_classifier_free_guidance,
            use_lsh=use_lsh,
            reduce_extra_notes=reduce_extra_notes,
            rhythm_control=rhythm_control
        )
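

# A minimal, hypothetical usage sketch of the sampler (constructing `model` and
# `background_cond` depends on the surrounding project and is not shown; `max_l` and `h`
# values are illustrative). `background_cond` must have two channels (chd_onset,
# chd_sustain) in the plain case and match `x` spatially, since it is concatenated
# channel-wise in `get_eps`.
#
#     model = ...  # a trained LatentDiffusion instance
#     sampler = SDFSampler(model, max_l=128, h=128)
#     background_cond = ...  # shape [batch_size, 2, max_l, h]
#     out = sampler.generate(background_cond=background_cond, batch_size=1)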