# Wan-2.2-5B / app_t2v.py
# rahul7star's picture
# Update app_t2v.py
# c5f7d50 verified
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download
from PIL import Image
import random
import numpy as np
import spaces
import gc
# Import for Stable Diffusion XL
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from compel import Compel, ReturnedEmbeddingsType
# Import for Wan2.2
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video
# --- Global Setup ---
# Everything in this section runs once, at import time. The pipelines and
# constants created here are module-level globals used by the request
# handlers defined further down in the file.
print("Starting Integrated Text-to-Image-to-Video App...")
# --- 1. Setup Text-to-Image Model (SDXL) ---
print("Loading Stable Diffusion XL model...")
# Use CUDA when it is available, otherwise fall back to CPU.
# NOTE(review): the weights below are loaded as float16; confirm that the
# CPU fallback is actually usable with half precision.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Initialize SDXL pipeline from the fp16 safetensors variant of the checkpoint.
sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
"votepurchase/pornmasterPro_noobV3VAE",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
# Alternative checkpoint, currently disabled:
# sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
# "stablediffusionapi/omnigenxl-nsfw-sfw",
# torch_dtype=torch.float16,
# variant="fp16",
# use_safetensors=True
# )
# Swap in the Euler Ancestral scheduler, reusing the config of the scheduler
# that shipped with the checkpoint.
sdxl_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sdxl_pipe.scheduler.config)
sdxl_pipe.to(device)
# Force all components to use the same dtype
sdxl_pipe.text_encoder.to(torch.float16)
sdxl_pipe.text_encoder_2.to(torch.float16)
sdxl_pipe.vae.to(torch.float16)
sdxl_pipe.unet.to(torch.float16)
# Initialize Compel for long prompt processing.
# Both SDXL tokenizers / text encoders are passed as pairs; pooled output is
# requested from the second encoder only (requires_pooled=[False, True]).
compel = Compel(
tokenizer=[sdxl_pipe.tokenizer, sdxl_pipe.tokenizer_2],
text_encoder=[sdxl_pipe.text_encoder, sdxl_pipe.text_encoder_2],
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
requires_pooled=[False, True],
truncate_long_prompts=False
)
# --- 2. Setup Image-to-Video Model (Wan2.2) ---
print("Loading Wan 2.2 TI2V-5B model...")
# Download model snapshots
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
# NOTE(review): `local_dir_use_symlinks` may be deprecated in recent
# huggingface_hub releases — confirm against the installed version.
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")
# Load the model configuration
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
# Frame-rate used to convert a requested duration (seconds) into a frame count.
FIXED_FPS = 24
# Lower / upper bounds applied to the computed frame count; they also define
# the range of the duration sliders in the UI.
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121
# Instantiate the pipeline
# device_id is 0 when CUDA is available and -1 otherwise.
# NOTE(review): presumably -1 tells WanTI2V to run on CPU — confirm.
device_id = 0 if torch.cuda.is_available() else -1
wan_pipeline = wan.WanTI2V(
config=cfg,
checkpoint_dir=ckpt_dir,
device_id=device_id,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=False,
convert_model_dtype=True,
)
###LORA ####
# Disabled LoRA experiment.
# NOTE(review): the referenced file is named for "T2V-A14B" while the task
# loaded above is 'ti2v-5B' — confirm compatibility before re-enabling.
# LORA_REPO_ID = "Kijai/WanVideo_comfy"
# LORA_FILENAME = "Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors"
# causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
# wan_pipeline.load_lora_weights(causvid_path, adapter_name="causvid_lora")
# wan_pipeline.set_adapters(["causvid_lora"], adapter_weights=[0.95])
# wan_pipeline.fuse_lora()
print("All models loaded and ready.")
# --- Constants ---
# Largest value of a signed 32-bit integer; upper bound for the seed sliders
# and for randomly drawn image seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the width / height sliders for generated images.
MAX_IMAGE_SIZE = 1216
# --- Helper Functions ---
def clear_gpu_memory():
    """Release unreferenced Python objects and cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are destroyed before the CUDA caching allocator is
    asked to give its unused blocks back; emptying the cache first would
    leave that memory cached until the next call.

    Safe to call on machines without CUDA: only the garbage collector runs.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
def process_long_prompt(prompt, negative_prompt=""):
    """Encode a prompt / negative-prompt pair with the module-level Compel.

    The two strings are passed to ``compel`` as one two-element batch, and
    the resulting ``(conditioning, pooled)`` pair is returned.

    Any exception raised while encoding is printed and swallowed; in that
    case ``(None, None)`` is returned so the caller can fall back to the
    pipeline's standard prompt handling.
    """
    outcome = (None, None)
    try:
        conditioning, pooled = compel([prompt, negative_prompt])
        outcome = (conditioning, pooled)
    except Exception as e:
        print(f"Long prompt processing failed: {e}, falling back to standard processing")
    return outcome
def select_best_size_for_image(image, available_sizes):
    """Pick the size string whose ratio is closest to the image's ratio.

    Args:
        image: Object exposing ``.size`` as ``(width, height)``, or ``None``.
        available_sizes: Non-empty sequence of strings shaped like
            ``"704*1280"``. The two integers are compared as
            ``first / second`` against the image's ``height / width``.

    Returns:
        The first entry of ``available_sizes`` when ``image`` is ``None``;
        otherwise the entry with the smallest absolute ratio difference
        (the earliest entry wins ties).
    """
    fallback = available_sizes[0]
    if image is None:
        return fallback

    source_width, source_height = image.size
    source_ratio = source_height / source_width

    def ratio_gap(candidate):
        # Distance between the candidate's ratio and the image's ratio.
        first, second = map(int, candidate.split('*'))
        return abs(source_ratio - first / second)

    # min() keeps the earliest element among equal keys, which matches a
    # strict "less than" scan over the list.
    return min(available_sizes, key=ratio_gap)
def validate_video_inputs(image, prompt, duration_seconds):
    """Validate user inputs for video generation.

    Args:
        image: ``None``, a NumPy array (first two axes are taken as the
            pixel grid), or an object exposing ``.size`` as a 2-tuple such
            as a PIL image.
        prompt: Prompt text; may be ``None`` or empty.
        duration_seconds: Requested clip length in seconds.

    Returns:
        A list of human-readable error messages; empty when all inputs
        are acceptable.
    """
    max_pixels = 4096 * 4096
    errors = []

    if not prompt or len(prompt.strip()) < 5:
        errors.append("Prompt must be at least 5 characters long.")

    if image is None:
        if duration_seconds > 5.0:
            errors.append("Videos longer than 5 seconds require an input image.")
        return errors

    if isinstance(image, np.ndarray):
        # Read the dimensions straight from the array shape instead of
        # building a throw-away PIL image just to ask for its size.
        pixel_count = int(np.prod(image.shape[:2], dtype=np.int64))
    else:
        pixel_count = image.size[0] * image.size[1]

    if pixel_count > max_pixels:
        errors.append("Image size is too large (maximum 4096x4096).")
    return errors
# --- Text-to-Image Generation Function ---
def get_t_duration(prompt,
                   negative_prompt,
                   seed,
                   randomize_seed,
                   width,
                   height,
                   guidance_scale,
                   num_inference_steps, progress):
    """Estimate the GPU time budget (seconds) for one text-to-image request.

    The parameter list mirrors ``generate_image`` so the function can be
    called with the same arguments; only ``num_inference_steps`` affects
    the result.

    Returns:
        50 for up to 35 inference steps (the same budget as the fixed
        ``duration=50`` on ``generate_image``), 75 for more than 35 steps.
    """
    # Only names that are actual parameters of this function are used here;
    # video-only settings such as a clip duration do not exist for images.
    if num_inference_steps > 35:
        return 75
    return 50
@spaces.GPU(duration=50)
def generate_image(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate one image from a text prompt with the SDXL pipeline.

    Prompts longer than 60 words or 300 characters are first encoded with
    ``process_long_prompt``; if that encoding is unavailable, or the prompt
    is short, the raw prompt strings are handed to the pipeline instead.

    Args:
        prompt: Positive prompt text.
        negative_prompt: Negative prompt text.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: When true, a seed is drawn from ``[0, MAX_SEED]``.
        width, height: Output image size in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of sampling steps.
        progress: Gradio progress reporter.

    Returns:
        ``(image, seed)``. If the pipeline raises ``RuntimeError``, the
        error is printed and a black image of the requested size is
        returned in place of the result.
    """
    progress(0, desc="Initializing image generation...")

    needs_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        progress(0.3, desc="Processing prompt...")

        embeddings = None
        if needs_long_prompt:
            print("Using long prompt processing...")
            conditioning, pooled = process_long_prompt(prompt, negative_prompt)
            if conditioning is not None:
                embeddings = (conditioning, pooled)

        progress(0.5, desc="Generating image...")
        if embeddings is None:
            # Standard path: let the pipeline tokenize and encode the text.
            prompt_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
            }
        else:
            # Row 0 holds the positive prompt, row 1 the negative prompt.
            conditioning, pooled = embeddings
            prompt_kwargs = {
                "prompt_embeds": conditioning[0:1],
                "pooled_prompt_embeds": pooled[0:1],
                "negative_prompt_embeds": conditioning[1:2],
                "negative_pooled_prompt_embeds": pooled[1:2],
            }

        result = sdxl_pipe(
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator,
            **prompt_kwargs
        )
        output_image = result.images[0]
        progress(1.0, desc="Complete!")
        return output_image, seed
    except RuntimeError as e:
        print(f"Error during generation: {e}")
        placeholder = Image.new('RGB', (width, height), color=(0, 0, 0))
        return placeholder, seed
    finally:
        clear_gpu_memory()
