jbilcke-hf HF Staff commited on
Commit
6373d0a
·
1 Parent(s): 6896d2c

experiment

Browse files
Files changed (2) hide show
  1. app.py +3 -1
  2. app_with_streaming.py +3 -2
app.py CHANGED
@@ -185,7 +185,9 @@ def frames_to_mp4_base64(frames, fps = 15):
185
 
186
  return "data:video/mp4;base64,"
187
 
188
- def initialize_vae_decoder(use_taehv=False, use_trt=False):
 
 
189
  if use_trt:
190
  from demo_utils.vae import VAETRTWrapper
191
  print("Initializing TensorRT VAE Decoder...")
 
185
 
186
  return "data:video/mp4;base64,"
187
 
188
+ # NOTE: use_taehv is enabled by default so that other output resolutions can be used;
189
+ # this may impact performance
190
+ def initialize_vae_decoder(use_taehv=True, use_trt=False):
191
  if use_trt:
192
  from demo_utils.vae import VAETRTWrapper
193
  print("Initializing TensorRT VAE Decoder...")
app_with_streaming.py CHANGED
@@ -253,8 +253,6 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
253
  seed = random.randint(0, 2**32 - 1)
254
 
255
 
256
- print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
257
-
258
  # Setup
259
  conditional_dict = text_encoder(text_prompts=[prompt])
260
  for key, value in conditional_dict.items():
@@ -267,6 +265,9 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
267
  # Calculate latent dimensions based on actual width/height (assuming 8x downsampling)
268
  latent_height = height // 8
269
  latent_width = width // 8
 
 
 
270
  noise = torch.randn([1, 21, 16, latent_height, latent_width], device=gpu, dtype=torch.float16, generator=rnd)
271
 
272
  vae_cache, latents_cache = None, None
 
253
  seed = random.randint(0, 2**32 - 1)
254
 
255
 
 
 
256
  # Setup
257
  conditional_dict = text_encoder(text_prompts=[prompt])
258
  for key, value in conditional_dict.items():
 
265
  # Calculate latent dimensions based on actual width/height (assuming 8x downsampling)
266
  latent_height = height // 8
267
  latent_width = width // 8
268
+
269
+ print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
270
+
271
  noise = torch.randn([1, 21, 16, latent_height, latent_width], device=gpu, dtype=torch.float16, generator=rnd)
272
 
273
  vae_cache, latents_cache = None, None