# Hugging Face Space — status: Paused
import os
import shutil
import sys
import tempfile

import gradio as gr
from huggingface_hub import snapshot_download
# Keep `spaces` imported before diffsynth: ZeroGPU expects to be imported before
# packages that initialise CUDA (presumably torch, via diffsynth) — verify before reordering.
import spaces
from diffsynth import ModelManager, WanVideoPusaPipeline, save_video, VideoData
# Root directory for all downloaded weights. Defaults to /tmp/model_zoo; set the
# MODEL_ZOO_DIR environment variable (non-empty) to store them elsewhere.
MODEL_ZOO_DIR = os.environ.get("MODEL_ZOO_DIR") or "/tmp/model_zoo"
# Base Wan2.1 text-to-video weights.
WAN_MODEL_DIR = os.path.join(MODEL_ZOO_DIR, "Wan2.1-T2V-14B")
# Directory holding the default Pusa V1 LoRA, and the stitched LoRA file itself.
LORA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
LORA_PATH = os.path.join(LORA_DIR, "pusa_v1.pt")
def _part_sort_key(filename, prefix):
    """Return a sort key ordering split-file parts by the suffix after *prefix*.

    Purely decimal suffixes compare as integers, so ``part10`` sorts after
    ``part9`` (plain string sorting would put it right after ``part1``).
    Other suffixes (e.g. ``aa``, ``ab``) fall back to string order after them.
    """
    suffix = filename[len(prefix):]
    if suffix.isdecimal():
        return (0, int(suffix), suffix)
    return (1, 0, suffix)


def _stitch_lora_parts(part_dir, out_path, prefix="pusa_v1.pt.part"):
    """Concatenate every ``<prefix>*`` file in *part_dir* into *out_path*.

    Raises:
        FileNotFoundError: if *part_dir* is missing or holds no part files.
    """
    part_files = sorted(
        (name for name in os.listdir(part_dir) if name.startswith(prefix)),
        key=lambda name: _part_sort_key(name, prefix),
    )
    if not part_files:
        # Fail loudly: an empty output file would pass the os.path.exists check
        # on every later call and never be re-downloaded.
        raise FileNotFoundError(f"No '{prefix}*' files found in {part_dir}")
    # Write under a temporary name and rename at the end, so an interrupted
    # stitch never leaves a truncated file at out_path.
    tmp_path = out_path + ".tmp"
    with open(tmp_path, "wb") as wfd:
        for part in part_files:
            with open(os.path.join(part_dir, part), "rb") as fd:
                # Stream in chunks rather than reading each part fully into memory.
                shutil.copyfileobj(fd, wfd)
    os.replace(tmp_path, out_path)


def _ensure_base_model():
    """Download the Wan2.1-T2V-14B weights into ``WAN_MODEL_DIR`` if it is missing.

    Raises:
        FileNotFoundError: if the download does not produce ``WAN_MODEL_DIR``.
    """
    # NOTE(review): a partially downloaded WAN_MODEL_DIR also passes this check
    # and will not be resumed; consider a completion marker if that bites.
    if os.path.exists(WAN_MODEL_DIR):
        return
    # snapshot_download mirrors the repo's folder layout under local_dir, so the
    # "Wan2.1-T2V-14B/*" files are downloaded into the *parent* directory to end
    # up directly in WAN_MODEL_DIR (not in WAN_MODEL_DIR/Wan2.1-T2V-14B/).
    # This relies on WAN_MODEL_DIR's basename matching the repo folder name.
    snapshot_download(
        repo_id="RaphaelLiu/PusaV1",
        allow_patterns=["Wan2.1-T2V-14B/*"],
        local_dir=os.path.dirname(WAN_MODEL_DIR),
        local_dir_use_symlinks=False,
        resume_download=True,
    )
    if not os.path.isdir(WAN_MODEL_DIR):
        raise FileNotFoundError(
            f"Download finished but {WAN_MODEL_DIR} does not exist; "
            "check allow_patterns against the repository layout."
        )


def _resolve_lora_path(lora_upload):
    """Return a local filesystem path to the LoRA weights to apply.

    Uses the uploaded file when one is given. Otherwise it downloads the split
    ``pusa_v1.pt.part*`` files once and stitches them into ``LORA_PATH``.
    """
    if lora_upload is not None:
        # Depending on the Gradio version, gr.File passes either a filepath str
        # or a tempfile-like object that exposes the path as ``.name``.
        return getattr(lora_upload, "name", lora_upload)
    if not os.path.exists(LORA_PATH):
        os.makedirs(LORA_DIR, exist_ok=True)
        # snapshot_download mirrors the repo layout under local_dir, so the
        # "PusaV1/..." parts are downloaded into the parent to land in LORA_DIR.
        snapshot_download(
            repo_id="RaphaelLiu/PusaV1",
            allow_patterns=["PusaV1/pusa_v1.pt.part*"],
            local_dir=os.path.dirname(LORA_DIR),
            local_dir_use_symlinks=False,
        )
        _stitch_lora_parts(LORA_DIR, LORA_PATH)
    return LORA_PATH


def generate_video(prompt, lora_upload):
    """Generate a video for *prompt* using Wan2.1-T2V-14B plus a Pusa LoRA.

    Args:
        prompt: Text prompt passed to the pipeline.
        lora_upload: Optional uploaded LoRA ``.pt`` from the ``gr.File`` input;
            when ``None`` the default Pusa V1 LoRA is downloaded and used.

    Returns:
        Path to ``output.mp4`` inside a newly created temporary directory.
    """
    # NOTE(review): `spaces` is imported but this function is not decorated with
    # @spaces.GPU, which ZeroGPU hardware needs in order to attach a GPU —
    # confirm the Space's hardware.
    _ensure_base_model()
    lora_path = _resolve_lora_path(lora_upload)

    # NOTE(review): the full model is reloaded on every request. Also confirm
    # that these diffsynth calls (ModelManager(pretrained_model_dir=...),
    # set_lora_adapters, result.frames) match the installed diffsynth version.
    manager = ModelManager(pretrained_model_dir=WAN_MODEL_DIR)
    pipe = WanVideoPusaPipeline(model_manager=manager)
    pipe.set_lora_adapters(lora_path)

    result: VideoData = pipe(prompt)

    # A fresh directory per call avoids clobbering concurrent outputs.
    tmp_dir = tempfile.mkdtemp()
    video_path = os.path.join(tmp_dir, "output.mp4")
    save_video(result.frames, video_path, fps=8)
    return video_path
# Gradio UI: a prompt box, an optional LoRA upload, and a button that runs
# generate_video and shows the resulting video.
with gr.Blocks() as demo:
    # Title emoji was mojibake ("π₯" = UTF-8 bytes of 🎥 decoded as ISO-8859-7).
    gr.Markdown("# 🎥 Pusa Text-to-Video (Wan2.1-T2V-14B)")
    prompt = gr.Textbox(label="Prompt", lines=4)
    lora_file = gr.File(label="Upload LoRA .pt (optional)", file_types=[".pt"])
    generate_btn = gr.Button("Generate")
    output_video = gr.Video(label="Generated Video")
    generate_btn.click(fn=generate_video, inputs=[prompt, lora_file], outputs=output_video)

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()