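# Gradio demo: Pusa text-to-video generation on top of Wan2.1-T2V-14B,
# driven by the PusaV1 LoRA from the RaphaelLiu/PusaV1 Hugging Face repo.
# Requires gradio, spaces, huggingface_hub, and a diffsynth build that
# provides ModelManager, WanVideoPusaPipeline, save_video, and VideoData.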
import os
import shutil
import tempfile

import gradio as gr
import spaces
from huggingface_hub import snapshot_download

from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData

# Weights are cached under one model zoo root so the repo-relative paths that
# snapshot_download preserves line up with the directories used below.
MODEL_ZOO_DIR = "/tmp/model_zoo"
WAN_MODEL_DIR = os.path.join(MODEL_ZOO_DIR, "Wan2.1-T2V-14B")
LORA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
LORA_PATH = os.path.join(LORA_DIR, "pusa_v1.pt")


@spaces.GPU
def generate_video(prompt, lora_upload):
    # Download the Wan2.1-T2V-14B weights only if they are missing.
    # local_dir is the model zoo root because snapshot_download keeps the
    # repo-relative path "Wan2.1-T2V-14B/..." under local_dir.
    if not os.path.exists(WAN_MODEL_DIR):
        snapshot_download(
            repo_id="RaphaelLiu/PusaV1",
            allow_patterns=["Wan2.1-T2V-14B/*"],
            local_dir=MODEL_ZOO_DIR,
            local_dir_use_symlinks=False,
            resume_download=True,
        )
    # Resolve the LoRA file: use the uploaded one if given, otherwise download
    # the default PusaV1 LoRA (shipped as split parts) and stitch it together.
    if lora_upload is not None:
        lora_path = lora_upload
    else:
        if not os.path.exists(LORA_PATH):
            os.makedirs(LORA_DIR, exist_ok=True)
            snapshot_download(
                repo_id="RaphaelLiu/PusaV1",
                allow_patterns=["PusaV1/pusa_v1.pt.part*"],
                local_dir=MODEL_ZOO_DIR,
                local_dir_use_symlinks=False,
            )
            # Concatenate the parts in order to rebuild pusa_v1.pt.
            part_files = sorted(
                f for f in os.listdir(LORA_DIR) if f.startswith("pusa_v1.pt.part")
            )
            with open(LORA_PATH, "wb") as wfd:
                for part in part_files:
                    with open(os.path.join(LORA_DIR, part), "rb") as fd:
                        shutil.copyfileobj(fd, wfd)
        lora_path = LORA_PATH
    # Load model and pipeline, then attach the LoRA
    manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
    pipe = WanVideoPusaPipeline(model_manager=manager)
    pipe.set_lora_adapters(lora_path)

    # Run generation
    result: VideoData = pipe(prompt)

    # Save video to a temp file and return its path to Gradio
    tmp_dir = tempfile.mkdtemp()
    video_path = os.path.join(tmp_dir, "output.mp4")
    save_video(result.frames, video_path, fps=8)
    return video_path
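

# Example (local smoke test with a hypothetical prompt), bypassing the UI:
#   video_path = generate_video("a red panda walking through bamboo", None)
#   print(video_path)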

with gr.Blocks() as demo:
    gr.Markdown("# πŸŽ₯ Pusa Text-to-Video (Wan2.1-T2V-14B)")
    prompt = gr.Textbox(label="Prompt", lines=4)
    lora_file = gr.File(label="Upload LoRA .pt (optional)", file_types=[".pt"])
    generate_btn = gr.Button("Generate")
    output_video = gr.Video(label="Generated Video")
    generate_btn.click(
        fn=generate_video, inputs=[prompt, lora_file], outputs=output_video
    )

demo.launch()