# --- Image-to-Video Generation Function ---
def get_video_duration(image, prompt, size, duration_seconds, sampling_steps, guide_scale, shift, seed, progress):
    """Return a GPU time budget (seconds) for one video request.

    Only ``sampling_steps`` and ``duration_seconds`` are inspected; the
    other parameters mirror the signature of ``generate_video``.

    Returns:
        120 when more than 35 steps are requested for a clip of at least
        2 seconds; 105 when fewer than 35 steps or a clip shorter than
        2 seconds is requested; 90 otherwise.

    NOTE(review): the middle tier (105) is larger than the fallback (90)
    even though it covers the lighter workloads — confirm that is intended.
    """
    heavy_request = sampling_steps > 35 and duration_seconds >= 2
    light_request = sampling_steps < 35 or duration_seconds < 2

    if heavy_request:
        return 120
    return 105 if light_request else 90
@spaces.GPU(duration=120)
def generate_video(
    image,
    prompt,
    size,
    duration_seconds,
    sampling_steps,
    guide_scale,
    shift,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a video from a prompt and an optional input image.

    Args:
        image: ``None``, a NumPy array, or an image object with
            ``.convert``; it is converted to RGB and resized to ``size``.
        prompt: Text prompt for the video.
        size: Key into ``SIZE_CONFIGS`` / ``MAX_AREA_CONFIGS`` shaped like
            ``"704*1280"``; for resizing it is parsed as ``height*width``.
        duration_seconds: Requested clip length; converted to a frame
            count at ``FIXED_FPS`` and clamped to
            ``[MIN_FRAMES_MODEL, MAX_FRAMES_MODEL]``.
        sampling_steps: Number of sampling steps (cast to ``int``).
        guide_scale: Guidance scale forwarded to the pipeline.
        shift: Sample shift forwarded to the pipeline.
        seed: Seed; ``-1`` draws a random one.
        progress: Gradio progress reporter.

    Returns:
        The path returned by ``cache_video`` for the generated clip.

    Raises:
        gr.Error: On invalid inputs, GPU out-of-memory, or any other
            failure during generation or saving. The original exception
            is attached as the cause.
    """
    errors = validate_video_inputs(image, prompt, duration_seconds)
    if errors:
        raise gr.Error("\n".join(errors))

    progress(0, desc="Setting up video generation...")
    if seed == -1:
        seed = random.randint(0, sys.maxsize)

    progress(0.1, desc="Processing image...")
    input_image = None
    if image is not None:
        if isinstance(image, np.ndarray):
            input_image = Image.fromarray(image).convert("RGB")
        else:
            input_image = image.convert("RGB")
        # Resize image to match selected size
        target_height, target_width = map(int, size.split('*'))
        input_image = input_image.resize((target_width, target_height))

    # Calculate number of frames based on duration. Plain-int clamping keeps
    # a NumPy scalar from leaking into the pipeline call.
    requested_frames = int(round(duration_seconds * FIXED_FPS))
    num_frames = max(MIN_FRAMES_MODEL, min(requested_frames, MAX_FRAMES_MODEL))

    progress(0.2, desc="Generating video...")
    # Bound up front so the cleanup in ``finally`` needs no introspection.
    video_tensor = None
    try:
        video_tensor = wan_pipeline.generate(
            input_prompt=prompt,
            img=input_image,
            size=SIZE_CONFIGS[size],
            max_area=MAX_AREA_CONFIGS[size],
            frame_num=num_frames,
            shift=shift,
            sample_solver='unipc',
            sampling_steps=int(sampling_steps),
            guide_scale=guide_scale,
            seed=seed,
            offload_model=True
        )
        progress(0.9, desc="Saving video...")
        video_path = cache_video(
            tensor=video_tensor[None],
            save_file=None,
            fps=cfg.sample_fps,
            normalize=True,
            value_range=(-1, 1)
        )
        progress(1.0, desc="Complete!")
    except torch.cuda.OutOfMemoryError as e:
        # Memory is released in ``finally``; keep the cause for the logs.
        raise gr.Error("GPU out of memory. Please try with lower settings.") from e
    except Exception as e:
        raise gr.Error(f"Video generation failed: {str(e)}") from e
    finally:
        # Drop the tensor reference before clearing caches so its memory
        # can actually be reclaimed.
        video_tensor = None
        clear_gpu_memory()
    return video_path
# --- Combined Generation Function ---
def generate_image_to_video(
    img_prompt,
    img_negative_prompt,
    img_seed,
    img_randomize_seed,
    img_width,
    img_height,
    img_guidance_scale,
    img_num_inference_steps,
    video_prompt,
    video_size,
    video_duration,
    video_sampling_steps,
    video_guide_scale,
    video_shift,
    video_seed
):
    """Run the full pipeline: text -> image -> video.

    The image is produced by ``generate_image``; the video resolution is
    then chosen from ``SUPPORTED_SIZES[TASK_NAME]`` to best match that
    image and passed to ``generate_video``. ``video_size`` is accepted for
    interface compatibility but is not used.

    Returns:
        ``(generated_image, video_path, used_seed, best_size)``.
    """
    # Stage 1: text -> image.
    image_args = (
        img_prompt,
        img_negative_prompt,
        img_seed,
        img_randomize_seed,
        img_width,
        img_height,
        img_guidance_scale,
        img_num_inference_steps,
    )
    generated_image, used_seed = generate_image(*image_args)

    # Pick the supported resolution closest to the generated image.
    size_options = list(SUPPORTED_SIZES[TASK_NAME])
    best_size = select_best_size_for_image(generated_image, size_options)

    # Stage 2: image -> video, using the auto-selected size.
    video_args = (
        generated_image,
        video_prompt,
        best_size,
        video_duration,
        video_sampling_steps,
        video_guide_scale,
        video_shift,
        video_seed,
    )
    video_path = generate_video(*video_args)

    return generated_image, video_path, used_seed, best_size
# --- Gradio Interface ---
# Custom CSS passed to gr.Blocks: limits the container width and fixes the
# heights of the elements whose elem_id is output_video, input_image or
# generated_image; also enlarges the tab buttons.
css = """
.gradio-container {max-width: 1400px !important; margin: 0 auto}
#output_video {height: 500px;}
#input_image {height: 400px;}
#generated_image {height: 400px;}
.tab-nav button {font-size: 18px !important; padding: 10px 20px !important;}
"""
# Prompt templates
# Maps a button label to a prompt pattern. Each pattern contains a
# "{subject}" placeholder that apply_template replaces with the leading
# part of the user's current prompt.
video_templates = {
"Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
"Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
"Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
"Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
"Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
}
def apply_template(template, current_prompt):
    """Combine a prompt template with the user's current prompt.

    When the template contains ``{subject}``, every occurrence is replaced
    with the text of ``current_prompt`` that precedes its first comma (the
    whole prompt if it has no comma). Otherwise the prompt is appended to
    the template, separated by a single space.
    """
    if "{subject}" not in template:
        return template + " " + current_prompt

    # partition() yields the whole string when no comma is present.
    subject, _, _ = current_prompt.partition(",")
    return template.replace("{subject}", subject)
# Gradio UI: three tabs (text-to-image, image-to-video, and the combined
# text-to-image-to-video pipeline) followed by their event wiring.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎨 Integrated Text-to-Image-to-Video Generator
Generate images from text and convert them to high-quality videos using:
- Text-to-Image generation
- **Wan 2.2 5B** for Image-to-Video generation
### ✨ Features:
- 📝 **Text-to-Image**: Generate images from text descriptions
- 🎬 **Image-to-Video**: Convert images (uploaded or generated) to videos
- 🔄 **Text-to-Image-to-Video**: Complete pipeline from text to video
""")
    # Badge section
    with gr.Tabs() as tabs:
        # Tab 1: Text-to-Image
        with gr.Tab("Text to Image", id="t2i_tab"):
            with gr.Row():
                with gr.Column(scale=1):
                    t2i_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        lines=3
                    )
                    t2i_negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value=" (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
                        lines=2
                    )
                    with gr.Row():
                        t2i_width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        t2i_height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                    with gr.Accordion("Advanced Settings", open=False):
                        t2i_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                        t2i_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                        t2i_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                        t2i_num_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
                    t2i_generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
                with gr.Column(scale=1):
                    t2i_output = gr.Image(label="Generated Image", elem_id="generated_image")
                    t2i_seed_output = gr.Number(label="Used Seed", interactive=False)
        # Tab 2: Image-to-Video
        with gr.Tab("Image to Video", id="i2v_tab"):
            with gr.Row():
                with gr.Column(scale=1):
                    i2v_image = gr.Image(type="numpy", label="Input Image", elem_id="input_image")
                    i2v_prompt = gr.Textbox(
                        label="Video Prompt",
                        value="Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions.",
                        lines=3
                    )
                    with gr.Accordion("Prompt Templates", open=False):
                        gr.Markdown("Click a template to apply it to your prompt:")
                        # One button per template; click handlers are
                        # attached in the event-handler section below.
                        template_buttons = {}
                        for name, template in video_templates.items():
                            btn = gr.Button(name, size="sm")
                            template_buttons[name] = (btn, template)
                    i2v_duration = gr.Slider(
                        label="Duration (seconds)",
                        minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                        maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
                        step=0.1,
                        value=2.0
                    )
                    i2v_size = gr.Dropdown(
                        label="Output Resolution",
                        choices=list(SUPPORTED_SIZES[TASK_NAME]),
                        value="704*1280"
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        i2v_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                        i2v_guide_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                        i2v_shift = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
                        i2v_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    i2v_generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
                with gr.Column(scale=1):
                    i2v_output = gr.Video(label="Generated Video", elem_id="output_video")
        # Tab 3: Text-to-Image-to-Video
        with gr.Tab("Text to Image to Video", id="t2i2v_tab"):
            gr.Markdown("### 🎯 Complete Pipeline: Generate an image from text, then convert it to video")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("#### Step 1: Image Generation Settings")
                    t2i2v_img_prompt = gr.Textbox(
                        label="Image Prompt",
                        placeholder="Describe the image to generate...",
                        lines=3
                    )
                    t2i2v_img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
                        lines=2
                    )
                    with gr.Row():
                        t2i2v_img_width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        t2i2v_img_height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                    with gr.Accordion("Image Advanced Settings", open=False):
                        t2i2v_img_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                        t2i2v_img_randomize = gr.Checkbox(label="Randomize seed", value=True)
                        t2i2v_img_guidance = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                        t2i2v_img_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
                    gr.Markdown("#### Step 2: Video Generation Settings")
                    t2i2v_video_prompt = gr.Textbox(
                        label="Video Prompt",
                        value="Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions.",
                        lines=3
                    )
                    t2i2v_video_duration = gr.Slider(
                        label="Duration (seconds)",
                        minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
                        maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
                        step=0.1,
                        value=2.0
                    )
                    # Add the missing video size dropdown component
                    t2i2v_video_size = gr.Dropdown(
                        label="Video Output Resolution",
                        choices=list(SUPPORTED_SIZES[TASK_NAME]),
                        value="704*1280",
                        info="This will be auto-adjusted based on generated image aspect ratio"
                    )
                    with gr.Accordion("Video Advanced Settings", open=False):
                        t2i2v_video_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                        t2i2v_video_guide = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                        t2i2v_video_shift = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
                        t2i2v_video_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    t2i2v_generate_btn = gr.Button("Generate Image → Video", variant="primary", size="lg")
                with gr.Column(scale=1):
                    gr.Markdown("#### Results")
                    t2i2v_image_output = gr.Image(label="Generated Image", elem_id="generated_image")
                    t2i2v_video_output = gr.Video(label="Generated Video", elem_id="output_video")
                    with gr.Row():
                        t2i2v_seed_output = gr.Number(label="Image Seed Used", interactive=False)
                        t2i2v_size_output = gr.Textbox(label="Video Size Used", interactive=False)
    # Event handlers
    # Tab 1: Text-to-Image
    t2i_generate_btn.click(
        fn=generate_image,
        inputs=[
            t2i_prompt, t2i_negative_prompt, t2i_seed, t2i_randomize_seed,
            t2i_width, t2i_height, t2i_guidance_scale, t2i_num_steps
        ],
        outputs=[t2i_output, t2i_seed_output]
    )
    # Tab 2: Image-to-Video
    # Connect template buttons
    for name, (btn, template) in template_buttons.items():
        btn.click(
            # Gradio passes the current textbox value as the first positional
            # argument, so it must be the lambda's first parameter. The
            # template is bound through a default value so each button keeps
            # its own template instead of the loop's last one.
            fn=lambda current_prompt, t=template: apply_template(t, current_prompt),
            inputs=[i2v_prompt],
            outputs=i2v_prompt
        )

    # Auto-select best size when image is uploaded
    def handle_image_upload(image):
        """Return an update selecting the resolution closest to the upload."""
        if image is None:
            return gr.update()
        pil_image = Image.fromarray(image).convert("RGB")
        available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
        best_size = select_best_size_for_image(pil_image, available_sizes)
        return gr.update(value=best_size)

    i2v_image.upload(
        fn=handle_image_upload,
        inputs=[i2v_image],
        outputs=[i2v_size]
    )
    i2v_generate_btn.click(
        fn=generate_video,
        inputs=[
            i2v_image, i2v_prompt, i2v_size, i2v_duration,
            i2v_steps, i2v_guide_scale, i2v_shift, i2v_seed
        ],
        outputs=i2v_output
    )
    # Tab 3: Text-to-Image-to-Video
    t2i2v_generate_btn.click(
        fn=generate_image_to_video,
        inputs=[
            t2i2v_img_prompt, t2i2v_img_negative, t2i2v_img_seed, t2i2v_img_randomize,
            t2i2v_img_width, t2i2v_img_height, t2i2v_img_guidance, t2i2v_img_steps,
            t2i2v_video_prompt, t2i2v_video_size, t2i2v_video_duration,
            t2i2v_video_steps, t2i2v_video_guide, t2i2v_video_shift, t2i2v_video_seed
        ],
        outputs=[t2i2v_image_output, t2i2v_video_output, t2i2v_seed_output, t2i2v_size_output]
    )
    # Examples
    gr.Examples(
        examples=[
            ["A majestic lion sitting on a rock at sunset, golden hour lighting, photorealistic", "Generate a video with the lion slowly turning its head and mane flowing in the wind"],
            ["A futuristic cyberpunk city with neon lights and flying cars", "Cinematic shot with smooth camera movement through the city streets"],
            ["A serene Japanese garden with cherry blossoms and a koi pond", "Gentle breeze causing cherry blossoms to fall, ripples in the pond"],
        ],
        inputs=[t2i2v_img_prompt, t2i2v_video_prompt],
        label="Example Prompts"
    )

if __name__ == "__main__":
    demo.launch()