rahul7star committed on
Commit bb6befb · verified · 1 Parent(s): 362d099

Update app_t2v.py

Files changed (1)
  1. app_t2v.py +265 -63
app_t2v.py CHANGED
@@ -1,76 +1,278 @@
-import spaces
 import gradio as gr
 import torch
-from diffusers import WanPipeline, AutoencoderKLWan
-from diffusers.utils import export_to_video
-
-import os
-import uuid
-
-# Load model once at startup
-dtype = torch.bfloat16
 # Instantiate the pipeline in the global scope
 print("Initializing WanTI2V pipeline...")
 device = "cuda" if torch.cuda.is_available() else "cpu"
 device_id = 0 if torch.cuda.is_available() else -1

-vae = AutoencoderKLWan.from_pretrained(
-    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", subfolder="vae", torch_dtype=torch.float32
-)
-pipe = WanPipeline.from_pretrained(
-    "Wan-AI/Wan2.2-T2V-A14B-Diffusers", vae=vae, torch_dtype=dtype
-)
-pipe.to(device)
-
-# Constants
-HEIGHT = 720
-WIDTH = 1280
-FPS = 16
-DEFAULT_NEGATIVE_PROMPT = (
-    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
-    "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,"
-    "画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,"
-    "杂乱的背景,三条腿,背景人很多,倒着走"
-)

-# Inference function
-@spaces.GPU(duration=150)
-def generate_video(prompt, num_frames=81, steps=40, scale=4.0, scale2=3.0):
-    if not prompt.strip():
-        return "Prompt is empty.", None
-
-    output = pipe(
-        prompt=prompt,
-        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
-        height=HEIGHT,
-        width=WIDTH,
-        num_frames=num_frames,
-        guidance_scale=scale,
-        guidance_scale_2=scale2,
-        num_inference_steps=steps,
-    ).frames[0]
-
-    filename = f"/tmp/t2v_{uuid.uuid4().hex}.mp4"
-    export_to_video(output, filename, fps=FPS)
-    return f"Generated {num_frames} frames at {FPS} FPS.", filename
-
-# Gradio UI
-iface = gr.Interface(
-    fn=generate_video,
-    inputs=[
-        gr.Textbox(label="Prompt", placeholder="Describe your video scene here..."),
-        gr.Slider(16, 81, value=81, step=1, label="Number of Frames"),
-        gr.Slider(10, 60, value=40, step=1, label="Inference Steps"),
-        gr.Slider(1.0, 8.0, value=4.0, step=0.1, label="Guidance Scale"),
-        gr.Slider(1.0, 8.0, value=3.0, step=0.1, label="Guidance Scale 2"),
-    ],
-    outputs=[
-        gr.Textbox(label="Status"),
-        gr.Video(label="Generated Video"),
-    ],
-    title="🧠 Wan2.2 Text-to-Video Generator",
-    description="Enter a scene description and generate a video using the Wan2.2 T2V model.",
 )

 if __name__ == "__main__":
-    iface.launch()

+import os
+import sys
+sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+
+# wan2.2-main/gradio_ti2v.py
 import gradio as gr
 import torch
+from huggingface_hub import snapshot_download, hf_hub_download
+from PIL import Image
+import random
+import numpy as np
+import spaces

+import wan
+from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
+from wan.utils.utils import cache_video
+
+import gc
+
+# --- 1. Global Setup and Model Loading ---
+
+print("Starting Gradio App for Wan 2.2 TI2V-5B...")
+
+# Download model snapshots from Hugging Face Hub
+repo_id = "Wan-AI/Wan2.2-TI2V-5B"
+print(f"Downloading/loading checkpoints for {repo_id}...")
+ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
+print(f"Using checkpoints from {ckpt_dir}")
+
+# Load the model configuration
+TASK_NAME = 'ti2v-5B'
+cfg = WAN_CONFIGS[TASK_NAME]
+FIXED_FPS = 24
+MIN_FRAMES_MODEL = 8
+MAX_FRAMES_MODEL = 121
+
+# Dimension calculation constants
+MOD_VALUE = 32
+DEFAULT_H_SLIDER_VALUE = 704
+DEFAULT_W_SLIDER_VALUE = 1280
+NEW_FORMULA_MAX_AREA = 1280.0 * 704.0
+
+SLIDER_MIN_H, SLIDER_MAX_H = 128, 1280
+SLIDER_MIN_W, SLIDER_MAX_W = 128, 1280

 # Instantiate the pipeline in the global scope
 print("Initializing WanTI2V pipeline...")
 device = "cuda" if torch.cuda.is_available() else "cpu"
 device_id = 0 if torch.cuda.is_available() else -1

