wangsssssss commited on
Commit
763f919
·
verified ·
1 Parent(s): b1b9957

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -23
app.py CHANGED
@@ -82,26 +82,26 @@ class Pipeline:
82
  images.append(image)
83
  return images
84
 
85
- # def decode_trajs(trajs):
86
- # cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
87
- # animations = []
88
- # for i in range(cat_trajs.shape[0]):
89
- # frames = decode_images(
90
- # cat_trajs[i]
91
- # )
92
- # # 生成唯一文件名(结合seed和样本索引,避免冲突)
93
- # gif_filename = f"{random.randint(0, 100000)}.gif"
94
- # gif_path = os.path.join(self.tmp_dir.name, gif_filename)
95
- # frames[0].save(
96
- # gif_path,
97
- # format="GIF",
98
- # append_images=frames[1:],
99
- # save_all=True,
100
- # duration=200,
101
- # loop=0
102
- # )
103
- # animations.append(gif_path)
104
- # return animations
105
 
106
  images = decode_images(samples)
107
  # animations = decode_trajs(trajs)
@@ -155,8 +155,8 @@ if __name__ == "__main__":
155
  with gr.Column(scale=1):
156
  btn = gr.Button("Generate")
157
  output_sample = gr.Image(label="Images")
158
- # with gr.Column(scale=2):
159
- # output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
160
 
161
  btn.click(fn=pipeline,
162
  inputs=[
@@ -166,5 +166,5 @@ if __name__ == "__main__":
166
  guidance,
167
  timeshift,
168
  order
169
- ], outputs=[output_sample])
170
  demo.launch()
 
82
  images.append(image)
83
  return images
84
 
85
+ def decode_trajs(trajs):
86
+ cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
87
+ animations = []
88
+ for i in range(cat_trajs.shape[0]):
89
+ frames = decode_images(
90
+ cat_trajs[i]
91
+ )
92
+ # 生成唯一文件名(结合seed和样本索引,避免冲突)
93
+ gif_filename = f"{random.randint(0, 100000)}.gif"
94
+ gif_path = os.path.join(self.tmp_dir.name, gif_filename)
95
+ frames[0].save(
96
+ gif_path,
97
+ format="GIF",
98
+ append_images=frames[1:],
99
+ save_all=True,
100
+ duration=200,
101
+ loop=0
102
+ )
103
+ animations.append(gif_path)
104
+ return animations
105
 
106
  images = decode_images(samples)
107
  # animations = decode_trajs(trajs)
 
155
  with gr.Column(scale=1):
156
  btn = gr.Button("Generate")
157
  output_sample = gr.Image(label="Images")
158
+ with gr.Column(scale=2):
159
+ output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
160
 
161
  btn.click(fn=pipeline,
162
  inputs=[
 
166
  guidance,
167
  timeshift,
168
  order
169
+ ], outputs=[output_sample, output_trajs])
170
  demo.launch()