rahul7star commited on
Commit
b9a5c9f
·
verified ·
1 Parent(s): 3098f22

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -7
app.py CHANGED
@@ -98,16 +98,70 @@ class PatchedModelManager(ModelManager):
98
  return None
99
 
100
  # Video generation logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  @spaces.GPU(duration=200)
102
  def generate_video(prompt: str):
103
  ensure_model_downloaded()
104
 
105
  # Load model using patched manager
106
- manager = ModelManager(
107
- file_path_list=[WAN_MODEL_PATH],
108
- torch_dtype=torch.float16,
109
- device="cuda"
110
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
 
113
 
@@ -117,11 +171,11 @@ def generate_video(prompt: str):
117
  # device="cuda"
118
  # )
119
 
120
- model = manager.load_model(WAN_MODEL_PATH)
121
 
122
  # Set up pipeline
123
  pipeline = WanVideoPusaPipeline(model=model)
124
- pipeline.set_lora_adapters(LORA_PATH)
125
 
126
  # Generate video
127
  result = pipeline(prompt)
 
98
  return None
99
 
100
  # Video generation logic
101
+
102
+
103
+
104
+ def generate_t2v_video(self, prompt, lora_alpha, num_inference_steps,
105
+ negative_prompt, progress=gr.Progress()):
106
+ """Generate video from text prompt"""
107
+ try:
108
+ progress(0.1, desc="Loading models...")
109
+ lora_path = "./model_zoo/PusaV1/pusa_v1.pt"
110
+ pipe = self.load_lora_and_get_pipe("t2v", lora_path, lora_alpha)
111
+
112
+ progress(0.3, desc="Generating video...")
113
+ video = pipe(
114
+ prompt=prompt,
115
+ negative_prompt=negative_prompt,
116
+ num_inference_steps=num_inference_steps,
117
+ height=720, width=1280, num_frames=81,
118
+ seed=0, tiled=True
119
+ )
120
+
121
+ progress(0.9, desc="Saving video...")
122
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
123
+ video_filename = os.path.join(self.output_dir, f"t2v_output_{timestamp}.mp4")
124
+ save_video(video, video_filename, fps=25, quality=5)
125
+
126
+ progress(1.0, desc="Complete!")
127
+ return video_filename, f"Video generated successfully! Saved to {video_filename}"
128
+
129
+ except Exception as e:
130
+ return None, f"Error: {str(e)}"
131
+
132
+
133
+
134
+
135
  @spaces.GPU(duration=200)
136
  def generate_video(prompt: str):
137
  ensure_model_downloaded()
138
 
139
  # Load model using patched manager
140
+
141
+
142
+ model_manager = ModelManager(device="cuda")
143
+ base_dir = "model_zoo/PusaV1/Wan2.1-T2V-14B"
144
+
145
+ model_files = sorted([os.path.join(self.base_dir, f) for f in os.listdir(self.base_dir) if f.endswith('.safetensors')])
146
+
147
+ model_manager.load_models(
148
+ [
149
+ model_files,
150
+ os.path.join(self.base_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
151
+ os.path.join(self.base_dir, "Wan2.1_VAE.pth"),
152
+ ],
153
+ torch_dtype=torch.bfloat16,
154
+ )
155
+
156
+
157
+
158
+
159
+
160
+ # manager = ModelManager(
161
+ # file_path_list=[WAN_MODEL_PATH],
162
+ # torch_dtype=torch.float16,
163
+ # device="cuda"
164
+ # )
165
 
166
 
167
 
 
171
  # device="cuda"
172
  # )
173
 
174
+ #model = manager.load_model(WAN_MODEL_PATH)
175
 
176
  # Set up pipeline
177
  pipeline = WanVideoPusaPipeline(model=model)
178
+ #pipeline.set_lora_adapters(LORA_PATH)
179
 
180
  # Generate video
181
  result = pipeline(prompt)