Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
@@ -1,60 +1,82 @@
|
|
1 |
import os
|
2 |
-
import gradio as gr
|
3 |
-
import subprocess
|
4 |
import shutil
|
5 |
-
import
|
|
|
6 |
from huggingface_hub import snapshot_download
|
7 |
|
8 |
-
# Paths
|
9 |
-
PUSA_REPO = "./PusaV1"
|
10 |
-
PUSA_SCRIPT = os.path.join(PUSA_REPO, "examples/pusavideo/wan_14b_text_to_video_pusa.py")
|
11 |
-
MODEL_DIR = "./model_zoo/PusaV1"
|
12 |
-
MODEL_PATH = os.path.join(MODEL_DIR, "pusa_v1.pt")
|
13 |
-
OUTPUT_VIDEO_PATH = os.path.join(tempfile.gettempdir(), "output.mp4")
|
14 |
-
|
15 |
-
def setup_dependencies():
|
16 |
-
subprocess.run(["pip", "install", "xfuser>=0.4.3", "absl-py", "peft", "lightning", "pandas", "deepspeed", "wandb", "av"])
|
17 |
-
subprocess.run(
|
18 |
-
'pip install flash-attn --no-build-isolation',
|
19 |
-
shell=True,
|
20 |
-
env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
|
21 |
-
)
|
22 |
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
return
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
for part in part_files:
|
31 |
-
with open(
|
32 |
-
shutil.copyfileobj(
|
|
|
|
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
38 |
"--prompt", prompt,
|
39 |
-
"--lora_path",
|
|
|
40 |
]
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
|
|
|
|
2 |
import shutil
|
3 |
+
import subprocess
|
4 |
+
import gradio as gr
|
5 |
from huggingface_hub import snapshot_download
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
|
9 |
+
HF_REPO = "RaphaelLiu/PusaV1"
|
10 |
+
MODEL_ZOO_DIR = "./model_zoo"
|
11 |
+
MODEL_PARTS_DIR = os.path.join(MODEL_ZOO_DIR, MODEL_SUBFOLDER)
|
12 |
+
FINAL_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
|
13 |
+
PUSA_SCRIPT_PATH = "PusaV1/examples/pusavideo/wan_14b_text_to_video_pusa.py"
|
14 |
+
|
15 |
+
|
16 |
+
def download_model_subset():
|
17 |
+
if os.path.exists(FINAL_MODEL_PATH):
|
18 |
+
print("β
Model already exists. Skipping download.")
|
19 |
return
|
20 |
+
|
21 |
+
print("β¬ Downloading model parts...")
|
22 |
+
snapshot_download(
|
23 |
+
repo_id=HF_REPO,
|
24 |
+
repo_type="model",
|
25 |
+
local_dir=MODEL_ZOO_DIR,
|
26 |
+
local_dir_use_symlinks=False,
|
27 |
+
allow_patterns=[f"{MODEL_SUBFOLDER}/*"]
|
28 |
+
)
|
29 |
+
os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)
|
30 |
+
|
31 |
+
part_files = sorted([
|
32 |
+
os.path.join(MODEL_PARTS_DIR, f)
|
33 |
+
for f in os.listdir(MODEL_PARTS_DIR)
|
34 |
+
if f.startswith("pusa_v1.pt.part")
|
35 |
+
])
|
36 |
+
|
37 |
+
print("π§© Stitching model parts...")
|
38 |
+
with open(FINAL_MODEL_PATH, 'wb') as f_out:
|
39 |
for part in part_files:
|
40 |
+
with open(part, 'rb') as f_in:
|
41 |
+
shutil.copyfileobj(f_in, f_out)
|
42 |
+
|
43 |
+
print(f"β
Final model saved at {FINAL_MODEL_PATH}")
|
44 |
|
45 |
+
|
46 |
+
def generate_video(prompt):
|
47 |
+
download_model_subset()
|
48 |
+
|
49 |
+
temp_output_dir = "/tmp/pusa_video_output"
|
50 |
+
os.makedirs(temp_output_dir, exist_ok=True)
|
51 |
+
|
52 |
+
command = [
|
53 |
+
"python", PUSA_SCRIPT_PATH,
|
54 |
"--prompt", prompt,
|
55 |
+
"--lora_path", FINAL_MODEL_PATH,
|
56 |
+
"--output_dir", temp_output_dir
|
57 |
]
|
58 |
+
|
59 |
+
try:
|
60 |
+
print("π Running inference...")
|
61 |
+
subprocess.run(command, check=True)
|
62 |
+
|
63 |
+
# Return first mp4 video found
|
64 |
+
for file in os.listdir(temp_output_dir):
|
65 |
+
if file.endswith(".mp4"):
|
66 |
+
return os.path.join(temp_output_dir, file)
|
67 |
+
|
68 |
+
return "β No video generated."
|
69 |
+
|
70 |
+
except subprocess.CalledProcessError as e:
|
71 |
+
return f"β Inference failed: {str(e)}"
|
72 |
+
|
73 |
+
|
74 |
+
with gr.Blocks() as demo:
|
75 |
+
gr.Markdown("## π§ββοΈ PusaV1 Text-to-Video Generator (Wan2.1-T2V-14B)")
|
76 |
+
prompt_input = gr.Textbox(label="Enter your prompt", lines=4, placeholder="A coral reef full of colorful fish...")
|
77 |
+
generate_button = gr.Button("Generate Video")
|
78 |
+
video_output = gr.Video(label="Generated Video")
|
79 |
+
|
80 |
+
generate_button.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
|
81 |
+
|
82 |
+
demo.launch()
|