import subprocess
import os

from huggingface_hub import snapshot_download, hf_hub_download

# Model storage locations; override the root via the DATA_ROOT environment variable.
DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
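
# Download the Wan2.1-T2V-1.3B base weights (diffusion transformer, text encoder and the
# VAE, which is loaded below from Wan2.1_VAE.pth) plus the distilled Self-Forcing DMD
# generator checkpoint that this demo runs with.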
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
    local_dir_use_symlinks=False,
    resume_download=True,
    repo_type="model"
)

hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir=OTHER_MODELS_PATH,
    local_dir_use_symlinks=False
)

import re
import random
import argparse
import hashlib
import urllib.request
import time
from PIL import Image
import torch
import gradio as gr
from omegaconf import OmegaConf
from tqdm import tqdm
import imageio
import av
import uuid
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Union

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_WIDTH = 832
DEFAULT_HEIGHT = 480

parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
args = parser.parse_args()

# The demo assumes a CUDA device is available.
gpu = "cuda"

try:
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
except FileNotFoundError as e:
    print(f"Error loading config file: {e}\nPlease ensure the config files are in the correct path.")
    exit(1)

print("Initializing models...")
text_encoder = WanTextEncoder()
transformer = WanDiffusionWrapper(is_causal=True)

try:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    # Prefer the EMA weights if present, otherwise fall back to the raw generator weights.
    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
except FileNotFoundError as e:
    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
    exit(1)

text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
transformer.eval().to(dtype=torch.float16).requires_grad_(False)

text_encoder.to(gpu)
transformer.to(gpu)
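
# Mutable runtime state shared across handlers: whether torch.compile / FP8 have been
# applied, whether the lightweight TAEHV decoder is active, and the current VAE decoder.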
APP_STATE = {
    "torch_compile_applied": False,
    "fp8_applied": False,
    "current_use_taehv": False,
    "current_vae_decoder": None,
}

# Feature flags. NOTE: LOW_MEMORY is referenced by the TAEHV decode path below but was
# left undefined; it is assumed here to default to False (parallel TAEHV decoding).
ENABLE_TORCH_COMPILATION = False
TORCH_COMPILATION_MODE = "default"
LOW_MEMORY = False

if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
    print("Applying torch.compile for speed optimization...")
    transformer.compile(mode=TORCH_COMPILATION_MODE)
    APP_STATE["torch_compile_applied"] = True
    print("torch.compile applied to transformer")
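

# Each generated block is encoded to a standalone MPEG-TS segment. Transport-stream
# segments can be played back-to-back without per-file container headers, which is what
# lets gr.Video(streaming=True) start playing the first chunk while later blocks are
# still being generated.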
def frames_to_ts_file(frames, filepath, fps=15):
    """
    Convert frames directly to a .ts file using PyAV.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        filepath: Output file path
        fps: Frames per second

    Returns:
        The filepath of the created file
    """
    if not frames:
        return filepath

    height, width = frames[0].shape[:2]

    container = av.open(filepath, mode='w', format='mpegts')

    stream = container.add_stream('h264', rate=fps)
    stream.width = width
    stream.height = height
    stream.pix_fmt = 'yuv420p'

    # Low-latency x264 settings: fastest preset, no lookahead buffering.
    stream.options = {
        'preset': 'ultrafast',
        'tune': 'zerolatency',
        'crf': '23',
        'profile': 'baseline',
        'level': '3.0'
    }

    try:
        for frame_np in frames:
            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
            frame = frame.reformat(format=stream.pix_fmt)
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush any frames still buffered inside the encoder.
        for packet in stream.encode():
            container.mux(packet)

    finally:
        container.close()

    return filepath
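
# Illustrative example (not executed): encode a short grey clip to one TS segment.
#   dummy_frames = [np.full((DEFAULT_HEIGHT, DEFAULT_WIDTH, 3), 128, dtype=np.uint8) for _ in range(30)]
#   frames_to_ts_file(dummy_frames, os.path.join("gradio_tmp", "example.ts"), fps=15)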
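

# initialize_vae_decoder() picks one of three decoder back-ends and stores it in APP_STATE:
#   --trt     : TensorRT-optimized causal VAE (VAETRTWrapper)
#   use_taehv : lightweight TAEHV approximate decoder (weights fetched on demand)
#   default   : the Wan2.1 causal VAE decoder, loaded from Wan2.1_VAE.pth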
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        print("Initializing TensorRT VAE Decoder...")
        vae_decoder = VAETRTWrapper()
        APP_STATE["current_use_taehv"] = False
    elif use_taehv:
        print("Initializing TAEHV VAE Decoder...")
        from demo_utils.taehv import TAEHV
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
            except Exception as e:
                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")

        class DotDict(dict): __getattr__ = dict.get

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                # Rescale TAEHV output from [0, 1] to [-1, 1] to match the default VAE's range.
                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)

        vae_decoder = TAEHVDiffusersWrapper()
        APP_STATE["current_use_taehv"] = True
    else:
        print("Initializing Default VAE Decoder...")
        vae_decoder = VAEDecoderWrapper()
        try:
            vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
            # Keep only the decoder half of the full VAE state dict.
            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
            vae_decoder.load_state_dict(decoder_state_dict)
        except FileNotFoundError:
            print("Warning: Default VAE weights not found.")
        APP_STATE["current_use_taehv"] = False

    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)

    if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
        print("Applying torch.compile to VAE decoder...")
        vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
        print("torch.compile applied to VAE decoder")

    APP_STATE["current_vae_decoder"] = vae_decoder
    print(f"VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")


# Start with the default causal VAE decoder (or the TensorRT one when --trt is set).
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

pipeline = CausalInferencePipeline(
    config, device=gpu, generator=transformer, text_encoder=text_encoder,
    vae=APP_STATE["current_vae_decoder"]
)

pipeline.to(dtype=torch.float16).to(gpu)
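
# The handler below generates the video block-by-block: each block of latent frames is
# denoised with the few-step distilled schedule, re-run at timestep 0 to extend the KV
# cache with clean context, decoded to pixels, and immediately encoded to a .ts chunk
# that is yielded to the UI while the next block is still being generated.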
@torch.no_grad()
def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
    """
    Generator function that yields .ts video chunks using PyAV for streaming.
    """
    # Fall back to sensible defaults when the UI passes None.
    if seed is None:
        seed = -1
    if fps is None:
        fps = 15
    if width is None:
        width = DEFAULT_WIDTH
    if height is None:
        height = DEFAULT_HEIGHT
    if duration is None:
        duration = 5

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    # Encode the prompt once; the conditioning is reused for every denoising step.
    conditional_dict = text_encoder(text_prompts=[prompt])
    for key, value in conditional_dict.items():
        conditional_dict[key] = value.to(dtype=torch.float16)

    rnd = torch.Generator(gpu).manual_seed(int(seed))
    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)

    # The Wan VAE downsamples spatially by 8x.
    latent_height = height // 8
    latent_width = width // 8

    print(f"video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")

    # Pre-sample all latent noise up front: 21 latent frames of 16 channels at 1/8 spatial resolution.
    noise = torch.randn([1, 21, 16, latent_height, latent_width], device=gpu, dtype=torch.float16, generator=rnd)

    vae_cache, latents_cache = None, None
    if not APP_STATE["current_use_taehv"] and not args.trt:
        vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]

    # Scale the number of generation blocks with the requested duration (8 blocks ~ 5 s).
    base_duration = 5.0
    base_blocks = 8
    num_blocks = max(1, int(base_blocks * duration / base_duration))
    # Guard: never schedule more blocks than the pre-sampled noise can cover.
    num_blocks = min(num_blocks, max(1, noise.shape[1] // pipeline.num_frame_per_block))

    current_start_frame = 0
    all_num_frames = [pipeline.num_frame_per_block] * num_blocks

    total_frames_yielded = 0

    os.makedirs("gradio_tmp", exist_ok=True)
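
    # For each block: run the few distilled denoising steps (re-noising between steps),
    # write clean context into the KV cache via an extra timestep-0 pass, decode to
    # pixels, and hand the frames to the streaming encoder.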
    for idx, current_num_frames in enumerate(all_num_frames):
        print(f"Processing block {idx+1}/{num_blocks}")

        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]

        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
            _, denoised_pred = pipeline.generator(
                noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
                timestep=timestep, kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length
            )
            if step_idx < len(pipeline.denoising_step_list) - 1:
                # Re-noise the prediction to the next timestep in the distilled schedule.
                next_timestep = pipeline.denoising_step_list[step_idx + 1]
                noisy_input = pipeline.scheduler.add_noise(
                    denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                ).unflatten(0, denoised_pred.shape[:2])

        if idx < len(all_num_frames) - 1:
            # Extra pass at timestep 0 so the KV cache holds clean context for the next block.
            pipeline.generator(
                noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
                timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length,
            )

        # Decode this block's latents to pixels with whichever decoder is active.
        if args.trt:
            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
        elif APP_STATE["current_use_taehv"]:
            if latents_cache is None:
                latents_cache = denoised_pred
            else:
                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
                latents_cache = denoised_pred[:, -3:]
            pixels = pipeline.vae.decode(denoised_pred)
        else:
            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)

        # Trim initial/overlapping frames so consecutive blocks do not duplicate content.
        if idx == 0 and not args.trt:
            pixels = pixels[:, 3:]
        elif APP_STATE["current_use_taehv"] and idx > 0:
            pixels = pixels[:, 12:]

        print(f"DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")

        # Convert each frame from [-1, 1] float CHW to uint8 HWC RGB.
        all_frames_from_block = []
        for frame_idx in range(pixels.shape[1]):
            frame_tensor = pixels[0, frame_idx]

            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
            frame_np = frame_np.to(torch.uint8).cpu().numpy()
            frame_np = np.transpose(frame_np, (1, 2, 0))

            all_frames_from_block.append(frame_np)
            total_frames_yielded += 1

            # Overall progress: completed blocks plus the fraction of the current block.
            blocks_completed = idx
            current_block_progress = (frame_idx + 1) / pixels.shape[1]
            total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
            total_progress = min(total_progress, 100.0)

            frame_status_html = (
                f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
                f"  <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
                f"  <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
                f"    <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
                f"  </div>"
                f"  <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
                f"    Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
                f"  </p>"
                f"</div>"
            )

            # Update only the status panel; the video component is updated once per block.
            yield None, frame_status_html

        if all_frames_from_block:
            print(f"Encoding block {idx} with {len(all_frames_from_block)} frames")

            try:
                chunk_uuid = str(uuid.uuid4())[:8]
                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
                ts_path = os.path.join("gradio_tmp", ts_filename)

                frames_to_ts_file(all_frames_from_block, ts_path, fps)

                total_progress = (idx + 1) / num_blocks * 100

                # Push the finished chunk to the streaming video component.
                yield ts_path, gr.update()

            except Exception as e:
                print(f"Error encoding block {idx}: {e}")
                import traceback
                traceback.print_exc()

        current_start_frame += current_num_frames

    final_status_html = (
        f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
        f"  <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
        f"    <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
        f"  </div>"
        f"  <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
        f"    <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
        f"      Generated {total_frames_yielded} frames across {num_blocks} blocks"
        f"    </p>"
        f"    <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
        f"      Playback: {fps} FPS • Format: MPEG-TS/H.264"
        f"    </p>"
        f"  </div>"
        f"</div>"
    )
    yield None, final_status_html
    print(f"PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
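

# Gradio UI: a prompt box plus sliders on the left, the streaming video player and a
# status panel on the right. start_btn.click() streams the handler's (video_chunk, html)
# yields into those two components.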
with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
    gr.Markdown("Real-time video generation with distilled Wan2.1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A stylish woman walks down a Tokyo street...",
                    lines=4,
                    value=""
                )
                start_btn = gr.Button("Start Streaming", variant="primary", size="lg")

            gr.Markdown("### Settings")
            with gr.Row():
                seed = gr.Slider(
                    label="Generation Seed (-1 for random)",
                    minimum=-1,
                    maximum=2147483647,
                    step=1,
                    value=-1
                )
                fps = gr.Slider(
                    label="Playback FPS",
                    minimum=1,
                    maximum=30,
                    value=args.fps,
                    step=1,
                    visible=False,
                    info="Frames per second for playback"
                )

            with gr.Row():
                duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=5,
                    value=3,
                    step=1,
                    info="Video duration in seconds"
                )

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=224,
                    maximum=832,
                    value=DEFAULT_WIDTH,
                    step=8,
                    info="Video width in pixels (8px steps)"
                )
                height = gr.Slider(
                    label="Height",
                    minimum=224,
                    maximum=832,
                    value=DEFAULT_HEIGHT,
                    step=8,
                    info="Video height in pixels (8px steps)"
                )

        with gr.Column(scale=3):
            gr.Markdown("### Video Stream")
            streaming_video = gr.Video(
                label="Live Stream",
                streaming=True,
                loop=True,
                height=400,
                autoplay=True,
                show_label=False
            )

            status_display = gr.HTML(
                value=(
                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
                    "Ready to start streaming...<br>"
                    "<small>Configure your prompt and click 'Start Streaming'</small>"
                    "</div>"
                ),
                label="Generation Status"
            )

    start_btn.click(
        fn=video_generation_handler_streaming,
        inputs=[prompt, seed, fps, width, height, duration],
        outputs=[streaming_video, status_display]
    )
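

# Clear the scratch directory on startup so stale .ts chunks from earlier runs are not served.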
if __name__ == "__main__":
    if os.path.exists("gradio_tmp"):
        shutil.rmtree("gradio_tmp")
    os.makedirs("gradio_tmp", exist_ok=True)

    print("Clapper Rendering Node (default engine is Wan2.1 1.3B Self-Forcing)")
    print("Temporary files will be stored in: gradio_tmp/")
    print("Chunk encoding: PyAV (MPEG-TS/H.264)")
    print(f"GPU acceleration: {gpu}")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        show_error=True,
        max_threads=40,
        mcp_server=True
    )