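"""Gradio demo: Wan2.2 text-to-video with the stage-2 transformer compressed via DFloat11."""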
import os
import time
import uuid
import torch
import gradio as gr
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from dfloat11 import DFloat11Model
import spaces
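# Keep the CUDA caching allocator from splitting blocks larger than 128 MB (reduces
# fragmentation) and enable the hf_transfer backend for faster model downloads.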
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
@spaces.GPU(enable_queue=True)
def generate_video(prompt, negative_prompt, width, height, num_frames,
guidance_scale, guidance_scale_2, num_inference_steps, fps, cpu_offload):
start_time = time.time()
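    # Release cached GPU memory from any previous run before loading the models.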
torch.cuda.empty_cache()
    # Load the Wan2.2 VAE in float32 and the text-to-video pipeline in bfloat16
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.2-T2V-A14B-Diffusers",
subfolder="vae",
torch_dtype=torch.float32,
)
pipe = WanPipeline.from_pretrained(
"Wan-AI/Wan2.2-T2V-A14B-Diffusers",
vae=vae,
torch_dtype=torch.bfloat16,
)
    # Apply DFloat11 lossless compression to the second-stage transformer
    # (pipe.transformer_2) only; the first-stage transformer stays in plain bfloat16
DFloat11Model.from_pretrained(
"DFloat11/Wan2.2-T2V-A14B-2-DF11",
device="cpu",
cpu_offload=cpu_offload,
bfloat16_model=pipe.transformer_2,
)
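    # Move idle sub-models to CPU so the whole pipeline fits in GPU memory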
pipe.enable_model_cpu_offload()
    # Run inference; .frames[0] is the frame list for the single generated video
output_frames = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=guidance_scale,
guidance_scale_2=guidance_scale_2,
num_inference_steps=num_inference_steps,
).frames[0]
# Export to video
output_path = f"/tmp/{uuid.uuid4().hex}_t2v.mp4"
export_to_video(output_frames, output_path, fps=fps)
elapsed = time.time() - start_time
print(f"✅ Generated in {elapsed:.2f}s, saved to {output_path}")
return output_path
# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("## 🎬 Wan2.2 + DFloat11 (Stage 2 only) - Text to Video Generator")
with gr.Row():
prompt = gr.Textbox(label="Prompt", value="A serene koi pond at night, with glowing lanterns reflecting on the rippling water. Ethereal fireflies dance above as cherry blossoms gently fall.", lines=3)
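        # Default negative prompt (Chinese): suppresses oversaturated colors, overexposure,
        # static/blurry detail, subtitles, worst/low quality, JPEG artifacts, ugly or mutilated
        # results, extra/fused fingers, poorly drawn hands and faces, deformed limbs,
        # cluttered or crowded backgrounds, three legs, and walking backwards.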
negative_prompt = gr.Textbox(label="Negative Prompt", value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", lines=3)
with gr.Row():
width = gr.Slider(256, 1280, value=1280, step=64, label="Width")
height = gr.Slider(256, 720, value=720, step=64, label="Height")
fps = gr.Slider(8, 30, value=16, step=1, label="FPS")
with gr.Row():
num_frames = gr.Slider(8, 81, value=81, step=1, label="Frames")
num_inference_steps = gr.Slider(10, 60, value=40, step=1, label="Inference Steps")
with gr.Row():
guidance_scale = gr.Slider(1.0, 10.0, value=4.0, step=0.1, label="Guidance Scale (Stage 1)")
guidance_scale_2 = gr.Slider(1.0, 10.0, value=3.0, step=0.1, label="Guidance Scale (Stage 2)")
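        # When enabled, DFloat11 keeps the compressed stage-2 weights in CPU memory and
        # streams them to the GPU during inference, trading some speed for lower peak VRAM.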
cpu_offload = gr.Checkbox(label="Enable CPU Offload", value=True)
with gr.Row():
btn = gr.Button("🚀 Generate Video")
output_video = gr.Video(label="Generated Video")
btn.click(
generate_video,
inputs=[prompt, negative_prompt, width, height, num_frames, guidance_scale, guidance_scale_2, num_inference_steps, fps, cpu_offload],
outputs=[output_video]
)
demo.launch()