from huggingface_hub import snapshot_download, hf_hub_download

snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="wan_models/Wan2.1-T2V-1.3B",
    local_dir_use_symlinks=False,
    resume_download=True,
    repo_type="model"
)

hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir=".",
    local_dir_use_symlinks=False
)
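
# NOTE: the two hub calls above run at import time. huggingface_hub should skip files that are
# already present and complete in local_dir, so restarting the app does not normally re-download
# the weights. resume_download and local_dir_use_symlinks are kept for compatibility; recent
# huggingface_hub releases treat them as deprecated.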

import os
import re
import random
import argparse
import hashlib
import urllib.request

from PIL import Image
import spaces
import numpy as np
import torch
import gradio as gr
from omegaconf import OmegaConf
from tqdm import tqdm
import imageio

from fastrtc import WebRTC, get_cloudflare_turn_credentials
from fastrtc.utils import AdditionalOutputs

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder

parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with FastRTC")
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt', help="Path to the model checkpoint.")
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
args = parser.parse_args()

gpu = "cuda"
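
# Example invocation (the script filename is assumed; use this file's actual name):
#   python app.py --checkpoint_path ./checkpoints/self_forcing_dmd.pt \
#       --config_path ./configs/self_forcing_dmd.yaml --port 7860 --share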

try:
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
except FileNotFoundError as e:
    print(f"Error loading config file: {e}\nPlease ensure config files are in the correct path.")
    exit(1)

print("Initializing models...")
text_encoder = WanTextEncoder()
transformer = WanDiffusionWrapper(is_causal=True)

try:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    # Prefer the EMA weights when the checkpoint contains them, otherwise fall back to the raw generator.
    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
except FileNotFoundError as e:
    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
    exit(1)

text_encoder.eval().to(dtype=torch.bfloat16).requires_grad_(False)
transformer.eval().to(dtype=torch.float16).requires_grad_(False)

text_encoder.to(gpu)
transformer.to(gpu)

APP_STATE = {
    "torch_compile_applied": False,
    "fp8_applied": False,
    "current_use_taehv": False,
    "current_vae_decoder": None,
}
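
# LOW_MEMORY is read by the TAEHV decode path below but was never defined in this script;
# defaulting it to False (parallel frame decode) is an assumption made here. Set it to True
# to trade speed for lower VRAM use on small GPUs.
LOW_MEMORY = False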


def initialize_vae_decoder(use_taehv=False, use_trt=False):
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        print("Initializing TensorRT VAE Decoder...")
        vae_decoder = VAETRTWrapper()
        APP_STATE["current_use_taehv"] = False
    elif use_taehv:
        print("Initializing TAEHV VAE Decoder...")
        from demo_utils.taehv import TAEHV
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
            except Exception as e:
                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")

        class DotDict(dict):
            __getattr__ = dict.get

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                # Rescale the decoded video from [0, 1] to [-1, 1] to match the default decoder.
                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)

        vae_decoder = TAEHVDiffusersWrapper()
        APP_STATE["current_use_taehv"] = True
    else:
        print("Initializing Default VAE Decoder...")
        vae_decoder = VAEDecoderWrapper()
        try:
            vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
            # Keep only the decoder-side weights from the full VAE checkpoint.
            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
            vae_decoder.load_state_dict(decoder_state_dict)
        except FileNotFoundError:
            print("Warning: Default VAE weights not found.")
        APP_STATE["current_use_taehv"] = False

    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)
    APP_STATE["current_vae_decoder"] = vae_decoder
    print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")


initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

pipeline = CausalInferencePipeline(
    config, device=gpu, generator=transformer, text_encoder=text_encoder,
    vae=APP_STATE["current_vae_decoder"]
)

pipeline.to(dtype=torch.float16).to(gpu)


def handle_additional_outputs(status_html_update, video_update, webrtc_output):
    return status_html_update, video_update, webrtc_output
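

# The handler below is a Python generator: fastrtc pulls each yielded BGR frame and streams
# it to the WebRTC component, while the accompanying AdditionalOutputs payload (status HTML,
# final-video update, WebRTC visibility update) is delivered to the matching Gradio components
# through handle_additional_outputs / webrtc_output.on_additional_outputs further below.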
@torch.no_grad()
@spaces.GPU
def video_generation_handler(prompt, seed, progress=gr.Progress()):
    """
    Generator function that yields BGR NumPy frames for real-time streaming.
    Returns cleanly when done - no infinite loops.
    """
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    print(f"🎬 Starting video generation with prompt: '{prompt}' and seed: {seed}")

    print("🔤 Encoding text prompt...")
    conditional_dict = text_encoder(text_prompts=[prompt])
    for key, value in conditional_dict.items():
        conditional_dict[key] = value.to(dtype=torch.float16)

    rnd = torch.Generator(gpu).manual_seed(int(seed))
    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
    # Latent noise for the whole clip: (batch, latent_frames, channels, height, width).
    noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)

    vae_cache, latents_cache = None, None
    if not APP_STATE["current_use_taehv"] and not args.trt:
        vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]

    num_blocks = 7
    current_start_frame = 0
    all_num_frames = [pipeline.num_frame_per_block] * num_blocks

    total_frames_yielded = 0
    all_frames_for_video = []

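    # Generation proceeds block by block: the 21 latent frames are split into num_blocks
    # chunks of pipeline.num_frame_per_block frames. Each chunk is denoised with the short
    # denoising_step_list while attending to previously generated chunks through the KV cache,
    # which is what lets frames be streamed before the full clip is finished.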
    for idx, current_num_frames in enumerate(all_num_frames):
        print(f"📦 Processing block {idx+1}/{num_blocks} with {current_num_frames} frames")

        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]

        # Denoise the current block, reusing the KV cache built from earlier blocks.
        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
            _, denoised_pred = pipeline.generator(
                noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
                timestep=timestep, kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length
            )
            if step_idx < len(pipeline.denoising_step_list) - 1:
                # Re-noise the prediction to the next timestep in the schedule.
                next_timestep = pipeline.denoising_step_list[step_idx + 1]
                noisy_input = pipeline.scheduler.add_noise(
                    denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                ).unflatten(0, denoised_pred.shape[:2])

        if idx < len(all_num_frames) - 1:
            # Run the generator once more at timestep 0 on the finished block so its context
            # is written into the KV cache for the next block.
            pipeline.generator(
                noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
                timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length,
            )

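        # Decode the block's latents to pixels. Three decoder paths are supported: the
        # TensorRT wrapper and the default causal VAE both carry a rolling cache between
        # blocks, while TAEHV instead keeps a 3-frame latent overlap (latents_cache) and
        # drops the re-decoded duplicate frames afterwards.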
        if args.trt:
            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
        elif APP_STATE["current_use_taehv"]:
            if latents_cache is None:
                latents_cache = denoised_pred
            else:
                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
                latents_cache = denoised_pred[:, -3:]
            pixels = pipeline.vae.decode(denoised_pred)
        else:
            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)

        # Skip frames that should not be emitted: the warm-up frames of the first block,
        # and the frames TAEHV re-decodes from its latent overlap.
        if idx == 0 and not args.trt:
            pixels = pixels[:, 3:]
        elif APP_STATE["current_use_taehv"] and idx > 0:
            pixels = pixels[:, 12:]

        print(f"📹 Decoded pixels shape: {pixels.shape}")

        for frame_idx in range(pixels.shape[1]):
            frame_tensor = pixels[0, frame_idx]

            # Map [-1, 1] float pixels to uint8 [0, 255].
            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
            frame_np = frame_np.to(torch.uint8).cpu().numpy()

            # CHW -> HWC for imageio and the WebRTC stream.
            frame_np = np.transpose(frame_np, (1, 2, 0))

            all_frames_for_video.append(frame_np)

            # Flip the channel order to BGR for streaming, as noted in the docstring.
            frame_bgr = frame_np[:, :, ::-1]

            total_frames_yielded += 1
            print(f"📺 Yielding frame {total_frames_yielded}: shape {frame_bgr.shape}, dtype {frame_bgr.dtype}")

            total_expected_frames = num_blocks * pipeline.num_frame_per_block
            current_frame_count = (idx * pipeline.num_frame_per_block) + frame_idx + 1
            frame_progress = 100 * (current_frame_count / total_expected_frames)

            if frame_idx == pixels.shape[1] - 1 and idx + 1 == num_blocks:
                # Last frame of the last block: report completion and attach the saved video.
                status_html = (
                    f"<div style='padding: 16px; border: 1px solid #198754; background-color: #d1e7dd; border-radius: 8px; font-family: sans-serif; text-align: center;'>"
                    f"  <h4 style='margin: 0 0 8px 0; color: #0f5132; font-size: 18px;'>🎉 Generation Complete!</h4>"
                    f"  <p style='margin: 0; color: #0f5132;'>"
                    f"    Total frames: {total_frames_yielded}. The final video is now available."
                    f"  </p>"
                    f"</div>"
                )

                print("💾 Saving final rendered video...")
                video_update = gr.update()
                try:
                    video_path = f"gradio_tmp/{seed}_{hashlib.md5(prompt.encode()).hexdigest()}.mp4"
                    imageio.mimwrite(video_path, all_frames_for_video, fps=15, quality=8)
                    print(f"✅ Video saved to {video_path}")
                    video_update = gr.update(value=video_path, visible=True)
                except Exception as e:
                    print(f"⚠️ Could not save final video: {e}")

                yield frame_bgr, AdditionalOutputs(status_html, video_update, gr.update(visible=False))
                return
            else:
                status_html = (
                    f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
                    f"  <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
                    f"  <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
                    f"    <div style='width: {frame_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
                    f"  </div>"
                    f"  <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
                    f"    Block {idx+1}/{num_blocks} &nbsp;|&nbsp; Frame {total_frames_yielded} &nbsp;|&nbsp; {frame_progress:.1f}%"
                    f"  </p>"
                    f"</div>"
                )

                yield frame_bgr, AdditionalOutputs(status_html, gr.update(visible=False), gr.update(visible=True))

        current_start_frame += current_num_frames

    print(f"✅ Video generation completed! Total frames yielded: {total_frames_yielded}")


with gr.Blocks(theme=gr.themes.Soft(), title="Self-Forcing FastRTC Demo") as demo:
    gr.Markdown("# 🚀 Self-Forcing Video Generation with FastRTC Streaming")
    gr.Markdown("*Real-time video generation streaming via WebRTC*")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("### Configure Generation")
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A stylish woman walks down a Tokyo street...",
                    lines=4,
                    value="A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage."
                )
                gr.Examples(
                    examples=[
                        "A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse.",
                        "A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves.",
                        "A drone shot of a surfer riding a wave on a sunny day. The camera follows the surfer as they carve through the water.",
                    ],
                    inputs=[prompt]
                )

            with gr.Row():
                seed = gr.Number(label="Seed", value=-1, info="Use -1 for a random seed.")

            with gr.Accordion("⚙️ Performance Options", open=False):
                gr.Markdown("*These optimizations are applied once per session*")

            start_btn = gr.Button("🎬 Start Generation", variant="primary", size="lg")

        with gr.Column(scale=3):
            gr.Markdown("### 📺 Live Video Stream")
            gr.Markdown("*Click 'Start Generation' to begin streaming*")

            webrtc_output = WebRTC(
                label="Generated Video Stream",
                modality="video",
                mode="receive",
                height=480,
                width=832,
                rtc_configuration=get_cloudflare_turn_credentials(),
                elem_id="video_stream"
            )

            final_video = gr.Video(label="Final Rendered Video", visible=False, interactive=False)

            status_html = gr.HTML(
                value="<div style='text-align: center; padding: 20px; color: #666;'>Ready to start generation...</div>",
                label="Generation Status"
            )

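    # Wire the UI: clicking the start button triggers the streaming handler; yielded frames
    # go to the WebRTC component, and the AdditionalOutputs emitted alongside them are
    # dispatched by handle_additional_outputs to the status HTML, the final video player,
    # and the WebRTC component's visibility.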
    webrtc_output.stream(
        fn=video_generation_handler,
        inputs=[prompt, seed],
        outputs=[webrtc_output],
        time_limit=300,
        trigger=start_btn.click,
    )

    webrtc_output.on_additional_outputs(
        fn=handle_additional_outputs,
        outputs=[status_html, final_video, webrtc_output]
    )


if __name__ == "__main__":
    # Start from a clean temp directory for the rendered MP4s.
    if os.path.exists("gradio_tmp"):
        import shutil
        shutil.rmtree("gradio_tmp")
    os.makedirs("gradio_tmp", exist_ok=True)

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        show_error=True
    )