import subprocess
# not sure why it works in the original space but says "pip not found" in mine
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import os
import base64

from huggingface_hub import snapshot_download, hf_hub_download

# Configuration for data paths
DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')

snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
    local_dir_use_symlinks=False,
    resume_download=True,
    repo_type="model"
)

hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir=OTHER_MODELS_PATH,
    local_dir_use_symlinks=False
)

import re
import random
import argparse
import hashlib
import urllib.request
import time
from PIL import Image
import torch
import gradio as gr
from omegaconf import OmegaConf
from tqdm import tqdm
import imageio
import av
import uuid
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Union

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM  #, BitsAndBytesConfig

import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_WIDTH = 832
DEFAULT_HEIGHT = 480


def create_vae_cache_for_resolution(latent_height, latent_width, device, dtype):
    """
    Create VAE cache tensors dynamically based on the latent resolution.
    The cache structure mirrors ZERO_VAE_CACHE but with resolution-dependent dimensions.
    """
    # Scale dimensions based on latent resolution
    # The original cache assumes 832x480 -> 104x60 latent dimensions
    # We need to scale proportionally for other resolutions
    cache = [
        torch.zeros(1, 16, 2, latent_height, latent_width, device=device, dtype=dtype),
        # First set of 384-channel caches at latent resolution
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height, latent_width, device=device, dtype=dtype),
        # Second set at 2x upsampled resolution
        torch.zeros(1, 192, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        torch.zeros(1, 384, 2, latent_height * 2, latent_width * 2, device=device, dtype=dtype),
        # Third set at 4x upsampled resolution
        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
        torch.zeros(1, 192, 2, latent_height * 4, latent_width * 4, device=device, dtype=dtype),
        # Fourth set at 8x upsampled resolution (final output resolution)
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype),
        torch.zeros(1, 96, 2, latent_height * 8, latent_width * 8, device=device, dtype=dtype)
    ]
    return cache

# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
parser.add_argument("--checkpoint_path", type=str,
                    default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'),
                    help="Path to the model checkpoint.")
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml',
                    help="Path to the model config.")
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")

args = parser.parse_args()

gpu = "cuda"

try:
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
except FileNotFoundError as e:
    print(f"Error loading config file: {e}\nPlease ensure the config files are in the correct path.")
    exit(1)

# Initialize Models
print("Initializing models...")
text_encoder = WanTextEncoder()
transformer = WanDiffusionWrapper(is_causal=True)

try:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
except FileNotFoundError as e:
    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
    exit(1)

text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
transformer.eval().to(dtype=torch.float16).requires_grad_(False)

text_encoder.to(gpu)
transformer.to(gpu)

APP_STATE = {
    "torch_compile_applied": False,
    "fp8_applied": False,
    "current_use_taehv": False,
    "current_vae_decoder": None,
}

# I've tried to enable it, but I didn't notice a significant performance improvement.
ENABLE_TORCH_COMPILATION = False
# “default”: The default mode, used when no mode parameter is specified. It provides a good balance between performance and overhead.
# “reduce-overhead”: Minimizes Python-related overhead using CUDA graphs. However, it may increase memory usage.
# “max-autotune”: Uses Triton or template-based matrix multiplications on supported devices. It takes longer to compile but optimizes for the fastest possible execution. On GPUs it enables CUDA graphs by default.
# “max-autotune-no-cudagraphs”: Similar to “max-autotune”, but without CUDA graphs.
TORCH_COMPILATION_MODE = "default"

# Apply torch.compile for maximum performance
if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
    print("🚀 Applying torch.compile for speed optimization...")
    transformer.compile(mode=TORCH_COMPILATION_MODE)
    APP_STATE["torch_compile_applied"] = True
    print("✅ torch.compile applied to transformer")
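
# Illustrative check (commented out, not executed): for the default 832x480 output,
# create_vae_cache_for_resolution (defined above) receives a 60x104 latent grid
# (480 // 8 by 832 // 8). Assuming ZERO_VAE_CACHE follows the layout described in the
# function's docstring, the generated tensors should line up shape-for-shape:
#
#   cache = create_vae_cache_for_resolution(DEFAULT_HEIGHT // 8, DEFAULT_WIDTH // 8,
#                                           device="cpu", dtype=torch.float16)
#   for ours, ref in zip(cache, ZERO_VAE_CACHE):
#       assert ours.shape == ref.shape, (ours.shape, ref.shape)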

def frames_to_mp4_base64(frames, fps=15):
    """
    Convert frames directly to a base64 data URI using PyAV.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        fps: Frames per second

    Returns:
        Base64 data URI string for the MP4 video
    """
    if not frames:
        return "data:video/mp4;base64,"

    height, width = frames[0].shape[:2]

    # Create temporary file for MP4 encoding
    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
    temp_filepath = temp_file.name
    temp_file.close()

    try:
        # Create container for MP4 format
        container = av.open(temp_filepath, mode='w', format='mp4')

        # Add video stream with fast settings
        stream = container.add_stream('h264', rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = 'yuv420p'

        # Optimize for low latency streaming
        stream.options = {
            'preset': 'ultrafast',
            'tune': 'zerolatency',
            'crf': '23',
            'profile': 'baseline',
            'level': '3.0'
        }

        try:
            for frame_np in frames:
                frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
                frame = frame.reformat(format=stream.pix_fmt)
                for packet in stream.encode(frame):
                    container.mux(packet)

            # Flush the encoder
            for packet in stream.encode():
                container.mux(packet)
        finally:
            container.close()

        # Read the MP4 file and encode to base64
        with open(temp_filepath, 'rb') as f:
            video_data = f.read()

        base64_data = base64.b64encode(video_data).decode('utf-8')
        return f"data:video/mp4;base64,{base64_data}"

    finally:
        # Clean up temporary file
        if os.path.exists(temp_filepath):
            os.unlink(temp_filepath)

    return "data:video/mp4;base64,"
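
# Illustrative helper (not part of the original app and never called here): shows how
# the base64 data URI produced by frames_to_mp4_base64 / the Gradio textbox could be
# written back to a playable .mp4 file. The name `data_uri_to_mp4_file` is made up
# for this sketch.
def data_uri_to_mp4_file(data_uri: str, out_path: str) -> bool:
    """Decode a 'data:video/mp4;base64,...' string to an .mp4 file; return False if empty."""
    prefix = "data:video/mp4;base64,"
    if not data_uri.startswith(prefix):
        raise ValueError("Not an MP4 base64 data URI")
    payload = data_uri[len(prefix):]
    if not payload:
        # An empty payload is the failure sentinel returned by the handlers above
        return False
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(payload))
    return True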

# note: we set use_taehv to be able to use other resolutions
# this might impact performance

# NOTE: LOW_MEMORY is referenced in the TAEHV wrapper below but was not defined in this
# file; default it to off, optionally enabled via an environment variable.
LOW_MEMORY = os.getenv("LOW_MEMORY", "0") == "1"


def initialize_vae_decoder(use_taehv=True, use_trt=False):
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        print("Initializing TensorRT VAE Decoder...")
        vae_decoder = VAETRTWrapper()
        APP_STATE["current_use_taehv"] = False
    elif use_taehv:
        print("Initializing TAEHV VAE Decoder...")
        from demo_utils.taehv import TAEHV
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
            except Exception as e:
                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")

        class DotDict(dict):
            __getattr__ = dict.get

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)

        vae_decoder = TAEHVDiffusersWrapper()
        APP_STATE["current_use_taehv"] = True
    else:
        print("Initializing Default VAE Decoder...")
        vae_decoder = VAEDecoderWrapper()
        try:
            vae_state_dict = torch.load(
                os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'),
                map_location="cpu"
            )
            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
            vae_decoder.load_state_dict(decoder_state_dict)
        except FileNotFoundError:
            print("Warning: Default VAE weights not found.")
        APP_STATE["current_use_taehv"] = False

    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)

    # Apply torch.compile to VAE decoder if enabled (following demo.py pattern)
    if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
        print("🚀 Applying torch.compile to VAE decoder...")
        vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
        print("✅ torch.compile applied to VAE decoder")

    APP_STATE["current_vae_decoder"] = vae_decoder
    print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")


# Initialize with default VAE
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

pipeline = CausalInferencePipeline(
    config, device=gpu, generator=transformer, text_encoder=text_encoder,
    vae=APP_STATE["current_vae_decoder"]
)

pipeline.to(dtype=torch.float16).to(gpu)
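
# Illustrative only (commented out): per the note above initialize_vae_decoder, the
# TAEHV decoder relaxes the resolution constraints of the default VAE cache. Whether
# the already-constructed pipeline picks up a new decoder by simply reassigning
# `pipeline.vae` is an assumption about CausalInferencePipeline, not something this
# file verifies.
#
#   initialize_vae_decoder(use_taehv=True, use_trt=args.trt)
#   pipeline.vae = APP_STATE["current_vae_decoder"]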

@torch.no_grad()
def video_generation_handler(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
    """
    Generate a video and return it as a single MP4 (base64 data URI).
    """
    # Add fallback values for None parameters
    if seed is None:
        seed = -1
    if fps is None:
        fps = 15
    if width is None:
        width = DEFAULT_WIDTH
    if height is None:
        height = DEFAULT_HEIGHT
    if duration is None:
        duration = 5

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    print(f"🎬 video_generation_handler called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")

    # Setup
    conditional_dict = text_encoder(text_prompts=[prompt])
    for key, value in conditional_dict.items():
        conditional_dict[key] = value.to(dtype=torch.float16)

    rnd = torch.Generator(gpu).manual_seed(int(seed))
    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)

    # Calculate latent dimensions based on actual width/height (assuming 8x spatial downsampling)
    latent_height = height // 8
    latent_width = width // 8

    noise = torch.randn([1, 21, 16, latent_height, latent_width], device=gpu, dtype=torch.float16, generator=rnd)

    vae_cache, latents_cache = None, None
    if not APP_STATE["current_use_taehv"] and not args.trt:
        # Create resolution-dependent VAE cache
        vae_cache = create_vae_cache_for_resolution(latent_height, latent_width, device=gpu, dtype=torch.float16)

    # Calculate the number of blocks from the requested duration: the base setup
    # generates roughly 5 seconds of video, so scale proportionally, and clamp so we
    # never index past the 21 pre-sampled latent frames.
    base_duration = 5.0  # seconds
    base_blocks = 8
    num_blocks = max(1, int(base_blocks * duration / base_duration))
    num_blocks = min(num_blocks, noise.shape[1] // pipeline.num_frame_per_block)

    current_start_frame = 0
    all_num_frames = [pipeline.num_frame_per_block] * num_blocks

    all_frames = []
    total_frames_generated = 0

    # Ensure temp directory exists
    os.makedirs("gradio_tmp", exist_ok=True)

    # Generation loop
    for idx, current_num_frames in enumerate(all_num_frames):
        print(f"📦 Processing block {idx+1}/{num_blocks}")

        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]

        # Denoising steps
        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
            _, denoised_pred = pipeline.generator(
                noisy_image_or_video=noisy_input,
                conditional_dict=conditional_dict,
                timestep=timestep,
                kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length
            )
            if step_idx < len(pipeline.denoising_step_list) - 1:
                next_timestep = pipeline.denoising_step_list[step_idx + 1]
                noisy_input = pipeline.scheduler.add_noise(
                    denoised_pred.flatten(0, 1),
                    torch.randn_like(denoised_pred.flatten(0, 1)),
                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                ).unflatten(0, denoised_pred.shape[:2])

        if idx < len(all_num_frames) - 1:
            pipeline.generator(
                noisy_image_or_video=denoised_pred,
                conditional_dict=conditional_dict,
                timestep=torch.zeros_like(timestep),
                kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length,
            )

        # Decode to pixels
        if args.trt:
            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
        elif APP_STATE["current_use_taehv"]:
            if latents_cache is None:
                latents_cache = denoised_pred
            else:
                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
                latents_cache = denoised_pred[:, -3:]
            pixels = pipeline.vae.decode(denoised_pred)
        else:
            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)

        # Handle frame skipping
        if idx == 0 and not args.trt:
            pixels = pixels[:, 3:]
        elif APP_STATE["current_use_taehv"] and idx > 0:
            pixels = pixels[:, 12:]

        print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")

        # Process all frames from this block and add to main collection
        for frame_idx in range(pixels.shape[1]):
            frame_tensor = pixels[0, frame_idx]

            # Convert to numpy (HWC, RGB, uint8)
            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
            frame_np = frame_np.to(torch.uint8).cpu().numpy()
            frame_np = np.transpose(frame_np, (1, 2, 0))  # CHW -> HWC

            all_frames.append(frame_np)
            total_frames_generated += 1

            print(f"📦 Block {idx+1}/{num_blocks}, Frame {frame_idx+1}/{pixels.shape[1]} - Total frames: {total_frames_generated}")

        current_start_frame += current_num_frames

    # Generate final MP4 as base64 data URI
    if all_frames:
        print(f"📹 Encoding final MP4 with {len(all_frames)} frames")
        try:
            base64_data_uri = frames_to_mp4_base64(all_frames, fps)
            print(f"✅ Video generation complete! {total_frames_generated} frames encoded to base64 data URI")
            return base64_data_uri
        except Exception as e:
            print(f"⚠️ Error encoding final video: {e}")
            import traceback
            traceback.print_exc()
            return "data:video/mp4;base64,"
    else:
        print("⚠️ No frames generated")
        return "data:video/mp4;base64,"
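
# Illustrative only (commented out, not executed): the handler can also be called
# directly, outside Gradio, and its return value written to disk with the
# illustrative data_uri_to_mp4_file helper sketched after frames_to_mp4_base64 above.
#
#   uri = video_generation_handler(
#       "A stylish woman walks down a Tokyo street...",
#       seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=3,
#   )
#   data_uri_to_mp4_file(uri, "sample.mp4")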

# --- Gradio UI Layout ---
with gr.Blocks(title="Wan2.1 1.3B Self-Forcing demo") as demo:
    gr.Markdown("Video generation with distilled Wan2.1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A stylish woman walks down a Tokyo street...",
                    lines=4,
                    value=""
                )

            start_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")

            gr.Markdown("### ⚙️ Settings")
            with gr.Row():
                seed = gr.Slider(
                    label="Generation Seed (-1 for random)",
                    minimum=-1,
                    maximum=2147483647,  # 2^31 - 1
                    step=1,
                    value=-1
                )
                fps = gr.Slider(
                    label="Playback FPS",
                    minimum=1,
                    maximum=30,
                    value=args.fps,
                    step=1,
                    visible=False,
                    info="Frames per second for playback"
                )
            with gr.Row():
                duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=5,
                    value=3,
                    step=1,
                    info="Video duration in seconds"
                )
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=224,
                    maximum=832,
                    value=DEFAULT_WIDTH,
                    step=8,
                    info="Video width in pixels (8px steps)"
                )
                height = gr.Slider(
                    label="Height",
                    minimum=224,
                    maximum=832,
                    value=DEFAULT_HEIGHT,
                    step=8,
                    info="Video height in pixels (8px steps)"
                )

        with gr.Column(scale=3):
            gr.Markdown("### 🎬 Generated Video (Base64)")
            video_output = gr.Textbox(
                label="Base64 Video Data URI",
                lines=10,
                max_lines=20,
                show_copy_button=True,
                placeholder="Generated video will appear here as base64 data URI..."
            )

    # Connect the generator to the text output
    start_btn.click(
        fn=video_generation_handler,
        inputs=[prompt, seed, fps, width, height, duration],
        outputs=[video_output]
    )

# --- Launch App ---
if __name__ == "__main__":
    if os.path.exists("gradio_tmp"):
        shutil.rmtree("gradio_tmp")  # shutil is already imported at module level
    os.makedirs("gradio_tmp", exist_ok=True)

    print("🚀 Video Generation Node (default engine is Wan2.1 1.3B Self-Forcing)")
    print("📁 Temporary files will be stored in: gradio_tmp/")
    print("🎯 Video encoding: PyAV (MP4/H.264)")
    print(f"⚡ GPU acceleration: {gpu}")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        show_error=True,
        max_threads=40,
        mcp_server=True
    )
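
# Usage sketch (not executed; "app.py" stands for whatever this file is saved as):
#
#   DATA_ROOT=/path/to/models python app.py --port 7860 --share
#
# DATA_ROOT controls where the Wan2.1 and Self-Forcing checkpoints are downloaded
# (see the path configuration at the top of the file); --port and --share are the
# argparse options defined above.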