multimodalart (HF Staff) committed
Commit 5e08962 · verified · 1 Parent(s): a999829

Update radial_attn/models/wan/sparse_transformer.py

radial_attn/models/wan/sparse_transformer.py CHANGED
@@ -367,176 +367,176 @@ class WanPipeline_Sparse(WanPipeline):

        return WanPipelineOutput(frames=video)

-# Add this entire function to the file
-@torch.no_grad()
-def wan_i2v_pipeline_call_sparse(
-    self,
-    image: PipelineImageInput,
-    prompt: Union[str, List[str]] = None,
-    negative_prompt: Union[str, List[str]] = None,
-    height: int = 480,
-    width: int = 832,
-    num_frames: int = 81,
-    num_inference_steps: int = 50,
-    guidance_scale: float = 5.0,
-    num_videos_per_prompt: Optional[int] = 1,
-    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-    latents: Optional[torch.Tensor] = None,
-    prompt_embeds: Optional[torch.Tensor] = None,
-    negative_prompt_embeds: Optional[torch.Tensor] = None,
-    image_embeds: Optional[torch.Tensor] = None,
-    last_image: Optional[torch.Tensor] = None,
-    output_type: Optional[str] = "np",
-    return_dict: bool = True,
-    attention_kwargs: Optional[Dict[str, Any]] = None,
-    callback_on_step_end: Optional[
-        Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
-    ] = None,
-    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-    max_sequence_length: int = 512,
-):
-    if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
-        callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
-    self.check_inputs(
-        prompt,
-        negative_prompt,
-        image,
-        height,
-        width,
-        prompt_embeds,
-        negative_prompt_embeds,
-        image_embeds,
-        callback_on_step_end_tensor_inputs,
-    )
-    if num_frames % self.vae_scale_factor_temporal != 1:
-        logger.warning(
-            f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
        )
-    num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
-    num_frames = max(num_frames, 1)
-
-    self._guidance_scale = guidance_scale
-    self._attention_kwargs = attention_kwargs
-    self._current_timestep = None
-    self._interrupt = False
-    device = self._execution_device
-
-    if prompt is not None and isinstance(prompt, str):
-        batch_size = 1
-    elif prompt is not None and isinstance(prompt, list):
-        batch_size = len(prompt)
-    else:
-        batch_size = prompt_embeds.shape[0]
-
-    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        do_classifier_free_guidance=self.do_classifier_free_guidance,
-        num_videos_per_prompt=num_videos_per_prompt,
-        prompt_embeds=prompt_embeds,
-        negative_prompt_embeds=negative_prompt_embeds,
-        max_sequence_length=max_sequence_length,
-        device=device,
-    )
-    transformer_dtype = self.transformer.dtype
-    prompt_embeds = prompt_embeds.to(transformer_dtype)
-    if negative_prompt_embeds is not None:
-        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
-    if image_embeds is None:
-        if last_image is None:
-            image_embeds = self.encode_image(image, device)
        else:
-            image_embeds = self.encode_image([image, last_image], device)
-    image_embeds = image_embeds.repeat(batch_size, 1, 1)
-    image_embeds = image_embeds.to(transformer_dtype)
-
-    self.scheduler.set_timesteps(num_inference_steps, device=device)
-    timesteps = self.scheduler.timesteps
-    num_channels_latents = self.vae.config.z_dim
-    image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
-    if last_image is not None:
-        last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
-            device, dtype=torch.float32
        )
-    latents, condition = self.prepare_latents(
-        image,
-        batch_size * num_videos_per_prompt,
-        num_channels_latents,
-        height,
-        width,
-        num_frames,
-        torch.float32,
-        device,
-        generator,
-        latents,
-        last_image,
-    )
-    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-    self._num_timesteps = len(timesteps)
-
-    with self.progress_bar(total=num_inference_steps) as progress_bar:
-        for i, t in enumerate(timesteps):
-            if self.interrupt:
-                continue
-            self._current_timestep = t
-            latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-            timestep = t.expand(latents.shape[0])
-            noise_pred = self.transformer(
-                hidden_states=latent_model_input,
-                timestep=timestep,
-                encoder_hidden_states=prompt_embeds,
-                encoder_hidden_states_image=image_embeds,
-                attention_kwargs=attention_kwargs,
-                return_dict=False,
-                numeral_timestep=i, # <--- MODIFICATION
-            )[0]
-            if self.do_classifier_free_guidance:
-                noise_uncond = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
-                    encoder_hidden_states=negative_prompt_embeds,
                    encoder_hidden_states_image=image_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    numeral_timestep=i, # <--- MODIFICATION
                )[0]
-                noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
-            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-            if callback_on_step_end is not None:
-                callback_kwargs = {}
-                for k in callback_on_step_end_tensor_inputs:
-                    callback_kwargs[k] = locals()[k]
-                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-                latents = callback_outputs.pop("latents", latents)
-                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                progress_bar.update()
-
-    self._current_timestep = None
-    if not output_type == "latent":
-        latents = latents.to(self.vae.dtype)
-        latents_mean = (
-            torch.tensor(self.vae.config.latents_mean)
-            .view(1, self.vae.config.z_dim, 1, 1, 1)
-            .to(latents.device, latents.dtype)
-        )
-        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-            latents.device, latents.dtype
-        )
-        latents = latents / latents_std + latents_mean
-        video = self.vae.decode(latents, return_dict=False)[0]
-        video = self.video_processor.postprocess_video(video, output_type=output_type)
-    else:
-        video = latents
-    self.maybe_free_model_hooks()
-    if not return_dict:
-        return (video,)
-    return WanPipelineOutput(frames=video)

def replace_sparse_forward():
    WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward
    WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward
    WanPipeline.__call__ = WanPipeline_Sparse.__call__
-    WanImageToVideoPipeline.__call__ = wan_i2v_pipeline_call_sparse
 

        return WanPipelineOutput(frames=video)

+class WanImageToVideoPipeline_Sparse(WanImageToVideoPipeline):
+    @torch.no_grad()
+    def __call__(
+        self,
+        image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        height: int = 480,
+        width: int = 832,
+        num_frames: int = 81,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 5.0,
+        num_videos_per_prompt: Optional[int] = 1,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        image_embeds: Optional[torch.Tensor] = None,
+        last_image: Optional[torch.Tensor] = None,
+        output_type: Optional[str] = "np",
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        max_sequence_length: int = 512,
+    ):
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        self.check_inputs(
+            prompt,
+            negative_prompt,
+            image,
+            height,
+            width,
+            prompt_embeds,
+            negative_prompt_embeds,
+            image_embeds,
+            callback_on_step_end_tensor_inputs,
        )
+        if num_frames % self.vae_scale_factor_temporal != 1:
+            logger.warning(
+                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+            )
+        num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+        num_frames = max(num_frames, 1)
+
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
+        self._interrupt = False
+        device = self._execution_device
+
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
        else:
+            batch_size = prompt_embeds.shape[0]
+
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        transformer_dtype = self.transformer.dtype
+        prompt_embeds = prompt_embeds.to(transformer_dtype)
+        if negative_prompt_embeds is not None:
+            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+        if image_embeds is None:
+            if last_image is None:
+                image_embeds = self.encode_image(image, device)
+            else:
+                image_embeds = self.encode_image([image, last_image], device)
+        image_embeds = image_embeds.repeat(batch_size, 1, 1)
+        image_embeds = image_embeds.to(transformer_dtype)
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+        num_channels_latents = self.vae.config.z_dim
+        image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+        if last_image is not None:
+            last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+                device, dtype=torch.float32
+            )
+        latents, condition = self.prepare_latents(
+            image,
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            num_frames,
+            torch.float32,
+            device,
+            generator,
+            latents,
+            last_image,
        )
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+                self._current_timestep = t
+                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                timestep = t.expand(latents.shape[0])
+                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
+                    encoder_hidden_states=prompt_embeds,
                    encoder_hidden_states_image=image_embeds,
                    attention_kwargs=attention_kwargs,
                    return_dict=False,
                    numeral_timestep=i, # <--- MODIFICATION
                )[0]
+                if self.do_classifier_free_guidance:
+                    noise_uncond = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states_image=image_embeds,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                        numeral_timestep=i, # <--- MODIFICATION
+                    )[0]
+                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+        self._current_timestep = None
+        if not output_type == "latent":
+            latents = latents.to(self.vae.dtype)
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+            latents = latents / latents_std + latents_mean
+            video = self.vae.decode(latents, return_dict=False)[0]
+            video = self.video_processor.postprocess_video(video, output_type=output_type)
+        else:
+            video = latents
+        self.maybe_free_model_hooks()
+        if not return_dict:
+            return (video,)
+        return WanPipelineOutput(frames=video)

def replace_sparse_forward():
    WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward
    WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward
    WanPipeline.__call__ = WanPipeline_Sparse.__call__
+    WanImageToVideoPipeline.__call__ = WanImageToVideoPipeline_Sparse.__call__
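
For reference, a minimal usage sketch of the patched image-to-video path after this change. It assumes a recent diffusers release that ships WanImageToVideoPipeline; the checkpoint id, input image, prompt, and generation settings below are illustrative placeholders, and any radial-attention-specific attention_kwargs expected by the sparse transformer are omitted.

import torch
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

from radial_attn.models.wan.sparse_transformer import replace_sparse_forward

# Monkey-patch the Wan transformer blocks and the pipeline __call__ methods
# with the sparse (radial attention) variants defined in this file.
replace_sparse_forward()

pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",  # example checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = load_image("first_frame.png")  # example conditioning image
output = pipe(
    image=image,
    prompt="a sailboat gliding across a calm lake at sunset",
    height=480,
    width=832,
    num_frames=81,
    num_inference_steps=50,
    guidance_scale=5.0,
)
export_to_video(output.frames[0], "output.mp4", fps=16)

Because replace_sparse_forward() rebinds WanImageToVideoPipeline.__call__ to the sparse subclass's method, an unmodified pipeline object picks up the numeral_timestep-aware denoising loop without any further changes to user code.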