#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video.

• Single global pipeline load (no repeated downloads)
• device_map="balanced" placement, intended to avoid OOM on a 24 GB A10
• Fast CLIP image processor, loaded explicitly with use_fast=True
• Per-denoising-step progress reporting in the UI
• Result offered as a downloadable .mp4 via gr.File
"""

import os

# Persist the Hugging Face cache so safetensors only download once.
# Set before importing the Hugging Face libraries, which read HF_HOME at import time.
os.environ["HF_HOME"] = "/mnt/data/huggingface"

import random

import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import AutoImageProcessor, CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF

# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
# Upper bound for randomly drawn seeds. Kept far below 2**53 so the value
# survives the round trip through the browser's Number widget unchanged.
MAX_SEED = 2**32 - 1


# -----------------------------------------------------------------------------
# LOAD PIPELINE ONCE
# -----------------------------------------------------------------------------
def load_pipeline():
    """Build and return the FLF2V pipeline (called once at import time).

    The CLIP image encoder is loaded in fp32, the VAE in ``DTYPE``. The CLIP
    image processor is loaded explicitly with ``use_fast=True``: passing that
    keyword to ``WanImageToVideoPipeline.from_pretrained`` does not reach the
    processor sub-component.
    """
    # 1) CLIP image encoder (fp32)
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) Fast CLIP image processor
    image_processor = AutoImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor", use_fast=True
    )
    # 3) VAE (DTYPE)
    vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=DTYPE)
    # 4) Full pipeline; "balanced" lets diffusers distribute the components
    #    over the available devices.
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=image_encoder,
        image_processor=image_processor,
        vae=vae,
        torch_dtype=DTYPE,
        device_map="balanced",
    )
    return pipe


PIPE = load_pipeline()


# -----------------------------------------------------------------------------
# HELPERS
# -----------------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize ``img`` to about ``max_area`` pixels, preserving its aspect ratio.

    Height and width are rounded down to a multiple of the model's spatial
    granularity (VAE spatial scale factor × transformer patch size). Each side
    is clamped to at least one such unit so that extreme aspect ratios cannot
    yield a zero-sized image.

    Returns:
        ``(resized_image, height, width)``
    """
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = max(mod, int(np.sqrt(max_area * ar)) // mod * mod)
    w = max(mod, int(np.sqrt(max_area / ar)) // mod * mod)
    return img.resize((w, h), Image.LANCZOS), h, w


def center_crop_resize(img: Image.Image, h, w):
    """Scale ``img`` so it covers ``h`` x ``w``, then center-crop to exactly that size."""
    ratio = max(w / img.width, h / img.height)
    img2 = img.resize(
        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
    )
    return TF.center_crop(img2, [h, w])


def _resolve_seed(seed) -> int:
    """Return ``seed`` as an int, drawing a random one when it is negative or empty.

    Uses the stdlib ``random`` module rather than ``torch.seed()`` so torch's
    global RNG is not reseeded as a side effect.
    """
    if seed is None or seed < 0:
        return random.randint(0, MAX_SEED)
    return int(seed)


# -----------------------------------------------------------------------------
# GENERATION + STREAMING
# -----------------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    """Generate a video that interpolates from ``first_frame`` to ``last_frame``.

    The first frame fixes the output resolution (see ``aspect_resize``). The last
    frame is center-cropped to match when its size differs.

    Returns:
        ``(video_path, seed)`` — path of the exported .mp4 and the seed actually
        used, so a run can be reproduced.

    Raises:
        gr.Error: if either frame was not provided.
    """
    if first_frame is None or last_frame is None:
        raise gr.Error("Please upload both a first and a last frame.")

    seed = _resolve_seed(seed)
    steps = int(steps)
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

    # 0–15%: resize
    progress(0.0, desc="Resizing first frame…")
    f_resized, h, w = aspect_resize(first_frame)
    if last_frame.size != f_resized.size:
        progress(0.15, desc="Resizing last frame…")
        l_resized = center_crop_resize(last_frame, h, w)
    else:
        # The last frame already has the target size; use it unchanged.
        l_resized = last_frame

    # 25–90%: denoising, reported once per step via the pipeline callback.
    def _on_step_end(pipe, step_index, timestep, callback_kwargs):
        done = step_index + 1
        progress(0.25 + 0.65 * done / steps, desc=f"Denoising step {done}/{steps}…")
        # The pipeline expects the (possibly modified) tensors back.
        return callback_kwargs

    progress(0.25, desc="Launching inference…")
    out = PIPE(
        image=f_resized,
        last_image=l_resized,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=int(num_frames),
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=_on_step_end,
    )

    # 90–100%: export
    progress(0.90, desc="Building video file…")
    video_path = export_to_video(out.frames[0], fps=int(fps))
    progress(1.0, desc="Done!")
    return video_path, seed


# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")

    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (opt)", placeholder="blurry, lowres")

    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        # Wan needs (num_frames - 1) divisible by 4; 17..129 in steps of 4 only
        # offers valid values (DEFAULT_FRAMES = 81 is one of them).
        num_frames = gr.Slider(17, 129, value=DEFAULT_FRAMES, step=4, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=rand)")

    run_btn = gr.Button("Generate")
    download = gr.File(label="Download .mp4", interactive=False)
    seed_used = gr.Number(label="Seed used", precision=0, interactive=False)

    run_btn.click(
        fn=generate,
        inputs=[
            first_img,
            last_img,
            prompt,
            negative,
            steps,
            guidance,
            num_frames,
            seed_input,
            fps,
        ],
        outputs=[download, seed_used],
    )

demo.queue().launch(server_name="0.0.0.0", server_port=7860)