Commit 8ff4968 (verified), committed by multimodalart (HF Staff)
Parent: 52f499a

Dimensions

Files changed (1):
  1. app.py (+78, -53)

app.py CHANGED
@@ -37,6 +37,15 @@ FIXED_FPS = 24
 MIN_FRAMES_MODEL = 8
 MAX_FRAMES_MODEL = 121
 
+# Dimension calculation constants (from Wan 2.1 Fast demo)
+MOD_VALUE = 32
+DEFAULT_H_SLIDER_VALUE = 512
+DEFAULT_W_SLIDER_VALUE = 896
+NEW_FORMULA_MAX_AREA = 480.0 * 832.0
+
+SLIDER_MIN_H, SLIDER_MAX_H = 128, 896
+SLIDER_MIN_W, SLIDER_MAX_W = 128, 896
+
 # Instantiate the pipeline in the global scope
 print("Initializing WanTI2V pipeline...")
 device = "cuda" if torch.cuda.is_available() else "cpu"
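
Note: the added constants above drive everything that follows in this commit. Uploads are auto-sized to at most NEW_FORMULA_MAX_AREA pixels, and every dimension is snapped down to a multiple of MOD_VALUE. A quick check of what they evaluate to (plain Python, independent of the app):

    print(480.0 * 832.0)               # NEW_FORMULA_MAX_AREA: 399360.0-pixel budget
    print(max(32, (547 // 32) * 32))   # snapping 547 to a multiple of MOD_VALUE: 544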
@@ -55,44 +64,52 @@ pipeline = wan.WanTI2V(
 )
 print("Pipeline initialized and ready.")
 
-# --- Helper Functions ---
-def select_best_size_for_image(image, available_sizes):
-    """Select the size option with aspect ratio closest to the input image."""
-    if image is None:
-        return available_sizes[0]  # Return first option if no image
-
-    img_width, img_height = image.size
-    img_aspect_ratio = img_height / img_width
-
-    best_size = available_sizes[0]
-    best_diff = float('inf')
-
-    for size_str in available_sizes:
-        # Parse size string like "704*1280"
-        height, width = map(int, size_str.split('*'))
-        size_aspect_ratio = height / width
-        diff = abs(img_aspect_ratio - size_aspect_ratio)
-
-        if diff < best_diff:
-            best_diff = diff
-            best_size = size_str
-
-    return best_size
-
-def handle_image_upload(image):
-    """Handle image upload and return the best matching size."""
-    if image is None:
-        return gr.update()
-
-    pil_image = Image.fromarray(image).convert("RGB")
-    available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
-    best_size = select_best_size_for_image(pil_image, available_sizes)
-
-    return gr.update(value=best_size)
+# --- Helper Functions (from Wan 2.1 Fast demo) ---
+def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
+                                  min_slider_h, max_slider_h,
+                                  min_slider_w, max_slider_w,
+                                  default_h, default_w):
+    orig_w, orig_h = pil_image.size
+    if orig_w <= 0 or orig_h <= 0:
+        return default_h, default_w
+
+    aspect_ratio = orig_h / orig_w
+
+    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
+    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
+
+    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
+    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
+
+    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
+    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
+
+    return new_h, new_w
+
+def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
+    if uploaded_pil_image is None:
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+    try:
+        # Convert numpy array to PIL Image if needed
+        if hasattr(uploaded_pil_image, 'shape'):  # numpy array
+            pil_image = Image.fromarray(uploaded_pil_image).convert("RGB")
+        else:  # already a PIL Image
+            pil_image = uploaded_pil_image
+
+        new_h, new_w = _calculate_new_dimensions_wan(
+            pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
+            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
+            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
+        )
+        return gr.update(value=new_h), gr.update(value=new_w)
+    except Exception as e:
+        gr.Warning("Error attempting to calculate new dimensions")
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
 
 def get_duration(image,
                  prompt,
-                 size,
+                 height,
+                 width,
                  duration_seconds,
                  sampling_steps,
                  guide_scale,
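
Note: a worked example of the new sizing helper, for a hypothetical 1024x768 (4:3 landscape) upload. The numbers follow directly from the formula above; this is an illustrative sketch assuming it runs in the app's module scope, where the helper and constants are defined:

    from PIL import Image

    img = Image.new("RGB", (1024, 768))   # stand-in for an uploaded image
    new_h, new_w = _calculate_new_dimensions_wan(
        img, MOD_VALUE, NEW_FORMULA_MAX_AREA,
        SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
        DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE,
    )
    # aspect_ratio = 768 / 1024 = 0.75
    # calc_h = round(sqrt(399360 * 0.75)) = 547 -> snapped down to 544
    # calc_w = round(sqrt(399360 / 0.75)) = 730 -> snapped down to 704
    assert (new_h, new_w) == (544, 704)

So the sliders land on 544x704, an area of 382,976 pixels, just under the 399,360-pixel budget.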
@@ -112,7 +129,8 @@ def get_duration(image,
 def generate_video(
     image,
     prompt,
-    size,
+    height,
+    width,
     duration_seconds,
     sampling_steps,
     guide_scale,
@@ -124,21 +142,27 @@
     if seed == -1:
         seed = random.randint(0, sys.maxsize)
 
+    # Ensure dimensions are multiples of MOD_VALUE
+    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
     input_image = None
     if image is not None:
         input_image = Image.fromarray(image).convert("RGB")
-        # Resize image to match selected size
-        target_height, target_width = map(int, size.split('*'))
-        input_image = input_image.resize((target_width, target_height))
+        # Resize image to match target dimensions
+        input_image = input_image.resize((target_w, target_h))
 
     # Calculate number of frames based on duration
     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
 
+    # Create size string for the pipeline
+    size_str = f"{target_h}*{target_w}"
+
     video_tensor = pipeline.generate(
         input_prompt=prompt,
         img=input_image,  # Pass None for T2V, Image for I2V
-        size=SIZE_CONFIGS[size],
-        max_area=MAX_AREA_CONFIGS[size],
+        size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
+        max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
         frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
         shift=shift,
         sample_solver='unipc',
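
Note: switching from SIZE_CONFIGS[size] to SIZE_CONFIGS.get(size_str, ...) is what lets arbitrary slider combinations through. Preset size strings still hit wan's config tables, while anything else falls back to the computed values instead of raising a KeyError. A minimal sketch of that fallback (the dict contents here are stand-ins, not wan's actual tables):

    SIZE_CONFIGS = {"704*1280": (704, 1280)}   # stand-in for wan.configs
    target_h, target_w = 544, 704
    size_str = f"{target_h}*{target_w}"        # "544*704" is not a preset key
    print(SIZE_CONFIGS.get(size_str, (target_h, target_w)))   # -> (544, 704), the fallback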
@@ -170,7 +194,7 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
 
     with gr.Row():
         with gr.Column(scale=2):
-            image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
+            image_input = gr.Image(type="numpy", label="Input Image (Optional, auto-resized to target H/W)", elem_id="input_image")
             prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
             duration_input = gr.Slider(
                 minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
@@ -180,40 +204,41 @@ with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
                 label="Duration (seconds)",
                 info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
             )
-            size_input = gr.Dropdown(label="Output Resolution", choices=list(SUPPORTED_SIZES[TASK_NAME]), value="704*1280")
-        with gr.Column(scale=2):
-            video_output = gr.Video(label="Generated Video", elem_id="output_video")
 
-
             with gr.Accordion("Advanced Settings", open=False):
+                with gr.Row():
+                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
+                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
                 steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                 scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                 shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
                 seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
 
+        with gr.Column(scale=2):
+            video_output = gr.Video(label="Generated Video", elem_id="output_video")
     run_button = gr.Button("Generate Video", variant="primary")
 
     # Add image upload handler
     image_input.upload(
-        fn=handle_image_upload,
-        inputs=[image_input],
-        outputs=[size_input]
+        fn=handle_image_upload_for_dims_wan,
+        inputs=[image_input, height_input, width_input],
+        outputs=[height_input, width_input]
     )
 
     image_input.clear(
-        fn=handle_image_upload,
-        inputs=[image_input],
-        outputs=[size_input]
+        fn=handle_image_upload_for_dims_wan,
+        inputs=[image_input, height_input, width_input],
+        outputs=[height_input, width_input]
    )
 
     example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
     gr.Examples(
         examples=[
-            [example_image_path, "The cat removes the glasses from its eyes.", "1280*704", 1.5],
-            [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", "1280*704", 2.0],
-            [None, "Drone footage flying over a futuristic city with flying cars.", "1280*704", 2.0],
+            [example_image_path, "The cat removes the glasses from its eyes.", 704, 1280, 1.5],
+            [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", 704, 1280, 2.0],
+            [None, "Drone footage flying over a futuristic city with flying cars.", 704, 1280, 2.0],
         ],
-        inputs=[image_input, prompt_input, size_input, duration_input],
+        inputs=[image_input, prompt_input, height_input, width_input, duration_input],
         outputs=video_output,
         fn=generate_video,
         cache_examples=False,
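
Note: the Examples rows now pass height and width as plain ints (704, 1280) that feed the two new sliders, replacing the old "1280*704" dropdown string. Duration still maps to a frame count exactly as the slider's info text describes; for reference, the clamping arithmetic (unchanged by this commit):

    import numpy as np

    FIXED_FPS = 24
    MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 121
    for duration in (0.1, 1.5, 10.0):
        frames = int(np.clip(int(round(duration * FIXED_FPS)),
                             MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
        print(duration, "->", frames)   # 0.1 -> 8, 1.5 -> 36, 10.0 -> 121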
@@ -221,7 +246,7 @@
 
     run_button.click(
         fn=generate_video,
-        inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
+        inputs=[image_input, prompt_input, height_input, width_input, duration_input, steps_input, scale_input, shift_input, seed_input],
         outputs=video_output
     )