import os
import sys

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download
from PIL import Image
import random
import numpy as np
import spaces
import gc

# Import for Stable Diffusion XL
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
from compel import Compel, ReturnedEmbeddingsType

# Import for Wan2.2
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video
# --- Global Setup ---
print("Starting Integrated Text-to-Image-to-Video App...")

# --- 1. Setup Text-to-Image Model (SDXL) ---
print("Loading Stable Diffusion XL model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize SDXL pipeline
sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
    "votepurchase/pornmasterPro_noobV3VAE",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True
)
# sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(
#     "stablediffusionapi/omnigenxl-nsfw-sfw",
#     torch_dtype=torch.float16,
#     variant="fp16",
#     use_safetensors=True
# )
sdxl_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sdxl_pipe.scheduler.config)
sdxl_pipe.to(device)

# Force all components to use the same dtype
sdxl_pipe.text_encoder.to(torch.float16)
sdxl_pipe.text_encoder_2.to(torch.float16)
sdxl_pipe.vae.to(torch.float16)
sdxl_pipe.unet.to(torch.float16)

# Initialize Compel for long prompt processing
compel = Compel(
    tokenizer=[sdxl_pipe.tokenizer, sdxl_pipe.tokenizer_2],
    text_encoder=[sdxl_pipe.text_encoder, sdxl_pipe.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
    truncate_long_prompts=False
)
# --- 2. Setup Image-to-Video Model (Wan2.2) ---
print("Loading Wan 2.2 TI2V-5B model...")

# Download model snapshots
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")

# Load the model configuration
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
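
# The TI2V-5B pipeline renders at a fixed 24 fps; requested durations are
# clamped to between 8 and 121 frames (roughly 0.3 to 5 seconds).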
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121

# Instantiate the pipeline
device_id = 0 if torch.cuda.is_available() else -1
wan_pipeline = wan.WanTI2V(
    config=cfg,
    checkpoint_dir=ckpt_dir,
    device_id=device_id,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_sp=False,
    t5_cpu=False,
    init_on_cpu=False,
    convert_model_dtype=True,
)

### LORA ###
# LORA_REPO_ID = "Kijai/WanVideo_comfy"
# LORA_FILENAME = "Wan22-Lightning/Wan2.2-Lightning_T2V-A14B-4steps-lora_LOW_fp16.safetensors"
# causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
# wan_pipeline.load_lora_weights(causvid_path, adapter_name="causvid_lora")
# wan_pipeline.set_adapters(["causvid_lora"], adapter_weights=[0.95])
# wan_pipeline.fuse_lora()

print("All models loaded and ready.")

# --- Constants ---
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1216
# --- Helper Functions ---
def clear_gpu_memory():
    """Clear GPU memory more thoroughly"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
def process_long_prompt(prompt, negative_prompt=""):
    """Simple long prompt processing using Compel"""
    try:
        conditioning, pooled = compel([prompt, negative_prompt])
        return conditioning, pooled
    except Exception as e:
        print(f"Long prompt processing failed: {e}, falling back to standard processing")
        return None, None
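
# Note: compel([prompt, negative_prompt]) returns a batched embedding tensor;
# generate_image slices row 0 as the positive and row 1 as the negative prompt.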
def select_best_size_for_image(image, available_sizes):
    """Select the size option with aspect ratio closest to the input image."""
    if image is None:
        return available_sizes[0]

    img_width, img_height = image.size
    img_aspect_ratio = img_height / img_width

    best_size = available_sizes[0]
    best_diff = float('inf')
    for size_str in available_sizes:
        height, width = map(int, size_str.split('*'))
        size_aspect_ratio = height / width
        diff = abs(img_aspect_ratio - size_aspect_ratio)
        if diff < best_diff:
            best_diff = diff
            best_size = size_str
    return best_size
def validate_video_inputs(image, prompt, duration_seconds):
    """Validate user inputs for video generation"""
    errors = []
    if not prompt or len(prompt.strip()) < 5:
        errors.append("Prompt must be at least 5 characters long.")
    if image is not None:
        if isinstance(image, np.ndarray):
            img = Image.fromarray(image)
        else:
            img = image
        if img.size[0] * img.size[1] > 4096 * 4096:
            errors.append("Image size is too large (maximum 4096x4096).")
    if duration_seconds > 5.0 and image is None:
        errors.append("Videos longer than 5 seconds require an input image.")
    return errors
# --- Text-to-Image Generation Function ---
def get_t_duration(prompt,
                   negative_prompt,
                   seed,
                   randomize_seed,
                   width,
                   height,
                   guidance_scale,
                   num_inference_steps, progress=None):
    """Calculate dynamic GPU duration for image generation"""
    if num_inference_steps > 35:
        return 105
    return 90
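
# Assumed ZeroGPU setup: spaces.GPU can take a callable with the same signature
# as the decorated function to size the GPU allocation per call; get_t_duration
# mirrors generate_image's parameters for that purpose. Remove the decorator if
# this Space does not run on ZeroGPU hardware.
@spaces.GPU(duration=get_t_duration)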
def generate_image(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate image from text prompt"""
    progress(0, desc="Initializing image generation...")

    use_long_prompt = len(prompt.split()) > 60 or len(prompt) > 300

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    try:
        progress(0.3, desc="Processing prompt...")
        if use_long_prompt:
            print("Using long prompt processing...")
            conditioning, pooled = process_long_prompt(prompt, negative_prompt)
            if conditioning is not None:
                progress(0.5, desc="Generating image...")
                output_image = sdxl_pipe(
                    prompt_embeds=conditioning[0:1],
                    pooled_prompt_embeds=pooled[0:1],
                    negative_prompt_embeds=conditioning[1:2],
                    negative_pooled_prompt_embeds=pooled[1:2],
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps,
                    width=width,
                    height=height,
                    generator=generator
                ).images[0]
                progress(1.0, desc="Complete!")
                return output_image, seed

        # Fall back to standard processing
        progress(0.5, desc="Generating image...")
        output_image = sdxl_pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            generator=generator
        ).images[0]
        progress(1.0, desc="Complete!")
        return output_image, seed
    except RuntimeError as e:
        print(f"Error during generation: {e}")
        error_img = Image.new('RGB', (width, height), color=(0, 0, 0))
        return error_img, seed
    finally:
        clear_gpu_memory()
# --- Image-to-Video Generation Function ---
def get_video_duration(image, prompt, size, duration_seconds, sampling_steps, guide_scale, shift, seed, progress=None):
    """Calculate dynamic GPU duration for video generation"""
    if sampling_steps > 35 and duration_seconds >= 2:
        return 120
    elif sampling_steps < 35 or duration_seconds < 2:
        return 105
    else:
        return 90
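
# Assumed ZeroGPU setup: get_video_duration mirrors generate_video's parameters
# so spaces.GPU(duration=...) can size the GPU allocation per call. Remove the
# decorator if this Space does not run on ZeroGPU hardware.
@spaces.GPU(duration=get_video_duration)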
def generate_video(
    image,
    prompt,
    size,
    duration_seconds,
    sampling_steps,
    guide_scale,
    shift,
    seed,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate video from image and prompt"""
    errors = validate_video_inputs(image, prompt, duration_seconds)
    if errors:
        raise gr.Error("\n".join(errors))

    progress(0, desc="Setting up video generation...")

    if seed == -1:
        seed = random.randint(0, sys.maxsize)

    progress(0.1, desc="Processing image...")
    input_image = None
    if image is not None:
        if isinstance(image, np.ndarray):
            input_image = Image.fromarray(image).convert("RGB")
        else:
            input_image = image.convert("RGB")
        # Resize image to match selected size
        target_height, target_width = map(int, size.split('*'))
        input_image = input_image.resize((target_width, target_height))

    # Calculate number of frames based on duration
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)

    progress(0.2, desc="Generating video...")
    try:
        video_tensor = wan_pipeline.generate(
            input_prompt=prompt,
            img=input_image,
            size=SIZE_CONFIGS[size],
            max_area=MAX_AREA_CONFIGS[size],
            frame_num=num_frames,
            shift=shift,
            sample_solver='unipc',
            sampling_steps=int(sampling_steps),
            guide_scale=guide_scale,
            seed=seed,
            offload_model=True
        )
        progress(0.9, desc="Saving video...")
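        # cache_video (wan.utils.utils) writes the frames to a temporary .mp4
        # when save_file is None and returns the resulting file path.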
        video_path = cache_video(
            tensor=video_tensor[None],
            save_file=None,
            fps=cfg.sample_fps,
            normalize=True,
            value_range=(-1, 1)
        )
        progress(1.0, desc="Complete!")
    except torch.cuda.OutOfMemoryError:
        clear_gpu_memory()
        raise gr.Error("GPU out of memory. Please try with lower settings.")
    except Exception as e:
        raise gr.Error(f"Video generation failed: {str(e)}")
    finally:
        if 'video_tensor' in locals():
            del video_tensor
        clear_gpu_memory()

    return video_path
# --- Combined Generation Function ---
def generate_image_to_video(
    img_prompt,
    img_negative_prompt,
    img_seed,
    img_randomize_seed,
    img_width,
    img_height,
    img_guidance_scale,
    img_num_inference_steps,
    video_prompt,
    video_size,
    video_duration,
    video_sampling_steps,
    video_guide_scale,
    video_shift,
    video_seed
):
    """Generate image from text, then use it to generate video"""
    # First generate image
    generated_image, used_seed = generate_image(
        img_prompt,
        img_negative_prompt,
        img_seed,
        img_randomize_seed,
        img_width,
        img_height,
        img_guidance_scale,
        img_num_inference_steps
    )

    # Select the video size whose aspect ratio best matches the generated image
    available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
    best_size = select_best_size_for_image(generated_image, available_sizes)

    # Then generate video using the generated image
    video_path = generate_video(
        generated_image,
        video_prompt,
        best_size,  # Use auto-selected size (overrides the UI's video_size)
        video_duration,
        video_sampling_steps,
        video_guide_scale,
        video_shift,
        video_seed
    )

    return generated_image, video_path, used_seed, best_size
# --- Gradio Interface ---
css = """
.gradio-container {max-width: 1400px !important; margin: 0 auto}
#output_video {height: 500px;}
#input_image {height: 400px;}
#generated_image {height: 400px;}
.tab-nav button {font-size: 18px !important; padding: 10px 20px !important;}
"""

# Prompt templates
video_templates = {
    "Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
    "Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
    "Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
    "Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
    "Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
}
def apply_template(template, current_prompt):
    """Apply prompt template"""
    if "{subject}" in template:
        subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
        return template.replace("{subject}", subject)
    return template + " " + current_prompt
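
# Example: apply_template(video_templates["Cinematic"], "a red fox, forest at dawn")
# -> "cinematic shot of a red fox, professional lighting, smooth camera movement, 4k quality"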
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🎨 Integrated Text-to-Image-to-Video Generator
Generate images from text and convert them to high-quality videos using:
- **Stable Diffusion XL** for Text-to-Image generation
- **Wan 2.2 TI2V-5B** for Image-to-Video generation

### ✨ Features:
- 📝 **Text-to-Image**: Generate images from text descriptions
- 🎬 **Image-to-Video**: Convert images (uploaded or generated) to videos
- 🔄 **Text-to-Image-to-Video**: Complete pipeline from text to video
""")
    # Badge section
    with gr.Tabs() as tabs:
        # Tab 1: Text-to-Image
        with gr.Tab("Text to Image", id="t2i_tab"):
            with gr.Row():
                with gr.Column(scale=1):
                    t2i_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        lines=3
                    )
                    t2i_negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value=" (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
                        lines=2
                    )
                    with gr.Row():
                        t2i_width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        t2i_height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                    with gr.Accordion("Advanced Settings", open=False):
                        t2i_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                        t2i_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                        t2i_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                        t2i_num_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
                    t2i_generate_btn = gr.Button("Generate Image", variant="primary", size="lg")
                with gr.Column(scale=1):
                    t2i_output = gr.Image(label="Generated Image", elem_id="generated_image")
                    t2i_seed_output = gr.Number(label="Used Seed", interactive=False)
        # Tab 2: Image-to-Video
        with gr.Tab("Image to Video", id="i2v_tab"):
            with gr.Row():
                with gr.Column(scale=1):
                    i2v_image = gr.Image(type="numpy", label="Input Image", elem_id="input_image")
                    i2v_prompt = gr.Textbox(
                        label="Video Prompt",
                        value="Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions.",
                        lines=3
                    )
                    with gr.Accordion("Prompt Templates", open=False):
                        gr.Markdown("Click a template to apply it to your prompt:")
                        template_buttons = {}
                        for name, template in video_templates.items():
                            btn = gr.Button(name, size="sm")
                            template_buttons[name] = (btn, template)
                    i2v_duration = gr.Slider(
                        label="Duration (seconds)",
                        minimum=round(MIN_FRAMES_MODEL / FIXED_FPS, 1),
                        maximum=round(MAX_FRAMES_MODEL / FIXED_FPS, 1),
                        step=0.1,
                        value=2.0
                    )
                    i2v_size = gr.Dropdown(
                        label="Output Resolution",
                        choices=list(SUPPORTED_SIZES[TASK_NAME]),
                        value="704*1280"
                    )
                    with gr.Accordion("Advanced Settings", open=False):
                        i2v_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                        i2v_guide_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                        i2v_shift = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
                        i2v_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    i2v_generate_btn = gr.Button("Generate Video", variant="primary", size="lg")
                with gr.Column(scale=1):
                    i2v_output = gr.Video(label="Generated Video", elem_id="output_video")
        # Tab 3: Text-to-Image-to-Video
        with gr.Tab("Text to Image to Video", id="t2i2v_tab"):
            gr.Markdown("### 🎯 Complete Pipeline: Generate an image from text, then convert it to video")
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("#### Step 1: Image Generation Settings")
                    t2i2v_img_prompt = gr.Textbox(
                        label="Image Prompt",
                        placeholder="Describe the image to generate...",
                        lines=3
                    )
                    t2i2v_img_negative = gr.Textbox(
                        label="Negative Prompt",
                        value="nsfw, (low quality, worst quality:1.2), very displeasing, 3d, watermark, signature, ugly, poorly drawn",
                        lines=2
                    )
                    with gr.Row():
                        t2i2v_img_width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                        t2i2v_img_height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
                    with gr.Accordion("Image Advanced Settings", open=False):
                        t2i2v_img_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                        t2i2v_img_randomize = gr.Checkbox(label="Randomize seed", value=True)
                        t2i2v_img_guidance = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, step=0.1, value=7)
                        t2i2v_img_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
                    gr.Markdown("#### Step 2: Video Generation Settings")
                    t2i2v_video_prompt = gr.Textbox(
                        label="Video Prompt",
                        value="Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions.",
                        lines=3
                    )
                    t2i2v_video_duration = gr.Slider(
                        label="Duration (seconds)",
                        minimum=round(MIN_FRAMES_MODEL / FIXED_FPS, 1),
                        maximum=round(MAX_FRAMES_MODEL / FIXED_FPS, 1),
                        step=0.1,
                        value=2.0
                    )
                    # Video size dropdown (auto-adjusted after image generation)
                    t2i2v_video_size = gr.Dropdown(
                        label="Video Output Resolution",
                        choices=list(SUPPORTED_SIZES[TASK_NAME]),
                        value="704*1280",
                        info="This will be auto-adjusted based on generated image aspect ratio"
                    )
                    with gr.Accordion("Video Advanced Settings", open=False):
                        t2i2v_video_steps = gr.Slider(label="Sampling Steps", minimum=10, maximum=50, value=38, step=1)
                        t2i2v_video_guide = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
                        t2i2v_video_shift = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
                        t2i2v_video_seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
                    t2i2v_generate_btn = gr.Button("Generate Image → Video", variant="primary", size="lg")
                with gr.Column(scale=1):
                    gr.Markdown("#### Results")
                    t2i2v_image_output = gr.Image(label="Generated Image", elem_id="generated_image")
                    t2i2v_video_output = gr.Video(label="Generated Video", elem_id="output_video")
                    with gr.Row():
                        t2i2v_seed_output = gr.Number(label="Image Seed Used", interactive=False)
                        t2i2v_size_output = gr.Textbox(label="Video Size Used", interactive=False)
    # Event handlers
    # Tab 1: Text-to-Image
    t2i_generate_btn.click(
        fn=generate_image,
        inputs=[
            t2i_prompt, t2i_negative_prompt, t2i_seed, t2i_randomize_seed,
            t2i_width, t2i_height, t2i_guidance_scale, t2i_num_steps
        ],
        outputs=[t2i_output, t2i_seed_output]
    )
    # Tab 2: Image-to-Video
    # Connect template buttons (bind each template via a default argument so the
    # current prompt value passed by Gradio lands in the first parameter)
    for name, (btn, template) in template_buttons.items():
        btn.click(
            fn=lambda current_prompt, t=template: apply_template(t, current_prompt),
            inputs=[i2v_prompt],
            outputs=i2v_prompt
        )
    # Auto-select best size when image is uploaded
    def handle_image_upload(image):
        if image is None:
            return gr.update()
        pil_image = Image.fromarray(image).convert("RGB")
        available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
        best_size = select_best_size_for_image(pil_image, available_sizes)
        return gr.update(value=best_size)

    i2v_image.upload(
        fn=handle_image_upload,
        inputs=[i2v_image],
        outputs=[i2v_size]
    )

    i2v_generate_btn.click(
        fn=generate_video,
        inputs=[
            i2v_image, i2v_prompt, i2v_size, i2v_duration,
            i2v_steps, i2v_guide_scale, i2v_shift, i2v_seed
        ],
        outputs=i2v_output
    )
    # Tab 3: Text-to-Image-to-Video
    t2i2v_generate_btn.click(
        fn=generate_image_to_video,
        inputs=[
            t2i2v_img_prompt, t2i2v_img_negative, t2i2v_img_seed, t2i2v_img_randomize,
            t2i2v_img_width, t2i2v_img_height, t2i2v_img_guidance, t2i2v_img_steps,
            t2i2v_video_prompt, t2i2v_video_size, t2i2v_video_duration,
            t2i2v_video_steps, t2i2v_video_guide, t2i2v_video_shift, t2i2v_video_seed
        ],
        outputs=[t2i2v_image_output, t2i2v_video_output, t2i2v_seed_output, t2i2v_size_output]
    )
    # Examples
    gr.Examples(
        examples=[
            ["A majestic lion sitting on a rock at sunset, golden hour lighting, photorealistic", "Generate a video with the lion slowly turning its head and mane flowing in the wind"],
            ["A futuristic cyberpunk city with neon lights and flying cars", "Cinematic shot with smooth camera movement through the city streets"],
            ["A serene Japanese garden with cherry blossoms and a koi pond", "Gentle breeze causing cherry blossoms to fall, ripples in the pond"],
        ],
        inputs=[t2i2v_img_prompt, t2i2v_video_prompt],
        label="Example Prompts"
    )

if __name__ == "__main__":
    demo.launch()