rahul7star committed
Commit d54a2c6 · verified · 1 Parent(s): e852cbe

Update app.py

Files changed (1)
  1. app.py +45 -54
app.py CHANGED
@@ -1,75 +1,66 @@
- import os
- import sys
  import gradio as gr
  import tempfile
  from huggingface_hub import snapshot_download
  import spaces
-
- from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData
-
- WAN_MODEL_DIR = "/tmp/model_zoo/Wan2.1-T2V-14B"
- LORA_DIR = "/tmp/model_zoo/PusaV1"
- LORA_PATH = os.path.join(LORA_DIR, "pusa_v1.pt")
-
- @spaces.GPU
- def generate_video(prompt, lora_upload):
-     # Download Wan2.1 model only if missing
-     if not os.path.exists(WAN_MODEL_DIR):
          snapshot_download(
-             repo_id="RaphaelLiu/PusaV1",
-             allow_patterns=["Wan2.1-T2V-14B/*"],
-             local_dir=WAN_MODEL_DIR,
              local_dir_use_symlinks=False,
-             resume_download=True,
          )
-
-     # Handle LoRA file (upload or default download + stitch)
-     if lora_upload is not None:
-         lora_path = lora_upload
-     else:
-         if not os.path.exists(LORA_PATH):
-             os.makedirs(LORA_DIR, exist_ok=True)
-             snapshot_download(
-                 repo_id="RaphaelLiu/PusaV1",
-                 allow_patterns=["PusaV1/pusa_v1.pt.part*"],
-                 local_dir=LORA_DIR,
-                 local_dir_use_symlinks=False,
-             )
-             # Stitch parts
-             part_files = sorted(
-                 f for f in os.listdir(LORA_DIR) if f.startswith("pusa_v1.pt.part")
-             )
-             with open(LORA_PATH, "wb") as wfd:
-                 for part in part_files:
-                     with open(os.path.join(LORA_DIR, part), "rb") as fd:
-                         wfd.write(fd.read())
-
-         lora_path = LORA_PATH
-
-     # Load model and pipeline
-     manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
-     pipe = WanVideoPusaPipeline(model_manager=manager)
-     pipe.set_lora_adapters(lora_path)
-
-     # Run generation
-     result: VideoData = pipe(prompt)
-
-     # Save video to temp file
      tmp_dir = tempfile.mkdtemp()
-     video_path = os.path.join(tmp_dir, "output.mp4")
-     save_video(result.frames, video_path, fps=8)
-
-     return video_path


  with gr.Blocks() as demo:
-     gr.Markdown("# 🎥 Pusa Text-to-Video (Wan2.1-T2V-14B)")
-     prompt = gr.Textbox(label="Prompt", lines=4)
-     lora_file = gr.File(label="Upload LoRA .pt (optional)", file_types=[".pt"])
-     generate_btn = gr.Button("Generate")
-     output_video = gr.Video(label="Generated Video")
-
-     generate_btn.click(fn=generate_video, inputs=[prompt, lora_file], outputs=output_video)

  demo.launch()

  import gradio as gr
+ import os
  import tempfile
  from huggingface_hub import snapshot_download
+ from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
  import spaces
+
+ # Constants
+ WAN_SUBFOLDER = "Wan2.1-T2V-14B"
+ MODEL_REPO_ID = "RaphaelLiu/PusaV1"
+ MODEL_ZOO_DIR = "./model_zoo"
+ WAN_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, WAN_SUBFOLDER)
+ LORA_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
+
+ # Ensure model is downloaded
+ def ensure_model_downloaded():
+     if not os.path.exists(WAN_MODEL_PATH):
+         print("Downloading Wan2.1-T2V-14B from HuggingFace Hub...")
          snapshot_download(
+             repo_id=MODEL_REPO_ID,
+             local_dir=MODEL_ZOO_DIR,
+             repo_type="model",
+             allow_patterns=[f"{WAN_SUBFOLDER}/**"],
              local_dir_use_symlinks=False,
          )
+         print("Model downloaded.")
+
+ # Video generation logic
+ @spaces.GPU
+ def generate_video(prompt: str):
+     ensure_model_downloaded()
+
+     # Load model
+     manager = ModelManager(pretrained_model_dir=WAN_MODEL_PATH)
+     model = manager.load_model()
+
+     # Set up pipeline
+     pipeline = WanVideoPusaPipeline(model=model)
+     pipeline.set_lora_adapters(LORA_PATH)
+
+     # Generate video
+     result = pipeline(prompt)
+
+     # Save video
      tmp_dir = tempfile.mkdtemp()
+     output_path = os.path.join(tmp_dir, "video.mp4")
+     save_video(result.frames, output_path, fps=8)
+
+     return output_path
+
+ # Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("## 🎥 Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
+
+     prompt_input = gr.Textbox(
+         lines=4,
+         label="Prompt",
+         placeholder="Describe your video (e.g. A coral reef full of colorful fish...)"
+     )
+
+     generate_btn = gr.Button("Generate Video")
+     video_output = gr.Video(label="Output")
+
+     generate_btn.click(fn=generate_video, inputs=prompt_input, outputs=video_output)
+
  demo.launch()
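
The updated generate_video still calls pipeline.set_lora_adapters(LORA_PATH), but the download-and-stitch step for pusa_v1.pt from the old version no longer runs, so that file may be missing on a fresh Space. Below is a minimal sketch of how that step could be kept as a companion to ensure_model_downloaded. It reuses the imports and constants already defined in the new app.py (os, snapshot_download, MODEL_REPO_ID, MODEL_ZOO_DIR, LORA_PATH), assumes, as the removed code did, that the RaphaelLiu/PusaV1 repo ships the LoRA as split files matching PusaV1/pusa_v1.pt.part*, and the helper name ensure_lora_downloaded is hypothetical.

# Hypothetical helper, based on the download-and-stitch logic removed above.
# Assumes the LoRA is still published as PusaV1/pusa_v1.pt.part* split files.
def ensure_lora_downloaded():
    if os.path.exists(LORA_PATH):
        return
    lora_dir = os.path.dirname(LORA_PATH)  # ./model_zoo/PusaV1
    os.makedirs(lora_dir, exist_ok=True)
    # snapshot_download keeps repo-relative paths under local_dir,
    # so the parts land in ./model_zoo/PusaV1/
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        allow_patterns=["PusaV1/pusa_v1.pt.part*"],
        local_dir=MODEL_ZOO_DIR,
        local_dir_use_symlinks=False,
    )
    # Concatenate the parts in name order to rebuild pusa_v1.pt
    part_files = sorted(
        f for f in os.listdir(lora_dir) if f.startswith("pusa_v1.pt.part")
    )
    with open(LORA_PATH, "wb") as out_file:
        for part in part_files:
            with open(os.path.join(lora_dir, part), "rb") as part_file:
                out_file.write(part_file.read())

If adopted, calling ensure_lora_downloaded() right after ensure_model_downloaded() inside generate_video would ensure the weights exist before set_lora_adapters runs.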