wangshuai6 commited on
Commit
e321f58
·
1 Parent(s): 56238f0
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -43,6 +43,8 @@ from src.diffusion.flow_matching.scheduling import LinearScheduler
43
  from PIL import Image
44
  import gradio as gr
45
  import tempfile
 
 
46
  from huggingface_hub import snapshot_download
47
 
48
 
@@ -65,9 +67,9 @@ def load_model(weight_dict, denoiser):
65
 
66
  class Pipeline:
67
  def __init__(self, vae, denoiser, conditioner, resolution):
68
- self.vae = vae.cuda()
69
- self.denoiser = denoiser.cuda()
70
- self.conditioner = conditioner.cuda()
71
  self.conditioner.compile()
72
  self.resolution = resolution
73
  self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
@@ -76,6 +78,7 @@ class Pipeline:
76
  def __del__(self):
77
  self.tmp_dir.cleanup()
78
 
 
79
  @torch.no_grad()
80
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
81
  def __call__(self, y, num_images, seed, image_height, image_width, num_steps, guidance, timeshift, order):
@@ -93,8 +96,7 @@ class Pipeline:
93
  self.denoiser.decoder_patch_scaling_h = image_height / 512
94
  self.denoiser.decoder_patch_scaling_w = image_width / 512
95
  xT = torch.randn((num_images, 3, image_height, image_width), device="cpu", dtype=torch.float32,
96
- generator=generator)
97
- xT = xT.to("cuda")
98
  with torch.no_grad():
99
  condition, uncondition = conditioner([y,]*num_images)
100
 
 
43
  from PIL import Image
44
  import gradio as gr
45
  import tempfile
46
+ import spaces
47
+
48
  from huggingface_hub import snapshot_download
49
 
50
 
 
67
 
68
  class Pipeline:
69
  def __init__(self, vae, denoiser, conditioner, resolution):
70
+ self.vae = vae
71
+ self.denoiser = denoiser
72
+ self.conditioner = conditioner
73
  self.conditioner.compile()
74
  self.resolution = resolution
75
  self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
 
78
  def __del__(self):
79
  self.tmp_dir.cleanup()
80
 
81
+ @spaces.GPU
82
  @torch.no_grad()
83
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
84
  def __call__(self, y, num_images, seed, image_height, image_width, num_steps, guidance, timeshift, order):
 
96
  self.denoiser.decoder_patch_scaling_h = image_height / 512
97
  self.denoiser.decoder_patch_scaling_w = image_width / 512
98
  xT = torch.randn((num_images, 3, image_height, image_width), device="cpu", dtype=torch.float32,
99
+ generator=generator).cuda()
 
100
  with torch.no_grad():
101
  condition, uncondition = conditioner([y,]*num_images)
102