Spaces: Running on Zero
wangshuai6 committed
Commit · 4d3bd2d
Parent(s): de07e5c
init
app.py CHANGED
@@ -33,6 +33,7 @@
 # step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
 import random
 import os
+import time
 import spaces
 import torch
 import argparse
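Note: `import time` is required by the wall-clock instrumentation this commit adds to `Pipeline.__call__` further down; without it the new `time.time()` calls would raise a `NameError`.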
@@ -69,8 +70,8 @@ def load_model(weight_dict, denoiser):
 class Pipeline:
     def __init__(self, vae, denoiser, conditioner, resolution):
         self.vae = vae
-        self.denoiser = denoiser
-        self.conditioner = conditioner
+        self.denoiser = denoiser.cuda()
+        self.conditioner = conditioner.cuda()
         self.resolution = resolution
         self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
         # self.denoiser.compile()
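Moving `.cuda()` into the constructor pins the denoiser and conditioner to the GPU once, when the `Pipeline` is built, rather than relying on callers to place them. A minimal usage sketch under the assumptions of this file (the `vae`, `denoiser`, and `conditioner` objects are the ones loaded in `__main__`; argument values mirror the UI defaults below):

pipeline = Pipeline(vae, denoiser, conditioner, resolution=512)
image = pipeline("a photo of a cat", seed=0, num_steps=25, guidance=4.0, timeshift=3.0, order=2)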
@@ -78,10 +79,10 @@ class Pipeline:
     def __del__(self):
         self.tmp_dir.cleanup()
 
-    @spaces.GPU
+    # @spaces.GPU
     @torch.no_grad()
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-    def __call__(self, y, seed, image_height, image_width, num_steps, guidance, timeshift, order):
+    def __call__(self, y, seed, num_steps, guidance, timeshift, order):
         diffusion_sampler = AdamLMSampler(
             order=order,
             scheduler=LinearScheduler(),
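`@spaces.GPU` is the ZeroGPU decorator from the `spaces` package: on "Running on Zero" hardware, a CUDA device is attached only for the duration of a decorated call. Commenting it out, as this commit does, assumes the process holds a persistent GPU instead. A minimal sketch of how the decorator is normally applied (the function name and body here are illustrative, not from app.py):

import spaces
import torch

@spaces.GPU(duration=120)  # request a GPU slice for up to 120 s per call
def generate(prompt: str) -> torch.Tensor:
    # Inside the decorated call, CUDA is available on ZeroGPU Spaces.
    return torch.randn(1, 3, 512, 512, device="cuda")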
@@ -91,18 +92,18 @@ class Pipeline:
             timeshift=timeshift
         )
         generator = torch.Generator(device="cpu").manual_seed(seed)
-
-        image_width = image_width // 32 * 32
-        self.denoiser.decoder_patch_scaling_h = image_height / 512
-        self.denoiser.decoder_patch_scaling_w = image_width / 512
-        xT = torch.randn((1, 3, image_height, image_width), device="cpu", dtype=torch.float32,
+        xT = torch.randn((1, 3, 512, 512), device="cpu", dtype=torch.float32,
                          generator=generator).cuda()
+
+        start = time.time()
         with torch.no_grad():
             condition, uncondition = conditioner([y,]*1)
+        print("conditioner:",time.time() - start)
 
-
+        start = time.time()
         # Sample images:
         samples, trajs = diffusion_sampler(denoiser, xT, condition, uncondition, return_x_trajs=True)
+        print("diffusion:",time.time() - start)
 
         def decode_images(samples):
             samples = vae.decode(samples)
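Drawing `xT` on the CPU with a seeded generator makes the initial noise reproducible for a given seed independent of the GPU model; the tensor is copied to CUDA afterwards. The `time.time()` brackets then report wall-clock time for the conditioner and the sampling loop separately. A self-contained sketch of the same pattern (`seeded_noise` is a hypothetical helper, not part of app.py):

import time
import torch

def seeded_noise(seed: int, shape=(1, 3, 512, 512)) -> torch.Tensor:
    # A CPU generator produces identical noise for a given seed on any machine;
    # only afterwards is the tensor moved to the GPU.
    g = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn(shape, dtype=torch.float32, generator=g).cuda()

start = time.time()
xT = seeded_noise(0)
print("noise:", time.time() - start)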
@@ -114,35 +115,35 @@ class Pipeline:
             images.append(image)
             return images
 
-        def decode_trajs(trajs):
-            cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
-            animations = []
-            for i in range(cat_trajs.shape[0]):
-                frames = decode_images(
-                    cat_trajs[i]
-                )
-                # generate a unique filename (combining seed and sample index to avoid collisions)
-                gif_filename = f"{random.randint(0, 100000)}.gif"
-                gif_path = os.path.join(self.tmp_dir.name, gif_filename)
-                frames[0].save(
-                    gif_path,
-                    format="GIF",
-                    append_images=frames[1:],
-                    save_all=True,
-                    duration=200,
-                    loop=0
-                )
-                animations.append(gif_path)
-            return animations
+        # def decode_trajs(trajs):
+        #     cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
+        #     animations = []
+        #     for i in range(cat_trajs.shape[0]):
+        #         frames = decode_images(
+        #             cat_trajs[i]
+        #         )
+        #         # generate a unique filename (combining seed and sample index to avoid collisions)
+        #         gif_filename = f"{random.randint(0, 100000)}.gif"
+        #         gif_path = os.path.join(self.tmp_dir.name, gif_filename)
+        #         frames[0].save(
+        #             gif_path,
+        #             format="GIF",
+        #             append_images=frames[1:],
+        #             save_all=True,
+        #             duration=200,
+        #             loop=0
+        #         )
+        #         animations.append(gif_path)
+        #     return animations
 
         images = decode_images(samples)
-        animations = decode_trajs(trajs)
+        # animations = decode_trajs(trajs)
 
-        return images
+        return images[0]
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--config", type=str, default="configs_t2i/
+    parser.add_argument("--config", type=str, default="configs_t2i/sft_res512.yaml")
     parser.add_argument("--resolution", type=int, default=512)
     parser.add_argument("--model_id", type=str, default="MCG-NJU/PixNerd-XXL-P16-T2I")
    parser.add_argument("--ckpt_path", type=str, default="models")
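The commented-out `decode_trajs` preserves the trajectory-GIF feature for later: it stacks the sampler's intermediate states, decodes them into PIL frames, and writes an animated GIF into the pipeline's temporary directory. Its Pillow call pattern is standard; a self-contained sketch (`frames_to_gif` is a hypothetical name):

import os
import tempfile
from PIL import Image

def frames_to_gif(frames, gif_path):
    # Pillow writes an animated GIF by saving the first frame and appending
    # the rest; duration is per-frame display time in ms, loop=0 repeats forever.
    frames[0].save(
        gif_path,
        format="GIF",
        append_images=frames[1:],
        save_all=True,
        duration=200,
        loop=0,
    )
    return gif_path

tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
frames = [Image.new("RGB", (64, 64), (16 * i, 0, 0)) for i in range(16)]
print(frames_to_gif(frames, os.path.join(tmp_dir.name, "demo.gif")))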
@@ -167,6 +168,7 @@ if __name__ == "__main__":
     ckpt = torch.load(ckpt_path, map_location="cpu")
     denoiser = load_model(ckpt, denoiser)
     denoiser = denoiser.cuda()
+    conditioner = conditioner.cuda()
     vae = vae.cuda()
     denoiser.eval()
 
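Note: this mirrors the `Pipeline.__init__` change above; the conditioner's condition/uncondition embeddings must land on the same device as `xT` and the denoiser, so it is moved to CUDA alongside them.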
@@ -179,27 +181,23 @@ if __name__ == "__main__":
         with gr.Column(scale=1):
             num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=25)
             guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
-            image_height = gr.Slider(minimum=128, maximum=1024, step=32, label="image height", value=512)
-            image_width = gr.Slider(minimum=128, maximum=1024, step=32, label="image width", value=512)
             label = gr.Textbox(label="positive prompt", value="a photo of a cat")
             seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
             timeshift = gr.Slider(minimum=0.1, maximum=5.0, step=0.1, label="timeshift", value=3.0)
             order = gr.Slider(minimum=1, maximum=4, step=1, label="order", value=2)
         with gr.Column(scale=2):
             btn = gr.Button("Generate")
-            output_sample = gr.
-        with gr.Column(scale=2):
-            output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
+            output_sample = gr.Image(label="Images")
+            # with gr.Column(scale=2):
+            #     output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
 
     btn.click(fn=pipeline,
               inputs=[
                   label,
                   seed,
-                  image_height,
-                  image_width,
                   num_steps,
                   guidance,
                   timeshift,
                   order
-              ], outputs=[output_sample
+              ], outputs=[output_sample])
     demo.launch()
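With the trajectory gallery disabled, the click handler now maps the six remaining inputs to a single `gr.Image` output, matching the new `__call__` signature and its `return images[0]`. A stripped-down, runnable sketch of the same wiring (`fake_pipeline` is a stand-in for the real `Pipeline.__call__`; most sliders are omitted for brevity):

import numpy as np
import gradio as gr

def fake_pipeline(y, seed, num_steps, guidance, timeshift, order):
    # Stand-in for Pipeline.__call__: returns a single image (here: seeded noise).
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="positive prompt", value="a photo of a cat")
    seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
    btn = gr.Button("Generate")
    out = gr.Image(label="Images")
    btn.click(fn=lambda y, s: fake_pipeline(y, int(s), 25, 4.0, 3.0, 2),
              inputs=[prompt, seed], outputs=[out])

demo.launch()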