jbilcke-hf (HF Staff) committed
Commit d9acd84 · verified · 1 Parent(s): 1d561d4

Upload handler.py

Files changed (1)
  1. handler.py +543 -0
handler.py ADDED
@@ -0,0 +1,543 @@
from dataclasses import dataclass
from pathlib import Path
import logging
import base64
import random
import gc
import os
import numpy as np
import torch
from typing import Dict, Any, Optional, List, Union, Tuple
import json
from omegaconf import OmegaConf
from PIL import Image
import io

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Get token from environment
hf_token = os.getenv("HF_API_TOKEN")

# Constraints
MAX_LARGE_SIDE = 1280
MAX_SMALL_SIDE = 768
MAX_FRAMES = 169  # Based on Wan model capabilities

@dataclass
class GenerationConfig:
    """Configuration for video generation using Wan model"""

    # general content settings
    prompt: str = ""
    negative_prompt: str = "worst quality, lowres, blurry, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles"

    # video model settings
    width: int = 960  # Wan model default width
    height: int = 576  # Wan model default height

    # number of frames (based on Wan model block structure)
    num_frames: int = 105  # 7 blocks * 15 frames per block

    # guidance and sampling settings
    guidance_scale: float = 7.5
    num_inference_steps: int = 4  # Distilled model uses fewer steps

    # reproducible generation settings
    seed: int = -1  # -1 means random seed

    # output settings
    fps: int = 15  # FPS of the final video
    quality: int = 18  # Video quality (CRF)

    # advanced settings
    mixed_precision: bool = True
    use_taehv: bool = False  # Whether to use TAEHV decoder
    use_trt: bool = False  # Whether to use TensorRT optimized decoder

    def validate_and_adjust(self) -> 'GenerationConfig':
        """Validate and adjust parameters to meet constraints"""
        # Ensure dimensions are multiples of 32 and within limits
        self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
        self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))

        # Ensure frame count is reasonable
        self.num_frames = min(self.num_frames, MAX_FRAMES)

        # Set random seed if not specified
        if self.seed == -1:
            self.seed = random.randint(0, 2**32 - 1)

        return self

def load_image_to_tensor_with_resize_and_crop(
    image_input: Union[str, bytes],
    target_height: int = 576,
    target_width: int = 960,
    quality: int = 100
) -> torch.Tensor:
    """Load and process an image into a tensor for Wan model.

    Args:
        image_input: Either a file path (str) or image data (bytes)
        target_height: Desired height of output tensor
        target_width: Desired width of output tensor
        quality: JPEG quality to use when re-encoding
    """
    # Handle base64 data URI
    if isinstance(image_input, str) and image_input.startswith('data:'):
        header, encoded = image_input.split(",", 1)
        image_data = base64.b64decode(encoded)
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
    # Handle raw bytes
    elif isinstance(image_input, bytes):
        image = Image.open(io.BytesIO(image_input)).convert("RGB")
    # Handle file path
    elif isinstance(image_input, str):
        image = Image.open(image_input).convert("RGB")
    else:
        raise ValueError("image_input must be either a file path, bytes, or base64 data URI")

    # Apply JPEG compression if quality < 100
    if quality < 100:
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=quality)
        buffer.seek(0)
        image = Image.open(buffer).convert("RGB")

    # Resize and crop to target dimensions
    input_width, input_height = image.size
    aspect_ratio_target = target_width / target_height
    aspect_ratio_frame = input_width / input_height

    if aspect_ratio_frame > aspect_ratio_target:
        new_width = int(input_height * aspect_ratio_target)
        new_height = input_height
        x_start = (input_width - new_width) // 2
        y_start = 0
    else:
        new_width = input_width
        new_height = int(input_width / aspect_ratio_target)
        x_start = 0
        y_start = (input_height - new_height) // 2

    image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
    image = image.resize((target_width, target_height))

    # Convert to tensor format expected by Wan model
    frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
    frame_tensor = (frame_tensor / 127.5) - 1.0

    return frame_tensor.unsqueeze(0)

def initialize_vae_decoder(use_taehv=False, use_trt=False, device="cuda"):
    """Initialize VAE decoder based on configuration"""
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        print("Initializing TensorRT VAE Decoder...")
        vae_decoder = VAETRTWrapper()
    elif use_taehv:
        print("Initializing TAEHV VAE Decoder...")
        from demo_utils.taehv import TAEHV
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"

        if not os.path.exists(taehv_checkpoint_path):
            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
            os.makedirs("checkpoints", exist_ok=True)
            import urllib.request
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
            except Exception as e:
                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")

        class DotDict(dict):
            __getattr__ = dict.get

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)

            def decode(self, latents, return_dict=None):
                return self.taehv.decode_video(latents, parallel=True).mul_(2).sub_(1)

        vae_decoder = TAEHVDiffusersWrapper()
    else:
        print("Initializing Default VAE Decoder...")
        vae_decoder = VAEDecoderWrapper()
        try:
            vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
            vae_decoder.load_state_dict(decoder_state_dict)
        except FileNotFoundError:
            print("Warning: Default VAE weights not found.")

    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(device)
    return vae_decoder

def create_wan_pipeline(
    config: GenerationConfig,
    device: str = "cuda"
) -> CausalInferencePipeline:
    """Create and configure the Wan video pipeline"""

    # Load configuration
    try:
        wan_config = OmegaConf.load("configs/self_forcing_dmd.yaml")
        default_config = OmegaConf.load("configs/default_config.yaml")
        wan_config = OmegaConf.merge(default_config, wan_config)
    except FileNotFoundError as e:
        logger.error(f"Error loading config file: {e}")
        raise RuntimeError(f"Config files not found: {e}")

    # Initialize model components
    text_encoder = WanTextEncoder()
    transformer = WanDiffusionWrapper(is_causal=True)

    # Load checkpoint
    checkpoint_path = "./checkpoints/self_forcing_dmd.pt"
    try:
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
    except FileNotFoundError as e:
        logger.error(f"Error loading checkpoint: {e}")
        raise RuntimeError(f"Checkpoint not found: {checkpoint_path}")

    # Move to device and set precision
    text_encoder.eval().to(dtype=torch.float16).requires_grad_(False).to(device)
    transformer.eval().to(dtype=torch.float16).requires_grad_(False).to(device)

    # Initialize VAE decoder
    vae_decoder = initialize_vae_decoder(
        use_taehv=config.use_taehv,
        use_trt=config.use_trt,
        device=device
    )

    # Create pipeline
    pipeline = CausalInferencePipeline(
        wan_config,
        device=device,
        generator=transformer,
        text_encoder=text_encoder,
        vae=vae_decoder
    )

    pipeline.to(dtype=torch.float16).to(device)

    return pipeline

def frames_to_video_bytes(frames: List[np.ndarray], fps: int = 15, quality: int = 18) -> bytes:
    """Convert frames to MP4 video bytes"""
    import tempfile
    import subprocess

    with tempfile.TemporaryDirectory() as temp_dir:
        # Save frames as images
        frame_paths = []
        for i, frame in enumerate(frames):
            frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
            Image.fromarray(frame).save(frame_path)
            frame_paths.append(frame_path)

        # Create video using ffmpeg
        output_path = os.path.join(temp_dir, "output.mp4")
        cmd = [
            "ffmpeg", "-y", "-framerate", str(fps),
            "-i", os.path.join(temp_dir, "frame_%06d.png"),
            "-c:v", "libx264", "-crf", str(quality),
            "-pix_fmt", "yuv420p", "-movflags", "faststart",
            output_path
        ]

        try:
            subprocess.run(cmd, check=True, capture_output=True)
            with open(output_path, "rb") as f:
                return f.read()
        except subprocess.CalledProcessError as e:
            logger.error(f"FFmpeg error: {e}")
            raise RuntimeError(f"Video encoding failed: {e}")

class EndpointHandler:
    """Handler for the Wan Video endpoint"""

    def __init__(self, model_path: str = "./"):
        """Initialize the endpoint handler

        Args:
            model_path: Path to model weights
        """
        # Enable TF32 for potential speedup on Ampere GPUs
        torch.backends.cuda.matmul.allow_tf32 = True

        # The pipeline will be loaded during inference to save memory
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Perform warm-up inference if GPU is available
        if self.device == "cuda":
            logger.info("Performing warm-up inference...")
            self._warmup()
            logger.info("Warm-up completed!")
        else:
            logger.info("CPU device detected, skipping warm-up")

    def _warmup(self):
        """Perform a warm-up inference to prepare the model for future requests"""
        try:
            # Create a simple test configuration
            test_config = GenerationConfig(
                prompt="a cat walking",
                negative_prompt="worst quality, lowres",
                width=480,  # Smaller resolution for faster warm-up
                height=320,
                num_frames=33,  # Fewer frames for faster warm-up
                guidance_scale=7.5,
                num_inference_steps=2,  # Fewer steps for faster warm-up
                seed=42,  # Fixed seed for consistent warm-up
                fps=15,
                mixed_precision=True,
            ).validate_and_adjust()

            # Create the pipeline if it doesn't exist
            if self.pipeline is None:
                self.pipeline = create_wan_pipeline(test_config, self.device)

            # Run a quick inference
            with torch.no_grad():
                # Set seeds for reproducibility
                random.seed(test_config.seed)
                np.random.seed(test_config.seed)
                torch.manual_seed(test_config.seed)

                # Generate video frames (simplified version)
                conditional_dict = self.pipeline.text_encoder(text_prompts=[test_config.prompt])
                for key, value in conditional_dict.items():
                    conditional_dict[key] = value.to(dtype=torch.float16)

                rnd = torch.Generator(self.device).manual_seed(int(test_config.seed))
                self.pipeline._initialize_kv_cache(1, torch.float16, device=self.device)
                self.pipeline._initialize_crossattn_cache(1, torch.float16, device=self.device)

                # Generate a small noise tensor for testing
                noise = torch.randn([1, 3, 8, 20, 32], device=self.device, dtype=torch.float16, generator=rnd)

                # Clean up
                del noise, conditional_dict
                torch.cuda.empty_cache()
                gc.collect()

            logger.info("Warm-up successful!")

        except Exception as e:
            # Log the error but don't fail initialization
            import traceback
            error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}"
            logger.warning(error_message)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Process inference requests

        Args:
            data: Request data containing inputs and parameters

        Returns:
            Dictionary with generated video and metadata
        """
        # Extract inputs and parameters
        inputs = data.get("inputs", {})

        # Support both formats:
        # 1. {"inputs": {"prompt": "...", "image": "..."}}
        # 2. {"inputs": "..."} (prompt only)
        if isinstance(inputs, str):
            input_prompt = inputs
            input_image = None
        else:
            input_prompt = inputs.get("prompt", "")
            input_image = inputs.get("image")

        params = data.get("parameters", {})

        if not input_prompt:
            raise ValueError("Prompt must be provided")

        # Create and validate configuration
        config = GenerationConfig(
            # general content settings
            prompt=input_prompt,
            negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),

            # video model settings
            width=params.get("width", GenerationConfig.width),
            height=params.get("height", GenerationConfig.height),
            num_frames=params.get("num_frames", GenerationConfig.num_frames),
            guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
            num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),

            # reproducible generation settings
            seed=params.get("seed", GenerationConfig.seed),

            # output settings
            fps=params.get("fps", GenerationConfig.fps),
            quality=params.get("quality", GenerationConfig.quality),

            # advanced settings
            mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
            use_taehv=params.get("use_taehv", GenerationConfig.use_taehv),
            use_trt=params.get("use_trt", GenerationConfig.use_trt),
        ).validate_and_adjust()

        try:
            with torch.no_grad():
                # Set random seeds for reproducibility
                random.seed(config.seed)
                np.random.seed(config.seed)
                torch.manual_seed(config.seed)

                # Create pipeline if not already created
                if self.pipeline is None:
                    self.pipeline = create_wan_pipeline(config, self.device)

                # Prepare text conditioning
                conditional_dict = self.pipeline.text_encoder(text_prompts=[config.prompt])
                for key, value in conditional_dict.items():
                    conditional_dict[key] = value.to(dtype=torch.float16)

                # Initialize caches
                rnd = torch.Generator(self.device).manual_seed(int(config.seed))
                self.pipeline._initialize_kv_cache(1, torch.float16, device=self.device)
                self.pipeline._initialize_crossattn_cache(1, torch.float16, device=self.device)

                # Generate noise tensor
                noise = torch.randn(
                    [1, 21, 16, config.height // 16, config.width // 16],
                    device=self.device,
                    dtype=torch.float16,
                    generator=rnd
                )

                # Initialize VAE cache
                vae_cache = None
                latents_cache = None
                if not config.use_taehv and not config.use_trt:
                    vae_cache = [c.to(device=self.device, dtype=torch.float16) for c in ZERO_VAE_CACHE]

                # Generation parameters
                num_blocks = 7
                current_start_frame = 0
                all_num_frames = [self.pipeline.num_frame_per_block] * num_blocks

                all_frames = []

                # Generate video blocks
                for idx, current_num_frames in enumerate(all_num_frames):
                    logger.info(f"Processing block {idx+1}/{num_blocks}")

                    noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]

                    # Denoising steps
                    for step_idx, current_timestep in enumerate(self.pipeline.denoising_step_list):
                        timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
                        _, denoised_pred = self.pipeline.generator(
                            noisy_image_or_video=noisy_input,
                            conditional_dict=conditional_dict,
                            timestep=timestep,
                            kv_cache=self.pipeline.kv_cache1,
                            crossattn_cache=self.pipeline.crossattn_cache,
                            current_start=current_start_frame * self.pipeline.frame_seq_length
                        )

                        if step_idx < len(self.pipeline.denoising_step_list) - 1:
                            next_timestep = self.pipeline.denoising_step_list[step_idx + 1]
                            noisy_input = self.pipeline.scheduler.add_noise(
                                denoised_pred.flatten(0, 1),
                                torch.randn_like(denoised_pred.flatten(0, 1)),
                                next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                            ).unflatten(0, denoised_pred.shape[:2])

                    # Update cache for next block
                    if idx < len(all_num_frames) - 1:
                        self.pipeline.generator(
                            noisy_image_or_video=denoised_pred,
                            conditional_dict=conditional_dict,
                            timestep=torch.zeros_like(timestep),
                            kv_cache=self.pipeline.kv_cache1,
                            crossattn_cache=self.pipeline.crossattn_cache,
                            current_start=current_start_frame * self.pipeline.frame_seq_length,
                        )

                    # Decode to pixels
                    if config.use_trt:
                        pixels, vae_cache = self.pipeline.vae.forward(denoised_pred.half(), *vae_cache)
                    elif config.use_taehv:
                        if latents_cache is None:
                            latents_cache = denoised_pred
                        else:
                            denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
                            latents_cache = denoised_pred[:, -3:]
                        pixels = self.pipeline.vae.decode(denoised_pred)
                    else:
                        pixels, vae_cache = self.pipeline.vae(denoised_pred.half(), *vae_cache)

                    # Handle frame skipping
                    if idx == 0 and not config.use_trt:
                        pixels = pixels[:, 3:]
                    elif config.use_taehv and idx > 0:
                        pixels = pixels[:, 12:]

                    # Convert frames to numpy
                    for frame_idx in range(pixels.shape[1]):
                        frame_tensor = pixels[0, frame_idx]
                        frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
                        frame_np = frame_np.to(torch.uint8).cpu().numpy()
                        frame_np = np.transpose(frame_np, (1, 2, 0))  # CHW -> HWC
                        all_frames.append(frame_np)

                    current_start_frame += current_num_frames

                # Convert frames to video
                video_bytes = frames_to_video_bytes(all_frames, fps=config.fps, quality=config.quality)

                # Convert to base64 data URI
                video_b64 = base64.b64encode(video_bytes).decode('utf-8')
                video_uri = f"data:video/mp4;base64,{video_b64}"

                # Prepare metadata
                metadata = {
                    "width": config.width,
                    "height": config.height,
                    "num_frames": len(all_frames),
                    "fps": config.fps,
                    "duration": len(all_frames) / config.fps,
                    "seed": config.seed,
                    "prompt": config.prompt,
                }

                # Clean up to prevent CUDA OOM errors
                del noise, conditional_dict, pixels
                if self.device == "cuda":
                    torch.cuda.empty_cache()
                gc.collect()

                return {
                    "video": video_uri,
                    "content-type": "video/mp4",
                    "metadata": metadata
                }

        except Exception as e:
            # Log the error and reraise
            import traceback
            error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
            logger.error(error_message)
            raise RuntimeError(error_message)
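
A minimal usage sketch (not part of the uploaded file): it assumes handler.py and the checkpoints/config files referenced above are present in the working directory, and it mirrors the request format parsed in __call__ ("inputs" as a plain string or a {"prompt": ...} object, with optional overrides under "parameters"). The prompt and parameter values below are illustrative only.

from handler import EndpointHandler

handler = EndpointHandler(model_path="./")
result = handler({
    "inputs": {"prompt": "a red fox running through snow at dusk"},  # or just a plain string
    "parameters": {"width": 960, "height": 576, "num_frames": 105, "seed": 42},
})
# result["video"] is a "data:video/mp4;base64,..." URI, and result["metadata"]
# reports width, height, num_frames, fps, duration, seed, and prompt.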