rahul7star commited on
Commit
de49789
·
verified ·
1 Parent(s): 886d416

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -6
app.py CHANGED
@@ -20,7 +20,24 @@ if not os.path.exists(DIFFSYNTH_PATH):
20
  )
21
 
22
  # Import core modules from PusaV1
23
- from PusaV1.diffsynth import ModelManager, WanVideoPusaPipeline, save_video
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Constants
26
  WAN_SUBFOLDER = "Wan2.1-T2V-14B"
@@ -43,6 +60,8 @@ def ensure_model_downloaded():
43
  print("Model downloaded.")
44
 
45
  # Subclass ModelManager to force WanModelPusa
 
 
46
  class PatchedModelManager(ModelManager):
47
  def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
48
  if file_path is None:
@@ -73,12 +92,19 @@ def generate_video(prompt: str):
73
  ensure_model_downloaded()
74
 
75
  # Load model using patched manager
 
 
 
 
 
 
76
  manager = PatchedModelManager(
77
- file_path_list=[WAN_MODEL_PATH],
78
- torch_dtype=torch.float16,
79
- device="cuda"
80
- )
81
- model = manager.load_model()
 
82
 
83
  # Set up pipeline
84
  pipeline = WanVideoPusaPipeline(model=model)
 
20
  )
21
 
22
  # Import core modules from PusaV1
23
+ from PusaV1.diffsynth import WanVideoPusaPipeline, save_video
24
+
25
+
26
+ from PusaV1.diffsynth import ModelManager as BaseModelManager
27
+
28
+ class PatchedModelManager(BaseModelManager):
29
+ def __init__(self, *args, **kwargs):
30
+ super().__init__(*args, **kwargs)
31
+
32
+ # Your custom architecture dict entries to patch or add
33
+ custom_architecture_dict = {
34
+ "WanModel": ("diffsynth.models.wan_model", "WanModelPusa", None),
35
+ # Add more fixes if needed
36
+ }
37
+
38
+ # Update or replace the architecture_dict
39
+ self.architecture_dict.update(custom_architecture_dict)
40
+
41
 
42
  # Constants
43
  WAN_SUBFOLDER = "Wan2.1-T2V-14B"
 
60
  print("Model downloaded.")
61
 
62
  # Subclass ModelManager to force WanModelPusa
63
+
64
+
65
  class PatchedModelManager(ModelManager):
66
  def load_model(self, file_path=None, model_names=None, device=None, torch_dtype=None):
67
  if file_path is None:
 
92
  ensure_model_downloaded()
93
 
94
  # Load model using patched manager
95
+ # manager = PatchedModelManager(
96
+ # file_path_list=[WAN_MODEL_PATH],
97
+ # torch_dtype=torch.float16,
98
+ # device="cuda"
99
+ # )
100
+
101
  manager = PatchedModelManager(
102
+ file_path_list=[WAN_MODEL_PATH],
103
+ torch_dtype=torch.float16,
104
+ device="cuda"
105
+ )
106
+
107
+ model = manager.load_model(WAN_MODEL_PATH)
108
 
109
  # Set up pipeline
110
  pipeline = WanVideoPusaPipeline(model=model)