rahul7star commited on
Commit
886d416
·
verified ·
1 Parent(s): db3e3e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -2,16 +2,11 @@ import gradio as gr
2
  import os
3
  import torch
4
  import tempfile
 
5
  from huggingface_hub import snapshot_download
6
- import sys, os
7
- sys.path.insert(0, os.path.abspath("./PusaV1"))
8
-
9
-
10
  import spaces
11
 
12
- import sys, os
13
-
14
- # Add PusaV1 to sys.path if not already
15
  PUSA_PATH = os.path.abspath("./PusaV1")
16
  if PUSA_PATH not in sys.path:
17
  sys.path.insert(0, PUSA_PATH)
@@ -24,15 +19,9 @@ if not os.path.exists(DIFFSYNTH_PATH):
24
  f"Ensure PusaV1 is correctly cloned and folder structure is intact."
25
  )
26
 
27
-
28
-
29
-
30
-
31
-
32
-
33
-
34
  from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video
35
-
36
  # Constants
37
  WAN_SUBFOLDER = "Wan2.1-T2V-14B"
38
  MODEL_REPO_ID = "RaphaelLiu/PusaV1"
@@ -53,19 +42,42 @@ def ensure_model_downloaded():
53
  )
54
  print("Model downloaded.")
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Video generation logic
57
- @spaces.GPU(duration = 200)
58
  def generate_video(prompt: str):
59
  ensure_model_downloaded()
60
 
61
- # Load model
62
- manager = ModelManager(
63
- file_path_list=[WAN_MODEL_PATH],
64
- torch_dtype=torch.float16,
65
- device="cuda"
66
- )
67
-
68
-
69
  model = manager.load_model()
70
 
71
  # Set up pipeline
 
2
  import os
3
  import torch
4
  import tempfile
5
+ import sys
6
  from huggingface_hub import snapshot_download
 
 
 
 
7
  import spaces
8
 
9
+ # Setup paths
 
 
10
  PUSA_PATH = os.path.abspath("./PusaV1")
11
  if PUSA_PATH not in sys.path:
12
  sys.path.insert(0, PUSA_PATH)
 
19
  f"Ensure PusaV1 is correctly cloned and folder structure is intact."
20
  )
21
 
22
+ # Import core modules from PusaV1
 
 
 
 
 
 
23
  from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video
24
+
25
  # Constants
26
  WAN_SUBFOLDER = "Wan2.1-T2V-14B"
27
  MODEL_REPO_ID = "RaphaelLiu/PusaV1"
 
42
  )
43
  print("Model downloaded.")
44
 
45
+ # Subclass ModelManager to force WanModelPusa
46
+ class PatchedModelManager(ModelManager):
47
+ def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
48
+ if file_path is None:
49
+ file_path = self.file_path_list[0]
50
+ print(f"[app.py] Forcing architecture: WanModelPusa for {file_path}")
51
+ for detector in self.model_detector:
52
+ if detector.match(file_path, {}):
53
+ model_names, models = detector.load(
54
+ file_path,
55
+ state_dict={},
56
+ device=device or self.device,
57
+ torch_dtype=torch_dtype or self.torch_dtype,
58
+ allowed_model_names=model_names,
59
+ model_manager=self,
60
+ forced_architecture="WanModelPusa"
61
+ )
62
+ for name, model in zip(model_names, models):
63
+ self.model.append(model)
64
+ self.model_path.append(file_path)
65
+ self.model_name.append(name)
66
+ return models[0] if models else None
67
+ print("No suitable model detector matched.")
68
+ return None
69
+
70
  # Video generation logic
71
+ @spaces.GPU(duration=200)
72
  def generate_video(prompt: str):
73
  ensure_model_downloaded()
74
 
75
+ # Load model using patched manager
76
+ manager = PatchedModelManager(
77
+ file_path_list=[WAN_MODEL_PATH],
78
+ torch_dtype=torch.float16,
79
+ device="cuda"
80
+ )
 
 
81
  model = manager.load_model()
82
 
83
  # Set up pipeline