rahul7star commited on
Commit
bf9f272
·
verified ·
1 Parent(s): 4fc218f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -10
app.py CHANGED
@@ -39,24 +39,37 @@ class PatchedModelManager(ModelManager):
39
 
40
 
41
  # Constants
42
- WAN_SUBFOLDER = "Wan2.1-T2V-14B"
43
- MODEL_REPO_ID = "RaphaelLiu/PusaV1"
 
 
44
  MODEL_ZOO_DIR = "./model_zoo"
45
- WAN_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, WAN_SUBFOLDER)
46
- LORA_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
 
 
47
 
48
- # Ensure model is downloaded
49
  def ensure_model_downloaded():
 
 
 
 
 
 
 
 
 
 
50
  if not os.path.exists(WAN_MODEL_PATH):
51
- print("Downloading Wan2.1-T2V-14B from HuggingFace Hub...")
52
  snapshot_download(
53
- repo_id=MODEL_REPO_ID,
54
- local_dir=MODEL_ZOO_DIR,
55
  repo_type="model",
56
- allow_patterns=[f"{WAN_SUBFOLDER}/**"],
57
  local_dir_use_symlinks=False,
58
  )
59
- print("Model downloaded.")
60
 
61
  # Subclass ModelManager to force WanModelPusa
62
 
 
39
 
40
 
41
  # Constants
42
+ import os
43
+ from huggingface_hub import snapshot_download
44
+
45
+ # Constants
46
  MODEL_ZOO_DIR = "./model_zoo"
47
+ PUSA_DIR = os.path.join(MODEL_ZOO_DIR, "PusaV1")
48
+ WAN_SUBFOLDER = "Wan2.1-T2V-14B"
49
+ WAN_MODEL_PATH = os.path.join(PUSA_DIR, WAN_SUBFOLDER)
50
+ LORA_PATH = os.path.join(PUSA_DIR, "pusa_v1.pt")
51
 
52
+ # Ensure model and weights are downloaded
53
  def ensure_model_downloaded():
54
+ if not os.path.exists(PUSA_DIR):
55
+ print("Downloading RaphaelLiu/PusaV1 to ./model_zoo/PusaV1 ...")
56
+ snapshot_download(
57
+ repo_id="RaphaelLiu/PusaV1",
58
+ local_dir=PUSA_DIR,
59
+ repo_type="model",
60
+ local_dir_use_symlinks=False,
61
+ )
62
+ print("✅ PusaV1 downloaded.")
63
+
64
  if not os.path.exists(WAN_MODEL_PATH):
65
+ print("Downloading Wan-AI/Wan2.1-T2V-14B to ./model_zoo/PusaV1 ...")
66
  snapshot_download(
67
+ repo_id="Wan-AI/Wan2.1-T2V-14B",
68
+ local_dir=os.path.join(PUSA_DIR, WAN_SUBFOLDER),
69
  repo_type="model",
 
70
  local_dir_use_symlinks=False,
71
  )
72
+ print(" Wan2.1-T2V-14B downloaded.")
73
 
74
  # Subclass ModelManager to force WanModelPusa
75