+
+# LoRA weights (LightX2V cfg/step-distill) pulled from the Hub
+LORA_REPO_ID = "Kijai/WanVideo_comfy"
+LORA_FILENAME = "Lightx2v/lightx2v_T2V_14B_cfg_step_distill_v2_lora_rank256_bf16.safetensors"
+
+pipeline = wan.WanTI2V(
+    config=cfg,
+    checkpoint_dir=ckpt_dir,
+    device_id=device_id,
+    rank=0,
+    t5_fsdp=False,
+    dit_fsdp=False,
+    use_sp=False,
+    t5_cpu=False,
+    init_on_cpu=False,
+    convert_model_dtype=True,
 )

+# Download the LoRA and fuse it into the pipeline
+causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
+pipeline.load_lora_weights(causvid_path, adapter_name="causvid_lora")
+pipeline.set_adapters(["causvid_lora"], adapter_weights=[0.95])
+pipeline.fuse_lora()
+
+print("Pipeline initialized and ready.")
+
+# --- Helper Functions (from Wan 2.1 Fast demo) ---
+def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area,
+                                  min_slider_h, max_slider_h,
+                                  min_slider_w, max_slider_w,
+                                  default_h, default_w):
+    orig_w, orig_h = pil_image.size
+    if orig_w <= 0 or orig_h <= 0:
+        return default_h, default_w
+
+    aspect_ratio = orig_h / orig_w
+
+    calc_h = round(np.sqrt(calculation_max_area * aspect_ratio))
+    calc_w = round(np.sqrt(calculation_max_area / aspect_ratio))
+
+    calc_h = max(mod_val, (calc_h // mod_val) * mod_val)
+    calc_w = max(mod_val, (calc_w // mod_val) * mod_val)
+
+    new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val))
+    new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val))
+
+    return new_h, new_w
+
+def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
+    if uploaded_pil_image is None:
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+    try:
+        # Convert numpy array to PIL Image if needed
+        if hasattr(uploaded_pil_image, 'shape'):  # numpy array
+            pil_image = Image.fromarray(uploaded_pil_image).convert("RGB")
+        else:  # already a PIL Image
+            pil_image = uploaded_pil_image
+
+        new_h, new_w = _calculate_new_dimensions_wan(
+            pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
+            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
+            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
+        )
+        return gr.update(value=new_h), gr.update(value=new_w)
+    except Exception as e:
+        gr.Warning(f"Error attempting to calculate new dimensions: {e}")
+        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
+
+def get_duration(image,
+                 prompt,
+                 height,
+                 width,
+                 duration_seconds,
+                 sampling_steps,
+                 guide_scale,
+                 shift,
+                 seed,
+                 progress):
+    """Calculate the dynamic GPU allocation (in seconds) based on the request parameters."""
+    if duration_seconds >= 3:
+        return 220
+    elif sampling_steps > 35 and duration_seconds >= 2:
+        return 180
+    elif sampling_steps < 35 or duration_seconds < 2:
+        return 105
+    else:
+        return 90
+
+# --- 2. Gradio Inference Function ---
+@spaces.GPU(duration=get_duration)
+def generate_video(
+    image,
+    prompt,
+    height,
+    width,
+    duration_seconds,
+    sampling_steps=38,
+    guide_scale=cfg.sample_guide_scale,
+    shift=cfg.sample_shift,
+    seed=42,
+    progress=gr.Progress(track_tqdm=True)
+):
+    """The main function to generate video, called by the Gradio interface."""
+    if seed == -1:
+        seed = random.randint(0, sys.maxsize)
+
+    # Ensure dimensions are multiples of MOD_VALUE
+    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
+    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
+
+    input_image = None
+    if image is not None:
+        input_image = Image.fromarray(image).convert("RGB")
+        # Resize image to match target dimensions
+        input_image = input_image.resize((target_w, target_h))
+
+    # Calculate number of frames based on duration
+    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+    # Create size string for the pipeline
+    size_str = f"{target_h}*{target_w}"
+
+    video_tensor = pipeline.generate(
+        input_prompt=prompt,
+        img=input_image,  # Pass None for T2V, an Image for I2V
+        size=SIZE_CONFIGS.get(size_str, (target_h, target_w)),
+        max_area=MAX_AREA_CONFIGS.get(size_str, target_h * target_w),
+        frame_num=num_frames,  # Use calculated frames instead of cfg.frame_num
+        shift=shift,
+        sample_solver='unipc',
+        sampling_steps=int(sampling_steps),
+        guide_scale=guide_scale,
+        seed=seed,
+        offload_model=True
+    )
+
+    # Save the video to a temporary file
+    video_path = cache_video(
+        tensor=video_tensor[None],  # Add a batch dimension
+        save_file=None,  # cache_video will create a temp file
+        fps=cfg.sample_fps,
+        normalize=True,
+        value_range=(-1, 1)
+    )
+    del video_tensor
+    gc.collect()
+    return video_path
+
+
+# --- 3. Gradio Interface ---
+css = ".gradio-container {max-width: 1100px !important; margin: 0 auto} #output_video {height: 500px;} #input_image {height: 500px;}"
+
+with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
+    gr.Markdown("# Wan 2.2 TI2V 5B")
+    gr.Markdown("Generate high-quality videos with the **Wan 2.2 5B Text-Image-to-Video model**. [[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            image_input = gr.Image(type="numpy", label="Input Image (optional, blank = text-to-video)", elem_id="input_image")
+            prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
+            duration_input = gr.Slider(
+                minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
+                maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
+                step=0.1,
+                value=2.0,
+                label="Duration (seconds)",
+                info=f"Clamped to the model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS} fps."
+            )
+
+            with gr.Accordion("Advanced Settings", open=False):
+                with gr.Row():
+                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
+                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
+                steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
+                scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
+                shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
+                seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
+
+        with gr.Column(scale=2):
+            video_output = gr.Video(label="Generated Video", elem_id="output_video")
+            run_button = gr.Button("Generate Video", variant="primary")
+
+    # Update the height/width sliders when an image is uploaded or cleared
+    image_input.upload(
+        fn=handle_image_upload_for_dims_wan,
+        inputs=[image_input, height_input, width_input],
+        outputs=[height_input, width_input]
+    )
+
+    image_input.clear(
+        fn=handle_image_upload_for_dims_wan,
+        inputs=[image_input, height_input, width_input],
+        outputs=[height_input, width_input]
+    )
+
+    example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
+    gr.Examples(
+        examples=[
+            [example_image_path, "The cat removes the glasses from its eyes.", 1088, 800, 1.5],
+            [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", 704, 1280, 2.0],
+            [None, "Drone footage flying over a futuristic city with flying cars.", 704, 1280, 2.0],
+        ],
+        inputs=[image_input, prompt_input, height_input, width_input, duration_input],
+        outputs=video_output,
+        fn=generate_video,
+        cache_examples="lazy",
+    )
+
+    run_button.click(
+        fn=generate_video,
+        inputs=[image_input, prompt_input, height_input, width_input, duration_input, steps_input, scale_input, shift_input, seed_input],
+        outputs=video_output
+    )
+
 if __name__ == "__main__":
+    demo.launch()