Commit 6373d0a
1 Parent(s): 6896d2c

experiment

Files changed:
- app.py +3 -1
- app_with_streaming.py +3 -2
app.py CHANGED

@@ -185,7 +185,9 @@ def frames_to_mp4_base64(frames, fps = 15):
 
     return "data:video/mp4;base64,"
 
-def initialize_vae_decoder(use_taehv=False, use_trt=False):
+# note: we set use_taehv to be able to use other resolutions
+# this might impact performance
+def initialize_vae_decoder(use_taehv=True, use_trt=False):
     if use_trt:
         from demo_utils.vae import VAETRTWrapper
         print("Initializing TensorRT VAE Decoder...")
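For context on the unchanged lines above: the hunk header shows frames_to_mp4_base64(frames, fps=15), which returns a "data:video/mp4;base64," URI. Below is a minimal sketch of what such a helper might look like, assuming imageio with the imageio-ffmpeg backend for encoding; it illustrates the pattern only and is not the repository's actual implementation.

import base64
import os
import tempfile

import imageio


def frames_to_mp4_base64(frames, fps=15):
    # frames: a list of HxWx3 uint8 numpy arrays.
    # Encode to MP4 in a temporary file, then wrap the bytes in a data URI.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "clip.mp4")
        imageio.mimwrite(path, frames, fps=fps)  # requires the imageio-ffmpeg backend
        with open(path, "rb") as f:
            video_bytes = f.read()
    return "data:video/mp4;base64," + base64.b64encode(video_bytes).decode("ascii")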
app_with_streaming.py CHANGED

@@ -253,8 +253,6 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
     seed = random.randint(0, 2**32 - 1)
 
 
-    print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
-
     # Setup
     conditional_dict = text_encoder(text_prompts=[prompt])
     for key, value in conditional_dict.items():
@@ -267,6 +265,9 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
     # Calculate latent dimensions based on actual width/height (assuming 8x downsampling)
     latent_height = height // 8
     latent_width = width // 8
+
+    print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
+
     noise = torch.randn([1, 21, 16, latent_height, latent_width], device=gpu, dtype=torch.float16, generator=rnd)
 
     vae_cache, latents_cache = None, None
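The second hunk only moves the log line after the latent dimensions are computed, but the surrounding setup is worth spelling out: latent height and width are the pixel dimensions divided by 8, and the noise tensor has shape [1, 21, 16, latent_height, latent_width]. Below is a standalone sketch of that computation; the example resolution, CPU device, and float32 dtype are assumptions for illustration (the app runs on the GPU in float16 with its own DEFAULT_WIDTH/DEFAULT_HEIGHT).

import torch

# Assumed example resolution; DEFAULT_WIDTH/DEFAULT_HEIGHT are not visible in this diff.
width, height = 832, 480
seed = 42

# 8x spatial downsampling between pixel space and the VAE latent space, as in the diff.
latent_height = height // 8   # 480 // 8 = 60
latent_width = width // 8     # 832 // 8 = 104

# Same shape as the diff; 21 and 16 are presumably the latent frame and channel counts.
rnd = torch.Generator(device="cpu").manual_seed(seed)
noise = torch.randn(
    [1, 21, 16, latent_height, latent_width],
    device="cpu",
    dtype=torch.float32,   # the app uses torch.float16 on the GPU
    generator=rnd,
)
print(noise.shape)  # torch.Size([1, 21, 16, 60, 104])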