wangshuai6 committed
Commit 4d3bd2d · 1 Parent(s): de07e5c
Files changed (1):
  1. app.py +39 -41

app.py CHANGED
@@ -33,6 +33,7 @@
 # step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
 import random
 import os
+import time
 import spaces
 import torch
 import argparse
@@ -69,8 +70,8 @@ def load_model(weight_dict, denoiser):
 class Pipeline:
     def __init__(self, vae, denoiser, conditioner, resolution):
         self.vae = vae
-        self.denoiser = denoiser
-        self.conditioner = conditioner
+        self.denoiser = denoiser.cuda()
+        self.conditioner = conditioner.cuda()
         self.resolution = resolution
         self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
         # self.denoiser.compile()
@@ -78,10 +79,10 @@ class Pipeline:
     def __del__(self):
         self.tmp_dir.cleanup()
 
-    @spaces.GPU
+    # @spaces.GPU
     @torch.no_grad()
     @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
-    def __call__(self, y, seed, image_height, image_width, num_steps, guidance, timeshift, order):
+    def __call__(self, y, seed, num_steps, guidance, timeshift, order):
         diffusion_sampler = AdamLMSampler(
             order=order,
             scheduler=LinearScheduler(),
@@ -91,18 +92,18 @@
             timeshift=timeshift
         )
         generator = torch.Generator(device="cpu").manual_seed(seed)
-        image_height = image_height // 32 * 32
-        image_width = image_width // 32 * 32
-        self.denoiser.decoder_patch_scaling_h = image_height / 512
-        self.denoiser.decoder_patch_scaling_w = image_width / 512
-        xT = torch.randn((1, 3, image_height, image_width), device="cpu", dtype=torch.float32,
+        xT = torch.randn((1, 3, 512, 512), device="cpu", dtype=torch.float32,
                          generator=generator).cuda()
+
+        start = time.time()
         with torch.no_grad():
             condition, uncondition = conditioner([y,]*1)
+        print("conditioner:", time.time() - start)
 
-
+        start = time.time()
         # Sample images:
         samples, trajs = diffusion_sampler(denoiser, xT, condition, uncondition, return_x_trajs=True)
+        print("diffusion:", time.time() - start)
 
         def decode_images(samples):
             samples = vae.decode(samples)
@@ -114,35 +115,35 @@ class Pipeline:
                 images.append(image)
             return images
 
-        def decode_trajs(trajs):
-            cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
-            animations = []
-            for i in range(cat_trajs.shape[0]):
-                frames = decode_images(
-                    cat_trajs[i]
-                )
-                # Generate a unique filename (combining the seed and sample index to avoid collisions)
-                gif_filename = f"{random.randint(0, 100000)}.gif"
-                gif_path = os.path.join(self.tmp_dir.name, gif_filename)
-                frames[0].save(
-                    gif_path,
-                    format="GIF",
-                    append_images=frames[1:],
-                    save_all=True,
-                    duration=200,
-                    loop=0
-                )
-                animations.append(gif_path)
-            return animations
+        # def decode_trajs(trajs):
+        #     cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
+        #     animations = []
+        #     for i in range(cat_trajs.shape[0]):
+        #         frames = decode_images(
+        #             cat_trajs[i]
+        #         )
+        #         # Generate a unique filename (combining the seed and sample index to avoid collisions)
+        #         gif_filename = f"{random.randint(0, 100000)}.gif"
+        #         gif_path = os.path.join(self.tmp_dir.name, gif_filename)
+        #         frames[0].save(
+        #             gif_path,
+        #             format="GIF",
+        #             append_images=frames[1:],
+        #             save_all=True,
+        #             duration=200,
+        #             loop=0
+        #         )
+        #         animations.append(gif_path)
+        #     return animations
 
         images = decode_images(samples)
-        animations = decode_trajs(trajs)
+        # animations = decode_trajs(trajs)
 
-        return images, animations
+        return images[0]
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--config", type=str, default="configs_t2i/inference_heavydecoder.yaml")
+    parser.add_argument("--config", type=str, default="configs_t2i/sft_res512.yaml")
     parser.add_argument("--resolution", type=int, default=512)
     parser.add_argument("--model_id", type=str, default="MCG-NJU/PixNerd-XXL-P16-T2I")
     parser.add_argument("--ckpt_path", type=str, default="models")
@@ -167,6 +168,7 @@ if __name__ == "__main__":
     ckpt = torch.load(ckpt_path, map_location="cpu")
     denoiser = load_model(ckpt, denoiser)
     denoiser = denoiser.cuda()
+    conditioner = conditioner.cuda()
     vae = vae.cuda()
     denoiser.eval()
@@ -179,27 +181,23 @@ if __name__ == "__main__":
         with gr.Column(scale=1):
             num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=25)
             guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
-            image_height = gr.Slider(minimum=128, maximum=1024, step=32, label="image height", value=512)
-            image_width = gr.Slider(minimum=128, maximum=1024, step=32, label="image width", value=512)
             label = gr.Textbox(label="positive prompt", value="a photo of a cat")
             seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
             timeshift = gr.Slider(minimum=0.1, maximum=5.0, step=0.1, label="timeshift", value=3.0)
             order = gr.Slider(minimum=1, maximum=4, step=1, label="order", value=2)
         with gr.Column(scale=2):
             btn = gr.Button("Generate")
-            output_sample = gr.Gallery(label="Images", columns=2, rows=2)
-        with gr.Column(scale=2):
-            output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
+            output_sample = gr.Image(label="Images")
+        # with gr.Column(scale=2):
+        #     output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
 
     btn.click(fn=pipeline,
               inputs=[
                   label,
                   seed,
-                  image_height,
-                  image_width,
                   num_steps,
                   guidance,
                   timeshift,
                   order
-              ], outputs=[output_sample, output_trajs])
+              ], outputs=[output_sample])
     demo.launch()
 