Spaces: Running on Zero
wangshuai6 committed · Commit e321f58 · Parent(s): 56238f0
init
app.py CHANGED
@@ -43,6 +43,8 @@ from src.diffusion.flow_matching.scheduling import LinearScheduler
 from PIL import Image
 import gradio as gr
 import tempfile
+import spaces
+
 from huggingface_hub import snapshot_download
 
 
@@ -65,9 +67,9 @@ def load_model(weight_dict, denoiser):
 
 class Pipeline:
     def __init__(self, vae, denoiser, conditioner, resolution):
-        self.vae = vae
-        self.denoiser = denoiser
-        self.conditioner = conditioner
+        self.vae = vae
+        self.denoiser = denoiser
+        self.conditioner = conditioner
         self.conditioner.compile()
         self.resolution = resolution
         self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
@@ -76,6 +78,7 @@ class Pipeline:
     def __del__(self):
         self.tmp_dir.cleanup()
 
+    @spaces.GPU
     @torch.no_grad()
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
     def __call__(self, y, num_images, seed, image_height, image_width, num_steps, guidance, timeshift, order):
@@ -93,8 +96,7 @@ class Pipeline:
         self.denoiser.decoder_patch_scaling_h = image_height / 512
         self.denoiser.decoder_patch_scaling_w = image_width / 512
         xT = torch.randn((num_images, 3, image_height, image_width), device="cpu", dtype=torch.float32,
-                         generator=generator)
-        xT = xT.to("cuda")
+                         generator=generator).cuda()
         with torch.no_grad():
             condition, uncondition = conditioner([y,]*num_images)
 
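The two additions, `import spaces` and the `@spaces.GPU` decorator on `__call__`, adapt the pipeline to ZeroGPU: on a Zero Space a GPU is only attached while a `spaces.GPU`-decorated function is running, so all CUDA work has to happen inside it. A minimal sketch of that pattern, using a hypothetical `generate` function and a plain Gradio interface rather than this Space's actual app code:

import gradio as gr
import spaces  # ZeroGPU helper package available on Hugging Face Spaces

@spaces.GPU  # a GPU is attached only for the duration of this call
def generate(prompt: str) -> str:
    # Model forward passes and .cuda() transfers belong inside the
    # decorated function; module-level code runs without a GPU.
    return prompt.upper()  # placeholder for the real inference call

demo = gr.Interface(fn=generate, inputs="text", outputs="text")
demo.launch()

Decorators apply bottom-up, so listing `@spaces.GPU` first, as the commit does above `@torch.no_grad()` and `@torch.autocast(...)`, makes the GPU grant the outermost wrapper around the whole call.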
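The last hunk folds the separate `xT = xT.to("cuda")` transfer into the `torch.randn(...)` expression. Drawing the initial noise on the CPU with a seeded generator keeps it deterministic for a given seed, and chaining `.cuda()` is behaviorally equivalent to the removed two-line version. A standalone sketch of that pattern, with illustrative shape and seed values and assuming CUDA is available:

import torch

seed = 0
generator = torch.Generator(device="cpu").manual_seed(seed)

# Sample the noise on CPU so its values depend only on the seed,
# then move it to the GPU in the same expression.
xT = torch.randn((4, 3, 512, 512), device="cpu", dtype=torch.float32,
                 generator=generator).cuda()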