rahul7star commited on
Commit
17d4813
Β·
verified Β·
1 Parent(s): 079f8aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -50
app.py CHANGED
@@ -1,60 +1,82 @@
1
  import os
2
- import gradio as gr
3
- import subprocess
4
  import shutil
5
- import tempfile
 
6
  from huggingface_hub import snapshot_download
7
 
8
- # Paths
9
- PUSA_REPO = "./PusaV1"
10
- PUSA_SCRIPT = os.path.join(PUSA_REPO, "examples/pusavideo/wan_14b_text_to_video_pusa.py")
11
- MODEL_DIR = "./model_zoo/PusaV1"
12
- MODEL_PATH = os.path.join(MODEL_DIR, "pusa_v1.pt")
13
- OUTPUT_VIDEO_PATH = os.path.join(tempfile.gettempdir(), "output.mp4")
14
-
15
- def setup_dependencies():
16
- subprocess.run(["pip", "install", "xfuser>=0.4.3", "absl-py", "peft", "lightning", "pandas", "deepspeed", "wandb", "av"])
17
- subprocess.run(
18
- 'pip install flash-attn --no-build-isolation',
19
- shell=True,
20
- env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}
21
- )
22
 
23
- def download_and_stitch_model():
24
- if os.path.exists(MODEL_PATH):
 
 
 
 
 
 
 
 
 
25
  return
26
- os.makedirs(MODEL_DIR, exist_ok=True)
27
- snapshot_download("RaphaelLiu/PusaV1", local_dir=MODEL_DIR, local_dir_use_symlinks=False)
28
- part_files = sorted([f for f in os.listdir(MODEL_DIR) if f.startswith("pusa_v1.pt.part")])
29
- with open(MODEL_PATH, 'wb') as outfile:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  for part in part_files:
31
- with open(os.path.join(MODEL_DIR, part), 'rb') as infile:
32
- shutil.copyfileobj(infile, outfile)
 
 
33
 
34
- def generate_video_from_text(prompt: str):
35
- download_and_stitch_model()
36
- cmd = [
37
- "python", PUSA_SCRIPT,
 
 
 
 
 
38
  "--prompt", prompt,
39
- "--lora_path", MODEL_PATH
 
40
  ]
41
- result = subprocess.run(cmd, capture_output=True, text=True)
42
- if os.path.exists(OUTPUT_VIDEO_PATH):
43
- return OUTPUT_VIDEO_PATH
44
- return f"Video generation failed:\n{result.stderr}"
45
-
46
- def main():
47
- setup_dependencies()
48
- download_and_stitch_model()
49
-
50
- with gr.Blocks() as demo:
51
- gr.Markdown("## πŸŽ₯ PusaV1 - Text to Video Generator")
52
- prompt = gr.Textbox(lines=6, label="Enter your scene prompt")
53
- output = gr.Video(label="Generated Video")
54
- run = gr.Button("Generate Video")
55
- run.click(generate_video_from_text, inputs=prompt, outputs=output)
56
-
57
- demo.launch()
58
-
59
- if __name__ == "__main__":
60
- main()
 
 
 
 
 
 
1
  import os
 
 
2
  import shutil
3
+ import subprocess
4
+ import gradio as gr
5
  from huggingface_hub import snapshot_download
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ MODEL_SUBFOLDER = "Wan2.1-T2V-14B"
9
+ HF_REPO = "RaphaelLiu/PusaV1"
10
+ MODEL_ZOO_DIR = "./model_zoo"
11
+ MODEL_PARTS_DIR = os.path.join(MODEL_ZOO_DIR, MODEL_SUBFOLDER)
12
+ FINAL_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
13
+ PUSA_SCRIPT_PATH = "PusaV1/examples/pusavideo/wan_14b_text_to_video_pusa.py"
14
+
15
+
16
+ def download_model_subset():
17
+ if os.path.exists(FINAL_MODEL_PATH):
18
+ print("βœ… Model already exists. Skipping download.")
19
  return
20
+
21
+ print("⏬ Downloading model parts...")
22
+ snapshot_download(
23
+ repo_id=HF_REPO,
24
+ repo_type="model",
25
+ local_dir=MODEL_ZOO_DIR,
26
+ local_dir_use_symlinks=False,
27
+ allow_patterns=[f"{MODEL_SUBFOLDER}/*"]
28
+ )
29
+ os.makedirs(os.path.dirname(FINAL_MODEL_PATH), exist_ok=True)
30
+
31
+ part_files = sorted([
32
+ os.path.join(MODEL_PARTS_DIR, f)
33
+ for f in os.listdir(MODEL_PARTS_DIR)
34
+ if f.startswith("pusa_v1.pt.part")
35
+ ])
36
+
37
+ print("🧩 Stitching model parts...")
38
+ with open(FINAL_MODEL_PATH, 'wb') as f_out:
39
  for part in part_files:
40
+ with open(part, 'rb') as f_in:
41
+ shutil.copyfileobj(f_in, f_out)
42
+
43
+ print(f"βœ… Final model saved at {FINAL_MODEL_PATH}")
44
 
45
+
46
+ def generate_video(prompt):
47
+ download_model_subset()
48
+
49
+ temp_output_dir = "/tmp/pusa_video_output"
50
+ os.makedirs(temp_output_dir, exist_ok=True)
51
+
52
+ command = [
53
+ "python", PUSA_SCRIPT_PATH,
54
  "--prompt", prompt,
55
+ "--lora_path", FINAL_MODEL_PATH,
56
+ "--output_dir", temp_output_dir
57
  ]
58
+
59
+ try:
60
+ print("πŸš€ Running inference...")
61
+ subprocess.run(command, check=True)
62
+
63
+ # Return first mp4 video found
64
+ for file in os.listdir(temp_output_dir):
65
+ if file.endswith(".mp4"):
66
+ return os.path.join(temp_output_dir, file)
67
+
68
+ return "❌ No video generated."
69
+
70
+ except subprocess.CalledProcessError as e:
71
+ return f"❌ Inference failed: {str(e)}"
72
+
73
+
74
+ with gr.Blocks() as demo:
75
+ gr.Markdown("## πŸ§˜β€β™‚οΈ PusaV1 Text-to-Video Generator (Wan2.1-T2V-14B)")
76
+ prompt_input = gr.Textbox(label="Enter your prompt", lines=4, placeholder="A coral reef full of colorful fish...")
77
+ generate_button = gr.Button("Generate Video")
78
+ video_output = gr.Video(label="Generated Video")
79
+
80
+ generate_button.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
81
+
82
+ demo.launch()